diff options
| author | Tom Christie | 2013-03-07 17:43:13 +0000 |
|---|---|---|
| committer | Tom Christie | 2013-03-07 17:43:13 +0000 |
| commit | a4b33992a5e2affb710d0c16f2286d8ddc81f07c (patch) | |
| tree | 29da9798f52a8ab1376f08b70d729e65caabebd3 /rest_framework | |
| parent | 1d62594fa9ed87545a312681f999bbfa0237491b (diff) | |
| parent | 5a56f92abf5f52ac153c4faa1b75af519c96a207 (diff) | |
| download | django-rest-framework-a4b33992a5e2affb710d0c16f2286d8ddc81f07c.tar.bz2 | |
Merge OAuth2 work.
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/authentication.py | 78 | ||||
| -rw-r--r-- | rest_framework/compat.py | 18 | ||||
| -rw-r--r-- | rest_framework/runtests/settings.py | 13 | ||||
| -rw-r--r-- | rest_framework/tests/authentication.py | 142 |
4 files changed, 246 insertions, 5 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 8ee3a900..4d6e3375 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -7,6 +7,7 @@ from django.core.exceptions import ImproperlyConfigured from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework.compat import CsrfViewMiddleware from rest_framework.compat import oauth, oauth_provider, oauth_provider_store +from rest_framework.compat import oauth2_provider, oauth2_provider_forms, oauth2_provider_backends from rest_framework.authtoken.models import Token import base64 @@ -251,3 +252,80 @@ class OAuthAuthentication(BaseAuthentication): Checks nonce of request, and return True if valid. """ return oauth_provider_store.check_nonce(request, oauth_request, oauth_request['oauth_nonce']) + + +class OAuth2Authentication(BaseAuthentication): + """ + OAuth 2 authentication backend using `django-oauth2-provider` + """ + require_active = True + + def __init__(self, **kwargs): + super(OAuth2Authentication, self).__init__(**kwargs) + if oauth2_provider is None: + raise ImproperlyConfigured("The 'django-oauth2-provider' package could not be imported. It is required for use with the 'OAuth2Authentication' class.") + + def authenticate(self, request): + """ + The Bearer type is the only finalized type + + Read the spec for more details + http://tools.ietf.org/html/rfc6749#section-7.1 + """ + auth = request.META.get('HTTP_AUTHORIZATION', '').split() + if not auth or auth[0].lower() != "bearer": + raise exceptions.AuthenticationFailed('Invalid Authorization token type') + + if len(auth) != 2: + raise exceptions.AuthenticationFailed('Invalid token header') + + return self.authenticate_credentials(request, auth[1]) + + def authenticate_credentials(self, request, access_token): + """ + :returns: two-tuple of (user, auth) if authentication succeeds, or None otherwise. + """ + + # authenticate the client + oauth2_client_form = oauth2_provider_forms.ClientAuthForm(request.REQUEST) + if not oauth2_client_form.is_valid(): + raise exceptions.AuthenticationFailed("Client could not be validated") + client = oauth2_client_form.cleaned_data.get('client') + + # retrieve the `oauth2_provider.models.OAuth2AccessToken` instance from the access_token + auth_backend = oauth2_provider_backends.AccessTokenBackend() + token = auth_backend.authenticate(access_token, client) + if token is None: + raise exceptions.AuthenticationFailed("Invalid token") # does not exist or is expired + + # TODO check scope + + if not self.check_active(token.user): + raise exceptions.AuthenticationFailed('User not active: %s' % token.user.username) + + if client and token: + request.user = token.user + return (request.user, None) + + raise exceptions.AuthenticationFailed( + 'You are not allowed to access this resource.') + + def authenticate_header(self, request): + """ + Bearer is the only finalized type currently + + Check details on the `OAuth2Authentication.authenticate` method + """ + return 'Bearer' + + def check_active(self, user): + """ + Ensures the user has an active account. + + Optimized for the ``django.contrib.auth.models.User`` case. + """ + if not self.require_active: + # Ignore & move on. + return True + + return user.is_active diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 6efe6762..69be9543 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -441,3 +441,21 @@ try: except ImportError: oauth_provider = None oauth_provider_store = None + +# OAuth 2 support is optional +try: + import provider.oauth2 as oauth2_provider + # # Hack to fix submodule import issues + # submodules = ['backends', 'forms', 'managers', 'models', 'urls', 'views'] + # for s in submodules: + # mod = __import__('provider.oauth2.%s.*' % s) + # setattr(oauth2_provider, s, mod) + from provider.oauth2 import backends as oauth2_provider_backends + from provider.oauth2 import models as oauth2_provider_models + from provider.oauth2 import forms as oauth2_provider_forms + +except ImportError: + oauth2_provider = None + oauth2_provider_backends = None + oauth2_provider_models = None + oauth2_provider_forms = None diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py index eb3f1115..9b519f27 100644 --- a/rest_framework/runtests/settings.py +++ b/rest_framework/runtests/settings.py @@ -107,8 +107,19 @@ try: except ImportError: pass else: - INSTALLED_APPS += ('oauth_provider',) + INSTALLED_APPS += ( + 'oauth_provider', + ) +try: + import provider +except ImportError: + pass +else: + INSTALLED_APPS += ( + 'provider', + 'provider.oauth2', + ) STATIC_URL = '/static/' diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index 91429841..ddd61b63 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -12,17 +12,19 @@ from rest_framework.authentication import ( TokenAuthentication, BasicAuthentication, SessionAuthentication, - OAuthAuthentication + OAuthAuthentication, + OAuth2Authentication ) from rest_framework.authtoken.models import Token -from rest_framework.compat import patterns +from rest_framework.compat import patterns, url, include +from rest_framework.compat import oauth2_provider, oauth2_provider_models +from rest_framework.compat import oauth, oauth_provider from rest_framework.tests.utils import RequestFactory from rest_framework.views import APIView -from rest_framework.compat import oauth, oauth_provider import json import base64 import time - +import datetime factory = RequestFactory() @@ -48,6 +50,12 @@ urlpatterns = patterns('', (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])) ) +if oauth2_provider is not None: + urlpatterns += patterns('', + url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), + url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), + ) + class BasicAuthTests(TestCase): """Basic authentication""" @@ -380,3 +388,129 @@ class OAuthTests(TestCase): response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) + + +class OAuth2Tests(TestCase): + """OAuth 2.0 authentication""" + urls = 'rest_framework.tests.authentication' + + def setUp(self): + self.csrf_client = Client(enforce_csrf_checks=True) + self.username = 'john' + self.email = 'lennon@thebeatles.com' + self.password = 'password' + self.user = User.objects.create_user(self.username, self.email, self.password) + + self.CLIENT_ID = 'client_key' + self.CLIENT_SECRET = 'client_secret' + self.ACCESS_TOKEN = "access_token" + self.REFRESH_TOKEN = "refresh_token" + + self.oauth2_client = oauth2_provider_models.Client.objects.create( + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + redirect_uri='', + client_type=0, + name='example', + user=None, + ) + + self.access_token = oauth2_provider_models.AccessToken.objects.create( + token=self.ACCESS_TOKEN, + client=self.oauth2_client, + user=self.user, + ) + self.refresh_token = oauth2_provider_models.RefreshToken.objects.create( + user=self.user, + access_token=self.access_token, + client=self.oauth2_client + ) + + def _create_authorization_header(self, token=None): + return "Bearer {0}".format(token or self.access_token.token) + + def _client_credentials_params(self): + return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET} + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_authorization_header_token_type_failing(self): + """Ensure that a wrong token type lead to the correct HTTP error status code""" + auth = "Wrong token-type-obsviously" + response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + params = self._client_credentials_params() + response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_authorization_header_token_format_failing(self): + """Ensure that a wrong token format lead to the correct HTTP error status code""" + auth = "Bearer wrong token format" + response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + params = self._client_credentials_params() + response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_authorization_header_token_failing(self): + """Ensure that a wrong token lead to the correct HTTP error status code""" + auth = "Bearer wrong-token" + response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + params = self._client_credentials_params() + response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_client_data_failing_auth(self): + """Ensure GETing form over OAuth with incorrect client credentials fails""" + auth = self._create_authorization_header() + params = self._client_credentials_params() + params['client_id'] += 'a' + response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_passing_auth(self): + """Ensure GETing form over OAuth with correct client credentials succeed""" + auth = self._create_authorization_header() + params = self._client_credentials_params() + response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_passing_auth(self): + """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" + auth = self._create_authorization_header() + params = self._client_credentials_params() + response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_token_removed_failing_auth(self): + """Ensure POSTing when there is no OAuth access token in db fails""" + self.access_token.delete() + auth = self._create_authorization_header() + params = self._client_credentials_params() + response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_refresh_token_failing_auth(self): + """Ensure POSTing with refresh token instead of access token fails""" + auth = self._create_authorization_header(token=self.refresh_token.token) + params = self._client_credentials_params() + response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_expired_access_token_failing_auth(self): + """Ensure POSTing with expired access token fails with an 'Invalid token' error""" + self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late + self.access_token.save() + auth = self._create_authorization_header() + params = self._client_credentials_params() + response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + self.assertIn('Invalid token', response.content) |
