aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2013-03-07 17:43:13 +0000
committerTom Christie2013-03-07 17:43:13 +0000
commita4b33992a5e2affb710d0c16f2286d8ddc81f07c (patch)
tree29da9798f52a8ab1376f08b70d729e65caabebd3 /rest_framework
parent1d62594fa9ed87545a312681f999bbfa0237491b (diff)
parent5a56f92abf5f52ac153c4faa1b75af519c96a207 (diff)
downloaddjango-rest-framework-a4b33992a5e2affb710d0c16f2286d8ddc81f07c.tar.bz2
Merge OAuth2 work.
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/authentication.py78
-rw-r--r--rest_framework/compat.py18
-rw-r--r--rest_framework/runtests/settings.py13
-rw-r--r--rest_framework/tests/authentication.py142
4 files changed, 246 insertions, 5 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index 8ee3a900..4d6e3375 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -7,6 +7,7 @@ from django.core.exceptions import ImproperlyConfigured
from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
+from rest_framework.compat import oauth2_provider, oauth2_provider_forms, oauth2_provider_backends
from rest_framework.authtoken.models import Token
import base64
@@ -251,3 +252,80 @@ class OAuthAuthentication(BaseAuthentication):
Checks nonce of request, and return True if valid.
"""
return oauth_provider_store.check_nonce(request, oauth_request, oauth_request['oauth_nonce'])
+
+
+class OAuth2Authentication(BaseAuthentication):
+ """
+ OAuth 2 authentication backend using `django-oauth2-provider`
+ """
+ require_active = True
+
+ def __init__(self, **kwargs):
+ super(OAuth2Authentication, self).__init__(**kwargs)
+ if oauth2_provider is None:
+ raise ImproperlyConfigured("The 'django-oauth2-provider' package could not be imported. It is required for use with the 'OAuth2Authentication' class.")
+
+ def authenticate(self, request):
+ """
+ The Bearer type is the only finalized type
+
+ Read the spec for more details
+ http://tools.ietf.org/html/rfc6749#section-7.1
+ """
+ auth = request.META.get('HTTP_AUTHORIZATION', '').split()
+ if not auth or auth[0].lower() != "bearer":
+ raise exceptions.AuthenticationFailed('Invalid Authorization token type')
+
+ if len(auth) != 2:
+ raise exceptions.AuthenticationFailed('Invalid token header')
+
+ return self.authenticate_credentials(request, auth[1])
+
+ def authenticate_credentials(self, request, access_token):
+ """
+ :returns: two-tuple of (user, auth) if authentication succeeds, or None otherwise.
+ """
+
+ # authenticate the client
+ oauth2_client_form = oauth2_provider_forms.ClientAuthForm(request.REQUEST)
+ if not oauth2_client_form.is_valid():
+ raise exceptions.AuthenticationFailed("Client could not be validated")
+ client = oauth2_client_form.cleaned_data.get('client')
+
+ # retrieve the `oauth2_provider.models.OAuth2AccessToken` instance from the access_token
+ auth_backend = oauth2_provider_backends.AccessTokenBackend()
+ token = auth_backend.authenticate(access_token, client)
+ if token is None:
+ raise exceptions.AuthenticationFailed("Invalid token") # does not exist or is expired
+
+ # TODO check scope
+
+ if not self.check_active(token.user):
+ raise exceptions.AuthenticationFailed('User not active: %s' % token.user.username)
+
+ if client and token:
+ request.user = token.user
+ return (request.user, None)
+
+ raise exceptions.AuthenticationFailed(
+ 'You are not allowed to access this resource.')
+
+ def authenticate_header(self, request):
+ """
+ Bearer is the only finalized type currently
+
+ Check details on the `OAuth2Authentication.authenticate` method
+ """
+ return 'Bearer'
+
+ def check_active(self, user):
+ """
+ Ensures the user has an active account.
+
+ Optimized for the ``django.contrib.auth.models.User`` case.
+ """
+ if not self.require_active:
+ # Ignore & move on.
+ return True
+
+ return user.is_active
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 6efe6762..69be9543 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -441,3 +441,21 @@ try:
except ImportError:
oauth_provider = None
oauth_provider_store = None
+
+# OAuth 2 support is optional
+try:
+ import provider.oauth2 as oauth2_provider
+ # # Hack to fix submodule import issues
+ # submodules = ['backends', 'forms', 'managers', 'models', 'urls', 'views']
+ # for s in submodules:
+ # mod = __import__('provider.oauth2.%s.*' % s)
+ # setattr(oauth2_provider, s, mod)
+ from provider.oauth2 import backends as oauth2_provider_backends
+ from provider.oauth2 import models as oauth2_provider_models
+ from provider.oauth2 import forms as oauth2_provider_forms
+
+except ImportError:
+ oauth2_provider = None
+ oauth2_provider_backends = None
+ oauth2_provider_models = None
+ oauth2_provider_forms = None
diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py
index eb3f1115..9b519f27 100644
--- a/rest_framework/runtests/settings.py
+++ b/rest_framework/runtests/settings.py
@@ -107,8 +107,19 @@ try:
except ImportError:
pass
else:
- INSTALLED_APPS += ('oauth_provider',)
+ INSTALLED_APPS += (
+ 'oauth_provider',
+ )
+try:
+ import provider
+except ImportError:
+ pass
+else:
+ INSTALLED_APPS += (
+ 'provider',
+ 'provider.oauth2',
+ )
STATIC_URL = '/static/'
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index 91429841..ddd61b63 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -12,17 +12,19 @@ from rest_framework.authentication import (
TokenAuthentication,
BasicAuthentication,
SessionAuthentication,
- OAuthAuthentication
+ OAuthAuthentication,
+ OAuth2Authentication
)
from rest_framework.authtoken.models import Token
-from rest_framework.compat import patterns
+from rest_framework.compat import patterns, url, include
+from rest_framework.compat import oauth2_provider, oauth2_provider_models
+from rest_framework.compat import oauth, oauth_provider
from rest_framework.tests.utils import RequestFactory
from rest_framework.views import APIView
-from rest_framework.compat import oauth, oauth_provider
import json
import base64
import time
-
+import datetime
factory = RequestFactory()
@@ -48,6 +50,12 @@ urlpatterns = patterns('',
(r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication]))
)
+if oauth2_provider is not None:
+ urlpatterns += patterns('',
+ url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
+ url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
+ )
+
class BasicAuthTests(TestCase):
"""Basic authentication"""
@@ -380,3 +388,129 @@ class OAuthTests(TestCase):
response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
+
+
+class OAuth2Tests(TestCase):
+ """OAuth 2.0 authentication"""
+ urls = 'rest_framework.tests.authentication'
+
+ def setUp(self):
+ self.csrf_client = Client(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.CLIENT_ID = 'client_key'
+ self.CLIENT_SECRET = 'client_secret'
+ self.ACCESS_TOKEN = "access_token"
+ self.REFRESH_TOKEN = "refresh_token"
+
+ self.oauth2_client = oauth2_provider_models.Client.objects.create(
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ redirect_uri='',
+ client_type=0,
+ name='example',
+ user=None,
+ )
+
+ self.access_token = oauth2_provider_models.AccessToken.objects.create(
+ token=self.ACCESS_TOKEN,
+ client=self.oauth2_client,
+ user=self.user,
+ )
+ self.refresh_token = oauth2_provider_models.RefreshToken.objects.create(
+ user=self.user,
+ access_token=self.access_token,
+ client=self.oauth2_client
+ )
+
+ def _create_authorization_header(self, token=None):
+ return "Bearer {0}".format(token or self.access_token.token)
+
+ def _client_credentials_params(self):
+ return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET}
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_type_failing(self):
+ """Ensure that a wrong token type lead to the correct HTTP error status code"""
+ auth = "Wrong token-type-obsviously"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ params = self._client_credentials_params()
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_format_failing(self):
+ """Ensure that a wrong token format lead to the correct HTTP error status code"""
+ auth = "Bearer wrong token format"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ params = self._client_credentials_params()
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_failing(self):
+ """Ensure that a wrong token lead to the correct HTTP error status code"""
+ auth = "Bearer wrong-token"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ params = self._client_credentials_params()
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_client_data_failing_auth(self):
+ """Ensure GETing form over OAuth with incorrect client credentials fails"""
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ params['client_id'] += 'a'
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_passing_auth(self):
+ """Ensure GETing form over OAuth with correct client credentials succeed"""
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_passing_auth(self):
+ """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_token_removed_failing_auth(self):
+ """Ensure POSTing when there is no OAuth access token in db fails"""
+ self.access_token.delete()
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_refresh_token_failing_auth(self):
+ """Ensure POSTing with refresh token instead of access token fails"""
+ auth = self._create_authorization_header(token=self.refresh_token.token)
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_expired_access_token_failing_auth(self):
+ """Ensure POSTing with expired access token fails with an 'Invalid token' error"""
+ self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
+ self.access_token.save()
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+ self.assertIn('Invalid token', response.content)