diff options
| author | Tom Christie | 2012-10-05 14:48:33 +0100 | 
|---|---|---|
| committer | Tom Christie | 2012-10-05 14:48:33 +0100 | 
| commit | 9d8bce8f5b0915223f57d9fe3d4b63029cfc64c2 (patch) | |
| tree | a0a3f9e5a80335dcba3315f81b498e3aed241dcb | |
| parent | 3e862c77379b2f84356e2e8f0be20b7aca5b9e89 (diff) | |
| download | django-rest-framework-9d8bce8f5b0915223f57d9fe3d4b63029cfc64c2.tar.bz2 | |
Remove Parser.can_handle_request()
| -rw-r--r-- | rest_framework/negotiation.py | 10 | ||||
| -rw-r--r-- | rest_framework/parsers.py | 17 | ||||
| -rw-r--r-- | rest_framework/request.py | 52 | ||||
| -rw-r--r-- | rest_framework/tests/decorators.py | 8 | ||||
| -rw-r--r-- | rest_framework/tests/request.py | 12 | ||||
| -rw-r--r-- | rest_framework/views.py | 38 | 
6 files changed, 79 insertions, 58 deletions
| diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py index 73ae7899..8b22f669 100644 --- a/rest_framework/negotiation.py +++ b/rest_framework/negotiation.py @@ -11,6 +11,16 @@ class BaseContentNegotiation(object):  class DefaultContentNegotiation(object):      settings = api_settings +    def select_parser(self, parsers, media_type): +        """ +        Given a list of parsers and a media type, return the appropriate +        parser to handle the incoming request. +        """ +        for parser in parsers: +            if media_type_matches(parser.media_type, media_type): +                return parser +        return None +      def negotiate(self, request, renderers, format=None, force=False):          """          Given a request and a list of renderers, return a two-tuple of: diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 5151b252..5325a64b 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -15,11 +15,9 @@ from django.http import QueryDict  from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser  from django.http.multipartparser import MultiPartParserError  from django.utils import simplejson as json -from rest_framework.compat import yaml +from rest_framework.compat import yaml, ETParseError  from rest_framework.exceptions import ParseError -from rest_framework.utils.mediatypes import media_type_matches  from xml.etree import ElementTree as ET -from rest_framework.compat import ETParseError  from xml.parsers.expat import ExpatError  import datetime  import decimal @@ -40,19 +38,6 @@ class BaseParser(object):      media_type = None -    def can_handle_request(self, content_type): -        """ -        Returns :const:`True` if this parser is able to deal with the given *content_type*. - -        The default implementation for this function is to check the *content_type* -        argument against the :attr:`media_type` attribute set on the class to see if -        they match. - -        This may be overridden to provide for other behavior, but typically you'll -        instead want to just set the :attr:`media_type` attribute on the class. -        """ -        return media_type_matches(self.media_type, content_type) -      def parse(self, string_or_stream, **opts):          """          The main entry point to parsers.  This is a light wrapper around diff --git a/rest_framework/request.py b/rest_framework/request.py index e254cf8e..ac15defc 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -34,8 +34,8 @@ def clone_request(request, method):      HTTP method.  Used for checking permissions against other methods.      """      ret = Request(request._request, -                  request.parser_classes, -                  request.authentication_classes) +                  request.parsers, +                  request.authenticators)      ret._data = request._data      ret._files = request._files      ret._content_type = request._content_type @@ -60,27 +60,20 @@ class Request(object):      _CONTENT_PARAM = api_settings.FORM_CONTENT_OVERRIDE      _CONTENTTYPE_PARAM = api_settings.FORM_CONTENTTYPE_OVERRIDE -    def __init__(self, request, parser_classes=None, authentication_classes=None): +    def __init__(self, request, parsers=None, authenticators=None, +                 negotiator=None):          self._request = request -        self.parser_classes = parser_classes or () -        self.authentication_classes = authentication_classes or () +        self.parsers = parsers or () +        self.authenticators = authenticators or () +        self.negotiator = negotiator or self._default_negotiator()          self._data = Empty          self._files = Empty          self._method = Empty          self._content_type = Empty          self._stream = Empty -    def get_parsers(self): -        """ -        Instantiates and returns the list of parsers the request will use. -        """ -        return [parser() for parser in self.parser_classes] - -    def get_authentications(self): -        """ -        Instantiates and returns the list of parsers the request will use. -        """ -        return [authentication() for authentication in self.authentication_classes] +    def _default_negotiator(self): +        return api_settings.DEFAULT_CONTENT_NEGOTIATION()      @property      def method(self): @@ -254,26 +247,27 @@ class Request(object):          if self.stream is None or self.content_type is None:              return (None, None) -        for parser in self.get_parsers(): -            if parser.can_handle_request(self.content_type): -                parsed = parser.parse(self.stream, meta=self.META, -                                      upload_handlers=self.upload_handlers) -                # Parser classes may return the raw data, or a -                # DataAndFiles object.  Unpack the result as required. -                try: -                    return (parsed.data, parsed.files) -                except AttributeError: -                    return (parsed, None) +        parser = self.negotiator.select_parser(self.parsers, self.content_type) + +        if not parser: +            raise exceptions.UnsupportedMediaType(self._content_type) -        raise exceptions.UnsupportedMediaType(self._content_type) +        parsed = parser.parse(self.stream, meta=self.META, +                              upload_handlers=self.upload_handlers) +        # Parser classes may return the raw data, or a +        # DataAndFiles object.  Unpack the result as required. +        try: +            return (parsed.data, parsed.files) +        except AttributeError: +            return (parsed, None)      def _authenticate(self):          """          Attempt to authenticate the request using each authentication instance in turn.          Returns a two-tuple of (user, authtoken).          """ -        for authentication in self.get_authentications(): -            user_auth_tuple = authentication.authenticate(self) +        for authenticator in self.authenticators: +            user_auth_tuple = authenticator.authenticate(self)              if not user_auth_tuple is None:                  return user_auth_tuple          return self._not_authenticated() diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py index e943d8fe..a3217bd6 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/decorators.py @@ -65,7 +65,9 @@ class DecoratorTestCase(TestCase):          @api_view(['GET'])          @parser_classes([JSONParser])          def view(request): -            self.assertEqual(request.parser_classes, [JSONParser]) +            self.assertEqual(len(request.parsers), 1) +            self.assertTrue(isinstance(request.parsers[0], +                                       JSONParser))              return Response({})          request = self.factory.get('/') @@ -76,7 +78,9 @@ class DecoratorTestCase(TestCase):          @api_view(['GET'])          @authentication_classes([BasicAuthentication])          def view(request): -            self.assertEqual(request.authentication_classes, [BasicAuthentication]) +            self.assertEqual(len(request.authenticators), 1) +            self.assertTrue(isinstance(request.authenticators[0], +                                       BasicAuthentication))              return Response({})          request = self.factory.get('/') diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py index 42274fcd..f5c63f11 100644 --- a/rest_framework/tests/request.py +++ b/rest_framework/tests/request.py @@ -61,7 +61,7 @@ class TestContentParsing(TestCase):          """          data = {'qwerty': 'uiop'}          request = Request(factory.post('/', data)) -        request.parser_classes = (FormParser, MultiPartParser) +        request.parsers = (FormParser(), MultiPartParser())          self.assertEqual(request.DATA.items(), data.items())      def test_request_DATA_with_text_content(self): @@ -72,7 +72,7 @@ class TestContentParsing(TestCase):          content = 'qwerty'          content_type = 'text/plain'          request = Request(factory.post('/', content, content_type=content_type)) -        request.parser_classes = (PlainTextParser,) +        request.parsers = (PlainTextParser(),)          self.assertEqual(request.DATA, content)      def test_request_POST_with_form_content(self): @@ -81,7 +81,7 @@ class TestContentParsing(TestCase):          """          data = {'qwerty': 'uiop'}          request = Request(factory.post('/', data)) -        request.parser_classes = (FormParser, MultiPartParser) +        request.parsers = (FormParser(), MultiPartParser())          self.assertEqual(request.POST.items(), data.items())      def test_standard_behaviour_determines_form_content_PUT(self): @@ -99,7 +99,7 @@ class TestContentParsing(TestCase):          else:              request = Request(factory.put('/', data)) -        request.parser_classes = (FormParser, MultiPartParser) +        request.parsers = (FormParser(), MultiPartParser())          self.assertEqual(request.DATA.items(), data.items())      def test_standard_behaviour_determines_non_form_content_PUT(self): @@ -110,7 +110,7 @@ class TestContentParsing(TestCase):          content = 'qwerty'          content_type = 'text/plain'          request = Request(factory.put('/', content, content_type=content_type)) -        request.parser_classes = (PlainTextParser, ) +        request.parsers = (PlainTextParser(), )          self.assertEqual(request.DATA, content)      def test_overloaded_behaviour_allows_content_tunnelling(self): @@ -124,7 +124,7 @@ class TestContentParsing(TestCase):              Request._CONTENTTYPE_PARAM: content_type          }          request = Request(factory.post('/', data)) -        request.parser_classes = (PlainTextParser, ) +        request.parsers = (PlainTextParser(), )          self.assertEqual(request.DATA, content)      # def test_accessing_post_after_data_form(self): diff --git a/rest_framework/views.py b/rest_framework/views.py index 166bf0b1..0aa1dd0d 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -70,6 +70,7 @@ class APIView(View):          as an attribute on the callable function.  This allows us to discover          information about the view when we do URL reverse lookups.          """ +        # TODO: deprecate?          view = super(APIView, cls).as_view(**initkwargs)          view.cls_instance = cls(**initkwargs)          return view @@ -84,6 +85,7 @@ class APIView(View):      @property      def default_response_headers(self): +        # TODO: Only vary by accept if multiple renderers          return {              'Allow': ', '.join(self.allowed_methods),              'Vary': 'Accept' @@ -94,6 +96,7 @@ class APIView(View):          Return the resource or view class name for use as this view's name.          Override to customize.          """ +        # TODO: deprecate?          name = self.__class__.__name__          name = _remove_trailing_string(name, 'View')          return _camelcase_to_spaces(name) @@ -103,6 +106,7 @@ class APIView(View):          Return the resource or view docstring for use as this view's description.          Override to customize.          """ +        # TODO: deprecate?          description = self.__doc__ or ''          description = _remove_leading_indent(description)          if html: @@ -113,6 +117,7 @@ class APIView(View):          """          Apply HTML markup to the description of this view.          """ +        # TODO: deprecate?          if apply_markdown:              description = apply_markdown(description)          else: @@ -137,6 +142,8 @@ class APIView(View):          """          raise exceptions.Throttled(wait) +    # API policy instantiation methods +      def get_format_suffix(self, **kwargs):          """          Determine if the request includes a '.json' style format suffix @@ -144,12 +151,24 @@ class APIView(View):          if self.settings.FORMAT_SUFFIX_KWARG:              return kwargs.get(self.settings.FORMAT_SUFFIX_KWARG) -    def get_renderers(self, format=None): +    def get_renderers(self):          """          Instantiates and returns the list of renderers that this view can use.          """          return [renderer(self) for renderer in self.renderer_classes] +    def get_parsers(self): +        """ +        Instantiates and returns the list of renderers that this view can use. +        """ +        return [parser() for parser in self.parser_classes] + +    def get_authenticators(self): +        """ +        Instantiates and returns the list of renderers that this view can use. +        """ +        return [auth() for auth in self.authentication_classes] +      def get_permissions(self):          """          Instantiates and returns the list of permissions that this view requires. @@ -166,7 +185,11 @@ class APIView(View):          """          Instantiate and return the content negotiation class to use.          """ -        return self.content_negotiation_class() +        if not getattr(self, '_negotiator', None): +            self._negotiator = self.content_negotiation_class() +        return self._negotiator + +    # API policy implementation methods      def perform_content_negotiation(self, request, force=False):          """ @@ -193,19 +216,24 @@ class APIView(View):              if not throttle.allow_request(request):                  self.throttled(request, throttle.wait()) +    # Dispatch methods +      def initialize_request(self, request, *args, **kargs):          """          Returns the initial request object.          """ -        return Request(request, parser_classes=self.parser_classes, -                       authentication_classes=self.authentication_classes) +        return Request(request, +                       parsers=self.get_parsers(), +                       authenticators=self.get_authenticators(), +                       negotiator=self.get_content_negotiator())      def initial(self, request, *args, **kwargs):          """ -        Runs anything that needs to occur prior to calling the method handlers. +        Runs anything that needs to occur prior to calling the method handler.          """          self.format_kwarg = self.get_format_suffix(**kwargs) +        # Ensure that the incoming request is permitted          if not self.has_permission(request):              self.permission_denied(request)          self.check_throttles(request) | 
