aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2013-01-22 09:12:48 -0800
committerTom Christie2013-01-22 09:12:48 -0800
commitdd10d538ffc8f76ccc670f65da2220b09c22688c (patch)
tree1af09c7dbcc939c749d30adf25b14d232200f44f /rest_framework
parente29ba356f054222893655901923811bd9675d4cc (diff)
parentb7ab2aee46c718f683b19eefba1b48f233da40e4 (diff)
downloaddjango-rest-framework-dd10d538ffc8f76ccc670f65da2220b09c22688c.tar.bz2
Merge pull request #416 from tomchristie/unauthenticated_response
Unauthenticated requests - 401 vs 403 responses
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/authentication.py80
-rw-r--r--rest_framework/exceptions.py16
-rw-r--r--rest_framework/request.py29
-rw-r--r--rest_framework/tests/authentication.py41
-rw-r--r--rest_framework/views.py21
5 files changed, 133 insertions, 54 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index 30c78ebc..fc169189 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -21,32 +21,46 @@ class BaseAuthentication(object):
"""
raise NotImplementedError(".authenticate() must be overridden.")
+ def authenticate_header(self, request):
+ """
+ Return a string to be used as the value of the `WWW-Authenticate`
+ header in a `401 Unauthenticated` response, or `None` if the
+ authentication scheme should return `403 Permission Denied` responses.
+ """
+ pass
+
class BasicAuthentication(BaseAuthentication):
"""
HTTP Basic authentication against username/password.
"""
+ www_authenticate_realm = 'api'
def authenticate(self, request):
"""
Returns a `User` if a correct username and password have been supplied
using HTTP Basic authentication. Otherwise returns `None`.
"""
- if 'HTTP_AUTHORIZATION' in request.META:
- auth = request.META['HTTP_AUTHORIZATION'].split()
- if len(auth) == 2 and auth[0].lower() == "basic":
- try:
- auth_parts = base64.b64decode(auth[1]).partition(':')
- except TypeError:
- return None
-
- try:
- userid = smart_unicode(auth_parts[0])
- password = smart_unicode(auth_parts[2])
- except DjangoUnicodeDecodeError:
- return None
-
- return self.authenticate_credentials(userid, password)
+ auth = request.META.get('HTTP_AUTHORIZATION', '').split()
+
+ if not auth or auth[0].lower() != "basic":
+ return None
+
+ if len(auth) != 2:
+ raise exceptions.AuthenticationFailed('Invalid basic header')
+
+ try:
+ auth_parts = base64.b64decode(auth[1]).partition(':')
+ except TypeError:
+ raise exceptions.AuthenticationFailed('Invalid basic header')
+
+ try:
+ userid = smart_unicode(auth_parts[0])
+ password = smart_unicode(auth_parts[2])
+ except DjangoUnicodeDecodeError:
+ raise exceptions.AuthenticationFailed('Invalid basic header')
+
+ return self.authenticate_credentials(userid, password)
def authenticate_credentials(self, userid, password):
"""
@@ -55,6 +69,10 @@ class BasicAuthentication(BaseAuthentication):
user = authenticate(username=userid, password=password)
if user is not None and user.is_active:
return (user, None)
+ raise exceptions.AuthenticationFailed('Invalid username/password')
+
+ def authenticate_header(self, request):
+ return 'Basic realm="%s"' % self.www_authenticate_realm
class SessionAuthentication(BaseAuthentication):
@@ -74,7 +92,7 @@ class SessionAuthentication(BaseAuthentication):
# Unauthenticated, CSRF validation not required
if not user or not user.is_active:
- return
+ return None
# Enforce CSRF validation for session based authentication.
class CSRFCheck(CsrfViewMiddleware):
@@ -85,7 +103,7 @@ class SessionAuthentication(BaseAuthentication):
reason = CSRFCheck().process_view(http_request, None, (), {})
if reason:
# CSRF failed, bail with explicit error message
- raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
+ raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason)
# CSRF passed with authenticated user
return (user, None)
@@ -112,14 +130,26 @@ class TokenAuthentication(BaseAuthentication):
def authenticate(self, request):
auth = request.META.get('HTTP_AUTHORIZATION', '').split()
- if len(auth) == 2 and auth[0].lower() == "token":
- key = auth[1]
- try:
- token = self.model.objects.get(key=key)
- except self.model.DoesNotExist:
- return None
+ if not auth or auth[0].lower() != "token":
+ return None
+
+ if len(auth) != 2:
+ raise exceptions.AuthenticationFailed('Invalid token header')
+
+ return self.authenticate_credentials(auth[1])
+
+ def authenticate_credentials(self, key):
+ try:
+ token = self.model.objects.get(key=key)
+ except self.model.DoesNotExist:
+ raise exceptions.AuthenticationFailed('Invalid token')
+
+ if token.user.is_active:
+ return (token.user, token)
+ raise exceptions.AuthenticationFailed('User inactive or deleted')
+
+ def authenticate_header(self, request):
+ return 'Token'
- if token.user.is_active:
- return (token.user, token)
# TODO: OAuthAuthentication
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index 89479deb..d635351c 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -23,6 +23,22 @@ class ParseError(APIException):
self.detail = detail or self.default_detail
+class AuthenticationFailed(APIException):
+ status_code = status.HTTP_401_UNAUTHORIZED
+ default_detail = 'Incorrect authentication credentials.'
+
+ def __init__(self, detail=None):
+ self.detail = detail or self.default_detail
+
+
+class NotAuthenticated(APIException):
+ status_code = status.HTTP_401_UNAUTHORIZED
+ default_detail = 'Authentication credentials were not provided.'
+
+ def __init__(self, detail=None):
+ self.detail = detail or self.default_detail
+
+
class PermissionDenied(APIException):
status_code = status.HTTP_403_FORBIDDEN
default_detail = 'You do not have permission to perform this action.'
diff --git a/rest_framework/request.py b/rest_framework/request.py
index b7133608..1c28cd17 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -86,6 +86,7 @@ class Request(object):
self._method = Empty
self._content_type = Empty
self._stream = Empty
+ self._authenticator = None
if self.parser_context is None:
self.parser_context = {}
@@ -166,7 +167,7 @@ class Request(object):
by the authentication classes provided to the request.
"""
if not hasattr(self, '_user'):
- self._user, self._auth = self._authenticate()
+ self._authenticator, self._user, self._auth = self._authenticate()
return self._user
@user.setter
@@ -185,7 +186,7 @@ class Request(object):
request, such as an authentication token.
"""
if not hasattr(self, '_auth'):
- self._user, self._auth = self._authenticate()
+ self._authenticator, self._user, self._auth = self._authenticate()
return self._auth
@auth.setter
@@ -196,6 +197,14 @@ class Request(object):
"""
self._auth = value
+ @property
+ def successful_authenticator(self):
+ """
+ Return the instance of the authentication instance class that was used
+ to authenticate the request, or `None`.
+ """
+ return self._authenticator
+
def _load_data_and_files(self):
"""
Parses the request content into self.DATA and self.FILES.
@@ -299,21 +308,23 @@ class Request(object):
def _authenticate(self):
"""
- Attempt to authenticate the request using each authentication instance in turn.
- Returns a two-tuple of (user, authtoken).
+ Attempt to authenticate the request using each authentication instance
+ in turn.
+ Returns a three-tuple of (authenticator, user, authtoken).
"""
for authenticator in self.authenticators:
user_auth_tuple = authenticator.authenticate(self)
if not user_auth_tuple is None:
- return user_auth_tuple
+ user, auth = user_auth_tuple
+ return (authenticator, user, auth)
return self._not_authenticated()
def _not_authenticated(self):
"""
- Return a two-tuple of (user, authtoken), representing an
- unauthenticated request.
+ Return a three-tuple of (authenticator, user, authtoken), representing
+ an unauthenticated request.
- By default this will be (AnonymousUser, None).
+ By default this will be (None, AnonymousUser, None).
"""
if api_settings.UNAUTHENTICATED_USER:
user = api_settings.UNAUTHENTICATED_USER()
@@ -325,7 +336,7 @@ class Request(object):
else:
auth = None
- return (user, auth)
+ return (None, user, auth)
def __getattr__(self, attr):
"""
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index e86041bc..1f17e8d2 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -4,7 +4,7 @@ from django.test import Client, TestCase
from rest_framework import permissions
from rest_framework.authtoken.models import Token
-from rest_framework.authentication import TokenAuthentication
+from rest_framework.authentication import TokenAuthentication, BasicAuthentication, SessionAuthentication
from rest_framework.compat import patterns
from rest_framework.views import APIView
@@ -21,10 +21,10 @@ class MockView(APIView):
def put(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3})
-MockView.authentication_classes += (TokenAuthentication,)
-
urlpatterns = patterns('',
- (r'^$', MockView.as_view()),
+ (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
+ (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
+ (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
)
@@ -43,24 +43,25 @@ class BasicAuthTests(TestCase):
def test_post_form_passing_basic_auth(self):
"""Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip()
- response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_json_passing_basic_auth(self):
"""Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip()
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_form_failing_basic_auth(self):
"""Ensure POSTing form over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/basic/', {'example': 'example'})
+ self.assertEqual(response.status_code, 401)
def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json')
+ self.assertEqual(response.status_code, 401)
+ self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
class SessionAuthTests(TestCase):
@@ -83,7 +84,7 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication without CSRF token fails.
"""
self.csrf_client.login(username=self.username, password=self.password)
- response = self.csrf_client.post('/', {'example': 'example'})
+ response = self.csrf_client.post('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 403)
def test_post_form_session_auth_passing(self):
@@ -91,7 +92,7 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication with logged in user and CSRF token passes.
"""
self.non_csrf_client.login(username=self.username, password=self.password)
- response = self.non_csrf_client.post('/', {'example': 'example'})
+ response = self.non_csrf_client.post('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 200)
def test_put_form_session_auth_passing(self):
@@ -99,14 +100,14 @@ class SessionAuthTests(TestCase):
Ensure PUTting form over session authentication with logged in user and CSRF token passes.
"""
self.non_csrf_client.login(username=self.username, password=self.password)
- response = self.non_csrf_client.put('/', {'example': 'example'})
+ response = self.non_csrf_client.put('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 200)
def test_post_form_session_auth_failing(self):
"""
Ensure POSTing form over session authentication without logged in user fails.
"""
- response = self.csrf_client.post('/', {'example': 'example'})
+ response = self.csrf_client.post('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 403)
@@ -127,24 +128,24 @@ class TokenAuthTests(TestCase):
def test_post_form_passing_token_auth(self):
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key
- response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_form_failing_token_auth(self):
"""Ensure POSTing form over token auth without correct credentials fails"""
- response = self.csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/token/', {'example': 'example'})
+ self.assertEqual(response.status_code, 401)
def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails"""
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json')
+ self.assertEqual(response.status_code, 401)
def test_token_has_auto_assigned_key_if_none_provided(self):
"""Ensure creating a token with no key will auto-assign a key"""
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 10bdd5a5..ac9b3385 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -148,6 +148,8 @@ class APIView(View):
"""
If request is not permitted, determine what kind of exception to raise.
"""
+ if not self.request.successful_authenticator:
+ raise exceptions.NotAuthenticated()
raise exceptions.PermissionDenied()
def throttled(self, request, wait):
@@ -156,6 +158,15 @@ class APIView(View):
"""
raise exceptions.Throttled(wait)
+ def get_authenticate_header(self, request):
+ """
+ If a request is unauthenticated, determine the WWW-Authenticate
+ header to use for 401 responses, if any.
+ """
+ authenticators = self.get_authenticators()
+ if authenticators:
+ return authenticators[0].authenticate_header(request)
+
def get_parser_context(self, http_request):
"""
Returns a dict that is passed through to Parser.parse(),
@@ -319,6 +330,16 @@ class APIView(View):
# Throttle wait header
self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
+ if isinstance(exc, (exceptions.NotAuthenticated,
+ exceptions.AuthenticationFailed)):
+ # WWW-Authenticate header for 401 responses, else coerce to 403
+ auth_header = self.get_authenticate_header(self.request)
+
+ if auth_header:
+ self.headers['WWW-Authenticate'] = auth_header
+ else:
+ exc.status_code = status.HTTP_403_FORBIDDEN
+
if isinstance(exc, exceptions.APIException):
return Response({'detail': exc.detail},
status=exc.status_code,