aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/authentication.py85
-rw-r--r--rest_framework/compat.py11
-rw-r--r--rest_framework/runtests/settings.py4
-rw-r--r--rest_framework/tests/authentication.py108
4 files changed, 204 insertions, 4 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index 14b2136b..c20d9cb5 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -6,6 +6,7 @@ from django.contrib.auth import authenticate
from django.utils.encoding import DjangoUnicodeDecodeError
from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware
+from rest_framework.compat import oauth2_provider, oauth2
from rest_framework.authtoken.models import Token
import base64
@@ -155,4 +156,86 @@ class TokenAuthentication(BaseAuthentication):
return 'Token'
-# TODO: OAuthAuthentication
+class OAuth2Authentication(BaseAuthentication):
+ """
+ OAuth 2 authentication backend using `django-oauth2-provider`
+ """
+ require_active = True
+
+ def __init__(self, **kwargs):
+ super(OAuth2Authentication, self).__init__(**kwargs)
+ if oauth2_provider is None:
+ raise ImproperlyConfigured("The 'django-oauth2-provider' package could not be imported. It is required for use with the 'OAuth2Authentication' class.")
+
+ def authenticate(self, request):
+ """
+ The Bearer type is the only finalized type
+
+ Read the spec for more details
+ http://tools.ietf.org/html/rfc6749#section-7.1
+ """
+ auth = request.META.get('HTTP_AUTHORIZATION', '').split()
+ print auth
+ if not auth or auth[0].lower() != "bearer":
+ return None
+
+ if len(auth) != 2:
+ raise exceptions.AuthenticationFailed('Invalid token header')
+
+ return self.authenticate_credentials(request, auth[1])
+
+ def authenticate_credentials(self, request, access_token):
+ """
+ :returns: two-tuple of (user, auth) if authentication succeeds, or None otherwise.
+ """
+
+ # authenticate the client
+ oauth2_client_form = oauth2.forms.ClientAuthForm(request.REQUEST)
+ if not oauth2_client_form.is_valid():
+ raise exceptions.AuthenticationFailed("Client could not be validated")
+ client = oauth2_client_form.cleaned_data.get('client')
+
+ # retrieve the `oauth2.models.OAuth2AccessToken` instance from the access_token
+ auth_backend = oauth2.backends.AccessTokenBackend()
+ token = auth_backend.authenticate(access_token, client)
+ if token is None:
+ raise exceptions.AuthenticationFailed("Invalid token") # does not exist or is expired
+
+ # TODO check scope
+ # try:
+ # self.validate_token(request, consumer, token)
+ # except oauth2.Error, e:
+ # print "got e"
+ # raise exceptions.AuthenticationFailed(e.message)
+
+ if not self.check_active(token.user):
+ raise exceptions.AuthenticationFailed('User not active: %s' % token.user.username)
+
+ if client and token:
+ request.user = token.user
+ return (request.user, None)
+
+ raise exceptions.AuthenticationFailed(
+ 'You are not allowed to access this resource.')
+
+ return None
+
+ def authenticate_header(self, request):
+ """
+ Bearer is the only finalized type currently
+
+ Check details on the `OAuth2Authentication.authenticate` method
+ """
+ return 'Bearer'
+
+ def check_active(self, user):
+ """
+ Ensures the user has an active account.
+
+ Optimized for the ``django.contrib.auth.models.User`` case.
+ """
+ if not self.require_active:
+ # Ignore & move on.
+ return True
+
+ return user.is_active
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 07fdddce..5bba0c86 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -426,3 +426,14 @@ try:
import defusedxml.ElementTree as etree
except ImportError:
etree = None
+
+
+# OAuth 2 support is optional
+try:
+ import provider as oauth2_provider
+except ImportError:
+ oauth2_provider = None
+try:
+ import provider.oauth2 as oauth2
+except ImportError:
+ oauth2 = None
diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py
index 03bfc216..67dc7fff 100644
--- a/rest_framework/runtests/settings.py
+++ b/rest_framework/runtests/settings.py
@@ -97,7 +97,9 @@ INSTALLED_APPS = (
# 'django.contrib.admindocs',
'rest_framework',
'rest_framework.authtoken',
- 'rest_framework.tests'
+ 'rest_framework.tests',
+ 'provider',
+ 'provider.oauth2',
)
STATIC_URL = '/static/'
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index 7b754af5..c2c23bcc 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -1,4 +1,5 @@
from __future__ import unicode_literals
+from django.core.urlresolvers import reverse
from django.contrib.auth.models import User
from django.http import HttpResponse
from django.test import Client, TestCase
@@ -11,13 +12,18 @@ from rest_framework.authentication import (
BaseAuthentication,
TokenAuthentication,
BasicAuthentication,
- SessionAuthentication
+ SessionAuthentication,
+ OAuth2Authentication
)
-from rest_framework.compat import patterns
+from rest_framework.compat import patterns, url, include
+from rest_framework.compat import oauth2
+from rest_framework.compat import oauth2_provider
from rest_framework.tests.utils import RequestFactory
from rest_framework.views import APIView
import json
import base64
+import datetime
+import unittest
factory = RequestFactory()
@@ -41,6 +47,8 @@ urlpatterns = patterns('',
(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
+ url(r'^oauth2/', include('provider.oauth2.urls', namespace = 'oauth2')),
+ url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
)
@@ -222,3 +230,99 @@ class IncorrectCredentialsTests(TestCase):
response = view(request)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.data, {'detail': 'Bad credentials'})
+
+
+class OAuth2Tests(TestCase):
+ """OAuth 2.0 authentication"""
+ urls = 'rest_framework.tests.authentication'
+
+ def setUp(self):
+ self.csrf_client = Client(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.CLIENT_ID = 'client_key'
+ self.CLIENT_SECRET = 'client_secret'
+ self.ACCESS_TOKEN = "access_token"
+ self.REFRESH_TOKEN = "refresh_token"
+
+ self.oauth2_client = oauth2.models.Client.objects.create(
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ redirect_uri='',
+ client_type=0,
+ name='example',
+ user=None,
+ )
+
+ self.access_token = oauth2.models.AccessToken.objects.create(
+ token=self.ACCESS_TOKEN,
+ client=self.oauth2_client,
+ user=self.user,
+ )
+ self.refresh_token = oauth2.models.RefreshToken.objects.create(
+ user=self.user,
+ access_token=self.access_token,
+ client=self.oauth2_client
+ )
+
+ def _create_authorization_header(self, token=None):
+ return "Bearer {0}".format(token or self.access_token.token)
+
+ def _client_credentials_params(self):
+ return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET}
+
+ @unittest.skipUnless(oauth2, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_client_data_failing_auth(self):
+ """Ensure GETing form over OAuth with incorrect client credentials fails"""
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ params['client_id'] += 'a'
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2, 'django-oauth2-provider not installed')
+ def test_get_form_passing_auth(self):
+ """Ensure GETing form over OAuth with correct client credentials succeed"""
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2, 'django-oauth2-provider not installed')
+ def test_post_form_passing_auth(self):
+ """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2, 'django-oauth2-provider not installed')
+ def test_post_form_token_removed_failing_auth(self):
+ """Ensure POSTing when there is no OAuth access token in db fails"""
+ self.access_token.delete()
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2, 'django-oauth2-provider not installed')
+ def test_post_form_with_refresh_token_failing_auth(self):
+ """Ensure POSTing with refresh token instead of access token fails"""
+ auth = self._create_authorization_header(token=self.refresh_token.token)
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2, 'django-oauth2-provider not installed')
+ def test_post_form_with_expired_access_token_failing_auth(self):
+ """Ensure POSTing with expired access token fails with an 'Invalid token' error"""
+ self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
+ self.access_token.save()
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+ self.assertIn('Invalid token', response.content)