diff options
Diffstat (limited to 'rest_framework')
44 files changed, 1351 insertions, 301 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 151ba832..f9882c57 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,3 +1,3 @@ -__version__ = '2.1.14' +__version__ = '2.1.17' VERSION = __version__ # synonym diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index c50bf944..76ee4bd6 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -23,34 +23,47 @@ class BaseAuthentication(object): """ raise NotImplementedError(".authenticate() must be overridden.") + def authenticate_header(self, request): + """ + Return a string to be used as the value of the `WWW-Authenticate` + header in a `401 Unauthenticated` response, or `None` if the + authentication scheme should return `403 Permission Denied` responses. + """ + pass + class BasicAuthentication(BaseAuthentication): """ HTTP Basic authentication against username/password. """ + www_authenticate_realm = 'api' def authenticate(self, request): """ Returns a `User` if a correct username and password have been supplied using HTTP Basic authentication. Otherwise returns `None`. """ - if 'HTTP_AUTHORIZATION' in request.META: - auth = request.META['HTTP_AUTHORIZATION'].split() - if len(auth) == 2 and auth[0].lower() == "basic": - try: - encoding = api_settings.HTTP_HEADER_ENCODING - b = base64.b64decode(auth[1].encode(encoding)) - auth_parts = b.decode(encoding).partition(':') - except TypeError: - return None - - try: - userid = smart_text(auth_parts[0]) - password = smart_text(auth_parts[2]) - except DjangoUnicodeDecodeError: - return None - - return self.authenticate_credentials(userid, password) + auth = request.META.get('HTTP_AUTHORIZATION', '').split() + + if not auth or auth[0].lower() != "basic": + return None + + if len(auth) != 2: + raise exceptions.AuthenticationFailed('Invalid basic header') + + encoding = api_settings.HTTP_HEADER_ENCODING + try: + auth_parts = base64.b64decode(auth[1].encode(encoding)).partition(':') + except TypeError: + raise exceptions.AuthenticationFailed('Invalid basic header') + + try: + userid = smart_text(auth_parts[0]) + password = smart_text(auth_parts[2]) + except DjangoUnicodeDecodeError: + raise exceptions.AuthenticationFailed('Invalid basic header') + + return self.authenticate_credentials(userid, password) def authenticate_credentials(self, userid, password): """ @@ -59,6 +72,10 @@ class BasicAuthentication(BaseAuthentication): user = authenticate(username=userid, password=password) if user is not None and user.is_active: return (user, None) + raise exceptions.AuthenticationFailed('Invalid username/password') + + def authenticate_header(self, request): + return 'Basic realm="%s"' % self.www_authenticate_realm class SessionAuthentication(BaseAuthentication): @@ -78,7 +95,7 @@ class SessionAuthentication(BaseAuthentication): # Unauthenticated, CSRF validation not required if not user or not user.is_active: - return + return None # Enforce CSRF validation for session based authentication. class CSRFCheck(CsrfViewMiddleware): @@ -89,7 +106,7 @@ class SessionAuthentication(BaseAuthentication): reason = CSRFCheck().process_view(http_request, None, (), {}) if reason: # CSRF failed, bail with explicit error message - raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) + raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason) # CSRF passed with authenticated user return (user, None) @@ -116,14 +133,26 @@ class TokenAuthentication(BaseAuthentication): def authenticate(self, request): auth = request.META.get('HTTP_AUTHORIZATION', '').split() - if len(auth) == 2 and auth[0].lower() == "token": - key = auth[1] - try: - token = self.model.objects.get(key=key) - except self.model.DoesNotExist: - return None + if not auth or auth[0].lower() != "token": + return None + + if len(auth) != 2: + raise exceptions.AuthenticationFailed('Invalid token header') + + return self.authenticate_credentials(auth[1]) + + def authenticate_credentials(self, key): + try: + token = self.model.objects.get(key=key) + except self.model.DoesNotExist: + raise exceptions.AuthenticationFailed('Invalid token') + + if token.user.is_active: + return (token.user, token) + raise exceptions.AuthenticationFailed('User inactive or deleted') + + def authenticate_header(self, request): + return 'Token' - if token.user.is_active: - return (token.user, token) # TODO: OAuthAuthentication diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index d318c723..7c03cb76 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -12,10 +12,11 @@ class ObtainAuthToken(APIView): permission_classes = () parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,) renderer_classes = (renderers.JSONRenderer,) + serializer_class = AuthTokenSerializer model = Token def post(self, request): - serializer = AuthTokenSerializer(data=request.DATA) + serializer = self.serializer_class(data=request.DATA) if serializer.is_valid(): token, created = Token.objects.get_or_create(user=serializer.object['user']) return Response({'token': token.key}) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 5924cd6d..ef11b85b 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -126,6 +126,12 @@ else: update_wrapper(view, cls.dispatch, assigned=()) return view +# Taken from @markotibold's attempt at supporting PATCH. +# https://github.com/markotibold/django-rest-framework/tree/patch +http_method_names = set(View.http_method_names) +http_method_names.add('patch') +View.http_method_names = list(http_method_names) # PATCH method is not implemented by Django + # PUT, DELETE do not require CSRF until 1.4. They should. Make it better. if django.VERSION >= (1, 4): from django.middleware.csrf import CsrfViewMiddleware diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 1b710a03..7a4103e1 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -1,4 +1,5 @@ from rest_framework.views import APIView +import types def api_view(http_method_names): @@ -23,6 +24,14 @@ def api_view(http_method_names): # pass # WrappedAPIView.__doc__ = func.doc <--- Not possible to do this + # api_view applied without (method_names) + assert not(isinstance(http_method_names, types.FunctionType)), \ + '@api_view missing list of allowed HTTP methods' + + # api_view applied with eg. string instead of list of strings + assert isinstance(http_method_names, (list, tuple)), \ + '@api_view expected a list of strings, recieved %s' % type(http_method_names).__name__ + allowed_methods = set(http_method_names) | set(('options',)) WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 89479deb..d635351c 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -23,6 +23,22 @@ class ParseError(APIException): self.detail = detail or self.default_detail +class AuthenticationFailed(APIException): + status_code = status.HTTP_401_UNAUTHORIZED + default_detail = 'Incorrect authentication credentials.' + + def __init__(self, detail=None): + self.detail = detail or self.default_detail + + +class NotAuthenticated(APIException): + status_code = status.HTTP_401_UNAUTHORIZED + default_detail = 'Authentication credentials were not provided.' + + def __init__(self, detail=None): + self.detail = detail or self.default_detail + + class PermissionDenied(APIException): status_code = status.HTTP_403_FORBIDDEN default_detail = 'You do not have permission to perform this action.' diff --git a/rest_framework/fields.py b/rest_framework/fields.py index adea5bf5..a66e1d7c 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,4 +1,3 @@ - from __future__ import unicode_literals import copy @@ -185,11 +184,13 @@ class WritableField(Field): try: if self._use_files: + files = files or {} native = files[field_name] else: native = data[field_name] except KeyError: - if self.default is not None: + if self.default is not None and not self.root.partial: + # Note: partial updates shouldn't set defaults native = self.default else: if self.required: diff --git a/rest_framework/generics.py b/rest_framework/generics.py index dd8dfcf8..19f2b704 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -47,14 +47,16 @@ class GenericAPIView(views.APIView): return serializer_class - def get_serializer(self, instance=None, data=None, files=None): + def get_serializer(self, instance=None, data=None, + files=None, partial=False): """ Return the serializer instance that should be used for validating and deserializing input, and for serializing output. """ serializer_class = self.get_serializer_class() context = self.get_serializer_context() - return serializer_class(instance, data=data, files=files, context=context) + return serializer_class(instance, data=data, files=files, + partial=partial, context=context) class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): @@ -171,6 +173,10 @@ class UpdateAPIView(mixins.UpdateModelMixin, def put(self, request, *args, **kwargs): return self.update(request, *args, **kwargs) + def patch(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + class ListCreateAPIView(mixins.ListModelMixin, mixins.CreateModelMixin, @@ -185,6 +191,23 @@ class ListCreateAPIView(mixins.ListModelMixin, return self.create(request, *args, **kwargs) +class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + SingleObjectAPIView): + """ + Concrete view for retrieving, updating a model instance. + """ + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + def put(self, request, *args, **kwargs): + return self.update(request, *args, **kwargs) + + def patch(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + + class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, mixins.DestroyModelMixin, SingleObjectAPIView): @@ -211,5 +234,9 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, def put(self, request, *args, **kwargs): return self.update(request, *args, **kwargs) + def patch(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 503376ce..acaf8a71 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -18,11 +18,14 @@ class CreateModelMixin(object): """ def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.DATA, files=request.FILES) + if serializer.is_valid(): self.pre_save(serializer.object) self.object = serializer.save() headers = self.get_success_headers(serializer.data) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + return Response(serializer.data, status=status.HTTP_201_CREATED, + headers=headers) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) def get_success_headers(self, data): @@ -84,20 +87,21 @@ class UpdateModelMixin(object): Should be mixed in with `SingleObjectBaseView`. """ def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) try: self.object = self.get_object() - created = False + success_status_code = status.HTTP_200_OK except Http404: self.object = None - created = True + success_status_code = status.HTTP_201_CREATED - serializer = self.get_serializer(self.object, data=request.DATA, files=request.FILES) + serializer = self.get_serializer(self.object, data=request.DATA, + files=request.FILES, partial=partial) if serializer.is_valid(): self.pre_save(serializer.object) self.object = serializer.save() - status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK - return Response(serializer.data, status=status_code) + return Response(serializer.data, status=success_status_code) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -117,7 +121,8 @@ class UpdateModelMixin(object): # Ensure we clean the attributes so that we don't eg return integer # pk using a string representation, as provided by the url conf kwarg. - obj.full_clean() + if hasattr(obj, 'full_clean'): + obj.full_clean() class DestroyModelMixin(object): diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index d241ade7..92d41e0e 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -34,6 +34,17 @@ class PreviousPageField(serializers.Field): return replace_query_param(url, self.page_field, page) +class DefaultObjectSerializer(serializers.Field): + """ + If no object serializer is specified, then this serializer will be applied + as the default. + """ + + def __init__(self, source=None, context=None): + # Note: Swallow context kwarg - only required for eg. ModelSerializer. + super(DefaultObjectSerializer, self).__init__(source=source) + + class PaginationSerializerOptions(serializers.SerializerOptions): """ An object that stores the options that may be provided to a @@ -44,7 +55,7 @@ class PaginationSerializerOptions(serializers.SerializerOptions): def __init__(self, meta): super(PaginationSerializerOptions, self).__init__(meta) self.object_serializer_class = getattr(meta, 'object_serializer_class', - serializers.Field) + DefaultObjectSerializer) class BasePaginationSerializer(serializers.Serializer): @@ -62,14 +73,13 @@ class BasePaginationSerializer(serializers.Serializer): super(BasePaginationSerializer, self).__init__(*args, **kwargs) results_field = self.results_field object_serializer = self.opts.object_serializer_class - self.fields[results_field] = object_serializer(source='object_list') - def to_native(self, obj): - """ - Prevent default behaviour of iterating over elements, and serializing - each in turn. - """ - return self.convert_object(obj) + if 'context' in kwargs: + context_kwarg = {'context': kwargs['context']} + else: + context_kwarg = {} + + self.fields[results_field] = object_serializer(source='object_list', **context_kwarg) class PaginationSerializer(BasePaginationSerializer): diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 7c01006a..4a2b34a5 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -8,12 +8,12 @@ on the request, such as form content or json encoded data. from django.http import QueryDict from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser from django.http.multipartparser import MultiPartParserError -from django.utils import simplejson as json from rest_framework.compat import yaml, ETParseError from rest_framework.exceptions import ParseError from rest_framework.compat import six from xml.etree import ElementTree as ET from xml.parsers.expat import ExpatError +import json import datetime import decimal diff --git a/rest_framework/relations.py b/rest_framework/relations.py index b7a6e0c1..c4f854ef 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -6,6 +6,7 @@ from django.core.urlresolvers import resolve, get_script_prefix from django import forms from django.forms import widgets from django.forms.models import ModelChoiceIterator +from django.utils.translation import ugettext_lazy as _ from rest_framework.fields import Field, WritableField from rest_framework.reverse import reverse from rest_framework.compat import urlparse @@ -103,7 +104,13 @@ class RelatedField(WritableField): ### Regular serializer stuff... def field_to_native(self, obj, field_name): - value = getattr(obj, self.source or field_name) + try: + value = getattr(obj, self.source or field_name) + except ObjectDoesNotExist: + return None + + if value is None: + return None return self.to_native(value) def field_from_native(self, data, files, field_name, into): @@ -144,7 +151,7 @@ class ManyRelatedMixin(object): value = data.getlist(self.source or field_name) except: # Non-form data - value = data.get(self.source or field_name) + value = data.get(self.source or field_name, []) else: if value == ['']: value = [] @@ -171,6 +178,11 @@ class PrimaryKeyRelatedField(RelatedField): default_read_only = False form_field_class = forms.ChoiceField + default_error_messages = { + 'does_not_exist': _("Invalid pk '%s' - object does not exist."), + 'incorrect_type': _('Incorrect type. Expected pk value, received %s.'), + } + # TODO: Remove these field hacks... def prepare_value(self, obj): return self.to_native(obj.pk) @@ -196,7 +208,11 @@ class PrimaryKeyRelatedField(RelatedField): try: return self.queryset.get(pk=data) except ObjectDoesNotExist: - msg = "Invalid pk '%s' - object does not exist." % smart_text(data) + msg = self.error_messages['does_not_exist'] % smart_text(data) + raise ValidationError(msg) + except (TypeError, ValueError): + received = type(data).__name__ + msg = self.error_messages['incorrect_type'] % received raise ValidationError(msg) def field_to_native(self, obj, field_name): @@ -205,7 +221,10 @@ class PrimaryKeyRelatedField(RelatedField): pk = obj.serializable_value(self.source or field_name) except AttributeError: # RelatedObject (reverse relationship) - obj = getattr(obj, self.source or field_name) + try: + obj = getattr(obj, self.source or field_name) + except ObjectDoesNotExist: + return None return self.to_native(obj.pk) # Forward relationship return self.to_native(pk) @@ -218,6 +237,11 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField): default_read_only = False form_field_class = forms.MultipleChoiceField + default_error_messages = { + 'does_not_exist': _("Invalid pk '%s' - object does not exist."), + 'incorrect_type': _('Incorrect type. Expected pk value, received %s.'), + } + def prepare_value(self, obj): return self.to_native(obj.pk) @@ -252,7 +276,11 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField): try: return self.queryset.get(pk=data) except ObjectDoesNotExist: - msg = "Invalid pk '%s' - object does not exist." % smart_text(data) + msg = self.error_messages['does_not_exist'] % smart_text(data) + raise ValidationError(msg) + except (TypeError, ValueError): + received = type(data).__name__ + msg = self.error_messages['incorrect_type'] % received raise ValidationError(msg) ### Slug relationships @@ -262,6 +290,11 @@ class SlugRelatedField(RelatedField): default_read_only = False form_field_class = forms.ChoiceField + default_error_messages = { + 'does_not_exist': _("Object with %s=%s does not exist."), + 'invalid': _('Invalid value.'), + } + def __init__(self, *args, **kwargs): self.slug_field = kwargs.pop('slug_field', None) assert self.slug_field, 'slug_field is required' @@ -277,8 +310,11 @@ class SlugRelatedField(RelatedField): try: return self.queryset.get(**{self.slug_field: data}) except ObjectDoesNotExist: - raise ValidationError('Object with %s=%s does not exist.' % + raise ValidationError(self.error_messages['does_not_exist'] % (self.slug_field, unicode(data))) + except (TypeError, ValueError): + msg = self.error_messages['invalid'] + raise ValidationError(msg) class ManySlugRelatedField(ManyRelatedMixin, SlugRelatedField): @@ -297,6 +333,14 @@ class HyperlinkedRelatedField(RelatedField): default_read_only = False form_field_class = forms.ChoiceField + default_error_messages = { + 'no_match': _('Invalid hyperlink - No URL match'), + 'incorrect_match': _('Invalid hyperlink - Incorrect URL match'), + 'configuration_error': _('Invalid hyperlink due to configuration error'), + 'does_not_exist': _("Invalid hyperlink - object does not exist."), + 'incorrect_type': _('Incorrect type. Expected url string, received %s.'), + } + def __init__(self, *args, **kwargs): try: self.view_name = kwargs.pop('view_name') @@ -333,21 +377,21 @@ class HyperlinkedRelatedField(RelatedField): slug = getattr(obj, self.slug_field, None) if not slug: - raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) + raise Exception('Could not resolve URL for field using view name "%s"' % view_name) kwargs = {self.slug_url_kwarg: slug} try: - return reverse(self.view_name, kwargs=kwargs, request=request, format=format) + return reverse(view_name, kwargs=kwargs, request=request, format=format) except: pass kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} try: - return reverse(self.view_name, kwargs=kwargs, request=request, format=format) + return reverse(view_name, kwargs=kwargs, request=request, format=format) except: pass - raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) + raise Exception('Could not resolve URL for field using view name "%s"' % view_name) def from_native(self, value): # Convert URL -> model instance pk @@ -355,7 +399,13 @@ class HyperlinkedRelatedField(RelatedField): if self.queryset is None: raise Exception('Writable related fields must include a `queryset` argument') - if value.startswith('http:') or value.startswith('https:'): + try: + http_prefix = value.startswith('http:') or value.startswith('https:') + except AttributeError: + msg = self.error_messages['incorrect_type'] + raise ValidationError(msg % type(value).__name__) + + if http_prefix: # If needed convert absolute URLs to relative path value = urlparse.urlparse(value).path prefix = get_script_prefix() @@ -365,10 +415,10 @@ class HyperlinkedRelatedField(RelatedField): try: match = resolve(value) except: - raise ValidationError('Invalid hyperlink - No URL match') + raise ValidationError(self.error_messages['no_match']) - if match.url_name != self.view_name: - raise ValidationError('Invalid hyperlink - Incorrect URL match') + if match.view_name != self.view_name: + raise ValidationError(self.error_messages['incorrect_match']) pk = match.kwargs.get(self.pk_url_kwarg, None) slug = match.kwargs.get(self.slug_url_kwarg, None) @@ -380,14 +430,18 @@ class HyperlinkedRelatedField(RelatedField): elif slug is not None: slug_field = self.get_slug_field() queryset = self.queryset.filter(**{slug_field: slug}) - # If none of those are defined, it's an error. + # If none of those are defined, it's probably a configuation error. else: - raise ValidationError('Invalid hyperlink') + raise ValidationError(self.error_messages['configuration_error']) try: obj = queryset.get() except ObjectDoesNotExist: - raise ValidationError('Invalid hyperlink - object does not exist.') + raise ValidationError(self.error_messages['does_not_exist']) + except (TypeError, ValueError): + msg = self.error_messages['incorrect_type'] + raise ValidationError(msg % type(value).__name__) + return obj @@ -410,6 +464,7 @@ class HyperlinkedIdentityField(Field): # TODO: Make view_name mandatory, and have the # HyperlinkedModelSerializer set it on-the-fly self.view_name = kwargs.pop('view_name', None) + # Optionally the format of the target hyperlink may be specified self.format = kwargs.pop('format', None) self.slug_field = kwargs.pop('slug_field', self.slug_field) @@ -421,9 +476,22 @@ class HyperlinkedIdentityField(Field): def field_to_native(self, obj, field_name): request = self.context.get('request', None) - format = self.format or self.context.get('format', None) + format = self.context.get('format', None) view_name = self.view_name or self.parent.opts.view_name kwargs = {self.pk_url_kwarg: obj.pk} + + # By default use whatever format is given for the current context + # unless the target is a different type to the source. + # + # Eg. Consider a HyperlinkedIdentityField pointing from a json + # representation to an html property of that representation... + # + # '/snippets/1/' should link to '/snippets/1/highlight/' + # ...but... + # '/snippets/1/.json' should link to '/snippets/1/highlight/.html' + if format and self.format and self.format != format: + format = self.format + try: return reverse(view_name, kwargs=kwargs, request=request, format=format) except: @@ -432,18 +500,18 @@ class HyperlinkedIdentityField(Field): slug = getattr(obj, self.slug_field, None) if not slug: - raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) + raise Exception('Could not resolve URL for field using view name "%s"' % view_name) kwargs = {self.slug_url_kwarg: slug} try: - return reverse(self.view_name, kwargs=kwargs, request=request, format=format) + return reverse(view_name, kwargs=kwargs, request=request, format=format) except: pass kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} try: - return reverse(self.view_name, kwargs=kwargs, request=request, format=format) + return reverse(view_name, kwargs=kwargs, request=request, format=format) except: pass - raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) + raise Exception('Could not resolve URL for field using view name "%s"' % view_name) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 54930167..b3ee0690 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -10,10 +10,10 @@ from __future__ import unicode_literals import copy import string +import json from django import forms from django.http.multipartparser import parse_header from django.template import RequestContext, loader, Template -from django.utils import simplejson as json from rest_framework.compat import yaml from rest_framework.exceptions import ConfigurationError from rest_framework.settings import api_settings diff --git a/rest_framework/request.py b/rest_framework/request.py index 048a1c41..23e1da87 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -86,6 +86,7 @@ class Request(object): self._method = Empty self._content_type = Empty self._stream = Empty + self._authenticator = None if self.parser_context is None: self.parser_context = {} @@ -166,7 +167,7 @@ class Request(object): by the authentication classes provided to the request. """ if not hasattr(self, '_user'): - self._user, self._auth = self._authenticate() + self._authenticator, self._user, self._auth = self._authenticate() return self._user @user.setter @@ -185,7 +186,7 @@ class Request(object): request, such as an authentication token. """ if not hasattr(self, '_auth'): - self._user, self._auth = self._authenticate() + self._authenticator, self._user, self._auth = self._authenticate() return self._auth @auth.setter @@ -196,6 +197,14 @@ class Request(object): """ self._auth = value + @property + def successful_authenticator(self): + """ + Return the instance of the authentication instance class that was used + to authenticate the request, or `None`. + """ + return self._authenticator + def _load_data_and_files(self): """ Parses the request content into self.DATA and self.FILES. @@ -299,21 +308,23 @@ class Request(object): def _authenticate(self): """ - Attempt to authenticate the request using each authentication instance in turn. - Returns a two-tuple of (user, authtoken). + Attempt to authenticate the request using each authentication instance + in turn. + Returns a three-tuple of (authenticator, user, authtoken). """ for authenticator in self.authenticators: user_auth_tuple = authenticator.authenticate(self) if not user_auth_tuple is None: - return user_auth_tuple + user, auth = user_auth_tuple + return (authenticator, user, auth) return self._not_authenticated() def _not_authenticated(self): """ - Return a two-tuple of (user, authtoken), representing an - unauthenticated request. + Return a three-tuple of (authenticator, user, authtoken), representing + an unauthenticated request. - By default this will be (AnonymousUser, None). + By default this will be (None, AnonymousUser, None). """ if api_settings.UNAUTHENTICATED_USER: user = api_settings.UNAUTHENTICATED_USER() @@ -325,7 +336,7 @@ class Request(object): else: auth = None - return (user, auth) + return (None, user, auth) def __getattr__(self, attr): """ diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 663f166b..3d3bcb3c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -2,6 +2,7 @@ import copy import datetime import types from decimal import Decimal +from django.core.paginator import Page from django.db import models from django.forms import widgets from django.utils.datastructures import SortedDict @@ -209,6 +210,11 @@ class BaseSerializer(Field): Converts a dictionary of data into a dictionary of deserialized fields. """ reverted_data = {} + + if data is not None and not isinstance(data, dict): + self._errors['non_field_errors'] = [u'Invalid data'] + return None + for field_name, field in self.fields.items(): field.initialize(parent=self, field_name=field_name) try: @@ -223,6 +229,8 @@ class BaseSerializer(Field): Run `validate_<fieldname>()` and `validate()` methods on the serializer """ for field_name, field in self.fields.items(): + if field_name in self._errors: + continue try: validate_method = getattr(self, 'validate_%s' % field_name, None) if validate_method: @@ -267,7 +275,11 @@ class BaseSerializer(Field): """ Serialize objects -> primitives. """ - if hasattr(obj, '__iter__'): + # Note: At the moment we have an ugly hack to determine if we should + # walk over iterables. At some point, serializers will require an + # explicit `many=True` in order to iterate over a set, and this hack + # will disappear. + if hasattr(obj, '__iter__') and not isinstance(obj, Page): return [self.convert_object(item) for item in obj] return self.convert_object(obj) @@ -277,7 +289,7 @@ class BaseSerializer(Field): """ if hasattr(data, '__iter__') and not isinstance(data, dict): # TODO: error data when deserializing lists - return (self.from_native(item) for item in data) + return [self.from_native(item, None) for item in data] self._errors = {} if data is not None or files is not None: @@ -294,15 +306,21 @@ class BaseSerializer(Field): Override default so that we can apply ModelSerializer as a nested field to relationships. """ - if self.source: - for component in self.source.split('.'): - obj = getattr(obj, component) + if self.source == '*': + return self.to_native(obj) + + try: + if self.source: + for component in self.source.split('.'): + obj = getattr(obj, component) + if is_simple_callable(obj): + obj = obj() + else: + obj = getattr(obj, field_name) if is_simple_callable(obj): obj = obj() - else: - obj = getattr(obj, field_name) - if is_simple_callable(obj): - obj = value() + except ObjectDoesNotExist: + return None # If the object has an "all" method, assume it's a relationship if is_simple_callable(getattr(obj, 'all', None)): @@ -408,7 +426,7 @@ class ModelSerializer(Serializer): """ Returns a default instance of the pk field. """ - return Field() + return self.get_field(model_field) def get_nested_field(self, model_field): """ @@ -426,7 +444,7 @@ class ModelSerializer(Serializer): # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) kwargs = { - 'null': model_field.null, + 'null': model_field.null or model_field.blank, 'queryset': model_field.rel.to._default_manager } @@ -445,11 +463,14 @@ class ModelSerializer(Serializer): if model_field.null or model_field.blank: kwargs['required'] = False + if isinstance(model_field, models.AutoField) or not model_field.editable: + kwargs['read_only'] = True + if model_field.has_default(): kwargs['required'] = False kwargs['default'] = model_field.get_default() - if model_field.__class__ == models.TextField: + if issubclass(model_field.__class__, models.TextField): kwargs['widget'] = widgets.Textarea # TODO: TypedChoiceField? @@ -458,6 +479,7 @@ class ModelSerializer(Serializer): return ChoiceField(**kwargs) field_mapping = { + models.AutoField: IntegerField, models.FloatField: FloatField, models.IntegerField: IntegerField, models.PositiveIntegerField: IntegerField, @@ -492,6 +514,22 @@ class ModelSerializer(Serializer): exclusions.remove(field_name) return exclusions + def full_clean(self, instance): + """ + Perform Django's full_clean, and populate the `errors` dictionary + if any validation errors occur. + + Note that we don't perform this inside the `.restore_object()` method, + so that subclasses can override `.restore_object()`, and still get + the full_clean validation checking. + """ + try: + instance.full_clean(exclude=self.get_validation_exclusions()) + except ValidationError, err: + self._errors = err.message_dict + return None + return instance + def restore_object(self, attrs, instance=None): """ Restore the model instance. @@ -531,13 +569,21 @@ class ModelSerializer(Serializer): return instance - def save(self, save_m2m=True): + def from_native(self, data, files): + """ + Override the default method to also include model field validation. + """ + instance = super(ModelSerializer, self).from_native(data, files) + if instance: + return self.full_clean(instance) + + def save(self): """ Save the deserialized object and return it. """ self.object.save() - if getattr(self, 'm2m_data', None) and save_m2m: + if getattr(self, 'm2m_data', None): for accessor_name, object_list in self.m2m_data.items(): setattr(self.object, accessor_name, object_list) self.m2m_data = {} diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 2358d188..13d03e62 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -119,9 +119,8 @@ def import_from_string(val, setting_name): module_path, class_name = '.'.join(parts[:-1]), parts[-1] module = importlib.import_module(module_path) return getattr(module, class_name) - except: - raise - msg = "Could not import '%s' for API setting '%s'" % (val, setting_name) + except ImportError as e: + msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e) raise ImportError(msg) diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 42e49cb9..092bf2e4 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -112,7 +112,7 @@ <div class="request-info"> <pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre> - <div> + </div> <div class="response-info"> <pre class="prettyprint"><div class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %} {% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|urlize_quoted_links }}</span> diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html index 6e2bd8d4..e10ce20f 100644 --- a/rest_framework/templates/rest_framework/login.html +++ b/rest_framework/templates/rest_framework/login.html @@ -25,14 +25,14 @@ <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post"> {% csrf_token %} <div id="div_id_username" class="clearfix control-group"> - <div class="controls" style="height: 30px"> - <Label class="span4" style="margin-top: 3px">Username:</label> + <div class="controls"> + <Label class="span4">Username:</label> <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username"> </div> </div> <div id="div_id_password" class="clearfix control-group"> - <div class="controls" style="height: 30px"> - <Label class="span4" style="margin-top: 3px">Password:</label> + <div class="controls"> + <Label class="span4">Password:</label> <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password"> </div> </div> diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 4205e57c..cbafbe0e 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -27,7 +27,7 @@ register = template.Library() # conflicts with this rest_framework template tag module. try: # Django 1.5+ - from django.contrib.staticfiles.templatetags import StaticFilesNode + from django.contrib.staticfiles.templatetags.staticfiles import StaticFilesNode @register.tag('static') def do_static(parser, token): diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index 8c0bfc47..ba2042cb 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -1,14 +1,13 @@ from django.contrib.auth.models import User from django.http import HttpResponse from django.test import Client, TestCase -from django.utils import simplejson as json - from rest_framework import permissions from rest_framework.authtoken.models import Token -from rest_framework.authentication import TokenAuthentication +from rest_framework.authentication import TokenAuthentication, BasicAuthentication, SessionAuthentication from rest_framework.compat import patterns from rest_framework.views import APIView +import json import base64 @@ -21,10 +20,10 @@ class MockView(APIView): def put(self, request): return HttpResponse({'a': 1, 'b': 2, 'c': 3}) -MockView.authentication_classes += (TokenAuthentication,) - urlpatterns = patterns('', - (r'^$', MockView.as_view()), + (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), + (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), + (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), ) @@ -42,25 +41,26 @@ class BasicAuthTests(TestCase): def test_post_form_passing_basic_auth(self): """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF""" - auth = 'Basic ' + base64.encodestring(('%s:%s' % (self.username, self.password)).encode('iso-8859-1')).strip().decode('iso-8859-1') - response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).encode('iso-8859-1').strip().decode('iso-8859-1') + response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) def test_post_json_passing_basic_auth(self): """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" - auth = 'Basic ' + base64.encodestring(('%s:%s' % (self.username, self.password)).encode('iso-8859-1')).strip().decode('iso-8859-1') - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).encode('iso-8859-1').strip().decode('iso-8859-1') + response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) def test_post_form_failing_basic_auth(self): """Ensure POSTing form over basic auth without correct credentials fails""" - response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/basic/', {'example': 'example'}) + self.assertEqual(response.status_code, 401) def test_post_json_failing_basic_auth(self): """Ensure POSTing json over basic auth without correct credentials fails""" - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') + self.assertEqual(response.status_code, 401) + self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') class SessionAuthTests(TestCase): @@ -83,7 +83,7 @@ class SessionAuthTests(TestCase): Ensure POSTing form over session authentication without CSRF token fails. """ self.csrf_client.login(username=self.username, password=self.password) - response = self.csrf_client.post('/', {'example': 'example'}) + response = self.csrf_client.post('/session/', {'example': 'example'}) self.assertEqual(response.status_code, 403) def test_post_form_session_auth_passing(self): @@ -91,7 +91,7 @@ class SessionAuthTests(TestCase): Ensure POSTing form over session authentication with logged in user and CSRF token passes. """ self.non_csrf_client.login(username=self.username, password=self.password) - response = self.non_csrf_client.post('/', {'example': 'example'}) + response = self.non_csrf_client.post('/session/', {'example': 'example'}) self.assertEqual(response.status_code, 200) def test_put_form_session_auth_passing(self): @@ -99,14 +99,14 @@ class SessionAuthTests(TestCase): Ensure PUTting form over session authentication with logged in user and CSRF token passes. """ self.non_csrf_client.login(username=self.username, password=self.password) - response = self.non_csrf_client.put('/', {'example': 'example'}) + response = self.non_csrf_client.put('/session/', {'example': 'example'}) self.assertEqual(response.status_code, 200) def test_post_form_session_auth_failing(self): """ Ensure POSTing form over session authentication without logged in user fails. """ - response = self.csrf_client.post('/', {'example': 'example'}) + response = self.csrf_client.post('/session/', {'example': 'example'}) self.assertEqual(response.status_code, 403) @@ -127,24 +127,24 @@ class TokenAuthTests(TestCase): def test_post_form_passing_token_auth(self): """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" auth = "Token " + self.key - response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) def test_post_json_passing_token_auth(self): """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" auth = "Token " + self.key - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) def test_post_form_failing_token_auth(self): """Ensure POSTing form over token auth without correct credentials fails""" - response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/token/', {'example': 'example'}) + self.assertEqual(response.status_code, 401) def test_post_json_failing_token_auth(self): """Ensure POSTing json over token auth without correct credentials fails""" - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') + self.assertEqual(response.status_code, 401) def test_token_has_auto_assigned_key_if_none_provided(self): """Ensure creating a token with no key will auto-assign a key""" diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py index 8079c8cb..82f912e9 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/decorators.py @@ -1,5 +1,4 @@ from django.test import TestCase -from django.test.client import RequestFactory from rest_framework import status from rest_framework.response import Response from rest_framework.renderers import JSONRenderer @@ -17,6 +16,8 @@ from rest_framework.decorators import ( permission_classes, ) +from rest_framework.tests.utils import RequestFactory + class DecoratorTestCase(TestCase): @@ -27,13 +28,27 @@ class DecoratorTestCase(TestCase): response.request = request return APIView.finalize_response(self, request, response, *args, **kwargs) - def test_wrap_view(self): + def test_api_view_incorrect(self): + """ + If @api_view is not applied correct, we should raise an assertion. + """ - @api_view(['GET']) + @api_view def view(request): - return Response({}) + return Response() + + request = self.factory.get('/') + self.assertRaises(AssertionError, view, request) - self.assertTrue(isinstance(view.cls_instance, APIView)) + def test_api_view_incorrect_arguments(self): + """ + If @api_view is missing arguments, we should raise an assertion. + """ + + with self.assertRaises(AssertionError): + @api_view('GET') + def view(request): + return Response() def test_calling_method(self): @@ -63,6 +78,20 @@ class DecoratorTestCase(TestCase): response = view(request) self.assertEqual(response.status_code, 405) + def test_calling_patch_method(self): + + @api_view(['GET', 'PATCH']) + def view(request): + return Response({}) + + request = self.factory.patch('/') + response = view(request) + self.assertEqual(response.status_code, 200) + + request = self.factory.post('/') + response = view(request) + self.assertEqual(response.status_code, 405) + def test_renderer_classes(self): @api_view(['GET']) diff --git a/rest_framework/tests/extras/__init__.py b/rest_framework/tests/extras/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/rest_framework/tests/extras/__init__.py diff --git a/rest_framework/tests/extras/bad_import.py b/rest_framework/tests/extras/bad_import.py new file mode 100644 index 00000000..68263d94 --- /dev/null +++ b/rest_framework/tests/extras/bad_import.py @@ -0,0 +1 @@ +raise ValueError diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py new file mode 100644 index 00000000..8068272d --- /dev/null +++ b/rest_framework/tests/fields.py @@ -0,0 +1,49 @@ +""" +General serializer field tests. +""" + +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +class TimestampedModel(models.Model): + added = models.DateTimeField(auto_now_add=True) + updated = models.DateTimeField(auto_now=True) + + +class CharPrimaryKeyModel(models.Model): + id = models.CharField(max_length=20, primary_key=True) + + +class TimestampedModelSerializer(serializers.ModelSerializer): + class Meta: + model = TimestampedModel + + +class CharPrimaryKeyModelSerializer(serializers.ModelSerializer): + class Meta: + model = CharPrimaryKeyModel + + +class ReadOnlyFieldTests(TestCase): + def test_auto_now_fields_read_only(self): + """ + auto_now and auto_now_add fields should be read_only by default. + """ + serializer = TimestampedModelSerializer() + self.assertEquals(serializer.fields['added'].read_only, True) + + def test_auto_pk_fields_read_only(self): + """ + AutoField fields should be read_only by default. + """ + serializer = TimestampedModelSerializer() + self.assertEquals(serializer.fields['id'].read_only, True) + + def test_non_auto_pk_fields_not_read_only(self): + """ + PK fields other than AutoField fields should not be read_only by default. + """ + serializer = CharPrimaryKeyModelSerializer() + self.assertEquals(serializer.fields['id'].read_only, False) diff --git a/rest_framework/tests/files.py b/rest_framework/tests/files.py index ca6bc905..0434f900 100644 --- a/rest_framework/tests/files.py +++ b/rest_framework/tests/files.py @@ -26,7 +26,6 @@ class UploadedFileSerializer(serializers.Serializer): class FileSerializerTests(TestCase): - def test_create(self): now = datetime.datetime.now() file = BytesIO(six.b('stuff')) @@ -38,3 +37,16 @@ class FileSerializerTests(TestCase): self.assertEquals(serializer.object.created, uploaded_file.created) self.assertEquals(serializer.object.file, uploaded_file.file) self.assertFalse(serializer.object is uploaded_file) + + def test_creation_failure(self): + """ + Passing files=None should result in an ValidationError + + Regression test for: + https://github.com/tomchristie/django-rest-framework/issues/542 + """ + now = datetime.datetime.now() + + serializer = UploadedFileSerializer(data={'created': now}) + self.assertFalse(serializer.is_valid()) + self.assertIn('file', serializer.errors) diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py index ba29dbed..72070a1a 100644 --- a/rest_framework/tests/genericrelations.py +++ b/rest_framework/tests/genericrelations.py @@ -1,27 +1,63 @@ from __future__ import unicode_literals +from django.contrib.contenttypes.models import ContentType +from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey +from django.db import models from django.test import TestCase from rest_framework import serializers -from rest_framework.tests.models import * + + +class Tag(models.Model): + """ + Tags have a descriptive slug, and are attached to an arbitrary object. + """ + tag = models.SlugField() + content_type = models.ForeignKey(ContentType) + object_id = models.PositiveIntegerField() + tagged_item = GenericForeignKey('content_type', 'object_id') + + def __unicode__(self): + return self.tag + + +class Bookmark(models.Model): + """ + A URL bookmark that may have multiple tags attached. + """ + url = models.URLField() + tags = GenericRelation(Tag) + + def __unicode__(self): + return 'Bookmark: %s' % self.url + + +class Note(models.Model): + """ + A textual note that may have multiple tags attached. + """ + text = models.TextField() + tags = GenericRelation(Tag) + + def __unicode__(self): + return 'Note: %s' % self.text class TestGenericRelations(TestCase): def setUp(self): - bookmark = Bookmark(url='https://www.djangoproject.com/') - bookmark.save() - django = Tag(tag_name='django') - django.save() - python = Tag(tag_name='python') - python.save() - t1 = TaggedItem(content_object=bookmark, tag=django) - t1.save() - t2 = TaggedItem(content_object=bookmark, tag=python) - t2.save() - self.bookmark = bookmark - - def test_reverse_generic_relation(self): + self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/') + Tag.objects.create(tagged_item=self.bookmark, tag='django') + Tag.objects.create(tagged_item=self.bookmark, tag='python') + self.note = Note.objects.create(text='Remember the milk') + Tag.objects.create(tagged_item=self.note, tag='reminder') + + def test_generic_relation(self): + """ + Test a relationship that spans a GenericRelation field. + IE. A reverse generic relationship. + """ + class BookmarkSerializer(serializers.ModelSerializer): - tags = serializers.ManyRelatedField(source='tags') + tags = serializers.ManyRelatedField() class Meta: model = Bookmark @@ -33,3 +69,33 @@ class TestGenericRelations(TestCase): 'url': 'https://www.djangoproject.com/' } self.assertEquals(serializer.data, expected) + + def test_generic_fk(self): + """ + Test a relationship that spans a GenericForeignKey field. + IE. A forward generic relationship. + """ + + class TagSerializer(serializers.ModelSerializer): + tagged_item = serializers.RelatedField() + + class Meta: + model = Tag + exclude = ('id', 'content_type', 'object_id') + + serializer = TagSerializer(Tag.objects.all()) + expected = [ + { + 'tag': u'django', + 'tagged_item': u'Bookmark: https://www.djangoproject.com/' + }, + { + 'tag': u'python', + 'tagged_item': u'Bookmark: https://www.djangoproject.com/' + }, + { + 'tag': u'reminder', + 'tagged_item': u'Note: Remember the milk' + } + ] + self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 215de0c4..fd01312a 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -1,12 +1,11 @@ from __future__ import unicode_literals - from django.db import models from django.test import TestCase -from django.test.client import RequestFactory -from django.utils import simplejson as json from rest_framework import generics, serializers, status +from rest_framework.tests.utils import RequestFactory from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel from rest_framework.compat import six +import json factory = RequestFactory() @@ -183,6 +182,20 @@ class TestInstanceView(TestCase): updated = self.objects.get(id=1) self.assertEquals(updated.text, 'foobar') + def test_patch_instance_view(self): + """ + PATCH requests to RetrieveUpdateDestroyAPIView should update an object. + """ + content = {'text': 'foobar'} + request = factory.patch('/1', json.dumps(content), + content_type='application/json') + + response = self.view(request, pk=1).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) + updated = self.objects.get(id=1) + self.assertEquals(updated.text, 'foobar') + def test_delete_instance_view(self): """ DELETE requests to RetrieveUpdateDestroyAPIView should delete an object. diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index ee4d8e57..c6a8224b 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -1,6 +1,6 @@ +import json from django.test import TestCase from django.test.client import RequestFactory -from django.utils import simplejson as json from rest_framework import generics, status, serializers from rest_framework.compat import patterns, url from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 0759650a..9ab15328 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -71,6 +71,7 @@ class SlugBasedModel(RESTFrameworkModel): class DefaultValueModel(RESTFrameworkModel): text = models.CharField(default='foobar', max_length=100) + extra = models.CharField(blank=True, null=True, max_length=100) class CallableDefaultValueModel(RESTFrameworkModel): @@ -85,27 +86,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel): text = models.CharField(max_length=100, default='anchor') rel = models.ManyToManyField(Anchor) -# Models to test generic relations - - -class Tag(RESTFrameworkModel): - tag_name = models.SlugField() - - -class TaggedItem(RESTFrameworkModel): - tag = models.ForeignKey(Tag, related_name='items') - content_type = models.ForeignKey(ContentType) - object_id = models.PositiveIntegerField() - content_object = GenericForeignKey('content_type', 'object_id') - - def __unicode__(self): - return self.tag.tag_name - - -class Bookmark(RESTFrameworkModel): - url = models.URLField() - tags = GenericRelation(TaggedItem) - # Model to test filtering. class FilterableItem(RESTFrameworkModel): @@ -176,3 +156,42 @@ class OptionalRelationModel(RESTFrameworkModel): # Model for RegexField class Book(RESTFrameworkModel): isbn = models.CharField(max_length=13) + + +# Models for relations tests +# ManyToMany +class ManyToManyTarget(RESTFrameworkModel): + name = models.CharField(max_length=100) + + +class ManyToManySource(RESTFrameworkModel): + name = models.CharField(max_length=100) + targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') + + +# ForeignKey +class ForeignKeyTarget(RESTFrameworkModel): + name = models.CharField(max_length=100) + + +class ForeignKeySource(RESTFrameworkModel): + name = models.CharField(max_length=100) + target = models.ForeignKey(ForeignKeyTarget, related_name='sources') + + +# Nullable ForeignKey +class NullableForeignKeySource(RESTFrameworkModel): + name = models.CharField(max_length=100) + target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, + related_name='nullable_sources') + + +# OneToOne +class OneToOneTarget(RESTFrameworkModel): + name = models.CharField(max_length=100) + + +class NullableOneToOneSource(RESTFrameworkModel): + name = models.CharField(max_length=100) + target = models.OneToOneField(OneToOneTarget, null=True, blank=True, + related_name='nullable_source') diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 81d297a1..697dfb5b 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -181,10 +181,10 @@ class UnitTestPagination(TestCase): """ Ensure context gets passed through to the object serializer. """ - serializer = PassOnContextPaginationSerializer(self.first_page) + serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) serializer.data results = serializer.fields[serializer.results_field] - self.assertTrue(serializer.context is results.context) + self.assertEquals(serializer.context, results.context) class TestUnpaginated(TestCase): @@ -252,6 +252,8 @@ class TestCustomPaginateByParam(TestCase): self.assertEquals(response.data['results'], self.data[:5]) +### Tests for context in pagination serializers + class CustomField(serializers.Field): def to_native(self, value): if not 'view' in self.context: @@ -262,6 +264,11 @@ class CustomField(serializers.Field): class BasicModelSerializer(serializers.Serializer): text = CustomField() + def __init__(self, *args, **kwargs): + super(BasicModelSerializer, self).__init__(*args, **kwargs) + if not 'view' in self.context: + raise RuntimeError("context isn't getting passed into serializer init") + class TestContextPassedToCustomField(TestCase): def setUp(self): @@ -279,3 +286,39 @@ class TestContextPassedToCustomField(TestCase): self.assertEquals(response.status_code, status.HTTP_200_OK) + +### Tests for custom pagination serializers + +class LinksSerializer(serializers.Serializer): + next = pagination.NextPageField(source='*') + prev = pagination.PreviousPageField(source='*') + + +class CustomPaginationSerializer(pagination.BasePaginationSerializer): + links = LinksSerializer(source='*') # Takes the page object as the source + total_results = serializers.Field(source='paginator.count') + + results_field = 'objects' + + +class TestCustomPaginationSerializer(TestCase): + def setUp(self): + objects = ['john', 'paul', 'george', 'ringo'] + paginator = Paginator(objects, 2) + self.page = paginator.page(1) + + def test_custom_pagination_serializer(self): + request = RequestFactory().get('/foobar') + serializer = CustomPaginationSerializer( + instance=self.page, + context={'request': request} + ) + expected = { + 'links': { + 'next': 'http://testserver/foobar?page=2', + 'prev': None + }, + 'total_results': 4, + 'objects': ['john', 'paul'] + } + self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py new file mode 100644 index 00000000..edc85f9e --- /dev/null +++ b/rest_framework/tests/relations.py @@ -0,0 +1,47 @@ +""" +General tests for relational fields. +""" + +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +class NullModel(models.Model): + pass + + +class FieldTests(TestCase): + def test_pk_related_field_with_empty_string(self): + """ + Regression test for #446 + + https://github.com/tomchristie/django-rest-framework/issues/446 + """ + field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) + self.assertRaises(serializers.ValidationError, field.from_native, '') + self.assertRaises(serializers.ValidationError, field.from_native, []) + + def test_hyperlinked_related_field_with_empty_string(self): + field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') + self.assertRaises(serializers.ValidationError, field.from_native, '') + self.assertRaises(serializers.ValidationError, field.from_native, []) + + def test_slug_related_field_with_empty_string(self): + field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') + self.assertRaises(serializers.ValidationError, field.from_native, '') + self.assertRaises(serializers.ValidationError, field.from_native, []) + + +class TestManyRelateMixin(TestCase): + def test_missing_many_to_many_related_field(self): + ''' + Regression test for #632 + + https://github.com/tomchristie/django-rest-framework/pull/632 + ''' + field = serializers.ManyRelatedField(read_only=False) + + into = {} + field.field_from_native({}, None, 'field_name', into) + self.assertEqual(into['field_name'], []) diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py index 407c04e0..b4ad3166 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/relations_hyperlink.py @@ -1,9 +1,9 @@ from __future__ import unicode_literals -from django.db import models from django.test import TestCase from rest_framework import serializers from rest_framework.compat import patterns, url +from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource def dummy_view(request, pk): @@ -15,20 +15,11 @@ urlpatterns = patterns('', url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'), url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'), url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'), + url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'), + url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'), ) -# ManyToMany - -class ManyToManyTarget(models.Model): - name = models.CharField(max_length=100) - - -class ManyToManySource(models.Model): - name = models.CharField(max_length=100) - targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') - - class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): sources = serializers.ManyHyperlinkedRelatedField(view_name='manytomanysource-detail') @@ -41,17 +32,6 @@ class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer): model = ManyToManySource -# ForeignKey - -class ForeignKeyTarget(models.Model): - name = models.CharField(max_length=100) - - -class ForeignKeySource(models.Model): - name = models.CharField(max_length=100) - target = models.ForeignKey(ForeignKeyTarget, related_name='sources') - - class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): sources = serializers.ManyHyperlinkedRelatedField(view_name='foreignkeysource-detail') @@ -65,16 +45,17 @@ class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): # Nullable ForeignKey +class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = NullableForeignKeySource -class NullableForeignKeySource(models.Model): - name = models.CharField(max_length=100) - target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, - related_name='nullable_sources') +# OneToOne +class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): + nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail') -class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: - model = NullableForeignKeySource + model = OneToOneTarget # TODO: Add test that .data cannot be accessed prior to .is_valid @@ -236,6 +217,13 @@ class HyperlinkedForeignKeyTests(TestCase): ] self.assertEquals(serializer.data, expected) + def test_foreign_key_update_incorrect_type(self): + data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': 2} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'target': [u'Incorrect type. Expected url string, received int.']}) + def test_reverse_foreign_key_update(self): data = {'url': '/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']} instance = ForeignKeyTarget.objects.get(pk=2) @@ -248,7 +236,7 @@ class HyperlinkedForeignKeyTests(TestCase): expected = [ {'url': '/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']}, {'url': '/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, - ] + ] self.assertEquals(new_serializer.data, expected) serializer.save() @@ -434,3 +422,24 @@ class HyperlinkedNullableForeignKeyTests(TestCase): # {'id': 2, 'name': 'target-2', 'sources': []}, # ] # self.assertEquals(serializer.data, expected) + + +class HyperlinkedNullableOneToOneTests(TestCase): + urls = 'rest_framework.tests.relations_hyperlink' + + def setUp(self): + target = OneToOneTarget(name='target-1') + target.save() + new_target = OneToOneTarget(name='target-2') + new_target.save() + source = NullableOneToOneSource(name='source-1', target=target) + source.save() + + def test_reverse_foreign_key_retrieve_with_null(self): + queryset = OneToOneTarget.objects.all() + serializer = NullableOneToOneTargetSerializer(queryset) + expected = [ + {'url': '/onetoonetarget/1/', 'name': u'target-1', 'nullable_source': '/nullableonetoonesource/1/'}, + {'url': '/onetoonetarget/2/', 'name': u'target-2', 'nullable_source': None}, + ] + self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index 442cbebe..e81f0e42 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -1,19 +1,7 @@ from __future__ import unicode_literals - -from django.db import models from django.test import TestCase from rest_framework import serializers - - -# ForeignKey - -class ForeignKeyTarget(models.Model): - name = models.CharField(max_length=100) - - -class ForeignKeySource(models.Model): - name = models.CharField(max_length=100) - target = models.ForeignKey(ForeignKeyTarget, related_name='sources') +from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource class ForeignKeySourceSerializer(serializers.ModelSerializer): @@ -34,20 +22,24 @@ class ForeignKeyTargetSerializer(serializers.ModelSerializer): model = ForeignKeyTarget -# Nullable ForeignKey - -class NullableForeignKeySource(models.Model): - name = models.CharField(max_length=100) - target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, - related_name='nullable_sources') - - class NullableForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: depth = 1 model = NullableForeignKeySource +class NullableOneToOneSourceSerializer(serializers.ModelSerializer): + class Meta: + model = NullableOneToOneSource + + +class NullableOneToOneTargetSerializer(serializers.ModelSerializer): + nullable_source = NullableOneToOneSourceSerializer() + + class Meta: + model = OneToOneTarget + + class ReverseForeignKeyTests(TestCase): def setUp(self): target = ForeignKeyTarget(name='target-1') @@ -102,3 +94,22 @@ class NestedNullableForeignKeyTests(TestCase): {'id': 3, 'name': 'source-3', 'target': None}, ] self.assertEquals(serializer.data, expected) + + +class NestedNullableOneToOneTests(TestCase): + def setUp(self): + target = OneToOneTarget(name='target-1') + target.save() + new_target = OneToOneTarget(name='target-2') + new_target.save() + source = NullableOneToOneSource(name='source-1', target=target) + source.save() + + def test_reverse_foreign_key_retrieve_with_null(self): + queryset = OneToOneTarget.objects.all() + serializer = NullableOneToOneTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'nullable_source': {'id': 1, 'name': u'source-1', 'target': 1}}, + {'id': 2, 'name': u'target-2', 'nullable_source': None}, + ] + self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py index a04c5c80..4d00795a 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/relations_pk.py @@ -3,17 +3,7 @@ from __future__ import unicode_literals from django.db import models from django.test import TestCase from rest_framework import serializers - - -# ManyToMany - -class ManyToManyTarget(models.Model): - name = models.CharField(max_length=100) - - -class ManyToManySource(models.Model): - name = models.CharField(max_length=100) - targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') +from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource class ManyToManyTargetSerializer(serializers.ModelSerializer): @@ -28,17 +18,6 @@ class ManyToManySourceSerializer(serializers.ModelSerializer): model = ManyToManySource -# ForeignKey - -class ForeignKeyTarget(models.Model): - name = models.CharField(max_length=100) - - -class ForeignKeySource(models.Model): - name = models.CharField(max_length=100) - target = models.ForeignKey(ForeignKeyTarget, related_name='sources') - - class ForeignKeyTargetSerializer(serializers.ModelSerializer): sources = serializers.ManyPrimaryKeyRelatedField() @@ -51,17 +30,17 @@ class ForeignKeySourceSerializer(serializers.ModelSerializer): model = ForeignKeySource -# Nullable ForeignKey +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): + class Meta: + model = NullableForeignKeySource -class NullableForeignKeySource(models.Model): - name = models.CharField(max_length=100) - target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, - related_name='nullable_sources') +# OneToOne +class NullableOneToOneTargetSerializer(serializers.ModelSerializer): + nullable_source = serializers.PrimaryKeyRelatedField() -class NullableForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: - model = NullableForeignKeySource + model = OneToOneTarget # TODO: Add test that .data cannot be accessed prior to .is_valid @@ -218,6 +197,13 @@ class PKForeignKeyTests(TestCase): ] self.assertEquals(serializer.data, expected) + def test_foreign_key_update_incorrect_type(self): + data = {'id': 1, 'name': u'source-1', 'target': 'foo'} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'target': [u'Incorrect type. Expected pk value, received str.']}) + def test_reverse_foreign_key_update(self): data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]} instance = ForeignKeyTarget.objects.get(pk=2) @@ -230,7 +216,7 @@ class PKForeignKeyTests(TestCase): expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': []}, - ] + ] self.assertEquals(new_serializer.data, expected) serializer.save() @@ -414,3 +400,22 @@ class PKNullableForeignKeyTests(TestCase): # {'id': 2, 'name': 'target-2', 'sources': []}, # ] # self.assertEquals(serializer.data, expected) + + +class PKNullableOneToOneTests(TestCase): + def setUp(self): + target = OneToOneTarget(name='target-1') + target.save() + new_target = OneToOneTarget(name='target-2') + new_target.save() + source = NullableOneToOneSource(name='source-1', target=target) + source.save() + + def test_reverse_foreign_key_retrieve_with_null(self): + queryset = OneToOneTarget.objects.all() + serializer = NullableOneToOneTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'nullable_source': 1}, + {'id': 2, 'name': u'target-2', 'nullable_source': None}, + ] + self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/relations_slug.py b/rest_framework/tests/relations_slug.py new file mode 100644 index 00000000..37ccc75e --- /dev/null +++ b/rest_framework/tests/relations_slug.py @@ -0,0 +1,257 @@ +from django.test import TestCase +from rest_framework import serializers +from rest_framework.tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget + + +class ForeignKeyTargetSerializer(serializers.ModelSerializer): + sources = serializers.ManySlugRelatedField(slug_field='name') + + class Meta: + model = ForeignKeyTarget + + +class ForeignKeySourceSerializer(serializers.ModelSerializer): + target = serializers.SlugRelatedField(slug_field='name') + + class Meta: + model = ForeignKeySource + + +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): + target = serializers.SlugRelatedField(slug_field='name', null=True) + + class Meta: + model = NullableForeignKeySource + + +# TODO: M2M Tests, FKTests (Non-nulable), One2One +class PKForeignKeyTests(TestCase): + def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + new_target = ForeignKeyTarget(name='target-2') + new_target.save() + for idx in range(1, 4): + source = ForeignKeySource(name='source-%d' % idx, target=target) + source.save() + + def test_foreign_key_retrieve(self): + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 'target-1'}, + {'id': 2, 'name': u'source-2', 'target': 'target-1'}, + {'id': 3, 'name': u'source-3', 'target': 'target-1'} + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_foreign_key_retrieve(self): + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, + {'id': 2, 'name': u'target-2', 'sources': []}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update(self): + data = {'id': 1, 'name': u'source-1', 'target': 'target-2'} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEquals(serializer.data, data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 'target-2'}, + {'id': 2, 'name': u'source-2', 'target': 'target-1'}, + {'id': 3, 'name': u'source-3', 'target': 'target-1'} + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update_incorrect_type(self): + data = {'id': 1, 'name': u'source-1', 'target': 123} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'target': [u'Object with name=123 does not exist.']}) + + def test_reverse_foreign_key_update(self): + data = {'id': 2, 'name': u'target-2', 'sources': ['source-1', 'source-3']} + instance = ForeignKeyTarget.objects.get(pk=2) + serializer = ForeignKeyTargetSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + # We shouldn't have saved anything to the db yet since save + # hasn't been called. + queryset = ForeignKeyTarget.objects.all() + new_serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, + {'id': 2, 'name': u'target-2', 'sources': []}, + ] + self.assertEquals(new_serializer.data, expected) + + serializer.save() + self.assertEquals(serializer.data, data) + + # Ensure target 2 is update, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': ['source-2']}, + {'id': 2, 'name': u'target-2', 'sources': ['source-1', 'source-3']}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_create(self): + data = {'id': 4, 'name': u'source-4', 'target': 'target-2'} + serializer = ForeignKeySourceSerializer(data=data) + serializer.is_valid() + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'source-4') + + # Ensure source 4 is added, and everything else is as expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 'target-1'}, + {'id': 2, 'name': u'source-2', 'target': 'target-1'}, + {'id': 3, 'name': u'source-3', 'target': 'target-1'}, + {'id': 4, 'name': u'source-4', 'target': 'target-2'}, + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_foreign_key_create(self): + data = {'id': 3, 'name': u'target-3', 'sources': ['source-1', 'source-3']} + serializer = ForeignKeyTargetSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'target-3') + + # Ensure target 3 is added, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': ['source-2']}, + {'id': 2, 'name': u'target-2', 'sources': []}, + {'id': 3, 'name': u'target-3', 'sources': ['source-1', 'source-3']}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update_with_invalid_null(self): + data = {'id': 1, 'name': u'source-1', 'target': None} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) + + +class SlugNullableForeignKeyTests(TestCase): + def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + for idx in range(1, 4): + if idx == 3: + target = None + source = NullableForeignKeySource(name='source-%d' % idx, target=target) + source.save() + + def test_foreign_key_retrieve_with_null(self): + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 'target-1'}, + {'id': 2, 'name': u'source-2', 'target': 'target-1'}, + {'id': 3, 'name': u'source-3', 'target': None}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_create_with_valid_null(self): + data = {'id': 4, 'name': u'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'source-4') + + # Ensure source 4 is created, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 'target-1'}, + {'id': 2, 'name': u'source-2', 'target': 'target-1'}, + {'id': 3, 'name': u'source-3', 'target': None}, + {'id': 4, 'name': u'source-4', 'target': None} + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_create_with_valid_emptystring(self): + """ + The emptystring should be interpreted as null in the context + of relationships. + """ + data = {'id': 4, 'name': u'source-4', 'target': ''} + expected_data = {'id': 4, 'name': u'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, expected_data) + self.assertEqual(obj.name, u'source-4') + + # Ensure source 4 is created, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 'target-1'}, + {'id': 2, 'name': u'source-2', 'target': 'target-1'}, + {'id': 3, 'name': u'source-3', 'target': None}, + {'id': 4, 'name': u'source-4', 'target': None} + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update_with_valid_null(self): + data = {'id': 1, 'name': u'source-1', 'target': None} + instance = NullableForeignKeySource.objects.get(pk=1) + serializer = NullableForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEquals(serializer.data, data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': None}, + {'id': 2, 'name': u'source-2', 'target': 'target-1'}, + {'id': 3, 'name': u'source-3', 'target': None} + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update_with_valid_emptystring(self): + """ + The emptystring should be interpreted as null in the context + of relationships. + """ + data = {'id': 1, 'name': u'source-1', 'target': ''} + expected_data = {'id': 1, 'name': u'source-1', 'target': None} + instance = NullableForeignKeySource.objects.get(pk=1) + serializer = NullableForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEquals(serializer.data, expected_data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': None}, + {'id': 2, 'name': u'source-2', 'target': 'target-1'}, + {'id': 3, 'name': u'source-3', 'target': None} + ] + self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py index 7d4575bb..92b1bfd8 100644 --- a/rest_framework/tests/request.py +++ b/rest_framework/tests/request.py @@ -1,12 +1,12 @@ """ Tests for content parsing, and form-overloaded content parsing. """ +import json from django.contrib.auth.models import User from django.contrib.auth import authenticate, login, logout from django.contrib.sessions.middleware import SessionMiddleware from django.test import TestCase, Client from django.test.client import RequestFactory -from django.utils import simplejson as json from rest_framework import status from rest_framework.authentication import SessionAuthentication from rest_framework.compat import patterns diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 6ce7de31..a00626b5 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -56,6 +56,19 @@ class ActionItemSerializer(serializers.ModelSerializer): model = ActionItem +class ActionItemSerializerCustomRestore(serializers.ModelSerializer): + + class Meta: + model = ActionItem + + def restore_object(self, data, instance=None): + if instance is None: + return ActionItem(**data) + for key, val in data.items(): + setattr(instance, key, val) + return instance + + class PersonSerializer(serializers.ModelSerializer): info = serializers.Field(source='info') @@ -71,6 +84,7 @@ class AlbumsSerializer(serializers.ModelSerializer): model = Album fields = ['title'] # lists are also valid options + class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer): class Meta: model = HasPositiveIntegerAsChoice @@ -163,7 +177,6 @@ class BasicTests(TestCase): """ Attempting to update fields set as read_only should have no effect. """ - serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99}) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() @@ -184,8 +197,7 @@ class ValidationTests(TestCase): 'content': 'x' * 1001, 'created': datetime.datetime(2012, 1, 1) } - self.actionitem = ActionItem(title='Some to do item', - ) + self.actionitem = ActionItem(title='Some to do item',) def test_create(self): serializer = CommentSerializer(data=self.data) @@ -217,30 +229,24 @@ class ValidationTests(TestCase): self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.errors, {}) - def test_field_validation(self): - - class CommentSerializerWithFieldValidator(CommentSerializer): - - def validate_content(self, attrs, source): - value = attrs[source] - if "test" not in value: - raise serializers.ValidationError("Test not in value") - return attrs - - data = { - 'email': 'tom@example.com', - 'content': 'A test comment', - 'created': datetime.datetime(2012, 1, 1) - } - - serializer = CommentSerializerWithFieldValidator(data=data) - self.assertTrue(serializer.is_valid()) + def test_bad_type_data_is_false(self): + """ + Data of the wrong type is not valid. + """ + data = ['i am', 'a', 'list'] + serializer = CommentSerializer(self.comment, data=data) + self.assertEquals(serializer.is_valid(), False) + self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) - data['content'] = 'This should not validate' + data = 'and i am a string' + serializer = CommentSerializer(self.comment, data=data) + self.assertEquals(serializer.is_valid(), False) + self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) - serializer = CommentSerializerWithFieldValidator(data=data) - self.assertFalse(serializer.is_valid()) - self.assertEquals(serializer.errors, {'content': ['Test not in value']}) + data = 42 + serializer = CommentSerializer(self.comment, data=data) + self.assertEquals(serializer.is_valid(), False) + self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) def test_cross_field_validation(self): @@ -282,6 +288,20 @@ class ValidationTests(TestCase): self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']}) + def test_modelserializer_max_length_exceeded_with_custom_restore(self): + """ + When overriding ModelSerializer.restore_object, validation tests should still apply. + Regression test for #623. + + https://github.com/tomchristie/django-rest-framework/pull/623 + """ + data = { + 'title': 'x' * 201, + } + serializer = ActionItemSerializerCustomRestore(data=data) + self.assertEquals(serializer.is_valid(), False) + self.assertEquals(serializer.errors, {'title': [u'Ensure this value has at most 200 characters (it has 201).']}) + def test_default_modelfield_max_length_exceeded(self): data = { 'title': 'Testing "info" field...', @@ -292,12 +312,69 @@ class ValidationTests(TestCase): self.assertEquals(serializer.errors, {'info': ['Ensure this value has at most 12 characters (it has 13).']}) +class CustomValidationTests(TestCase): + class CommentSerializerWithFieldValidator(CommentSerializer): + + def validate_email(self, attrs, source): + value = attrs[source] + + return attrs + + def validate_content(self, attrs, source): + value = attrs[source] + if "test" not in value: + raise serializers.ValidationError("Test not in value") + return attrs + + def test_field_validation(self): + data = { + 'email': 'tom@example.com', + 'content': 'A test comment', + 'created': datetime.datetime(2012, 1, 1) + } + + serializer = self.CommentSerializerWithFieldValidator(data=data) + self.assertTrue(serializer.is_valid()) + + data['content'] = 'This should not validate' + + serializer = self.CommentSerializerWithFieldValidator(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'content': [u'Test not in value']}) + + def test_missing_data(self): + """ + Make sure that validate_content isn't called if the field is missing + """ + incomplete_data = { + 'email': 'tom@example.com', + 'created': datetime.datetime(2012, 1, 1) + } + serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'content': [u'This field is required.']}) + + def test_wrong_data(self): + """ + Make sure that validate_content isn't called if the field input is wrong + """ + wrong_data = { + 'email': 'not an email', + 'content': 'A test comment', + 'created': datetime.datetime(2012, 1, 1) + } + serializer = self.CommentSerializerWithFieldValidator(data=wrong_data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'email': [u'Enter a valid e-mail address.']}) + + class PositiveIntegerAsChoiceTests(TestCase): def test_positive_integer_in_json_is_correctly_parsed(self): - data = {'some_integer':1} + data = {'some_integer': 1} serializer = PositiveIntegerAsChoiceSerializer(data=data) self.assertEquals(serializer.is_valid(), True) + class ModelValidationTests(TestCase): def test_validate_unique(self): """ @@ -342,7 +419,6 @@ class ModelValidationTests(TestCase): self.assertTrue(photo_serializer.save()) - class RegexValidationTest(TestCase): def test_create_failed(self): serializer = BookSerializer(data={'isbn': '1234567890'}) @@ -553,6 +629,21 @@ class DefaultValueTests(TestCase): self.assertEquals(instance.pk, 1) self.assertEquals(instance.text, 'overridden') + def test_partial_update_default(self): + """ Regression test for issue #532 """ + data = {'text': 'overridden'} + serializer = self.serializer_class(data=data, partial=True) + self.assertEquals(serializer.is_valid(), True) + instance = serializer.save() + + data = {'extra': 'extra_value'} + serializer = self.serializer_class(instance=instance, data=data, partial=True) + self.assertEquals(serializer.is_valid(), True) + instance = serializer.save() + + self.assertEquals(instance.extra, 'extra_value') + self.assertEquals(instance.text, 'overridden') + class CallableDefaultValueTests(TestCase): def setUp(self): diff --git a/rest_framework/tests/settings.py b/rest_framework/tests/settings.py new file mode 100644 index 00000000..0293fdc3 --- /dev/null +++ b/rest_framework/tests/settings.py @@ -0,0 +1,21 @@ +"""Tests for the settings module""" +from django.test import TestCase + +from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS + + +class TestSettings(TestCase): + """Tests relating to the api settings""" + + def test_non_import_errors(self): + """Make sure other errors aren't suppressed.""" + settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) + with self.assertRaises(ValueError): + settings.DEFAULT_MODEL_SERIALIZER_CLASS + + def test_import_error_message_maintained(self): + """Make sure real import errors are captured and raised sensibly.""" + settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) + with self.assertRaises(ImportError) as cm: + settings.DEFAULT_MODEL_SERIALIZER_CLASS + self.assertTrue('ImportError' in str(cm.exception)) diff --git a/rest_framework/tests/urlpatterns.py b/rest_framework/tests/urlpatterns.py new file mode 100644 index 00000000..43e8ef69 --- /dev/null +++ b/rest_framework/tests/urlpatterns.py @@ -0,0 +1,78 @@ +from collections import namedtuple + +from django.core import urlresolvers + +from django.test import TestCase +from django.test.client import RequestFactory + +from rest_framework.compat import patterns, url, include +from rest_framework.urlpatterns import format_suffix_patterns + + +# A container class for test paths for the test case +URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs']) + + +def dummy_view(request, *args, **kwargs): + pass + + +class FormatSuffixTests(TestCase): + """ + Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters. + """ + def _resolve_urlpatterns(self, urlpatterns, test_paths): + factory = RequestFactory() + try: + urlpatterns = format_suffix_patterns(urlpatterns) + except: + self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns") + resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns) + for test_path in test_paths: + request = factory.get(test_path.path) + try: + callback, callback_args, callback_kwargs = resolver.resolve(request.path_info) + except: + self.fail("Failed to resolve URL: %s" % request.path_info) + self.assertEquals(callback_args, test_path.args) + self.assertEquals(callback_kwargs, test_path.kwargs) + + def test_format_suffix(self): + urlpatterns = patterns( + '', + url(r'^test$', dummy_view), + ) + test_paths = [ + URLTestPath('/test', (), {}), + URLTestPath('/test.api', (), {'format': 'api'}), + URLTestPath('/test.asdf', (), {'format': 'asdf'}), + ] + self._resolve_urlpatterns(urlpatterns, test_paths) + + def test_default_args(self): + urlpatterns = patterns( + '', + url(r'^test$', dummy_view, {'foo': 'bar'}), + ) + test_paths = [ + URLTestPath('/test', (), {'foo': 'bar', }), + URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}), + URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}), + ] + self._resolve_urlpatterns(urlpatterns, test_paths) + + def test_included_urls(self): + nested_patterns = patterns( + '', + url(r'^path$', dummy_view) + ) + urlpatterns = patterns( + '', + url(r'^test/', include(nested_patterns), {'foo': 'bar'}), + ) + test_paths = [ + URLTestPath('/test/path', (), {'foo': 'bar', }), + URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}), + URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}), + ] + self._resolve_urlpatterns(urlpatterns, test_paths) diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py new file mode 100644 index 00000000..3906adb9 --- /dev/null +++ b/rest_framework/tests/utils.py @@ -0,0 +1,27 @@ +from django.test.client import RequestFactory, FakePayload +from django.test.client import MULTIPART_CONTENT +from urlparse import urlparse + + +class RequestFactory(RequestFactory): + + def __init__(self, **defaults): + super(RequestFactory, self).__init__(**defaults) + + def patch(self, path, data={}, content_type=MULTIPART_CONTENT, + **extra): + "Construct a PATCH request." + + patch_data = self._encode_data(data, content_type) + + parsed = urlparse(path) + r = { + 'CONTENT_LENGTH': len(patch_data), + 'CONTENT_TYPE': content_type, + 'PATH_INFO': self._get_path(parsed), + 'QUERY_STRING': parsed[4], + 'REQUEST_METHOD': 'PATCH', + 'wsgi.input': FakePayload(patch_data), + } + r.update(extra) + return self.request(**r) diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py index e51ca9f3..f2432516 100644 --- a/rest_framework/tests/views.py +++ b/rest_framework/tests/views.py @@ -20,7 +20,7 @@ class BasicView(APIView): return Response({'method': 'POST', 'data': request.DATA}) -@api_view(['GET', 'POST', 'PUT']) +@api_view(['GET', 'POST', 'PUT', 'PATCH']) def basic_view(request): if request.method == 'GET': return {'method': 'GET'} @@ -28,6 +28,8 @@ def basic_view(request): return {'method': 'POST', 'data': request.DATA} elif request.method == 'PUT': return {'method': 'PUT', 'data': request.DATA} + elif request.method == 'PATCH': + return {'method': 'PATCH', 'data': request.DATA} def sanitise_json_error(error_dict): diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index 143928c9..47789026 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -1,5 +1,35 @@ -from rest_framework.compat import url +from rest_framework.compat import url, include from rest_framework.settings import api_settings +from django.core.urlresolvers import RegexURLResolver + + +def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required): + ret = [] + for urlpattern in urlpatterns: + if isinstance(urlpattern, RegexURLResolver): + # Set of included URL patterns + regex = urlpattern.regex.pattern + namespace = urlpattern.namespace + app_name = urlpattern.app_name + kwargs = urlpattern.default_kwargs + # Add in the included patterns, after applying the suffixes + patterns = apply_suffix_patterns(urlpattern.url_patterns, + suffix_pattern, + suffix_required) + ret.append(url(regex, include(patterns, namespace, app_name), kwargs)) + + else: + # Regular URL pattern + regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern + view = urlpattern._callback or urlpattern._callback_str + kwargs = urlpattern.default_args + name = urlpattern.name + # Add in both the existing and the new urlpattern + if not suffix_required: + ret.append(urlpattern) + ret.append(url(regex, view, kwargs, name)) + + return ret def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None): @@ -28,15 +58,4 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None): else: suffix_pattern = r'\.(?P<%s>[a-z]+)$' % suffix_kwarg - ret = [] - for urlpattern in urlpatterns: - # Form our complementing '.format' urlpattern - regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern - view = urlpattern._callback or urlpattern._callback_str - kwargs = urlpattern.default_args - name = urlpattern.name - # Add in both the existing and the new urlpattern - if not suffix_required: - ret.append(urlpattern) - ret.append(url(regex, view, kwargs, name)) - return ret + return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required) diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 2d1fb353..7afe100a 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -4,7 +4,7 @@ Helper classes for parsers. import datetime import decimal import types -from django.utils import simplejson as json +import json from django.utils.datastructures import SortedDict from rest_framework.compat import timezone from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata @@ -12,7 +12,7 @@ from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata class JSONEncoder(json.JSONEncoder): """ - JSONEncoder subclass that knows how to encode date/time, + JSONEncoder subclass that knows how to encode date/time/timedelta, decimal types, and generators. """ def default(self, o): @@ -34,6 +34,8 @@ class JSONEncoder(json.JSONEncoder): if o.microsecond: r = r[:12] return r + elif isinstance(o, datetime.timedelta): + return str(o.total_seconds()) elif isinstance(o, decimal.Decimal): return str(o) elif hasattr(o, '__iter__'): diff --git a/rest_framework/views.py b/rest_framework/views.py index 10bdd5a5..ac9b3385 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -148,6 +148,8 @@ class APIView(View): """ If request is not permitted, determine what kind of exception to raise. """ + if not self.request.successful_authenticator: + raise exceptions.NotAuthenticated() raise exceptions.PermissionDenied() def throttled(self, request, wait): @@ -156,6 +158,15 @@ class APIView(View): """ raise exceptions.Throttled(wait) + def get_authenticate_header(self, request): + """ + If a request is unauthenticated, determine the WWW-Authenticate + header to use for 401 responses, if any. + """ + authenticators = self.get_authenticators() + if authenticators: + return authenticators[0].authenticate_header(request) + def get_parser_context(self, http_request): """ Returns a dict that is passed through to Parser.parse(), @@ -319,6 +330,16 @@ class APIView(View): # Throttle wait header self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait + if isinstance(exc, (exceptions.NotAuthenticated, + exceptions.AuthenticationFailed)): + # WWW-Authenticate header for 401 responses, else coerce to 403 + auth_header = self.get_authenticate_header(self.request) + + if auth_header: + self.headers['WWW-Authenticate'] = auth_header + else: + exc.status_code = status.HTTP_403_FORBIDDEN + if isinstance(exc, exceptions.APIException): return Response({'detail': exc.detail}, status=exc.status_code, |
