diff options
| -rw-r--r-- | rest_framework/authentication.py | 12 | ||||
| -rw-r--r-- | rest_framework/tests/test_authentication.py | 26 | 
2 files changed, 36 insertions, 2 deletions
| diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index b0e88d88..da9ca510 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -6,6 +6,7 @@ import base64  from django.contrib.auth import authenticate  from django.core.exceptions import ImproperlyConfigured +from django.conf import settings  from rest_framework import exceptions, HTTP_HEADER_ENCODING  from rest_framework.compat import CsrfViewMiddleware  from rest_framework.compat import oauth, oauth_provider, oauth_provider_store @@ -291,6 +292,7 @@ class OAuth2Authentication(BaseAuthentication):      OAuth 2 authentication backend using `django-oauth2-provider`      """      www_authenticate_realm = 'api' +    allow_query_params_token = settings.DEBUG      def __init__(self, *args, **kwargs):          super(OAuth2Authentication, self).__init__(*args, **kwargs) @@ -308,7 +310,13 @@ class OAuth2Authentication(BaseAuthentication):          auth = get_authorization_header(request).split() -        if not auth or auth[0].lower() != b'bearer': +        if auth and auth[0].lower() == b'bearer': +            access_token = auth[1] +        elif 'access_token' in request.POST: +            access_token = request.POST['access_token'] +        elif 'access_token' in request.GET and self.allow_query_params_token: +            access_token = request.GET['access_token'] +        else:              return None          if len(auth) == 1: @@ -318,7 +326,7 @@ class OAuth2Authentication(BaseAuthentication):              msg = 'Invalid bearer header. Token string should not contain spaces.'              raise exceptions.AuthenticationFailed(msg) -        return self.authenticate_credentials(request, auth[1]) +        return self.authenticate_credentials(request, access_token)      def authenticate_credentials(self, request, access_token):          """ diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index 8caeb081..c37d2a51 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -3,6 +3,7 @@ from django.contrib.auth.models import User  from django.http import HttpResponse  from django.test import TestCase  from django.utils import unittest +from django.utils.http import urlencode  from rest_framework import HTTP_HEADER_ENCODING  from rest_framework import exceptions  from rest_framework import permissions @@ -53,10 +54,14 @@ urlpatterns = patterns('',          permission_classes=[permissions.TokenHasReadWriteScope]))  ) +class OAuth2AuthenticationDebug(OAuth2Authentication): +    allow_query_params_token = True +  if oauth2_provider is not None:      urlpatterns += patterns('',          url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),          url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), +        url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),          url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],              permission_classes=[permissions.TokenHasReadWriteScope])),      ) @@ -546,6 +551,27 @@ class OAuth2Tests(TestCase):          self.assertEqual(response.status_code, 200)      @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_post_form_passing_auth_url_transport(self): +        """Ensure GETing form over OAuth with correct client credentials in form data succeed""" +        response = self.csrf_client.post('/oauth2-test/', +                data={'access_token': self.access_token.token}) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_get_form_passing_auth_url_transport(self): +        """Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True""" +        query = urlencode({'access_token': self.access_token.token}) +        response = self.csrf_client.get('/oauth2-test-debug/?%s' % query) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_get_form_failing_auth_url_transport(self): +        """Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False""" +        query = urlencode({'access_token': self.access_token.token}) +        response = self.csrf_client.get('/oauth2-test/?%s' % query) +        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')      def test_post_form_passing_auth(self):          """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""          auth = self._create_authorization_header() | 
