diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/serializers.py | 43 | ||||
| -rw-r--r-- | rest_framework/tests/authentication.py | 39 | ||||
| -rw-r--r-- | rest_framework/views.py | 13 |
3 files changed, 70 insertions, 25 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 91af2af2..ba9e9e9c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -418,6 +418,27 @@ class ModelSerializer(Serializer): """ _options_class = ModelSerializerOptions + field_mapping = { + models.AutoField: IntegerField, + models.FloatField: FloatField, + models.IntegerField: IntegerField, + models.PositiveIntegerField: IntegerField, + models.SmallIntegerField: IntegerField, + models.PositiveSmallIntegerField: IntegerField, + models.DateTimeField: DateTimeField, + models.DateField: DateField, + models.TimeField: TimeField, + models.EmailField: EmailField, + models.CharField: CharField, + models.URLField: URLField, + models.SlugField: SlugField, + models.TextField: CharField, + models.CommaSeparatedIntegerField: CharField, + models.BooleanField: BooleanField, + models.FileField: FileField, + models.ImageField: ImageField, + } + def get_default_fields(self): """ Return all the fields that should be serialized for the model. @@ -515,28 +536,8 @@ class ModelSerializer(Serializer): kwargs['choices'] = model_field.flatchoices return ChoiceField(**kwargs) - field_mapping = { - models.AutoField: IntegerField, - models.FloatField: FloatField, - models.IntegerField: IntegerField, - models.PositiveIntegerField: IntegerField, - models.SmallIntegerField: IntegerField, - models.PositiveSmallIntegerField: IntegerField, - models.DateTimeField: DateTimeField, - models.DateField: DateField, - models.TimeField: TimeField, - models.EmailField: EmailField, - models.CharField: CharField, - models.URLField: URLField, - models.SlugField: SlugField, - models.TextField: CharField, - models.CommaSeparatedIntegerField: CharField, - models.BooleanField: BooleanField, - models.FileField: FileField, - models.ImageField: ImageField, - } try: - return field_mapping[model_field.__class__](**kwargs) + return self.field_mapping[model_field.__class__](**kwargs) except KeyError: return ModelField(model_field=model_field, **kwargs) diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index 3ceab808..c2c23bcc 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -4,13 +4,21 @@ from django.contrib.auth.models import User from django.http import HttpResponse from django.test import Client, TestCase from rest_framework import HTTP_HEADER_ENCODING +from rest_framework import exceptions from rest_framework import permissions from rest_framework import status from rest_framework.authtoken.models import Token -from rest_framework.authentication import TokenAuthentication, BasicAuthentication, SessionAuthentication, OAuth2Authentication +from rest_framework.authentication import ( + BaseAuthentication, + TokenAuthentication, + BasicAuthentication, + SessionAuthentication, + OAuth2Authentication +) from rest_framework.compat import patterns, url, include from rest_framework.compat import oauth2 from rest_framework.compat import oauth2_provider +from rest_framework.tests.utils import RequestFactory from rest_framework.views import APIView import json import base64 @@ -18,17 +26,21 @@ import datetime import unittest +factory = RequestFactory() + + class MockView(APIView): permission_classes = (permissions.IsAuthenticated,) + def get(self, request): + return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + def post(self, request): return HttpResponse({'a': 1, 'b': 2, 'c': 3}) def put(self, request): return HttpResponse({'a': 1, 'b': 2, 'c': 3}) - def get(self, request): - return HttpResponse({'a': 1, 'b': 2, 'c': 3}) urlpatterns = patterns('', (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), @@ -199,6 +211,27 @@ class TokenAuthTests(TestCase): self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) +class IncorrectCredentialsTests(TestCase): + def test_incorrect_credentials(self): + """ + If a request contains bad authentication credentials, then + authentication should run and error, even if no permissions + are set on the view. + """ + class IncorrectCredentialsAuth(BaseAuthentication): + def authenticate(self, request): + raise exceptions.AuthenticationFailed('Bad credentials') + + request = factory.get('/') + view = MockView.as_view( + authentication_classes=(IncorrectCredentialsAuth,), + permission_classes=() + ) + response = view(request) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.data, {'detail': 'Bad credentials'}) + + class OAuth2Tests(TestCase): """OAuth 2.0 authentication""" urls = 'rest_framework.tests.authentication' diff --git a/rest_framework/views.py b/rest_framework/views.py index 69377bc0..81cbdcbb 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -8,7 +8,7 @@ from django.utils.html import escape from django.utils.safestring import mark_safe from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions -from rest_framework.compat import View, apply_markdown, smart_text +from rest_framework.compat import View, apply_markdown from rest_framework.response import Response from rest_framework.request import Request from rest_framework.settings import api_settings @@ -257,6 +257,16 @@ class APIView(View): return (renderers[0], renderers[0].media_type) raise + def perform_authentication(self, request): + """ + Perform authentication on the incoming request. + + Note that if you override this and simply 'pass', then authentication + will instead be performed lazily, the first time either + `request.user` or `request.auth` is accessed. + """ + request.user + def check_permissions(self, request): """ Check if the request should be permitted. @@ -305,6 +315,7 @@ class APIView(View): self.format_kwarg = self.get_format_suffix(**kwargs) # Ensure that the incoming request is permitted + self.perform_authentication(request) self.check_permissions(request) self.check_throttles(request) |
