diff options
| author | Tom Christie | 2013-01-21 21:29:49 +0000 | 
|---|---|---|
| committer | Tom Christie | 2013-01-21 21:29:49 +0000 | 
| commit | 65b62d64ec54b528b62a1500b8f6ffe216d45c09 (patch) | |
| tree | eb30f11fdb82a7940070cd9dca2d276c00cfb2ee /rest_framework | |
| parent | 36fa722ebb1b438b710b90fe470fbdbf82fd676e (diff) | |
| download | django-rest-framework-65b62d64ec54b528b62a1500b8f6ffe216d45c09.tar.bz2 | |
WWW-Authenticate responses
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/authentication.py | 4 | ||||
| -rw-r--r-- | rest_framework/tests/authentication.py | 41 | ||||
| -rw-r--r-- | rest_framework/views.py | 21 | 
3 files changed, 43 insertions, 23 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 6dc80498..fc169189 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -71,7 +71,7 @@ class BasicAuthentication(BaseAuthentication):              return (user, None)          raise exceptions.AuthenticationFailed('Invalid username/password') -    def authenticate_header(self): +    def authenticate_header(self, request):          return 'Basic realm="%s"' % self.www_authenticate_realm @@ -148,7 +148,7 @@ class TokenAuthentication(BaseAuthentication):              return (token.user, token)          raise exceptions.AuthenticationFailed('User inactive or deleted') -    def authenticate_header(self): +    def authenticate_header(self, request):          return 'Token' diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index e86041bc..1f17e8d2 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -4,7 +4,7 @@ from django.test import Client, TestCase  from rest_framework import permissions  from rest_framework.authtoken.models import Token -from rest_framework.authentication import TokenAuthentication +from rest_framework.authentication import TokenAuthentication, BasicAuthentication, SessionAuthentication  from rest_framework.compat import patterns  from rest_framework.views import APIView @@ -21,10 +21,10 @@ class MockView(APIView):      def put(self, request):          return HttpResponse({'a': 1, 'b': 2, 'c': 3}) -MockView.authentication_classes += (TokenAuthentication,) -  urlpatterns = patterns('', -    (r'^$', MockView.as_view()), +    (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), +    (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), +    (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),      (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),  ) @@ -43,24 +43,25 @@ class BasicAuthTests(TestCase):      def test_post_form_passing_basic_auth(self):          """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""          auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() -        response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 200)      def test_post_json_passing_basic_auth(self):          """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""          auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() -        response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 200)      def test_post_form_failing_basic_auth(self):          """Ensure POSTing form over basic auth without correct credentials fails""" -        response = self.csrf_client.post('/', {'example': 'example'}) -        self.assertEqual(response.status_code, 403) +        response = self.csrf_client.post('/basic/', {'example': 'example'}) +        self.assertEqual(response.status_code, 401)      def test_post_json_failing_basic_auth(self):          """Ensure POSTing json over basic auth without correct credentials fails""" -        response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') -        self.assertEqual(response.status_code, 403) +        response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') +        self.assertEqual(response.status_code, 401) +        self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')  class SessionAuthTests(TestCase): @@ -83,7 +84,7 @@ class SessionAuthTests(TestCase):          Ensure POSTing form over session authentication without CSRF token fails.          """          self.csrf_client.login(username=self.username, password=self.password) -        response = self.csrf_client.post('/', {'example': 'example'}) +        response = self.csrf_client.post('/session/', {'example': 'example'})          self.assertEqual(response.status_code, 403)      def test_post_form_session_auth_passing(self): @@ -91,7 +92,7 @@ class SessionAuthTests(TestCase):          Ensure POSTing form over session authentication with logged in user and CSRF token passes.          """          self.non_csrf_client.login(username=self.username, password=self.password) -        response = self.non_csrf_client.post('/', {'example': 'example'}) +        response = self.non_csrf_client.post('/session/', {'example': 'example'})          self.assertEqual(response.status_code, 200)      def test_put_form_session_auth_passing(self): @@ -99,14 +100,14 @@ class SessionAuthTests(TestCase):          Ensure PUTting form over session authentication with logged in user and CSRF token passes.          """          self.non_csrf_client.login(username=self.username, password=self.password) -        response = self.non_csrf_client.put('/', {'example': 'example'}) +        response = self.non_csrf_client.put('/session/', {'example': 'example'})          self.assertEqual(response.status_code, 200)      def test_post_form_session_auth_failing(self):          """          Ensure POSTing form over session authentication without logged in user fails.          """ -        response = self.csrf_client.post('/', {'example': 'example'}) +        response = self.csrf_client.post('/session/', {'example': 'example'})          self.assertEqual(response.status_code, 403) @@ -127,24 +128,24 @@ class TokenAuthTests(TestCase):      def test_post_form_passing_token_auth(self):          """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""          auth = "Token " + self.key -        response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 200)      def test_post_json_passing_token_auth(self):          """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""          auth = "Token " + self.key -        response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 200)      def test_post_form_failing_token_auth(self):          """Ensure POSTing form over token auth without correct credentials fails""" -        response = self.csrf_client.post('/', {'example': 'example'}) -        self.assertEqual(response.status_code, 403) +        response = self.csrf_client.post('/token/', {'example': 'example'}) +        self.assertEqual(response.status_code, 401)      def test_post_json_failing_token_auth(self):          """Ensure POSTing json over token auth without correct credentials fails""" -        response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') -        self.assertEqual(response.status_code, 403) +        response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') +        self.assertEqual(response.status_code, 401)      def test_token_has_auto_assigned_key_if_none_provided(self):          """Ensure creating a token with no key will auto-assign a key""" diff --git a/rest_framework/views.py b/rest_framework/views.py index fdb373da..ac9b3385 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -148,7 +148,7 @@ class APIView(View):          """          If request is not permitted, determine what kind of exception to raise.          """ -        if self.request.successful_authenticator: +        if not self.request.successful_authenticator:              raise exceptions.NotAuthenticated()          raise exceptions.PermissionDenied() @@ -158,6 +158,15 @@ class APIView(View):          """          raise exceptions.Throttled(wait) +    def get_authenticate_header(self, request): +        """ +        If a request is unauthenticated, determine the WWW-Authenticate +        header to use for 401 responses, if any. +        """ +        authenticators = self.get_authenticators() +        if authenticators: +            return authenticators[0].authenticate_header(request) +      def get_parser_context(self, http_request):          """          Returns a dict that is passed through to Parser.parse(), @@ -321,6 +330,16 @@ class APIView(View):              # Throttle wait header              self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait +        if isinstance(exc, (exceptions.NotAuthenticated, +                            exceptions.AuthenticationFailed)): +            # WWW-Authenticate header for 401 responses, else coerce to 403 +            auth_header = self.get_authenticate_header(self.request) + +            if auth_header: +                self.headers['WWW-Authenticate'] = auth_header +            else: +                exc.status_code = status.HTTP_403_FORBIDDEN +          if isinstance(exc, exceptions.APIException):              return Response({'detail': exc.detail},                              status=exc.status_code,  | 
