diff options
| -rw-r--r-- | rest_framework/decorators.py | 9 | ||||
| -rw-r--r-- | rest_framework/tests/decorators.py | 22 | 
2 files changed, 31 insertions, 0 deletions
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 1b710a03..7a4103e1 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -1,4 +1,5 @@  from rest_framework.views import APIView +import types  def api_view(http_method_names): @@ -23,6 +24,14 @@ def api_view(http_method_names):          #         pass          #     WrappedAPIView.__doc__ = func.doc    <--- Not possible to do this +        # api_view applied without (method_names) +        assert not(isinstance(http_method_names, types.FunctionType)), \ +            '@api_view missing list of allowed HTTP methods' + +        # api_view applied with eg. string instead of list of strings +        assert isinstance(http_method_names, (list, tuple)), \ +            '@api_view expected a list of strings, recieved %s' % type(http_method_names).__name__ +          allowed_methods = set(http_method_names) | set(('options',))          WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py index 4012188d..82f912e9 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/decorators.py @@ -28,6 +28,28 @@ class DecoratorTestCase(TestCase):          response.request = request          return APIView.finalize_response(self, request, response, *args, **kwargs) +    def test_api_view_incorrect(self): +        """ +        If @api_view is not applied correct, we should raise an assertion. +        """ + +        @api_view +        def view(request): +            return Response() + +        request = self.factory.get('/') +        self.assertRaises(AssertionError, view, request) + +    def test_api_view_incorrect_arguments(self): +        """ +        If @api_view is missing arguments, we should raise an assertion. +        """ + +        with self.assertRaises(AssertionError): +            @api_view('GET') +            def view(request): +                return Response() +      def test_calling_method(self):          @api_view(['GET'])  | 
