diff options
| author | Tom Christie | 2013-03-07 09:01:53 +0000 | 
|---|---|---|
| committer | Tom Christie | 2013-03-07 09:01:53 +0000 | 
| commit | d4e3610e716f2fbbda32aefb972e604446054127 (patch) | |
| tree | db9daafbe8736d7c8854bd5ef4c310ad1dd6cb0b /rest_framework | |
| parent | ddd7125a63c5187483058bad27c94676b9b6c16e (diff) | |
| parent | 2eabc5c2b46d9f4cc7a467af849ff31397b9d7bf (diff) | |
| download | django-rest-framework-d4e3610e716f2fbbda32aefb972e604446054127.tar.bz2 | |
Merge & clean OAuth support
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/authentication.py | 102 | ||||
| -rw-r--r-- | rest_framework/compat.py | 15 | ||||
| -rw-r--r-- | rest_framework/runtests/settings.py | 12 | ||||
| -rw-r--r-- | rest_framework/tests/authentication.py | 162 | 
4 files changed, 281 insertions, 10 deletions
| diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 14b2136b..24a8e336 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -3,9 +3,10 @@ Provides a set of pluggable authentication policies.  """  from __future__ import unicode_literals  from django.contrib.auth import authenticate -from django.utils.encoding import DjangoUnicodeDecodeError +from django.core.exceptions import ImproperlyConfigured  from rest_framework import exceptions, HTTP_HEADER_ENCODING  from rest_framework.compat import CsrfViewMiddleware +from rest_framework.compat import oauth, oauth_provider, oauth_provider_store  from rest_framework.authtoken.models import Token  import base64 @@ -58,11 +59,7 @@ class BasicAuthentication(BaseAuthentication):          except (TypeError, UnicodeDecodeError):              raise exceptions.AuthenticationFailed('Invalid basic header') -        try: -            userid, password = auth_parts[0], auth_parts[2] -        except DjangoUnicodeDecodeError: -            raise exceptions.AuthenticationFailed('Invalid basic header') - +        userid, password = auth_parts[0], auth_parts[2]          return self.authenticate_credentials(userid, password)      def authenticate_credentials(self, userid, password): @@ -155,4 +152,95 @@ class TokenAuthentication(BaseAuthentication):          return 'Token' -# TODO: OAuthAuthentication +class OAuthAuthentication(BaseAuthentication): +    """ +    OAuth 1.0a authentication backend using `django-oauth-plus` and `oauth2`. + +    Note: The `oauth2` package actually provides oauth1.0a support.  Urg. +    """ +    www_authenticate_realm = 'api' + +    def __init__(self, **kwargs): +        super(OAuthAuthentication, self).__init__(**kwargs) + +        if oauth is None: +            raise ImproperlyConfigured("The 'oauth2' package could not be imported. It is required for use with the 'OAuthAuthentication' class.") + +        if oauth_provider is None: +            raise ImproperlyConfigured("The 'django-oauth-plus' package could not be imported. It is required for use with the 'OAuthAuthentication' class.") + +    def authenticate(self, request): +        """ +        Returns two-tuple of (user, token) if authentication succeeds, +        or None otherwise. +        """ +        if not self.is_valid_request(request): +            return None + +        oauth_request = oauth_provider.utils.get_oauth_request(request) + +        if not self.check_nonce(request, oauth_request): +            raise exceptions.AuthenticationFailed("Nonce check failed") + +        try: +            consumer_key = oauth_request.get_parameter('oauth_consumer_key') +            consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key) +        except oauth_provider_store.InvalidConsumerError, err: +            raise exceptions.AuthenticationFailed(err) + +        if consumer.status != oauth_provider.consts.ACCEPTED: +            msg = 'Invalid consumer key status: %s' % consumer.get_status_display() +            raise exceptions.AuthenticationFailed(msg) + +        try: +            token_param = oauth_request.get_parameter('oauth_token') +            token = oauth_provider_store.get_access_token(request, oauth_request, consumer, token_param) +        except oauth_provider_store.InvalidTokenError: +            msg = 'Invalid access token: %s' % oauth_request.get_parameter('oauth_token') +            raise exceptions.AuthenticationFailed(msg) + +        try: +            self.validate_token(request, consumer, token) +        except oauth.Error, e: +            raise exceptions.AuthenticationFailed(e.message) + +        user = token.user + +        if not user.is_active: +            raise exceptions.AuthenticationFailed('User inactive or deleted: %s' % user.username) + +        return (token.user, token) + +    def authenticate_header(self, request): +        return 'OAuth realm="%s"' % self.www_authenticate_realm + +    def is_in(self, params): +        """ +        Checks to ensure that all the OAuth parameter names are in the +        provided ``params``. +        """ +        for param_name in oauth_provider.consts.OAUTH_PARAMETERS_NAMES: +            if param_name not in params: +                return False + +        return True + +    def is_valid_request(self, request): +        """ +        Checks whether the required parameters are either in the HTTP +        `Authorization` header sent by some clients. +        (The preferred method according to OAuth spec.) +        Or fall back to `GET/POST`. +        """ +        auth_params = request.META.get('HTTP_AUTHORIZATION', []) +        return self.is_in(auth_params) or self.is_in(request.REQUEST) + +    def validate_token(self, request, consumer, token): +        oauth_server, oauth_request = oauth_provider.utils.initialize_server_request(request) +        return oauth_server.verify_request(oauth_request, consumer, token) + +    def check_nonce(self, request, oauth_request): +        """ +        Checks nonce of request. +        """ +        return oauth_provider.store.store.check_nonce(request, oauth_request, oauth_request['oauth_nonce']) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 07fdddce..6efe6762 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -426,3 +426,18 @@ try:      import defusedxml.ElementTree as etree  except ImportError:      etree = None + +# OAuth is optional +try: +    # Note: The `oauth2` package actually provides oauth1.0a support.  Urg. +    import oauth2 as oauth +except ImportError: +    oauth = None + +# OAuth is optional +try: +    import oauth_provider +    from oauth_provider.store import store as oauth_provider_store +except ImportError: +    oauth_provider = None +    oauth_provider_store = None diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py index 03bfc216..eb3f1115 100644 --- a/rest_framework/runtests/settings.py +++ b/rest_framework/runtests/settings.py @@ -97,9 +97,19 @@ INSTALLED_APPS = (      # 'django.contrib.admindocs',      'rest_framework',      'rest_framework.authtoken', -    'rest_framework.tests' +    'rest_framework.tests',  ) +# OAuth is optional and won't work if there is no oauth_provider & oauth2 +try: +    import oauth_provider +    import oauth2 +except ImportError: +    pass +else: +    INSTALLED_APPS += ('oauth_provider',) + +  STATIC_URL = '/static/'  PASSWORD_HASHERS = ( diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index 7b754af5..91429841 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -2,22 +2,26 @@ from __future__ import unicode_literals  from django.contrib.auth.models import User  from django.http import HttpResponse  from django.test import Client, TestCase +from django.utils import unittest  from rest_framework import HTTP_HEADER_ENCODING  from rest_framework import exceptions  from rest_framework import permissions  from rest_framework import status -from rest_framework.authtoken.models import Token  from rest_framework.authentication import (      BaseAuthentication,      TokenAuthentication,      BasicAuthentication, -    SessionAuthentication +    SessionAuthentication, +    OAuthAuthentication  ) +from rest_framework.authtoken.models import Token  from rest_framework.compat import patterns  from rest_framework.tests.utils import RequestFactory  from rest_framework.views import APIView +from rest_framework.compat import oauth, oauth_provider  import json  import base64 +import time  factory = RequestFactory() @@ -41,6 +45,7 @@ urlpatterns = patterns('',      (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),      (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),      (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), +    (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication]))  ) @@ -222,3 +227,156 @@ class IncorrectCredentialsTests(TestCase):          response = view(request)          self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)          self.assertEqual(response.data, {'detail': 'Bad credentials'}) + + +class OAuthTests(TestCase): +    """OAuth 1.0a authentication""" +    urls = 'rest_framework.tests.authentication' + +    def setUp(self): +        # these imports are here because oauth is optional and hiding them in try..except block or compat +        # could obscure problems if something breaks +        from oauth_provider.models import Consumer, Resource +        from oauth_provider.models import Token as OAuthToken +        from oauth_provider import consts + +        self.consts = consts + +        self.csrf_client = Client(enforce_csrf_checks=True) +        self.username = 'john' +        self.email = 'lennon@thebeatles.com' +        self.password = 'password' +        self.user = User.objects.create_user(self.username, self.email, self.password) + +        self.CONSUMER_KEY = 'consumer_key' +        self.CONSUMER_SECRET = 'consumer_secret' +        self.TOKEN_KEY = "token_key" +        self.TOKEN_SECRET = "token_secret" + +        self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET, +            name='example', user=self.user, status=self.consts.ACCEPTED) + +        self.resource = Resource.objects.create(name="resource name", url="api/") +        self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, resource=self.resource, +            token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True +        ) + +    def _create_authorization_header(self): +        params = { +            'oauth_version': "1.0", +            'oauth_nonce': oauth.generate_nonce(), +            'oauth_timestamp': int(time.time()), +            'oauth_token': self.token.key, +            'oauth_consumer_key': self.consumer.key +        } + +        req = oauth.Request(method="GET", url="http://example.com", parameters=params) + +        signature_method = oauth.SignatureMethod_PLAINTEXT() +        req.sign_request(signature_method, self.consumer, self.token) + +        return req.to_header()["Authorization"] + +    def _create_authorization_url_parameters(self): +        params = { +            'oauth_version': "1.0", +            'oauth_nonce': oauth.generate_nonce(), +            'oauth_timestamp': int(time.time()), +            'oauth_token': self.token.key, +            'oauth_consumer_key': self.consumer.key +        } + +        req = oauth.Request(method="GET", url="http://example.com", parameters=params) + +        signature_method = oauth.SignatureMethod_PLAINTEXT() +        req.sign_request(signature_method, self.consumer, self.token) +        return dict(req) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_post_form_passing_oauth(self): +        """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" +        auth = self._create_authorization_header() +        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_post_form_repeated_nonce_failing_oauth(self): +        """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails""" +        auth = self._create_authorization_header() +        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 200) + +        # simulate reply attack auth header containes already used (nonce, timestamp) pair +        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_post_form_token_removed_failing_oauth(self): +        """Ensure POSTing when there is no OAuth access token in db fails""" +        self.token.delete() +        auth = self._create_authorization_header() +        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_post_form_consumer_status_not_accepted_failing_oauth(self): +        """Ensure POSTing when consumer status is anything other than ACCEPTED fails""" +        for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED): +            self.consumer.status = consumer_status +            self.consumer.save() + +            auth = self._create_authorization_header() +            response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +            self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_post_form_with_request_token_failing_oauth(self): +        """Ensure POSTing with unauthorized request token instead of access token fails""" +        self.token.token_type = self.token.REQUEST +        self.token.save() + +        auth = self._create_authorization_header() +        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_post_form_with_urlencoded_parameters(self): +        """Ensure POSTing with x-www-form-urlencoded auth parameters passes""" +        params = self._create_authorization_url_parameters() +        response = self.csrf_client.post('/oauth/', params) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_get_form_with_url_parameters(self): +        """Ensure GETing with auth in url parameters passes""" +        params = self._create_authorization_url_parameters() +        response = self.csrf_client.get('/oauth/', params) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_post_hmac_sha1_signature_passes(self): +        """Ensure POSTing using HMAC_SHA1 signature method passes""" +        params = { +            'oauth_version': "1.0", +            'oauth_nonce': oauth.generate_nonce(), +            'oauth_timestamp': int(time.time()), +            'oauth_token': self.token.key, +            'oauth_consumer_key': self.consumer.key +        } + +        req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) + +        signature_method = oauth.SignatureMethod_HMAC_SHA1() +        req.sign_request(signature_method, self.consumer, self.token) +        auth = req.to_header()["Authorization"] + +        response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 200) | 
