aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--djangorestframework/tests/decorators.py102
-rw-r--r--rest_framework/decorators.py107
2 files changed, 179 insertions, 30 deletions
diff --git a/djangorestframework/tests/decorators.py b/djangorestframework/tests/decorators.py
new file mode 100644
index 00000000..0d3be8f3
--- /dev/null
+++ b/djangorestframework/tests/decorators.py
@@ -0,0 +1,102 @@
+from django.test import TestCase
+from djangorestframework.response import Response
+from djangorestframework.compat import RequestFactory
+from djangorestframework.renderers import JSONRenderer
+from djangorestframework.parsers import JSONParser
+from djangorestframework.authentication import BasicAuthentication
+from djangorestframework.throttling import SimpleRateThottle
+from djangorestframework.permissions import IsAuthenticated
+from djangorestframework.decorators import (
+ api_view,
+ renderer_classes,
+ parser_classes,
+ authentication_classes,
+ throttle_classes,
+ permission_classes,
+ LazyViewCreator
+)
+
+
+class DecoratorTestCase(TestCase):
+
+ def setUp(self):
+ self.factory = RequestFactory()
+
+ def test_wrap_view(self):
+
+ @api_view(['GET'])
+ def view(request):
+ return Response({})
+
+ self.assertTrue(isinstance(view, LazyViewCreator))
+
+ def test_calling_method(self):
+
+ @api_view(['GET'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEqual(response.status_code, 200)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, 405)
+
+ def test_renderer_classes(self):
+
+ @renderer_classes([JSONRenderer])
+ @api_view(['GET'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEqual(response.renderer_classes, [JSONRenderer])
+
+ def test_parser_classes(self):
+
+ @parser_classes([JSONParser])
+ @api_view(['GET'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEqual(response.request.parser_classes, [JSONParser])
+
+ def test_authentication_classes(self):
+
+ @authentication_classes([BasicAuthentication])
+ @api_view(['GET'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEqual(response.request.authentication_classes, [BasicAuthentication])
+
+# Doesn't look like these bits are working quite yet
+
+# def test_throttle_classes(self):
+#
+# @throttle_classes([SimpleRateThottle])
+# @api_view(['GET'])
+# def view(request):
+# return Response({})
+#
+# request = self.factory.get('/')
+# response = view(request)
+# self.assertEqual(response.request.throttle, [SimpleRateThottle])
+
+# def test_permission_classes(self):
+
+# @permission_classes([IsAuthenticated])
+# @api_view(['GET'])
+# def view(request):
+# return Response({})
+
+# request = self.factory.get('/')
+# response = view(request)
+# self.assertEqual(response.request.permission_classes, [IsAuthenticated])
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 9836c966..1483cb56 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -1,5 +1,4 @@
from functools import wraps
-from django.http import Http404
from django.utils.decorators import available_attrs
from django.core.exceptions import PermissionDenied
from rest_framework import exceptions
@@ -7,47 +6,95 @@ from rest_framework import status
from rest_framework.response import Response
from rest_framework.request import Request
from rest_framework.settings import api_settings
+from rest_framwork.views import APIView
-def api_view(allowed_methods):
+class LazyViewCreator(object):
+
"""
- Decorator for function based views.
+ This class is responsible for dynamically creating an APIView subclass that
+ will wrap a function-based view. Instances of this class are created
+ by the function-based view decorators (below), and each decorator is
+ responsible for setting attributes on the instance that will eventually be
+ copied onto the final class-based view. The CBV gets created lazily the first
+ time it's needed, and then cached for future use.
- @api_view(['GET', 'POST'])
- def my_view(request):
- # request will be an instance of `Request`
- # `Response` objects will have .request set automatically
- # APIException instances will be handled
+ This is done so that the ordering of stacked decorators is irrelevant.
"""
- allowed_methods = [method.upper() for method in allowed_methods]
- def decorator(func):
- @wraps(func, assigned=available_attrs(func))
- def inner(request, *args, **kwargs):
- try:
+ def __init__(self, wrapped_view):
+
+ self.wrapped_view = wrapped_view
+
+ # Each item in this dictionary will be copied onto the final
+ # class-based view that gets created when this object is called
+ self.final_view_attrs = {
+ 'http_method_names': APIView.http_method_names,
+ 'renderer_classes': APIView.renderer_classes,
+ 'parser_classes': APIView.parser_classes,
+ 'authentication_classes': APIView.authentication_classes,
+ 'throttle_classes': APIView.throttle_classes,
+ 'permission_classes': APIView.permission_classes,
+ }
+ self._cached_view = None
+
+ def handler(self, *args, **kwargs):
+ return self.wrapped_view(*args, **kwargs)
+
+ @property
+ def view(self):
+ """
+ Accessor for the dynamically created class-based view. This will
+ be created if necessary and cached for next time.
+ """
- request = Request(request)
+ if self._cached_view is None:
- if request.method not in allowed_methods:
- raise exceptions.MethodNotAllowed(request.method)
+ class WrappedAPIView(APIView):
+ pass
- response = func(request, *args, **kwargs)
+ for attr, value in self.final_view_attrs.items():
+ setattr(WrappedAPIView, attr, value)
- if isinstance(response, Response):
- response.request = request
- if api_settings.FORMAT_SUFFIX_KWARG:
- response.format = kwargs.get(api_settings.FORMAT_SUFFIX_KWARG, None)
- return response
+ # Attach the wrapped view function for each of the
+ # allowed HTTP methods
+ for method in WrappedAPIView.http_method_names:
+ setattr(WrappedAPIView, method.lower(), self.handler)
- except exceptions.APIException as exc:
- return Response({'detail': exc.detail}, status=exc.status_code)
+ self._cached_view = WrappedAPIView.as_view()
- except Http404 as exc:
- return Response({'detail': 'Not found'},
- status=status.HTTP_404_NOT_FOUND)
+ return self._cached_view
- except PermissionDenied as exc:
- return Response({'detail': 'Permission denied'},
- status=status.HTTP_403_FORBIDDEN)
+ def __call__(self, *args, **kwargs):
+ """
+ This is the actual code that gets run per-request
+ """
+ return self.view(*args, **kwargs)
+
+ @staticmethod
+ def maybe_create(func_or_instance):
+ """
+ If the argument is already an instance of LazyViewCreator,
+ just return it. Otherwise, create a new one.
+ """
+ if isinstance(func_or_instance, LazyViewCreator):
+ return func_or_instance
+ return LazyViewCreator(func_or_instance)
+
+
+def _create_attribute_setting_decorator(attribute, filter=lambda item: item):
+ def decorator(value):
+ def inner(func):
+ wrapper = LazyViewCreator.maybe_create(func)
+ wrapper.final_view_attrs[attribute] = filter(value)
+ return wrapper
return inner
return decorator
+
+
+api_view = _create_attribute_setting_decorator('http_method_names', filter=lambda methods: [method.lower() for method in methods])
+renderer_classes = _create_attribute_setting_decorator('renderer_classes')
+parser_classes = _create_attribute_setting_decorator('parser_classes')
+authentication_classes = _create_attribute_setting_decorator('authentication_classes')
+throttle_classes = _create_attribute_setting_decorator('throttle_classes')
+permission_classes = _create_attribute_setting_decorator('permission_classes')