diff options
Diffstat (limited to 'rest_framework/request.py')
| -rw-r--r-- | rest_framework/request.py | 38 | 
1 files changed, 27 insertions, 11 deletions
| diff --git a/rest_framework/request.py b/rest_framework/request.py index 40467c03..27532661 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -42,13 +42,20 @@ class override_method(object):          self.view = view          self.request = request          self.method = method +        self.action = getattr(view, 'action', None)      def __enter__(self):          self.view.request = clone_request(self.request, self.method) +        if self.action is not None: +            # For viewsets we also set the `.action` attribute. +            action_map = getattr(self.view, 'action_map', {}) +            self.view.action = action_map.get(self.method.lower())          return self.view.request      def __exit__(self, *args, **kwarg):          self.view.request = self.request +        if self.action is not None: +            self.view.action = self.action  class Empty(object): @@ -280,16 +287,19 @@ class Request(object):              self._method = self._request.method              # Allow X-HTTP-METHOD-OVERRIDE header -            self._method = self.META.get('HTTP_X_HTTP_METHOD_OVERRIDE', -                                         self._method) +            if 'HTTP_X_HTTP_METHOD_OVERRIDE' in self.META: +                self._method = self.META['HTTP_X_HTTP_METHOD_OVERRIDE'].upper()      def _load_stream(self):          """          Return the content body of the request, as a stream.          """          try: -            content_length = int(self.META.get('CONTENT_LENGTH', -                                    self.META.get('HTTP_CONTENT_LENGTH'))) +            content_length = int( +                self.META.get( +                    'CONTENT_LENGTH', self.META.get('HTTP_CONTENT_LENGTH') +                ) +            )          except (ValueError, TypeError):              content_length = 0 @@ -313,9 +323,11 @@ class Request(object):          )          # We only need to use form overloading on form POST requests. -        if (not USE_FORM_OVERLOADING +        if ( +            not USE_FORM_OVERLOADING              or self._request.method != 'POST' -            or not is_form_media_type(self._content_type)): +            or not is_form_media_type(self._content_type) +        ):              return          # At this point we're committed to parsing the request as form data. @@ -323,15 +335,19 @@ class Request(object):          self._files = self._request.FILES          # Method overloading - change the method and remove the param from the content. -        if (self._METHOD_PARAM and -            self._METHOD_PARAM in self._data): +        if ( +            self._METHOD_PARAM and +            self._METHOD_PARAM in self._data +        ):              self._method = self._data[self._METHOD_PARAM].upper()          # Content overloading - modify the content type, and force re-parse. -        if (self._CONTENT_PARAM and +        if ( +            self._CONTENT_PARAM and              self._CONTENTTYPE_PARAM and              self._CONTENT_PARAM in self._data and -            self._CONTENTTYPE_PARAM in self._data): +            self._CONTENTTYPE_PARAM in self._data +        ):              self._content_type = self._data[self._CONTENTTYPE_PARAM]              self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding']))              self._data, self._files = (Empty, Empty) @@ -387,7 +403,7 @@ class Request(object):                  self._not_authenticated()                  raise -            if not user_auth_tuple is None: +            if user_auth_tuple is not None:                  self._authenticator = authenticator                  self._user, self._auth = user_auth_tuple                  return | 
