diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/authentication.py | 14 | ||||
| -rw-r--r-- | rest_framework/tests/test_authentication.py | 9 | 
2 files changed, 16 insertions, 7 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index da9ca510..887ef5d7 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -310,6 +310,13 @@ class OAuth2Authentication(BaseAuthentication):          auth = get_authorization_header(request).split() +        if len(auth) == 1: +            msg = 'Invalid bearer header. No credentials provided.' +            raise exceptions.AuthenticationFailed(msg) +        elif len(auth) > 2: +            msg = 'Invalid bearer header. Token string should not contain spaces.' +            raise exceptions.AuthenticationFailed(msg) +          if auth and auth[0].lower() == b'bearer':              access_token = auth[1]          elif 'access_token' in request.POST: @@ -319,13 +326,6 @@ class OAuth2Authentication(BaseAuthentication):          else:              return None -        if len(auth) == 1: -            msg = 'Invalid bearer header. No credentials provided.' -            raise exceptions.AuthenticationFailed(msg) -        elif len(auth) > 2: -            msg = 'Invalid bearer header. Token string should not contain spaces.' -            raise exceptions.AuthenticationFailed(msg) -          return self.authenticate_credentials(request, access_token)      def authenticate_credentials(self, request, access_token): diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index a1c43d9c..34bf2910 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -550,6 +550,15 @@ class OAuth2Tests(TestCase):          self.assertEqual(response.status_code, 401)      @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_get_form_with_wrong_authorization_header_token_missing(self): +        """Ensure that a missing token lead to the correct HTTP error status code""" +        auth = "Bearer" +        response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 401) +        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 401) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')      def test_get_form_passing_auth(self):          """Ensure GETing form over OAuth with correct client credentials succeed"""          auth = self._create_authorization_header()  | 
