diff options
Diffstat (limited to 'rest_framework/tests/authentication.py')
| -rw-r--r-- | rest_framework/tests/authentication.py | 79 |
1 files changed, 55 insertions, 24 deletions
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index 8ef9d3ff..91429841 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -3,30 +3,42 @@ from django.contrib.auth.models import User from django.http import HttpResponse from django.test import Client, TestCase from django.utils import unittest -import time -from rest_framework import HTTP_HEADER_ENCODING, status +from rest_framework import HTTP_HEADER_ENCODING +from rest_framework import exceptions from rest_framework import permissions +from rest_framework import status +from rest_framework.authentication import ( + BaseAuthentication, + TokenAuthentication, + BasicAuthentication, + SessionAuthentication, + OAuthAuthentication +) from rest_framework.authtoken.models import Token -from rest_framework.authentication import TokenAuthentication, BasicAuthentication, SessionAuthentication, OAuthAuthentication from rest_framework.compat import patterns +from rest_framework.tests.utils import RequestFactory from rest_framework.views import APIView -from rest_framework.compat import oauth -from rest_framework.compat import oauth_provider +from rest_framework.compat import oauth, oauth_provider import json import base64 +import time + + +factory = RequestFactory() class MockView(APIView): permission_classes = (permissions.IsAuthenticated,) + def get(self, request): + return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + def post(self, request): return HttpResponse({'a': 1, 'b': 2, 'c': 3}) def put(self, request): return HttpResponse({'a': 1, 'b': 2, 'c': 3}) - def get(self, request): - return HttpResponse({'a': 1, 'b': 2, 'c': 3}) urlpatterns = patterns('', (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), @@ -54,7 +66,7 @@ class BasicAuthTests(TestCase): base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) auth = 'Basic %s' % base64_credentials response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_json_passing_basic_auth(self): """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" @@ -62,17 +74,17 @@ class BasicAuthTests(TestCase): base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) auth = 'Basic %s' % base64_credentials response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_failing_basic_auth(self): """Ensure POSTing form over basic auth without correct credentials fails""" response = self.csrf_client.post('/basic/', {'example': 'example'}) - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_post_json_failing_basic_auth(self): """Ensure POSTing json over basic auth without correct credentials fails""" response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') @@ -97,7 +109,7 @@ class SessionAuthTests(TestCase): """ self.csrf_client.login(username=self.username, password=self.password) response = self.csrf_client.post('/session/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_post_form_session_auth_passing(self): """ @@ -105,7 +117,7 @@ class SessionAuthTests(TestCase): """ self.non_csrf_client.login(username=self.username, password=self.password) response = self.non_csrf_client.post('/session/', {'example': 'example'}) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_put_form_session_auth_passing(self): """ @@ -113,14 +125,14 @@ class SessionAuthTests(TestCase): """ self.non_csrf_client.login(username=self.username, password=self.password) response = self.non_csrf_client.put('/session/', {'example': 'example'}) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_session_auth_failing(self): """ Ensure POSTing form over session authentication without logged in user fails. """ response = self.csrf_client.post('/session/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) class TokenAuthTests(TestCase): @@ -141,23 +153,23 @@ class TokenAuthTests(TestCase): """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" auth = "Token " + self.key response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_json_passing_token_auth(self): """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" auth = "Token " + self.key response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_failing_token_auth(self): """Ensure POSTing form over token auth without correct credentials fails""" response = self.csrf_client.post('/token/', {'example': 'example'}) - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_post_json_failing_token_auth(self): """Ensure POSTing json over token auth without correct credentials fails""" response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') - self.assertEqual(response.status_code, 401) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_token_has_auto_assigned_key_if_none_provided(self): """Ensure creating a token with no key will auto-assign a key""" @@ -170,7 +182,7 @@ class TokenAuthTests(TestCase): client = Client(enforce_csrf_checks=True) response = client.post('/auth-token/', json.dumps({'username': self.username, 'password': self.password}), 'application/json') - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) def test_token_login_json_bad_creds(self): @@ -192,9 +204,31 @@ class TokenAuthTests(TestCase): client = Client(enforce_csrf_checks=True) response = client.post('/auth-token/', {'username': self.username, 'password': self.password}) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) + +class IncorrectCredentialsTests(TestCase): + def test_incorrect_credentials(self): + """ + If a request contains bad authentication credentials, then + authentication should run and error, even if no permissions + are set on the view. + """ + class IncorrectCredentialsAuth(BaseAuthentication): + def authenticate(self, request): + raise exceptions.AuthenticationFailed('Bad credentials') + + request = factory.get('/') + view = MockView.as_view( + authentication_classes=(IncorrectCredentialsAuth,), + permission_classes=() + ) + response = view(request) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.data, {'detail': 'Bad credentials'}) + + class OAuthTests(TestCase): """OAuth 1.0a authentication""" urls = 'rest_framework.tests.authentication' @@ -222,13 +256,11 @@ class OAuthTests(TestCase): self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET, name='example', user=self.user, status=self.consts.ACCEPTED) - self.resource = Resource.objects.create(name="resource name", url="api/") self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, resource=self.resource, token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True ) - def _create_authorization_header(self): params = { 'oauth_version': "1.0", @@ -348,4 +380,3 @@ class OAuthTests(TestCase): response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) - |
