diff options
Diffstat (limited to 'rest_framework/request.py')
| -rw-r--r-- | rest_framework/request.py | 172 |
1 files changed, 140 insertions, 32 deletions
diff --git a/rest_framework/request.py b/rest_framework/request.py index 919716f4..e4b5bc26 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -4,7 +4,7 @@ The Request class is used as a wrapper around the standard request object. The wrapped request then offers a richer API, in particular : - content automatically parsed according to `Content-Type` header, - and available as `request.DATA` + and available as `request.data` - full support of PUT method, including support for file uploads - form overloading of HTTP method, content type and content """ @@ -12,11 +12,13 @@ from __future__ import unicode_literals from django.conf import settings from django.http import QueryDict from django.http.multipartparser import parse_header +from django.utils import six from django.utils.datastructures import MultiValueDict from rest_framework import HTTP_HEADER_ENCODING from rest_framework import exceptions -from rest_framework.compat import BytesIO from rest_framework.settings import api_settings +import sys +import warnings def is_form_media_type(media_type): @@ -28,6 +30,36 @@ def is_form_media_type(media_type): base_media_type == 'multipart/form-data') +class override_method(object): + """ + A context manager that temporarily overrides the method on a request, + additionally setting the `view.request` attribute. + + Usage: + + with override_method(view, request, 'POST') as request: + ... # Do stuff with `view` and `request` + """ + def __init__(self, view, request, method): + self.view = view + self.request = request + self.method = method + self.action = getattr(view, 'action', None) + + def __enter__(self): + self.view.request = clone_request(self.request, self.method) + if self.action is not None: + # For viewsets we also set the `.action` attribute. + action_map = getattr(self.view, 'action_map', {}) + self.view.action = action_map.get(self.method.lower()) + return self.view.request + + def __exit__(self, *args, **kwarg): + self.view.request = self.request + if self.action is not None: + self.view.action = self.action + + class Empty(object): """ Placeholder for unset attributes. @@ -52,6 +84,7 @@ def clone_request(request, method): parser_context=request.parser_context) ret._data = request._data ret._files = request._files + ret._full_data = request._full_data ret._content_type = request._content_type ret._stream = request._stream ret._method = method @@ -61,6 +94,14 @@ def clone_request(request, method): ret._auth = request._auth if hasattr(request, '_authenticator'): ret._authenticator = request._authenticator + if hasattr(request, 'accepted_renderer'): + ret.accepted_renderer = request.accepted_renderer + if hasattr(request, 'accepted_media_type'): + ret.accepted_media_type = request.accepted_media_type + if hasattr(request, 'version'): + ret.version = request.version + if hasattr(request, 'versioning_scheme'): + ret.versioning_scheme = request.versioning_scheme return ret @@ -103,6 +144,7 @@ class Request(object): self.parser_context = parser_context self._data = Empty self._files = Empty + self._full_data = Empty self._method = Empty self._content_type = Empty self._stream = Empty @@ -156,13 +198,31 @@ class Request(object): return self._stream @property - def QUERY_PARAMS(self): + def query_params(self): """ More semantically correct name for request.GET. """ return self._request.GET @property + def QUERY_PARAMS(self): + """ + Synonym for `.query_params`, for backwards compatibility. + """ + warnings.warn( + "`request.QUERY_PARAMS` is deprecated. Use `request.query_params` instead.", + DeprecationWarning, + stacklevel=1 + ) + return self._request.GET + + @property + def data(self): + if not _hasattr(self, '_full_data'): + self._load_data_and_files() + return self._full_data + + @property def DATA(self): """ Parses the request body and returns the data. @@ -170,6 +230,11 @@ class Request(object): Similar to usual behaviour of `request.POST`, except that it handles arbitrary parsers, and also works on methods other than POST (eg PUT). """ + warnings.warn( + "`request.DATA` is deprecated. Use `request.data` instead.", + DeprecationWarning, + stacklevel=1 + ) if not _hasattr(self, '_data'): self._load_data_and_files() return self._data @@ -182,6 +247,11 @@ class Request(object): Similar to usual behaviour of `request.FILES`, except that it handles arbitrary parsers, and also works on methods other than POST (eg PUT). """ + warnings.warn( + "`request.FILES` is deprecated. Use `request.data` instead.", + DeprecationWarning, + stacklevel=1 + ) if not _hasattr(self, '_files'): self._load_data_and_files() return self._files @@ -200,10 +270,14 @@ class Request(object): def user(self, value): """ Sets the user on the current request. This is necessary to maintain - compatilbility with django.contrib.auth where the user proprety is + compatibility with django.contrib.auth where the user property is set in the login and logout functions. + + Note that we also set the user on Django's underlying `HttpRequest` + instance, ensuring that it is available to any middleware in the stack. """ self._user = value + self._request.user = value @property def auth(self): @@ -222,6 +296,7 @@ class Request(object): request, such as an authentication token. """ self._auth = value + self._request.auth = value @property def successful_authenticator(self): @@ -235,13 +310,18 @@ class Request(object): def _load_data_and_files(self): """ - Parses the request content into self.DATA and self.FILES. + Parses the request content into `self.data`. """ if not _hasattr(self, '_content_type'): self._load_method_and_content_type() if not _hasattr(self, '_data'): self._data, self._files = self._parse() + if self._files: + self._full_data = self._data.copy() + self._full_data.update(self._files) + else: + self._full_data = self._data def _load_method_and_content_type(self): """ @@ -256,18 +336,20 @@ class Request(object): if not _hasattr(self, '_method'): self._method = self._request.method - if self._method == 'POST': - # Allow X-HTTP-METHOD-OVERRIDE header - self._method = self.META.get('HTTP_X_HTTP_METHOD_OVERRIDE', - self._method) + # Allow X-HTTP-METHOD-OVERRIDE header + if 'HTTP_X_HTTP_METHOD_OVERRIDE' in self.META: + self._method = self.META['HTTP_X_HTTP_METHOD_OVERRIDE'].upper() def _load_stream(self): """ Return the content body of the request, as a stream. """ try: - content_length = int(self.META.get('CONTENT_LENGTH', - self.META.get('HTTP_CONTENT_LENGTH'))) + content_length = int( + self.META.get( + 'CONTENT_LENGTH', self.META.get('HTTP_CONTENT_LENGTH') + ) + ) except (ValueError, TypeError): content_length = 0 @@ -276,7 +358,7 @@ class Request(object): elif hasattr(self._request, 'read'): self._stream = self._request else: - self._stream = BytesIO(self.raw_post_data) + self._stream = six.BytesIO(self.raw_post_data) def _perform_form_overloading(self): """ @@ -291,28 +373,36 @@ class Request(object): ) # We only need to use form overloading on form POST requests. - if (not USE_FORM_OVERLOADING - or self._request.method != 'POST' - or not is_form_media_type(self._content_type)): + if ( + self._request.method != 'POST' or + not USE_FORM_OVERLOADING or + not is_form_media_type(self._content_type) + ): return # At this point we're committed to parsing the request as form data. self._data = self._request.POST self._files = self._request.FILES + self._full_data = self._data.copy() + self._full_data.update(self._files) # Method overloading - change the method and remove the param from the content. - if (self._METHOD_PARAM and - self._METHOD_PARAM in self._data): + if ( + self._METHOD_PARAM and + self._METHOD_PARAM in self._data + ): self._method = self._data[self._METHOD_PARAM].upper() # Content overloading - modify the content type, and force re-parse. - if (self._CONTENT_PARAM and + if ( + self._CONTENT_PARAM and self._CONTENTTYPE_PARAM and self._CONTENT_PARAM in self._data and - self._CONTENTTYPE_PARAM in self._data): + self._CONTENTTYPE_PARAM in self._data + ): self._content_type = self._data[self._CONTENTTYPE_PARAM] - self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(HTTP_HEADER_ENCODING)) - self._data, self._files = (Empty, Empty) + self._stream = six.BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding'])) + self._data, self._files, self._full_data = (Empty, Empty, Empty) def _parse(self): """ @@ -324,7 +414,7 @@ class Request(object): media_type = self.content_type if stream is None or media_type is None: - empty_data = QueryDict('', self._request._encoding) + empty_data = QueryDict('', encoding=self._request._encoding) empty_files = MultiValueDict() return (empty_data, empty_files) @@ -333,7 +423,17 @@ class Request(object): if not parser: raise exceptions.UnsupportedMediaType(media_type) - parsed = parser.parse(stream, media_type, self.parser_context) + try: + parsed = parser.parse(stream, media_type, self.parser_context) + except: + # If we get an exception during parsing, fill in empty data and + # re-raise. Ensures we don't simply repeat the error when + # attempting to render the browsable renderer response, or when + # logging the request or similar. + self._data = QueryDict('', encoding=self._request._encoding) + self._files = MultiValueDict() + self._full_data = self._data + raise # Parser classes may return the raw data, or a # DataAndFiles object. Unpack the result as required. @@ -356,9 +456,9 @@ class Request(object): self._not_authenticated() raise - if not user_auth_tuple is None: + if user_auth_tuple is not None: self._authenticator = authenticator - self._user, self._auth = user_auth_tuple + self.user, self.auth = user_auth_tuple return self._not_authenticated() @@ -373,17 +473,25 @@ class Request(object): self._authenticator = None if api_settings.UNAUTHENTICATED_USER: - self._user = api_settings.UNAUTHENTICATED_USER() + self.user = api_settings.UNAUTHENTICATED_USER() else: - self._user = None + self.user = None if api_settings.UNAUTHENTICATED_TOKEN: - self._auth = api_settings.UNAUTHENTICATED_TOKEN() + self.auth = api_settings.UNAUTHENTICATED_TOKEN() else: - self._auth = None + self.auth = None - def __getattr__(self, attr): + def __getattribute__(self, attr): """ - Proxy other attributes to the underlying HttpRequest object. + If an attribute does not exist on this instance, then we also attempt + to proxy it to the underlying HttpRequest object. """ - return getattr(self._request, attr) + try: + return super(Request, self).__getattribute__(attr) + except AttributeError: + info = sys.exc_info() + try: + return getattr(self._request, attr) + except AttributeError: + six.reraise(info[0], info[1], info[2].tb_next) |
