aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py2
-rw-r--r--rest_framework/authentication.py80
-rw-r--r--rest_framework/decorators.py9
-rw-r--r--rest_framework/exceptions.py16
-rw-r--r--rest_framework/pagination.py26
-rw-r--r--rest_framework/relations.py20
-rw-r--r--rest_framework/request.py29
-rw-r--r--rest_framework/serializers.py12
-rw-r--r--rest_framework/tests/authentication.py41
-rw-r--r--rest_framework/tests/decorators.py22
-rw-r--r--rest_framework/tests/genericrelations.py96
-rw-r--r--rest_framework/tests/models.py21
-rw-r--r--rest_framework/tests/pagination.py43
-rw-r--r--rest_framework/tests/relations_hyperlink.py9
-rw-r--r--rest_framework/tests/relations_pk.py7
-rw-r--r--rest_framework/tests/relations_slug.py162
-rw-r--r--rest_framework/tests/serializer.py88
-rw-r--r--rest_framework/tests/urlpatterns.py78
-rw-r--r--rest_framework/urlpatterns.py45
-rw-r--r--rest_framework/views.py21
20 files changed, 660 insertions, 167 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index bc267fad..f9882c57 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,3 +1,3 @@
-__version__ = '2.1.16'
+__version__ = '2.1.17'
VERSION = __version__ # synonym
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index 30c78ebc..fc169189 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -21,32 +21,46 @@ class BaseAuthentication(object):
"""
raise NotImplementedError(".authenticate() must be overridden.")
+ def authenticate_header(self, request):
+ """
+ Return a string to be used as the value of the `WWW-Authenticate`
+ header in a `401 Unauthenticated` response, or `None` if the
+ authentication scheme should return `403 Permission Denied` responses.
+ """
+ pass
+
class BasicAuthentication(BaseAuthentication):
"""
HTTP Basic authentication against username/password.
"""
+ www_authenticate_realm = 'api'
def authenticate(self, request):
"""
Returns a `User` if a correct username and password have been supplied
using HTTP Basic authentication. Otherwise returns `None`.
"""
- if 'HTTP_AUTHORIZATION' in request.META:
- auth = request.META['HTTP_AUTHORIZATION'].split()
- if len(auth) == 2 and auth[0].lower() == "basic":
- try:
- auth_parts = base64.b64decode(auth[1]).partition(':')
- except TypeError:
- return None
-
- try:
- userid = smart_unicode(auth_parts[0])
- password = smart_unicode(auth_parts[2])
- except DjangoUnicodeDecodeError:
- return None
-
- return self.authenticate_credentials(userid, password)
+ auth = request.META.get('HTTP_AUTHORIZATION', '').split()
+
+ if not auth or auth[0].lower() != "basic":
+ return None
+
+ if len(auth) != 2:
+ raise exceptions.AuthenticationFailed('Invalid basic header')
+
+ try:
+ auth_parts = base64.b64decode(auth[1]).partition(':')
+ except TypeError:
+ raise exceptions.AuthenticationFailed('Invalid basic header')
+
+ try:
+ userid = smart_unicode(auth_parts[0])
+ password = smart_unicode(auth_parts[2])
+ except DjangoUnicodeDecodeError:
+ raise exceptions.AuthenticationFailed('Invalid basic header')
+
+ return self.authenticate_credentials(userid, password)
def authenticate_credentials(self, userid, password):
"""
@@ -55,6 +69,10 @@ class BasicAuthentication(BaseAuthentication):
user = authenticate(username=userid, password=password)
if user is not None and user.is_active:
return (user, None)
+ raise exceptions.AuthenticationFailed('Invalid username/password')
+
+ def authenticate_header(self, request):
+ return 'Basic realm="%s"' % self.www_authenticate_realm
class SessionAuthentication(BaseAuthentication):
@@ -74,7 +92,7 @@ class SessionAuthentication(BaseAuthentication):
# Unauthenticated, CSRF validation not required
if not user or not user.is_active:
- return
+ return None
# Enforce CSRF validation for session based authentication.
class CSRFCheck(CsrfViewMiddleware):
@@ -85,7 +103,7 @@ class SessionAuthentication(BaseAuthentication):
reason = CSRFCheck().process_view(http_request, None, (), {})
if reason:
# CSRF failed, bail with explicit error message
- raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
+ raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason)
# CSRF passed with authenticated user
return (user, None)
@@ -112,14 +130,26 @@ class TokenAuthentication(BaseAuthentication):
def authenticate(self, request):
auth = request.META.get('HTTP_AUTHORIZATION', '').split()
- if len(auth) == 2 and auth[0].lower() == "token":
- key = auth[1]
- try:
- token = self.model.objects.get(key=key)
- except self.model.DoesNotExist:
- return None
+ if not auth or auth[0].lower() != "token":
+ return None
+
+ if len(auth) != 2:
+ raise exceptions.AuthenticationFailed('Invalid token header')
+
+ return self.authenticate_credentials(auth[1])
+
+ def authenticate_credentials(self, key):
+ try:
+ token = self.model.objects.get(key=key)
+ except self.model.DoesNotExist:
+ raise exceptions.AuthenticationFailed('Invalid token')
+
+ if token.user.is_active:
+ return (token.user, token)
+ raise exceptions.AuthenticationFailed('User inactive or deleted')
+
+ def authenticate_header(self, request):
+ return 'Token'
- if token.user.is_active:
- return (token.user, token)
# TODO: OAuthAuthentication
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 1b710a03..7a4103e1 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -1,4 +1,5 @@
from rest_framework.views import APIView
+import types
def api_view(http_method_names):
@@ -23,6 +24,14 @@ def api_view(http_method_names):
# pass
# WrappedAPIView.__doc__ = func.doc <--- Not possible to do this
+ # api_view applied without (method_names)
+ assert not(isinstance(http_method_names, types.FunctionType)), \
+ '@api_view missing list of allowed HTTP methods'
+
+ # api_view applied with eg. string instead of list of strings
+ assert isinstance(http_method_names, (list, tuple)), \
+ '@api_view expected a list of strings, recieved %s' % type(http_method_names).__name__
+
allowed_methods = set(http_method_names) | set(('options',))
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index 89479deb..d635351c 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -23,6 +23,22 @@ class ParseError(APIException):
self.detail = detail or self.default_detail
+class AuthenticationFailed(APIException):
+ status_code = status.HTTP_401_UNAUTHORIZED
+ default_detail = 'Incorrect authentication credentials.'
+
+ def __init__(self, detail=None):
+ self.detail = detail or self.default_detail
+
+
+class NotAuthenticated(APIException):
+ status_code = status.HTTP_401_UNAUTHORIZED
+ default_detail = 'Authentication credentials were not provided.'
+
+ def __init__(self, detail=None):
+ self.detail = detail or self.default_detail
+
+
class PermissionDenied(APIException):
status_code = status.HTTP_403_FORBIDDEN
default_detail = 'You do not have permission to perform this action.'
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index d241ade7..92d41e0e 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -34,6 +34,17 @@ class PreviousPageField(serializers.Field):
return replace_query_param(url, self.page_field, page)
+class DefaultObjectSerializer(serializers.Field):
+ """
+ If no object serializer is specified, then this serializer will be applied
+ as the default.
+ """
+
+ def __init__(self, source=None, context=None):
+ # Note: Swallow context kwarg - only required for eg. ModelSerializer.
+ super(DefaultObjectSerializer, self).__init__(source=source)
+
+
class PaginationSerializerOptions(serializers.SerializerOptions):
"""
An object that stores the options that may be provided to a
@@ -44,7 +55,7 @@ class PaginationSerializerOptions(serializers.SerializerOptions):
def __init__(self, meta):
super(PaginationSerializerOptions, self).__init__(meta)
self.object_serializer_class = getattr(meta, 'object_serializer_class',
- serializers.Field)
+ DefaultObjectSerializer)
class BasePaginationSerializer(serializers.Serializer):
@@ -62,14 +73,13 @@ class BasePaginationSerializer(serializers.Serializer):
super(BasePaginationSerializer, self).__init__(*args, **kwargs)
results_field = self.results_field
object_serializer = self.opts.object_serializer_class
- self.fields[results_field] = object_serializer(source='object_list')
- def to_native(self, obj):
- """
- Prevent default behaviour of iterating over elements, and serializing
- each in turn.
- """
- return self.convert_object(obj)
+ if 'context' in kwargs:
+ context_kwarg = {'context': kwargs['context']}
+ else:
+ context_kwarg = {}
+
+ self.fields[results_field] = object_serializer(source='object_list', **context_kwarg)
class PaginationSerializer(BasePaginationSerializer):
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 7ded3891..af63ceaa 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -177,7 +177,7 @@ class PrimaryKeyRelatedField(RelatedField):
default_error_messages = {
'does_not_exist': _("Invalid pk '%s' - object does not exist."),
- 'invalid': _('Invalid value.'),
+ 'incorrect_type': _('Incorrect type. Expected pk value, received %s.'),
}
# TODO: Remove these field hacks...
@@ -208,7 +208,8 @@ class PrimaryKeyRelatedField(RelatedField):
msg = self.error_messages['does_not_exist'] % smart_unicode(data)
raise ValidationError(msg)
except (TypeError, ValueError):
- msg = self.error_messages['invalid']
+ received = type(data).__name__
+ msg = self.error_messages['incorrect_type'] % received
raise ValidationError(msg)
def field_to_native(self, obj, field_name):
@@ -235,7 +236,7 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):
default_error_messages = {
'does_not_exist': _("Invalid pk '%s' - object does not exist."),
- 'invalid': _('Invalid value.'),
+ 'incorrect_type': _('Incorrect type. Expected pk value, received %s.'),
}
def prepare_value(self, obj):
@@ -275,7 +276,8 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):
msg = self.error_messages['does_not_exist'] % smart_unicode(data)
raise ValidationError(msg)
except (TypeError, ValueError):
- msg = self.error_messages['invalid']
+ received = type(data).__name__
+ msg = self.error_messages['incorrect_type'] % received
raise ValidationError(msg)
### Slug relationships
@@ -333,7 +335,7 @@ class HyperlinkedRelatedField(RelatedField):
'incorrect_match': _('Invalid hyperlink - Incorrect URL match'),
'configuration_error': _('Invalid hyperlink due to configuration error'),
'does_not_exist': _("Invalid hyperlink - object does not exist."),
- 'invalid': _('Invalid value.'),
+ 'incorrect_type': _('Incorrect type. Expected url string, received %s.'),
}
def __init__(self, *args, **kwargs):
@@ -397,8 +399,8 @@ class HyperlinkedRelatedField(RelatedField):
try:
http_prefix = value.startswith('http:') or value.startswith('https:')
except AttributeError:
- msg = self.error_messages['invalid']
- raise ValidationError(msg)
+ msg = self.error_messages['incorrect_type']
+ raise ValidationError(msg % type(value).__name__)
if http_prefix:
# If needed convert absolute URLs to relative path
@@ -434,8 +436,8 @@ class HyperlinkedRelatedField(RelatedField):
except ObjectDoesNotExist:
raise ValidationError(self.error_messages['does_not_exist'])
except (TypeError, ValueError):
- msg = self.error_messages['invalid']
- raise ValidationError(msg)
+ msg = self.error_messages['incorrect_type']
+ raise ValidationError(msg % type(value).__name__)
return obj
diff --git a/rest_framework/request.py b/rest_framework/request.py
index b7133608..1c28cd17 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -86,6 +86,7 @@ class Request(object):
self._method = Empty
self._content_type = Empty
self._stream = Empty
+ self._authenticator = None
if self.parser_context is None:
self.parser_context = {}
@@ -166,7 +167,7 @@ class Request(object):
by the authentication classes provided to the request.
"""
if not hasattr(self, '_user'):
- self._user, self._auth = self._authenticate()
+ self._authenticator, self._user, self._auth = self._authenticate()
return self._user
@user.setter
@@ -185,7 +186,7 @@ class Request(object):
request, such as an authentication token.
"""
if not hasattr(self, '_auth'):
- self._user, self._auth = self._authenticate()
+ self._authenticator, self._user, self._auth = self._authenticate()
return self._auth
@auth.setter
@@ -196,6 +197,14 @@ class Request(object):
"""
self._auth = value
+ @property
+ def successful_authenticator(self):
+ """
+ Return the instance of the authentication instance class that was used
+ to authenticate the request, or `None`.
+ """
+ return self._authenticator
+
def _load_data_and_files(self):
"""
Parses the request content into self.DATA and self.FILES.
@@ -299,21 +308,23 @@ class Request(object):
def _authenticate(self):
"""
- Attempt to authenticate the request using each authentication instance in turn.
- Returns a two-tuple of (user, authtoken).
+ Attempt to authenticate the request using each authentication instance
+ in turn.
+ Returns a three-tuple of (authenticator, user, authtoken).
"""
for authenticator in self.authenticators:
user_auth_tuple = authenticator.authenticate(self)
if not user_auth_tuple is None:
- return user_auth_tuple
+ user, auth = user_auth_tuple
+ return (authenticator, user, auth)
return self._not_authenticated()
def _not_authenticated(self):
"""
- Return a two-tuple of (user, authtoken), representing an
- unauthenticated request.
+ Return a three-tuple of (authenticator, user, authtoken), representing
+ an unauthenticated request.
- By default this will be (AnonymousUser, None).
+ By default this will be (None, AnonymousUser, None).
"""
if api_settings.UNAUTHENTICATED_USER:
user = api_settings.UNAUTHENTICATED_USER()
@@ -325,7 +336,7 @@ class Request(object):
else:
auth = None
- return (user, auth)
+ return (None, user, auth)
def __getattr__(self, attr):
"""
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 27458f96..6ecc7b45 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -2,6 +2,7 @@ import copy
import datetime
import types
from decimal import Decimal
+from django.core.paginator import Page
from django.db import models
from django.forms import widgets
from django.utils.datastructures import SortedDict
@@ -227,6 +228,8 @@ class BaseSerializer(Field):
Run `validate_<fieldname>()` and `validate()` methods on the serializer
"""
for field_name, field in self.fields.items():
+ if field_name in self._errors:
+ continue
try:
validate_method = getattr(self, 'validate_%s' % field_name, None)
if validate_method:
@@ -271,7 +274,11 @@ class BaseSerializer(Field):
"""
Serialize objects -> primitives.
"""
- if hasattr(obj, '__iter__'):
+ # Note: At the moment we have an ugly hack to determine if we should
+ # walk over iterables. At some point, serializers will require an
+ # explicit `many=True` in order to iterate over a set, and this hack
+ # will disappear.
+ if hasattr(obj, '__iter__') and not isinstance(obj, Page):
return [self.convert_object(item) for item in obj]
return self.convert_object(obj)
@@ -298,6 +305,9 @@ class BaseSerializer(Field):
Override default so that we can apply ModelSerializer as a nested
field to relationships.
"""
+ if self.source == '*':
+ return self.to_native(obj)
+
try:
if self.source:
for component in self.source.split('.'):
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index e86041bc..1f17e8d2 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -4,7 +4,7 @@ from django.test import Client, TestCase
from rest_framework import permissions
from rest_framework.authtoken.models import Token
-from rest_framework.authentication import TokenAuthentication
+from rest_framework.authentication import TokenAuthentication, BasicAuthentication, SessionAuthentication
from rest_framework.compat import patterns
from rest_framework.views import APIView
@@ -21,10 +21,10 @@ class MockView(APIView):
def put(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3})
-MockView.authentication_classes += (TokenAuthentication,)
-
urlpatterns = patterns('',
- (r'^$', MockView.as_view()),
+ (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
+ (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
+ (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
)
@@ -43,24 +43,25 @@ class BasicAuthTests(TestCase):
def test_post_form_passing_basic_auth(self):
"""Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip()
- response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_json_passing_basic_auth(self):
"""Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip()
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_form_failing_basic_auth(self):
"""Ensure POSTing form over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/basic/', {'example': 'example'})
+ self.assertEqual(response.status_code, 401)
def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json')
+ self.assertEqual(response.status_code, 401)
+ self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
class SessionAuthTests(TestCase):
@@ -83,7 +84,7 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication without CSRF token fails.
"""
self.csrf_client.login(username=self.username, password=self.password)
- response = self.csrf_client.post('/', {'example': 'example'})
+ response = self.csrf_client.post('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 403)
def test_post_form_session_auth_passing(self):
@@ -91,7 +92,7 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication with logged in user and CSRF token passes.
"""
self.non_csrf_client.login(username=self.username, password=self.password)
- response = self.non_csrf_client.post('/', {'example': 'example'})
+ response = self.non_csrf_client.post('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 200)
def test_put_form_session_auth_passing(self):
@@ -99,14 +100,14 @@ class SessionAuthTests(TestCase):
Ensure PUTting form over session authentication with logged in user and CSRF token passes.
"""
self.non_csrf_client.login(username=self.username, password=self.password)
- response = self.non_csrf_client.put('/', {'example': 'example'})
+ response = self.non_csrf_client.put('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 200)
def test_post_form_session_auth_failing(self):
"""
Ensure POSTing form over session authentication without logged in user fails.
"""
- response = self.csrf_client.post('/', {'example': 'example'})
+ response = self.csrf_client.post('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 403)
@@ -127,24 +128,24 @@ class TokenAuthTests(TestCase):
def test_post_form_passing_token_auth(self):
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key
- response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_form_failing_token_auth(self):
"""Ensure POSTing form over token auth without correct credentials fails"""
- response = self.csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/token/', {'example': 'example'})
+ self.assertEqual(response.status_code, 401)
def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails"""
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json')
+ self.assertEqual(response.status_code, 401)
def test_token_has_auto_assigned_key_if_none_provided(self):
"""Ensure creating a token with no key will auto-assign a key"""
diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py
index 5e6bce4e..82f912e9 100644
--- a/rest_framework/tests/decorators.py
+++ b/rest_framework/tests/decorators.py
@@ -28,13 +28,27 @@ class DecoratorTestCase(TestCase):
response.request = request
return APIView.finalize_response(self, request, response, *args, **kwargs)
- def test_wrap_view(self):
+ def test_api_view_incorrect(self):
+ """
+ If @api_view is not applied correct, we should raise an assertion.
+ """
- @api_view(['GET'])
+ @api_view
def view(request):
- return Response({})
+ return Response()
+
+ request = self.factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ def test_api_view_incorrect_arguments(self):
+ """
+ If @api_view is missing arguments, we should raise an assertion.
+ """
- self.assertTrue(isinstance(view.cls_instance, APIView))
+ with self.assertRaises(AssertionError):
+ @api_view('GET')
+ def view(request):
+ return Response()
def test_calling_method(self):
diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py
index bc7378e1..146ad1e4 100644
--- a/rest_framework/tests/genericrelations.py
+++ b/rest_framework/tests/genericrelations.py
@@ -1,25 +1,61 @@
+from django.contrib.contenttypes.models import ContentType
+from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey
+from django.db import models
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import *
+
+
+class Tag(models.Model):
+ """
+ Tags have a descriptive slug, and are attached to an arbitrary object.
+ """
+ tag = models.SlugField()
+ content_type = models.ForeignKey(ContentType)
+ object_id = models.PositiveIntegerField()
+ tagged_item = GenericForeignKey('content_type', 'object_id')
+
+ def __unicode__(self):
+ return self.tag
+
+
+class Bookmark(models.Model):
+ """
+ A URL bookmark that may have multiple tags attached.
+ """
+ url = models.URLField()
+ tags = GenericRelation(Tag)
+
+ def __unicode__(self):
+ return 'Bookmark: %s' % self.url
+
+
+class Note(models.Model):
+ """
+ A textual note that may have multiple tags attached.
+ """
+ text = models.TextField()
+ tags = GenericRelation(Tag)
+
+ def __unicode__(self):
+ return 'Note: %s' % self.text
class TestGenericRelations(TestCase):
def setUp(self):
- bookmark = Bookmark(url='https://www.djangoproject.com/')
- bookmark.save()
- django = Tag(tag_name='django')
- django.save()
- python = Tag(tag_name='python')
- python.save()
- t1 = TaggedItem(content_object=bookmark, tag=django)
- t1.save()
- t2 = TaggedItem(content_object=bookmark, tag=python)
- t2.save()
- self.bookmark = bookmark
-
- def test_reverse_generic_relation(self):
+ self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
+ Tag.objects.create(tagged_item=self.bookmark, tag='django')
+ Tag.objects.create(tagged_item=self.bookmark, tag='python')
+ self.note = Note.objects.create(text='Remember the milk')
+ Tag.objects.create(tagged_item=self.note, tag='reminder')
+
+ def test_generic_relation(self):
+ """
+ Test a relationship that spans a GenericRelation field.
+ IE. A reverse generic relationship.
+ """
+
class BookmarkSerializer(serializers.ModelSerializer):
- tags = serializers.ManyRelatedField(source='tags')
+ tags = serializers.ManyRelatedField()
class Meta:
model = Bookmark
@@ -31,3 +67,33 @@ class TestGenericRelations(TestCase):
'url': u'https://www.djangoproject.com/'
}
self.assertEquals(serializer.data, expected)
+
+ def test_generic_fk(self):
+ """
+ Test a relationship that spans a GenericForeignKey field.
+ IE. A forward generic relationship.
+ """
+
+ class TagSerializer(serializers.ModelSerializer):
+ tagged_item = serializers.RelatedField()
+
+ class Meta:
+ model = Tag
+ exclude = ('id', 'content_type', 'object_id')
+
+ serializer = TagSerializer(Tag.objects.all())
+ expected = [
+ {
+ 'tag': u'django',
+ 'tagged_item': u'Bookmark: https://www.djangoproject.com/'
+ },
+ {
+ 'tag': u'python',
+ 'tagged_item': u'Bookmark: https://www.djangoproject.com/'
+ },
+ {
+ 'tag': u'reminder',
+ 'tagged_item': u'Note: Remember the milk'
+ }
+ ]
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index 93f09761..9ab15328 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -86,27 +86,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel):
text = models.CharField(max_length=100, default='anchor')
rel = models.ManyToManyField(Anchor)
-# Models to test generic relations
-
-
-class Tag(RESTFrameworkModel):
- tag_name = models.SlugField()
-
-
-class TaggedItem(RESTFrameworkModel):
- tag = models.ForeignKey(Tag, related_name='items')
- content_type = models.ForeignKey(ContentType)
- object_id = models.PositiveIntegerField()
- content_object = GenericForeignKey('content_type', 'object_id')
-
- def __unicode__(self):
- return self.tag.tag_name
-
-
-class Bookmark(RESTFrameworkModel):
- url = models.URLField()
- tags = GenericRelation(TaggedItem)
-
# Model to test filtering.
class FilterableItem(RESTFrameworkModel):
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index 3b550877..697dfb5b 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -252,6 +252,8 @@ class TestCustomPaginateByParam(TestCase):
self.assertEquals(response.data['results'], self.data[:5])
+### Tests for context in pagination serializers
+
class CustomField(serializers.Field):
def to_native(self, value):
if not 'view' in self.context:
@@ -262,6 +264,11 @@ class CustomField(serializers.Field):
class BasicModelSerializer(serializers.Serializer):
text = CustomField()
+ def __init__(self, *args, **kwargs):
+ super(BasicModelSerializer, self).__init__(*args, **kwargs)
+ if not 'view' in self.context:
+ raise RuntimeError("context isn't getting passed into serializer init")
+
class TestContextPassedToCustomField(TestCase):
def setUp(self):
@@ -279,3 +286,39 @@ class TestContextPassedToCustomField(TestCase):
self.assertEquals(response.status_code, status.HTTP_200_OK)
+
+### Tests for custom pagination serializers
+
+class LinksSerializer(serializers.Serializer):
+ next = pagination.NextPageField(source='*')
+ prev = pagination.PreviousPageField(source='*')
+
+
+class CustomPaginationSerializer(pagination.BasePaginationSerializer):
+ links = LinksSerializer(source='*') # Takes the page object as the source
+ total_results = serializers.Field(source='paginator.count')
+
+ results_field = 'objects'
+
+
+class TestCustomPaginationSerializer(TestCase):
+ def setUp(self):
+ objects = ['john', 'paul', 'george', 'ringo']
+ paginator = Paginator(objects, 2)
+ self.page = paginator.page(1)
+
+ def test_custom_pagination_serializer(self):
+ request = RequestFactory().get('/foobar')
+ serializer = CustomPaginationSerializer(
+ instance=self.page,
+ context={'request': request}
+ )
+ expected = {
+ 'links': {
+ 'next': 'http://testserver/foobar?page=2',
+ 'prev': None
+ },
+ 'total_results': 4,
+ 'objects': ['john', 'paul']
+ }
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py
index 7d65eae7..6d137f68 100644
--- a/rest_framework/tests/relations_hyperlink.py
+++ b/rest_framework/tests/relations_hyperlink.py
@@ -215,6 +215,13 @@ class HyperlinkedForeignKeyTests(TestCase):
]
self.assertEquals(serializer.data, expected)
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': 2}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'target': [u'Incorrect type. Expected url string, received int.']})
+
def test_reverse_foreign_key_update(self):
data = {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}
instance = ForeignKeyTarget.objects.get(pk=2)
@@ -227,7 +234,7 @@ class HyperlinkedForeignKeyTests(TestCase):
expected = [
{'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
{'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []},
- ]
+ ]
self.assertEquals(new_serializer.data, expected)
serializer.save()
diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py
index dd1e86b5..3391e60a 100644
--- a/rest_framework/tests/relations_pk.py
+++ b/rest_framework/tests/relations_pk.py
@@ -194,6 +194,13 @@ class PKForeignKeyTests(TestCase):
]
self.assertEquals(serializer.data, expected)
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'id': 1, 'name': u'source-1', 'target': 'foo'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'target': [u'Incorrect type. Expected pk value, received str.']})
+
def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': u'target-2', 'sources': [1, 3]}
instance = ForeignKeyTarget.objects.get(pk=2)
diff --git a/rest_framework/tests/relations_slug.py b/rest_framework/tests/relations_slug.py
index 503b61e8..37ccc75e 100644
--- a/rest_framework/tests/relations_slug.py
+++ b/rest_framework/tests/relations_slug.py
@@ -1,9 +1,23 @@
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import NullableForeignKeySource, ForeignKeyTarget
+from rest_framework.tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget
-class NullableSlugSourceSerializer(serializers.ModelSerializer):
+class ForeignKeyTargetSerializer(serializers.ModelSerializer):
+ sources = serializers.ManySlugRelatedField(slug_field='name')
+
+ class Meta:
+ model = ForeignKeyTarget
+
+
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ target = serializers.SlugRelatedField(slug_field='name')
+
+ class Meta:
+ model = ForeignKeySource
+
+
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
target = serializers.SlugRelatedField(slug_field='name', null=True)
class Meta:
@@ -11,6 +25,132 @@ class NullableSlugSourceSerializer(serializers.ModelSerializer):
# TODO: M2M Tests, FKTests (Non-nulable), One2One
+class PKForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': u'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': u'source-3', 'target': 'target-1'}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
+ {'id': 2, 'name': u'target-2', 'sources': []},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update(self):
+ data = {'id': 1, 'name': u'source-1', 'target': 'target-2'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 'target-2'},
+ {'id': 2, 'name': u'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': u'source-3', 'target': 'target-1'}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'id': 1, 'name': u'source-1', 'target': 123}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'target': [u'Object with name=123 does not exist.']})
+
+ def test_reverse_foreign_key_update(self):
+ data = {'id': 2, 'name': u'target-2', 'sources': ['source-1', 'source-3']}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
+ {'id': 2, 'name': u'target-2', 'sources': []},
+ ]
+ self.assertEquals(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEquals(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': ['source-2']},
+ {'id': 2, 'name': u'target-2', 'sources': ['source-1', 'source-3']},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'id': 4, 'name': u'source-4', 'target': 'target-2'}
+ serializer = ForeignKeySourceSerializer(data=data)
+ serializer.is_valid()
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': u'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': u'source-3', 'target': 'target-1'},
+ {'id': 4, 'name': u'source-4', 'target': 'target-2'},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'id': 3, 'name': u'target-3', 'sources': ['source-1', 'source-3']}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'target-3')
+
+ # Ensure target 3 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': ['source-2']},
+ {'id': 2, 'name': u'target-2', 'sources': []},
+ {'id': 3, 'name': u'target-3', 'sources': ['source-1', 'source-3']},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'id': 1, 'name': u'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'target': [u'Value may not be null']})
+
class SlugNullableForeignKeyTests(TestCase):
def setUp(self):
@@ -24,7 +164,7 @@ class SlugNullableForeignKeyTests(TestCase):
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableSlugSourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset)
expected = [
{'id': 1, 'name': u'source-1', 'target': 'target-1'},
{'id': 2, 'name': u'source-2', 'target': 'target-1'},
@@ -34,7 +174,7 @@ class SlugNullableForeignKeyTests(TestCase):
def test_foreign_key_create_with_valid_null(self):
data = {'id': 4, 'name': u'source-4', 'target': None}
- serializer = NullableSlugSourceSerializer(data=data)
+ serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEquals(serializer.data, data)
@@ -42,7 +182,7 @@ class SlugNullableForeignKeyTests(TestCase):
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableSlugSourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset)
expected = [
{'id': 1, 'name': u'source-1', 'target': 'target-1'},
{'id': 2, 'name': u'source-2', 'target': 'target-1'},
@@ -58,7 +198,7 @@ class SlugNullableForeignKeyTests(TestCase):
"""
data = {'id': 4, 'name': u'source-4', 'target': ''}
expected_data = {'id': 4, 'name': u'source-4', 'target': None}
- serializer = NullableSlugSourceSerializer(data=data)
+ serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEquals(serializer.data, expected_data)
@@ -66,7 +206,7 @@ class SlugNullableForeignKeyTests(TestCase):
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableSlugSourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset)
expected = [
{'id': 1, 'name': u'source-1', 'target': 'target-1'},
{'id': 2, 'name': u'source-2', 'target': 'target-1'},
@@ -78,14 +218,14 @@ class SlugNullableForeignKeyTests(TestCase):
def test_foreign_key_update_with_valid_null(self):
data = {'id': 1, 'name': u'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableSlugSourceSerializer(instance, data=data)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.data, data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableSlugSourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset)
expected = [
{'id': 1, 'name': u'source-1', 'target': None},
{'id': 2, 'name': u'source-2', 'target': 'target-1'},
@@ -101,14 +241,14 @@ class SlugNullableForeignKeyTests(TestCase):
data = {'id': 1, 'name': u'source-1', 'target': ''}
expected_data = {'id': 1, 'name': u'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableSlugSourceSerializer(instance, data=data)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.data, expected_data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableSlugSourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset)
expected = [
{'id': 1, 'name': u'source-1', 'target': None},
{'id': 2, 'name': u'source-2', 'target': 'target-1'},
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index bd96ba23..b4428ca3 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -162,7 +162,6 @@ class BasicTests(TestCase):
"""
Attempting to update fields set as read_only should have no effect.
"""
-
serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99})
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
@@ -183,8 +182,7 @@ class ValidationTests(TestCase):
'content': 'x' * 1001,
'created': datetime.datetime(2012, 1, 1)
}
- self.actionitem = ActionItem(title='Some to do item',
- )
+ self.actionitem = ActionItem(title='Some to do item',)
def test_create(self):
serializer = CommentSerializer(data=self.data)
@@ -216,31 +214,6 @@ class ValidationTests(TestCase):
self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.errors, {})
- def test_field_validation(self):
-
- class CommentSerializerWithFieldValidator(CommentSerializer):
-
- def validate_content(self, attrs, source):
- value = attrs[source]
- if "test" not in value:
- raise serializers.ValidationError("Test not in value")
- return attrs
-
- data = {
- 'email': 'tom@example.com',
- 'content': 'A test comment',
- 'created': datetime.datetime(2012, 1, 1)
- }
-
- serializer = CommentSerializerWithFieldValidator(data=data)
- self.assertTrue(serializer.is_valid())
-
- data['content'] = 'This should not validate'
-
- serializer = CommentSerializerWithFieldValidator(data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'content': [u'Test not in value']})
-
def test_bad_type_data_is_false(self):
"""
Data of the wrong type is not valid.
@@ -310,12 +283,69 @@ class ValidationTests(TestCase):
self.assertEquals(serializer.errors, {'info': [u'Ensure this value has at most 12 characters (it has 13).']})
+class CustomValidationTests(TestCase):
+ class CommentSerializerWithFieldValidator(CommentSerializer):
+
+ def validate_email(self, attrs, source):
+ value = attrs[source]
+
+ return attrs
+
+ def validate_content(self, attrs, source):
+ value = attrs[source]
+ if "test" not in value:
+ raise serializers.ValidationError("Test not in value")
+ return attrs
+
+ def test_field_validation(self):
+ data = {
+ 'email': 'tom@example.com',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = self.CommentSerializerWithFieldValidator(data=data)
+ self.assertTrue(serializer.is_valid())
+
+ data['content'] = 'This should not validate'
+
+ serializer = self.CommentSerializerWithFieldValidator(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'content': [u'Test not in value']})
+
+ def test_missing_data(self):
+ """
+ Make sure that validate_content isn't called if the field is missing
+ """
+ incomplete_data = {
+ 'email': 'tom@example.com',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'content': [u'This field is required.']})
+
+ def test_wrong_data(self):
+ """
+ Make sure that validate_content isn't called if the field input is wrong
+ """
+ wrong_data = {
+ 'email': 'not an email',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = self.CommentSerializerWithFieldValidator(data=wrong_data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'email': [u'Enter a valid e-mail address.']})
+
+
class PositiveIntegerAsChoiceTests(TestCase):
def test_positive_integer_in_json_is_correctly_parsed(self):
- data = {'some_integer':1}
+ data = {'some_integer': 1}
serializer = PositiveIntegerAsChoiceSerializer(data=data)
self.assertEquals(serializer.is_valid(), True)
+
class ModelValidationTests(TestCase):
def test_validate_unique(self):
"""
diff --git a/rest_framework/tests/urlpatterns.py b/rest_framework/tests/urlpatterns.py
new file mode 100644
index 00000000..43e8ef69
--- /dev/null
+++ b/rest_framework/tests/urlpatterns.py
@@ -0,0 +1,78 @@
+from collections import namedtuple
+
+from django.core import urlresolvers
+
+from django.test import TestCase
+from django.test.client import RequestFactory
+
+from rest_framework.compat import patterns, url, include
+from rest_framework.urlpatterns import format_suffix_patterns
+
+
+# A container class for test paths for the test case
+URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs'])
+
+
+def dummy_view(request, *args, **kwargs):
+ pass
+
+
+class FormatSuffixTests(TestCase):
+ """
+ Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters.
+ """
+ def _resolve_urlpatterns(self, urlpatterns, test_paths):
+ factory = RequestFactory()
+ try:
+ urlpatterns = format_suffix_patterns(urlpatterns)
+ except:
+ self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
+ resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
+ for test_path in test_paths:
+ request = factory.get(test_path.path)
+ try:
+ callback, callback_args, callback_kwargs = resolver.resolve(request.path_info)
+ except:
+ self.fail("Failed to resolve URL: %s" % request.path_info)
+ self.assertEquals(callback_args, test_path.args)
+ self.assertEquals(callback_kwargs, test_path.kwargs)
+
+ def test_format_suffix(self):
+ urlpatterns = patterns(
+ '',
+ url(r'^test$', dummy_view),
+ )
+ test_paths = [
+ URLTestPath('/test', (), {}),
+ URLTestPath('/test.api', (), {'format': 'api'}),
+ URLTestPath('/test.asdf', (), {'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
+
+ def test_default_args(self):
+ urlpatterns = patterns(
+ '',
+ url(r'^test$', dummy_view, {'foo': 'bar'}),
+ )
+ test_paths = [
+ URLTestPath('/test', (), {'foo': 'bar', }),
+ URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}),
+ URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
+
+ def test_included_urls(self):
+ nested_patterns = patterns(
+ '',
+ url(r'^path$', dummy_view)
+ )
+ urlpatterns = patterns(
+ '',
+ url(r'^test/', include(nested_patterns), {'foo': 'bar'}),
+ )
+ test_paths = [
+ URLTestPath('/test/path', (), {'foo': 'bar', }),
+ URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}),
+ URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py
index 143928c9..47789026 100644
--- a/rest_framework/urlpatterns.py
+++ b/rest_framework/urlpatterns.py
@@ -1,5 +1,35 @@
-from rest_framework.compat import url
+from rest_framework.compat import url, include
from rest_framework.settings import api_settings
+from django.core.urlresolvers import RegexURLResolver
+
+
+def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required):
+ ret = []
+ for urlpattern in urlpatterns:
+ if isinstance(urlpattern, RegexURLResolver):
+ # Set of included URL patterns
+ regex = urlpattern.regex.pattern
+ namespace = urlpattern.namespace
+ app_name = urlpattern.app_name
+ kwargs = urlpattern.default_kwargs
+ # Add in the included patterns, after applying the suffixes
+ patterns = apply_suffix_patterns(urlpattern.url_patterns,
+ suffix_pattern,
+ suffix_required)
+ ret.append(url(regex, include(patterns, namespace, app_name), kwargs))
+
+ else:
+ # Regular URL pattern
+ regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern
+ view = urlpattern._callback or urlpattern._callback_str
+ kwargs = urlpattern.default_args
+ name = urlpattern.name
+ # Add in both the existing and the new urlpattern
+ if not suffix_required:
+ ret.append(urlpattern)
+ ret.append(url(regex, view, kwargs, name))
+
+ return ret
def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
@@ -28,15 +58,4 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
else:
suffix_pattern = r'\.(?P<%s>[a-z]+)$' % suffix_kwarg
- ret = []
- for urlpattern in urlpatterns:
- # Form our complementing '.format' urlpattern
- regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern
- view = urlpattern._callback or urlpattern._callback_str
- kwargs = urlpattern.default_args
- name = urlpattern.name
- # Add in both the existing and the new urlpattern
- if not suffix_required:
- ret.append(urlpattern)
- ret.append(url(regex, view, kwargs, name))
- return ret
+ return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required)
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 10bdd5a5..ac9b3385 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -148,6 +148,8 @@ class APIView(View):
"""
If request is not permitted, determine what kind of exception to raise.
"""
+ if not self.request.successful_authenticator:
+ raise exceptions.NotAuthenticated()
raise exceptions.PermissionDenied()
def throttled(self, request, wait):
@@ -156,6 +158,15 @@ class APIView(View):
"""
raise exceptions.Throttled(wait)
+ def get_authenticate_header(self, request):
+ """
+ If a request is unauthenticated, determine the WWW-Authenticate
+ header to use for 401 responses, if any.
+ """
+ authenticators = self.get_authenticators()
+ if authenticators:
+ return authenticators[0].authenticate_header(request)
+
def get_parser_context(self, http_request):
"""
Returns a dict that is passed through to Parser.parse(),
@@ -319,6 +330,16 @@ class APIView(View):
# Throttle wait header
self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
+ if isinstance(exc, (exceptions.NotAuthenticated,
+ exceptions.AuthenticationFailed)):
+ # WWW-Authenticate header for 401 responses, else coerce to 403
+ auth_header = self.get_authenticate_header(self.request)
+
+ if auth_header:
+ self.headers['WWW-Authenticate'] = auth_header
+ else:
+ exc.status_code = status.HTTP_403_FORBIDDEN
+
if isinstance(exc, exceptions.APIException):
return Response({'detail': exc.detail},
status=exc.status_code,