diff options
| author | Tom Christie | 2013-01-22 09:12:48 -0800 |
|---|---|---|
| committer | Tom Christie | 2013-01-22 09:12:48 -0800 |
| commit | dd10d538ffc8f76ccc670f65da2220b09c22688c (patch) | |
| tree | 1af09c7dbcc939c749d30adf25b14d232200f44f /rest_framework | |
| parent | e29ba356f054222893655901923811bd9675d4cc (diff) | |
| parent | b7ab2aee46c718f683b19eefba1b48f233da40e4 (diff) | |
| download | django-rest-framework-dd10d538ffc8f76ccc670f65da2220b09c22688c.tar.bz2 | |
Merge pull request #416 from tomchristie/unauthenticated_response
Unauthenticated requests - 401 vs 403 responses
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/authentication.py | 80 | ||||
| -rw-r--r-- | rest_framework/exceptions.py | 16 | ||||
| -rw-r--r-- | rest_framework/request.py | 29 | ||||
| -rw-r--r-- | rest_framework/tests/authentication.py | 41 | ||||
| -rw-r--r-- | rest_framework/views.py | 21 |
5 files changed, 133 insertions, 54 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 30c78ebc..fc169189 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -21,32 +21,46 @@ class BaseAuthentication(object): """ raise NotImplementedError(".authenticate() must be overridden.") + def authenticate_header(self, request): + """ + Return a string to be used as the value of the `WWW-Authenticate` + header in a `401 Unauthenticated` response, or `None` if the + authentication scheme should return `403 Permission Denied` responses. + """ + pass + class BasicAuthentication(BaseAuthentication): """ HTTP Basic authentication against username/password. """ + www_authenticate_realm = 'api' def authenticate(self, request): """ Returns a `User` if a correct username and password have been supplied using HTTP Basic authentication. Otherwise returns `None`. """ - if 'HTTP_AUTHORIZATION' in request.META: - auth = request.META['HTTP_AUTHORIZATION'].split() - if len(auth) == 2 and auth[0].lower() == "basic": - try: - auth_parts = base64.b64decode(auth[1]).partition(':') - except TypeError: - return None - - try: - userid = smart_unicode(auth_parts[0]) - password = smart_unicode(auth_parts[2]) - except DjangoUnicodeDecodeError: - return None - - return self.authenticate_credentials(userid, password) + auth = request.META.get('HTTP_AUTHORIZATION', '').split() + + if not auth or auth[0].lower() != "basic": + return None + + if len(auth) != 2: + raise exceptions.AuthenticationFailed('Invalid basic header') + + try: + auth_parts = base64.b64decode(auth[1]).partition(':') + except TypeError: + raise exceptions.AuthenticationFailed('Invalid basic header') + + try: + userid = smart_unicode(auth_parts[0]) + password = smart_unicode(auth_parts[2]) + except DjangoUnicodeDecodeError: + raise exceptions.AuthenticationFailed('Invalid basic header') + + return self.authenticate_credentials(userid, password) def authenticate_credentials(self, userid, password): """ @@ -55,6 +69,10 @@ class BasicAuthentication(BaseAuthentication): user = authenticate(username=userid, password=password) if user is not None and user.is_active: return (user, None) + raise exceptions.AuthenticationFailed('Invalid username/password') + + def authenticate_header(self, request): + return 'Basic realm="%s"' % self.www_authenticate_realm class SessionAuthentication(BaseAuthentication): @@ -74,7 +92,7 @@ class SessionAuthentication(BaseAuthentication): # Unauthenticated, CSRF validation not required if not user or not user.is_active: - return + return None # Enforce CSRF validation for session based authentication. class CSRFCheck(CsrfViewMiddleware): @@ -85,7 +103,7 @@ class SessionAuthentication(BaseAuthentication): reason = CSRFCheck().process_view(http_request, None, (), {}) if reason: # CSRF failed, bail with explicit error message - raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) + raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason) # CSRF passed with authenticated user return (user, None) @@ -112,14 +130,26 @@ class TokenAuthentication(BaseAuthentication): def authenticate(self, request): auth = request.META.get('HTTP_AUTHORIZATION', '').split() - if len(auth) == 2 and auth[0].lower() == "token": - key = auth[1] - try: - token = self.model.objects.get(key=key) - except self.model.DoesNotExist: - return None + if not auth or auth[0].lower() != "token": + return None + + if len(auth) != 2: + raise exceptions.AuthenticationFailed('Invalid token header') + + return self.authenticate_credentials(auth[1]) + + def authenticate_credentials(self, key): + try: + token = self.model.objects.get(key=key) + except self.model.DoesNotExist: + raise exceptions.AuthenticationFailed('Invalid token') + + if token.user.is_active: + return (token.user, token) + raise exceptions.AuthenticationFailed('User inactive or deleted') + + def authenticate_header(self, request): + return 'Token' - if token.user.is_active: - return (token.user, token) # TODO: OAuthAuthentication diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 89479deb..d635351c 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -23,6 +23,22 @@ class ParseError(APIException): self.detail = detail or self.default_detail +class AuthenticationFailed(APIException): + status_code = status.HTTP_401_UNAUTHORIZED + default_detail = 'Incorrect authentication credentials.' + + def __init__(self, detail=None): + self.detail = detail or self.default_detail + + +class NotAuthenticated(APIException): + status_code = status.HTTP_401_UNAUTHORIZED + default_detail = 'Authentication credentials were not provided.' + + def __init__(self, detail=None): + self.detail = detail or self.default_detail + + class PermissionDenied(APIException): status_code = status.HTTP_403_FORBIDDEN default_detail = 'You do not have permission to perform this action.' diff --git a/rest_framework/request.py b/rest_framework/request.py index b7133608..1c28cd17 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -86,6 +86,7 @@ class Request(object): self._method = Empty self._content_type = Empty self._stream = Empty + self._authenticator = None if self.parser_context is None: self.parser_context = {} @@ -166,7 +167,7 @@ class Request(object): by the authentication classes provided to the request. """ if not hasattr(self, '_user'): - self._user, self._auth = self._authenticate() + self._authenticator, self._user, self._auth = self._authenticate() return self._user @user.setter @@ -185,7 +186,7 @@ class Request(object): request, such as an authentication token. """ if not hasattr(self, '_auth'): - self._user, self._auth = self._authenticate() + self._authenticator, self._user, self._auth = self._authenticate() return self._auth @auth.setter @@ -196,6 +197,14 @@ class Request(object): """ self._auth = value + @property + def successful_authenticator(self): + """ + Return the instance of the authentication instance class that was used + to authenticate the request, or `None`. + """ + return self._authenticator + def _load_data_and_files(self): """ Parses the request content into self.DATA and self.FILES. @@ -299,21 +308,23 @@ class Request(object): def _authenticate(self): """ - Attempt to authenticate the request using each authentication instance in turn. - Returns a two-tuple of (user, authtoken). + Attempt to authenticate the request using each authentication instance + in turn. + Returns a three-tuple of (authenticator, user, authtoken). """ for authenticator in self.authenticators: user_auth_tuple = authenticator.authenticate(self) if not user_auth_tuple is None: - return user_auth_tuple + user, auth = user_auth_tuple + return (authenticator, user, auth) return self._not_authenticated() def _not_authenticated(self): """ - Return a two-tuple of (user, authtoken), representing an - unauthenticated request. + Return a three-tuple of (authenticator, user, authtoken), representing + an unauthenticated request. - By default this will be (AnonymousUser, None). + By default this will be (None, AnonymousUser, None). """ if api_settings.UNAUTHENTICATED_USER: user = api_settings.UNAUTHENTICATED_USER() @@ -325,7 +336,7 @@ class Request(object): else: auth = None - return (user, auth) + return (None, user, auth) def __getattr__(self, attr): """ diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index e86041bc..1f17e8d2 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -4,7 +4,7 @@ from django.test import Client, TestCase from rest_framework import permissions from rest_framework.authtoken.models import Token -from rest_framework.authentication import TokenAuthentication +from rest_framework.authentication import TokenAuthentication, BasicAuthentication, SessionAuthentication from rest_framework.compat import patterns from rest_framework.views import APIView @@ -21,10 +21,10 @@ class MockView(APIView): def put(self, request): return HttpResponse({'a': 1, 'b': 2, 'c': 3}) -MockView.authentication_classes += (TokenAuthentication,) - urlpatterns = patterns('', - (r'^$', MockView.as_view()), + (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), + (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), + (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), ) @@ -43,24 +43,25 @@ class BasicAuthTests(TestCase): def test_post_form_passing_basic_auth(self): """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF""" auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() - response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) def test_post_json_passing_basic_auth(self): """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) def test_post_form_failing_basic_auth(self): """Ensure POSTing form over basic auth without correct credentials fails""" - response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/basic/', {'example': 'example'}) + self.assertEqual(response.status_code, 401) def test_post_json_failing_basic_auth(self): """Ensure POSTing json over basic auth without correct credentials fails""" - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') + self.assertEqual(response.status_code, 401) + self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') class SessionAuthTests(TestCase): @@ -83,7 +84,7 @@ class SessionAuthTests(TestCase): Ensure POSTing form over session authentication without CSRF token fails. """ self.csrf_client.login(username=self.username, password=self.password) - response = self.csrf_client.post('/', {'example': 'example'}) + response = self.csrf_client.post('/session/', {'example': 'example'}) self.assertEqual(response.status_code, 403) def test_post_form_session_auth_passing(self): @@ -91,7 +92,7 @@ class SessionAuthTests(TestCase): Ensure POSTing form over session authentication with logged in user and CSRF token passes. """ self.non_csrf_client.login(username=self.username, password=self.password) - response = self.non_csrf_client.post('/', {'example': 'example'}) + response = self.non_csrf_client.post('/session/', {'example': 'example'}) self.assertEqual(response.status_code, 200) def test_put_form_session_auth_passing(self): @@ -99,14 +100,14 @@ class SessionAuthTests(TestCase): Ensure PUTting form over session authentication with logged in user and CSRF token passes. """ self.non_csrf_client.login(username=self.username, password=self.password) - response = self.non_csrf_client.put('/', {'example': 'example'}) + response = self.non_csrf_client.put('/session/', {'example': 'example'}) self.assertEqual(response.status_code, 200) def test_post_form_session_auth_failing(self): """ Ensure POSTing form over session authentication without logged in user fails. """ - response = self.csrf_client.post('/', {'example': 'example'}) + response = self.csrf_client.post('/session/', {'example': 'example'}) self.assertEqual(response.status_code, 403) @@ -127,24 +128,24 @@ class TokenAuthTests(TestCase): def test_post_form_passing_token_auth(self): """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" auth = "Token " + self.key - response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) def test_post_json_passing_token_auth(self): """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" auth = "Token " + self.key - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) def test_post_form_failing_token_auth(self): """Ensure POSTing form over token auth without correct credentials fails""" - response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/token/', {'example': 'example'}) + self.assertEqual(response.status_code, 401) def test_post_json_failing_token_auth(self): """Ensure POSTing json over token auth without correct credentials fails""" - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') + self.assertEqual(response.status_code, 401) def test_token_has_auto_assigned_key_if_none_provided(self): """Ensure creating a token with no key will auto-assign a key""" diff --git a/rest_framework/views.py b/rest_framework/views.py index 10bdd5a5..ac9b3385 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -148,6 +148,8 @@ class APIView(View): """ If request is not permitted, determine what kind of exception to raise. """ + if not self.request.successful_authenticator: + raise exceptions.NotAuthenticated() raise exceptions.PermissionDenied() def throttled(self, request, wait): @@ -156,6 +158,15 @@ class APIView(View): """ raise exceptions.Throttled(wait) + def get_authenticate_header(self, request): + """ + If a request is unauthenticated, determine the WWW-Authenticate + header to use for 401 responses, if any. + """ + authenticators = self.get_authenticators() + if authenticators: + return authenticators[0].authenticate_header(request) + def get_parser_context(self, http_request): """ Returns a dict that is passed through to Parser.parse(), @@ -319,6 +330,16 @@ class APIView(View): # Throttle wait header self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait + if isinstance(exc, (exceptions.NotAuthenticated, + exceptions.AuthenticationFailed)): + # WWW-Authenticate header for 401 responses, else coerce to 403 + auth_header = self.get_authenticate_header(self.request) + + if auth_header: + self.headers['WWW-Authenticate'] = auth_header + else: + exc.status_code = status.HTTP_403_FORBIDDEN + if isinstance(exc, exceptions.APIException): return Response({'detail': exc.detail}, status=exc.status_code, |
