aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2012-09-26 13:10:01 -0700
committerTom Christie2012-09-26 13:10:01 -0700
commit622e001e0bde3742f5e8830436ea0304c892480a (patch)
tree709f150fbef84b776ec20614b45c511d89293f19 /rest_framework
parentd3e0ac864f8df4568fe9c0f81d64b41c9f531a02 (diff)
parent686a03481799b465d57f5b6a3f93afe44cb077ca (diff)
downloaddjango-rest-framework-622e001e0bde3742f5e8830436ea0304c892480a.tar.bz2
Merge pull request #261 from j4mie/improved-view-decorators
First stab at new function-based view decorators
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/decorators.py92
-rw-r--r--rest_framework/tests/decorators.py107
2 files changed, 168 insertions, 31 deletions
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 9836c966..fc81489e 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -1,5 +1,4 @@
from functools import wraps
-from django.http import Http404
from django.utils.decorators import available_attrs
from django.core.exceptions import PermissionDenied
from rest_framework import exceptions
@@ -7,47 +6,78 @@ from rest_framework import status
from rest_framework.response import Response
from rest_framework.request import Request
from rest_framework.settings import api_settings
+from rest_framework.views import APIView
-def api_view(allowed_methods):
- """
- Decorator for function based views.
+def api_view(http_method_names):
- @api_view(['GET', 'POST'])
- def my_view(request):
- # request will be an instance of `Request`
- # `Response` objects will have .request set automatically
- # APIException instances will be handled
"""
- allowed_methods = [method.upper() for method in allowed_methods]
+ Decorator that converts a function-based view into an APIView subclass.
+ Takes a list of allowed methods for the view as an argument.
+ """
def decorator(func):
- @wraps(func, assigned=available_attrs(func))
- def inner(request, *args, **kwargs):
- try:
- request = Request(request)
+ class WrappedAPIView(APIView):
+ pass
+
+ WrappedAPIView.http_method_names = [method.lower() for method in http_method_names]
+
+ def handler(self, *args, **kwargs):
+ return func(*args, **kwargs)
- if request.method not in allowed_methods:
- raise exceptions.MethodNotAllowed(request.method)
+ for method in http_method_names:
+ setattr(WrappedAPIView, method.lower(), handler)
- response = func(request, *args, **kwargs)
+ WrappedAPIView.renderer_classes = getattr(func, 'renderer_classes',
+ APIView.renderer_classes)
- if isinstance(response, Response):
- response.request = request
- if api_settings.FORMAT_SUFFIX_KWARG:
- response.format = kwargs.get(api_settings.FORMAT_SUFFIX_KWARG, None)
- return response
+ WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
+ APIView.parser_classes)
- except exceptions.APIException as exc:
- return Response({'detail': exc.detail}, status=exc.status_code)
+ WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes',
+ APIView.authentication_classes)
- except Http404 as exc:
- return Response({'detail': 'Not found'},
- status=status.HTTP_404_NOT_FOUND)
+ WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes',
+ APIView.throttle_classes)
+
+ WrappedAPIView.permission_classes = getattr(func, 'permission_classes',
+ APIView.permission_classes)
+
+ return WrappedAPIView.as_view()
+ return decorator
- except PermissionDenied as exc:
- return Response({'detail': 'Permission denied'},
- status=status.HTTP_403_FORBIDDEN)
- return inner
+
+def renderer_classes(renderer_classes):
+ def decorator(func):
+ func.renderer_classes = renderer_classes
+ return func
+ return decorator
+
+
+def parser_classes(parser_classes):
+ def decorator(func):
+ func.parser_classes = parser_classes
+ return func
+ return decorator
+
+
+def authentication_classes(authentication_classes):
+ def decorator(func):
+ func.authentication_classes = authentication_classes
+ return func
+ return decorator
+
+
+def throttle_classes(throttle_classes):
+ def decorator(func):
+ func.throttle_classes = throttle_classes
+ return func
+ return decorator
+
+
+def permission_classes(permission_classes):
+ def decorator(func):
+ func.permission_classes = permission_classes
+ return func
return decorator
diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py
new file mode 100644
index 00000000..d41f05d4
--- /dev/null
+++ b/rest_framework/tests/decorators.py
@@ -0,0 +1,107 @@
+from django.test import TestCase
+from rest_framework.response import Response
+from rest_framework.compat import RequestFactory
+from rest_framework.renderers import JSONRenderer
+from rest_framework.parsers import JSONParser
+from rest_framework.authentication import BasicAuthentication
+from rest_framework.throttling import SimpleRateThottle
+from rest_framework.permissions import IsAuthenticated
+from rest_framework.views import APIView
+from rest_framework.decorators import (
+ api_view,
+ renderer_classes,
+ parser_classes,
+ authentication_classes,
+ throttle_classes,
+ permission_classes,
+)
+
+
+class DecoratorTestCase(TestCase):
+
+ def setUp(self):
+ self.factory = RequestFactory()
+
+ def _finalize_response(self, request, response, *args, **kwargs):
+ print "HAI"
+ response.request = request
+ return APIView.finalize_response(self, request, response, *args, **kwargs)
+
+ def test_wrap_view(self):
+
+ @api_view(['GET'])
+ def view(request):
+ return Response({})
+
+ self.assertTrue(isinstance(view.cls_instance, APIView))
+
+ def test_calling_method(self):
+
+ @api_view(['GET'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEqual(response.status_code, 200)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, 405)
+
+ def test_renderer_classes(self):
+
+ @api_view(['GET'])
+ @renderer_classes([JSONRenderer])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertTrue(isinstance(response.renderer, JSONRenderer))
+
+ def test_parser_classes(self):
+
+ @api_view(['GET'])
+ @parser_classes([JSONParser])
+ def view(request):
+ self.assertEqual(request.parser_classes, [JSONParser])
+ return Response({})
+
+ request = self.factory.get('/')
+ view(request)
+
+ def test_authentication_classes(self):
+
+ @api_view(['GET'])
+ @authentication_classes([BasicAuthentication])
+ def view(request):
+ self.assertEqual(request.authentication_classes, [BasicAuthentication])
+ return Response({})
+
+ request = self.factory.get('/')
+ view(request)
+
+ def test_permission_classes(self):
+
+ @api_view(['GET'])
+ @permission_classes([IsAuthenticated])
+ def view(request):
+ self.assertEqual(request.permission_classes, [IsAuthenticated])
+ return Response({})
+
+ request = self.factory.get('/')
+ view(request)
+
+# Doesn't look like this bits are working quite yet
+
+# def test_throttle_classes(self):
+
+# @api_view(['GET'])
+# @throttle_classes([SimpleRateThottle])
+# def view(request):
+# self.assertEqual(request.throttle_classes, [SimpleRateThottle])
+# return Response({})
+
+# request = self.factory.get('/')
+# view(request)