aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/authentication.py183
-rw-r--r--rest_framework/compat.py79
-rw-r--r--rest_framework/exceptions.py5
-rw-r--r--rest_framework/parsers.py98
-rw-r--r--rest_framework/permissions.py28
-rw-r--r--rest_framework/renderers.py111
-rw-r--r--rest_framework/reverse.py12
-rw-r--r--rest_framework/settings.py15
-rw-r--r--rest_framework/utils/encoders.py65
-rw-r--r--rest_framework/versioning.py174
-rw-r--r--rest_framework/views.py56
11 files changed, 263 insertions, 563 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index 4832ad33..124ef68a 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -3,14 +3,9 @@ Provides various authentication policies.
"""
from __future__ import unicode_literals
import base64
-
from django.contrib.auth import authenticate
-from django.core.exceptions import ImproperlyConfigured
from django.middleware.csrf import CsrfViewMiddleware
-from django.conf import settings
from rest_framework import exceptions, HTTP_HEADER_ENCODING
-from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
-from rest_framework.compat import oauth2_provider, provider_now, check_nonce
from rest_framework.authtoken.models import Token
@@ -178,181 +173,3 @@ class TokenAuthentication(BaseAuthentication):
def authenticate_header(self, request):
return 'Token'
-
-
-class OAuthAuthentication(BaseAuthentication):
- """
- OAuth 1.0a authentication backend using `django-oauth-plus` and `oauth2`.
-
- Note: The `oauth2` package actually provides oauth1.0a support. Urg.
- We import it from the `compat` module as `oauth`.
- """
- www_authenticate_realm = 'api'
-
- def __init__(self, *args, **kwargs):
- super(OAuthAuthentication, self).__init__(*args, **kwargs)
-
- if oauth is None:
- raise ImproperlyConfigured(
- "The 'oauth2' package could not be imported."
- "It is required for use with the 'OAuthAuthentication' class.")
-
- if oauth_provider is None:
- raise ImproperlyConfigured(
- "The 'django-oauth-plus' package could not be imported."
- "It is required for use with the 'OAuthAuthentication' class.")
-
- def authenticate(self, request):
- """
- Returns two-tuple of (user, token) if authentication succeeds,
- or None otherwise.
- """
- try:
- oauth_request = oauth_provider.utils.get_oauth_request(request)
- except oauth.Error as err:
- raise exceptions.AuthenticationFailed(err.message)
-
- if not oauth_request:
- return None
-
- oauth_params = oauth_provider.consts.OAUTH_PARAMETERS_NAMES
-
- found = any(param for param in oauth_params if param in oauth_request)
- missing = list(param for param in oauth_params if param not in oauth_request)
-
- if not found:
- # OAuth authentication was not attempted.
- return None
-
- if missing:
- # OAuth was attempted but missing parameters.
- msg = 'Missing parameters: %s' % (', '.join(missing))
- raise exceptions.AuthenticationFailed(msg)
-
- if not self.check_nonce(request, oauth_request):
- msg = 'Nonce check failed'
- raise exceptions.AuthenticationFailed(msg)
-
- try:
- consumer_key = oauth_request.get_parameter('oauth_consumer_key')
- consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key)
- except oauth_provider.store.InvalidConsumerError:
- msg = 'Invalid consumer token: %s' % oauth_request.get_parameter('oauth_consumer_key')
- raise exceptions.AuthenticationFailed(msg)
-
- if consumer.status != oauth_provider.consts.ACCEPTED:
- msg = 'Invalid consumer key status: %s' % consumer.get_status_display()
- raise exceptions.AuthenticationFailed(msg)
-
- try:
- token_param = oauth_request.get_parameter('oauth_token')
- token = oauth_provider_store.get_access_token(request, oauth_request, consumer, token_param)
- except oauth_provider.store.InvalidTokenError:
- msg = 'Invalid access token: %s' % oauth_request.get_parameter('oauth_token')
- raise exceptions.AuthenticationFailed(msg)
-
- try:
- self.validate_token(request, consumer, token)
- except oauth.Error as err:
- raise exceptions.AuthenticationFailed(err.message)
-
- user = token.user
-
- if not user.is_active:
- msg = 'User inactive or deleted: %s' % user.username
- raise exceptions.AuthenticationFailed(msg)
-
- return (token.user, token)
-
- def authenticate_header(self, request):
- """
- If permission is denied, return a '401 Unauthorized' response,
- with an appropriate 'WWW-Authenticate' header.
- """
- return 'OAuth realm="%s"' % self.www_authenticate_realm
-
- def validate_token(self, request, consumer, token):
- """
- Check the token and raise an `oauth.Error` exception if invalid.
- """
- oauth_server, oauth_request = oauth_provider.utils.initialize_server_request(request)
- oauth_server.verify_request(oauth_request, consumer, token)
-
- def check_nonce(self, request, oauth_request):
- """
- Checks nonce of request, and return True if valid.
- """
- oauth_nonce = oauth_request['oauth_nonce']
- oauth_timestamp = oauth_request['oauth_timestamp']
- return check_nonce(request, oauth_request, oauth_nonce, oauth_timestamp)
-
-
-class OAuth2Authentication(BaseAuthentication):
- """
- OAuth 2 authentication backend using `django-oauth2-provider`
- """
- www_authenticate_realm = 'api'
- allow_query_params_token = settings.DEBUG
-
- def __init__(self, *args, **kwargs):
- super(OAuth2Authentication, self).__init__(*args, **kwargs)
-
- if oauth2_provider is None:
- raise ImproperlyConfigured(
- "The 'django-oauth2-provider' package could not be imported. "
- "It is required for use with the 'OAuth2Authentication' class.")
-
- def authenticate(self, request):
- """
- Returns two-tuple of (user, token) if authentication succeeds,
- or None otherwise.
- """
-
- auth = get_authorization_header(request).split()
-
- if len(auth) == 1:
- msg = 'Invalid bearer header. No credentials provided.'
- raise exceptions.AuthenticationFailed(msg)
- elif len(auth) > 2:
- msg = 'Invalid bearer header. Token string should not contain spaces.'
- raise exceptions.AuthenticationFailed(msg)
-
- if auth and auth[0].lower() == b'bearer':
- access_token = auth[1]
- elif 'access_token' in request.POST:
- access_token = request.POST['access_token']
- elif 'access_token' in request.GET and self.allow_query_params_token:
- access_token = request.GET['access_token']
- else:
- return None
-
- return self.authenticate_credentials(request, access_token)
-
- def authenticate_credentials(self, request, access_token):
- """
- Authenticate the request, given the access token.
- """
-
- try:
- token = oauth2_provider.oauth2.models.AccessToken.objects.select_related('user')
- # provider_now switches to timezone aware datetime when
- # the oauth2_provider version supports to it.
- token = token.get(token=access_token, expires__gt=provider_now())
- except oauth2_provider.oauth2.models.AccessToken.DoesNotExist:
- raise exceptions.AuthenticationFailed('Invalid token')
-
- user = token.user
-
- if not user.is_active:
- msg = 'User inactive or deleted: %s' % user.get_username()
- raise exceptions.AuthenticationFailed(msg)
-
- return (user, token)
-
- def authenticate_header(self, request):
- """
- Bearer is the only finalized type currently
-
- Check details on the `OAuth2Authentication.authenticate` method
- """
- return 'Bearer realm="%s"' % self.www_authenticate_realm
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 69fdd793..3c8fb0da 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -5,15 +5,13 @@ versions of django/python, and compatibility wrappers around optional packages.
# flake8: noqa
from __future__ import unicode_literals
-
-import inspect
-
from django.core.exceptions import ImproperlyConfigured
+from django.conf import settings
from django.utils.encoding import force_text
from django.utils.six.moves.urllib import parse as urlparse
-from django.conf import settings
from django.utils import six
import django
+import inspect
def unicode_repr(instance):
@@ -33,6 +31,13 @@ def unicode_to_repr(value):
return value
+def unicode_http_header(value):
+ # Coerce HTTP header value to unicode.
+ if isinstance(value, six.binary_type):
+ return value.decode('iso-8859-1')
+ return value
+
+
# OrderedDict only available in Python 2.7.
# This will always be the case in Django 1.7 and above, as these versions
# no longer support Python 2.6.
@@ -207,72 +212,6 @@ except ImportError:
apply_markdown = None
-# Yaml is optional
-try:
- import yaml
-except ImportError:
- yaml = None
-
-
-# XML is optional
-try:
- import defusedxml.ElementTree as etree
-except ImportError:
- etree = None
-
-
-# OAuth2 is optional
-try:
- # Note: The `oauth2` package actually provides oauth1.0a support. Urg.
- import oauth2 as oauth
-except ImportError:
- oauth = None
-
-
-# OAuthProvider is optional
-try:
- import oauth_provider
- from oauth_provider.store import store as oauth_provider_store
-
- # check_nonce's calling signature in django-oauth-plus changes sometime
- # between versions 2.0 and 2.2.1
- def check_nonce(request, oauth_request, oauth_nonce, oauth_timestamp):
- check_nonce_args = inspect.getargspec(oauth_provider_store.check_nonce).args
- if 'timestamp' in check_nonce_args:
- return oauth_provider_store.check_nonce(
- request, oauth_request, oauth_nonce, oauth_timestamp
- )
- return oauth_provider_store.check_nonce(
- request, oauth_request, oauth_nonce
- )
-
-except (ImportError, ImproperlyConfigured):
- oauth_provider = None
- oauth_provider_store = None
- check_nonce = None
-
-
-# OAuth 2 support is optional
-try:
- import provider as oauth2_provider
- from provider import scope as oauth2_provider_scope
- from provider import constants as oauth2_constants
-
- if oauth2_provider.__version__ in ('0.2.3', '0.2.4'):
- # 0.2.3 and 0.2.4 are supported version that do not support
- # timezone aware datetimes
- import datetime
-
- provider_now = datetime.datetime.now
- else:
- # Any other supported version does use timezone aware datetimes
- from django.utils.timezone import now as provider_now
-except ImportError:
- oauth2_provider = None
- oauth2_provider_scope = None
- oauth2_constants = None
- provider_now = None
-
# `separators` argument to `json.dumps()` differs between 2.x and 3.x
# See: http://bugs.python.org/issue22767
if six.PY3:
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index 1f381e4e..bcfd8961 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -89,6 +89,11 @@ class PermissionDenied(APIException):
default_detail = _('You do not have permission to perform this action.')
+class NotFound(APIException):
+ status_code = status.HTTP_404_NOT_FOUND
+ default_detail = _('Not found')
+
+
class MethodNotAllowed(APIException):
status_code = status.HTTP_405_METHOD_NOT_ALLOWED
default_detail = _("Method '%s' not allowed.")
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index 3e3395c0..cb23423d 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -14,12 +14,9 @@ from django.http.multipartparser import MultiPartParserError, parse_header, Chun
from django.utils import six
from django.utils.six.moves.urllib import parse as urlparse
from django.utils.encoding import force_text
-from rest_framework.compat import etree, yaml
from rest_framework.exceptions import ParseError
from rest_framework import renderers
import json
-import datetime
-import decimal
class DataAndFiles(object):
@@ -67,29 +64,6 @@ class JSONParser(BaseParser):
raise ParseError('JSON parse error - %s' % six.text_type(exc))
-class YAMLParser(BaseParser):
- """
- Parses YAML-serialized data.
- """
-
- media_type = 'application/yaml'
-
- def parse(self, stream, media_type=None, parser_context=None):
- """
- Parses the incoming bytestream as YAML and returns the resulting data.
- """
- assert yaml, 'YAMLParser requires pyyaml to be installed'
-
- parser_context = parser_context or {}
- encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
-
- try:
- data = stream.read().decode(encoding)
- return yaml.safe_load(data)
- except (ValueError, yaml.parser.ParserError) as exc:
- raise ParseError('YAML parse error - %s' % six.text_type(exc))
-
-
class FormParser(BaseParser):
"""
Parser for form data.
@@ -138,78 +112,6 @@ class MultiPartParser(BaseParser):
raise ParseError('Multipart form parse error - %s' % six.text_type(exc))
-class XMLParser(BaseParser):
- """
- XML parser.
- """
-
- media_type = 'application/xml'
-
- def parse(self, stream, media_type=None, parser_context=None):
- """
- Parses the incoming bytestream as XML and returns the resulting data.
- """
- assert etree, 'XMLParser requires defusedxml to be installed'
-
- parser_context = parser_context or {}
- encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
- parser = etree.DefusedXMLParser(encoding=encoding)
- try:
- tree = etree.parse(stream, parser=parser, forbid_dtd=True)
- except (etree.ParseError, ValueError) as exc:
- raise ParseError('XML parse error - %s' % six.text_type(exc))
- data = self._xml_convert(tree.getroot())
-
- return data
-
- def _xml_convert(self, element):
- """
- convert the xml `element` into the corresponding python object
- """
-
- children = list(element)
-
- if len(children) == 0:
- return self._type_convert(element.text)
- else:
- # if the fist child tag is list-item means all children are list-item
- if children[0].tag == "list-item":
- data = []
- for child in children:
- data.append(self._xml_convert(child))
- else:
- data = {}
- for child in children:
- data[child.tag] = self._xml_convert(child)
-
- return data
-
- def _type_convert(self, value):
- """
- Converts the value returned by the XMl parse into the equivalent
- Python type
- """
- if value is None:
- return value
-
- try:
- return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
- except ValueError:
- pass
-
- try:
- return int(value)
- except ValueError:
- pass
-
- try:
- return decimal.Decimal(value)
- except decimal.InvalidOperation:
- pass
-
- return value
-
-
class FileUploadParser(BaseParser):
"""
Parser for file upload data.
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index 3f6f5961..9069d315 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -3,8 +3,7 @@ Provides a set of pluggable permission policies.
"""
from __future__ import unicode_literals
from django.http import Http404
-from rest_framework.compat import (get_model_name, oauth2_provider_scope,
- oauth2_constants)
+from rest_framework.compat import get_model_name
SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS']
@@ -199,28 +198,3 @@ class DjangoObjectPermissions(DjangoModelPermissions):
return False
return True
-
-
-class TokenHasReadWriteScope(BasePermission):
- """
- The request is authenticated as a user and the token used has the right scope
- """
-
- def has_permission(self, request, view):
- token = request.auth
- read_only = request.method in SAFE_METHODS
-
- if not token:
- return False
-
- if hasattr(token, 'resource'): # OAuth 1
- return read_only or not request.auth.resource.is_readonly
- elif hasattr(token, 'scope'): # OAuth 2
- required = oauth2_constants.READ if read_only else oauth2_constants.WRITE
- return oauth2_provider_scope.check(required, request.auth.scope)
-
- assert False, (
- 'TokenHasReadWriteScope requires either the'
- '`OAuthAuthentication` or `OAuth2Authentication` authentication '
- 'class to be used.'
- )
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 634338e9..c4de30db 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -17,11 +17,8 @@ from django.http.multipartparser import parse_header
from django.template import Context, RequestContext, loader, Template
from django.test.client import encode_multipart
from django.utils import six
-from django.utils.encoding import smart_text
-from django.utils.xmlutils import SimplerXMLGenerator
-from django.utils.six.moves import StringIO
from rest_framework import exceptions, serializers, status, VERSION
-from rest_framework.compat import SHORT_SEPARATORS, LONG_SEPARATORS, yaml
+from rest_framework.compat import SHORT_SEPARATORS, LONG_SEPARATORS
from rest_framework.exceptions import ParseError
from rest_framework.settings import api_settings
from rest_framework.request import is_form_media_type, override_method
@@ -112,112 +109,6 @@ class JSONRenderer(BaseRenderer):
return ret
-class JSONPRenderer(JSONRenderer):
- """
- Renderer which serializes to json,
- wrapping the json output in a callback function.
- """
-
- media_type = 'application/javascript'
- format = 'jsonp'
- callback_parameter = 'callback'
- default_callback = 'callback'
- charset = 'utf-8'
-
- def get_callback(self, renderer_context):
- """
- Determine the name of the callback to wrap around the json output.
- """
- request = renderer_context.get('request', None)
- params = request and request.query_params or {}
- return params.get(self.callback_parameter, self.default_callback)
-
- def render(self, data, accepted_media_type=None, renderer_context=None):
- """
- Renders into jsonp, wrapping the json output in a callback function.
-
- Clients may set the callback function name using a query parameter
- on the URL, for example: ?callback=exampleCallbackName
- """
- renderer_context = renderer_context or {}
- callback = self.get_callback(renderer_context)
- json = super(JSONPRenderer, self).render(data, accepted_media_type,
- renderer_context)
- return callback.encode(self.charset) + b'(' + json + b');'
-
-
-class XMLRenderer(BaseRenderer):
- """
- Renderer which serializes to XML.
- """
-
- media_type = 'application/xml'
- format = 'xml'
- charset = 'utf-8'
-
- def render(self, data, accepted_media_type=None, renderer_context=None):
- """
- Renders `data` into serialized XML.
- """
- if data is None:
- return ''
-
- stream = StringIO()
-
- xml = SimplerXMLGenerator(stream, self.charset)
- xml.startDocument()
- xml.startElement("root", {})
-
- self._to_xml(xml, data)
-
- xml.endElement("root")
- xml.endDocument()
- return stream.getvalue()
-
- def _to_xml(self, xml, data):
- if isinstance(data, (list, tuple)):
- for item in data:
- xml.startElement("list-item", {})
- self._to_xml(xml, item)
- xml.endElement("list-item")
-
- elif isinstance(data, dict):
- for key, value in six.iteritems(data):
- xml.startElement(key, {})
- self._to_xml(xml, value)
- xml.endElement(key)
-
- elif data is None:
- # Don't output any value
- pass
-
- else:
- xml.characters(smart_text(data))
-
-
-class YAMLRenderer(BaseRenderer):
- """
- Renderer which serializes to YAML.
- """
-
- media_type = 'application/yaml'
- format = 'yaml'
- encoder = encoders.SafeDumper
- charset = 'utf-8'
- ensure_ascii = False
-
- def render(self, data, accepted_media_type=None, renderer_context=None):
- """
- Renders `data` into serialized YAML.
- """
- assert yaml, 'YAMLRenderer requires pyyaml to be installed'
-
- if data is None:
- return ''
-
- return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder, allow_unicode=not self.ensure_ascii)
-
-
class TemplateHTMLRenderer(BaseRenderer):
"""
An HTML renderer for use with templates.
diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py
index a74e8aa2..8fcca55b 100644
--- a/rest_framework/reverse.py
+++ b/rest_framework/reverse.py
@@ -9,6 +9,18 @@ from django.utils.functional import lazy
def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
"""
+ If versioning is being used then we pass any `reverse` calls through
+ to the versioning scheme instance, so that the resulting URL
+ can be modified if needed.
+ """
+ scheme = getattr(request, 'versioning_scheme', None)
+ if scheme is not None:
+ return scheme.reverse(viewname, args, kwargs, request, format, **extra)
+ return _reverse(viewname, args, kwargs, request, format, **extra)
+
+
+def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
+ """
Same as `django.core.urlresolvers.reverse`, but optionally takes a request
and returns a fully qualified URL, using the request to get the base URL.
"""
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 33f84813..877d461b 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -5,11 +5,11 @@ For example your project's `settings.py` file might look like this:
REST_FRAMEWORK = {
'DEFAULT_RENDERER_CLASSES': (
'rest_framework.renderers.JSONRenderer',
- 'rest_framework.renderers.YAMLRenderer',
+ 'rest_framework.renderers.TemplateHTMLRenderer',
)
'DEFAULT_PARSER_CLASSES': (
'rest_framework.parsers.JSONParser',
- 'rest_framework.parsers.YAMLParser',
+ 'rest_framework.parsers.TemplateHTMLRenderer',
)
}
@@ -46,6 +46,7 @@ DEFAULTS = {
'DEFAULT_THROTTLE_CLASSES': (),
'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',
'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata',
+ 'DEFAULT_VERSIONING_CLASS': None,
# Generic view behavior
'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer',
@@ -67,6 +68,11 @@ DEFAULTS = {
'SEARCH_PARAM': 'search',
'ORDERING_PARAM': 'ordering',
+ # Versioning
+ 'DEFAULT_VERSION': None,
+ 'ALLOWED_VERSIONS': None,
+ 'VERSION_PARAM': 'version',
+
# Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
@@ -123,6 +129,7 @@ IMPORT_STRINGS = (
'DEFAULT_THROTTLE_CLASSES',
'DEFAULT_CONTENT_NEGOTIATION_CLASS',
'DEFAULT_METADATA_CLASS',
+ 'DEFAULT_VERSIONING_CLASS',
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
'DEFAULT_FILTER_BACKENDS',
'EXCEPTION_HANDLER',
@@ -139,7 +146,9 @@ def perform_import(val, setting_name):
If the given setting is a string import notation,
then perform the necessary import or imports.
"""
- if isinstance(val, six.string_types):
+ if val is None:
+ return None
+ elif isinstance(val, six.string_types):
return import_from_string(val, setting_name)
elif isinstance(val, (list, tuple)):
return [import_from_string(item, setting_name) for item in val]
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index 73cbe5d8..0bd24939 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -6,11 +6,8 @@ from django.db.models.query import QuerySet
from django.utils import six, timezone
from django.utils.encoding import force_text
from django.utils.functional import Promise
-from rest_framework.compat import OrderedDict
-from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList
import datetime
import decimal
-import types
import json
@@ -58,65 +55,3 @@ class JSONEncoder(json.JSONEncoder):
elif hasattr(obj, '__iter__'):
return tuple(item for item in obj)
return super(JSONEncoder, self).default(obj)
-
-
-try:
- import yaml
-except ImportError:
- SafeDumper = None
-else:
- # Adapted from http://pyyaml.org/attachment/ticket/161/use_ordered_dict.py
- class SafeDumper(yaml.SafeDumper):
- """
- Handles decimals as strings.
- Handles OrderedDicts as usual dicts, but preserves field order, rather
- than the usual behaviour of sorting the keys.
- """
- def represent_decimal(self, data):
- return self.represent_scalar('tag:yaml.org,2002:str', six.text_type(data))
-
- def represent_mapping(self, tag, mapping, flow_style=None):
- value = []
- node = yaml.MappingNode(tag, value, flow_style=flow_style)
- if self.alias_key is not None:
- self.represented_objects[self.alias_key] = node
- best_style = True
- if hasattr(mapping, 'items'):
- mapping = list(mapping.items())
- if not isinstance(mapping, OrderedDict):
- mapping.sort()
- for item_key, item_value in mapping:
- node_key = self.represent_data(item_key)
- node_value = self.represent_data(item_value)
- if not (isinstance(node_key, yaml.ScalarNode) and not node_key.style):
- best_style = False
- if not (isinstance(node_value, yaml.ScalarNode) and not node_value.style):
- best_style = False
- value.append((node_key, node_value))
- if flow_style is None:
- if self.default_flow_style is not None:
- node.flow_style = self.default_flow_style
- else:
- node.flow_style = best_style
- return node
-
- SafeDumper.add_representer(
- decimal.Decimal,
- SafeDumper.represent_decimal
- )
- SafeDumper.add_representer(
- OrderedDict,
- yaml.representer.SafeRepresenter.represent_dict
- )
- SafeDumper.add_representer(
- ReturnDict,
- yaml.representer.SafeRepresenter.represent_dict
- )
- SafeDumper.add_representer(
- ReturnList,
- yaml.representer.SafeRepresenter.represent_list
- )
- SafeDumper.add_representer(
- types.GeneratorType,
- yaml.representer.SafeRepresenter.represent_list
- )
diff --git a/rest_framework/versioning.py b/rest_framework/versioning.py
new file mode 100644
index 00000000..440efd13
--- /dev/null
+++ b/rest_framework/versioning.py
@@ -0,0 +1,174 @@
+# coding: utf-8
+from __future__ import unicode_literals
+from django.utils.translation import ugettext_lazy as _
+from rest_framework import exceptions
+from rest_framework.compat import unicode_http_header
+from rest_framework.reverse import _reverse
+from rest_framework.settings import api_settings
+from rest_framework.templatetags.rest_framework import replace_query_param
+from rest_framework.utils.mediatypes import _MediaType
+import re
+
+
+class BaseVersioning(object):
+ default_version = api_settings.DEFAULT_VERSION
+ allowed_versions = api_settings.ALLOWED_VERSIONS
+ version_param = api_settings.VERSION_PARAM
+
+ def determine_version(self, request, *args, **kwargs):
+ msg = '{cls}.determine_version() must be implemented.'
+ raise NotImplemented(msg.format(
+ cls=self.__class__.__name__
+ ))
+
+ def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
+ return _reverse(viewname, args, kwargs, request, format, **extra)
+
+ def is_allowed_version(self, version):
+ if not self.allowed_versions:
+ return True
+ return (version == self.default_version) or (version in self.allowed_versions)
+
+
+class AcceptHeaderVersioning(BaseVersioning):
+ """
+ GET /something/ HTTP/1.1
+ Host: example.com
+ Accept: application/json; version=1.0
+ """
+ invalid_version_message = _("Invalid version in 'Accept' header.")
+
+ def determine_version(self, request, *args, **kwargs):
+ media_type = _MediaType(request.accepted_media_type)
+ version = media_type.params.get(self.version_param, self.default_version)
+ version = unicode_http_header(version)
+ if not self.is_allowed_version(version):
+ raise exceptions.NotAcceptable(self.invalid_version_message)
+ return version
+
+ # We don't need to implement `reverse`, as the versioning is based
+ # on the `Accept` header, not on the request URL.
+
+
+class URLPathVersioning(BaseVersioning):
+ """
+ To the client this is the same style as `NamespaceVersioning`.
+ The difference is in the backend - this implementation uses
+ Django's URL keyword arguments to determine the version.
+
+ An example URL conf for two views that accept two different versions.
+
+ urlpatterns = [
+ url(r'^(?P<version>{v1,v2})/users/$', users_list, name='users-list'),
+ url(r'^(?P<version>{v1,v2})/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail')
+ ]
+
+ GET /1.0/something/ HTTP/1.1
+ Host: example.com
+ Accept: application/json
+ """
+ invalid_version_message = _('Invalid version in URL path.')
+
+ def determine_version(self, request, *args, **kwargs):
+ version = kwargs.get(self.version_param, self.default_version)
+ if not self.is_allowed_version(version):
+ raise exceptions.NotFound(self.invalid_version_message)
+ return version
+
+ def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
+ if request.version is not None:
+ kwargs = {} if (kwargs is None) else kwargs
+ kwargs[self.version_param] = request.version
+
+ return super(URLPathVersioning, self).reverse(
+ viewname, args, kwargs, request, format, **extra
+ )
+
+
+class NamespaceVersioning(BaseVersioning):
+ """
+ To the client this is the same style as `URLPathVersioning`.
+ The difference is in the backend - this implementation uses
+ Django's URL namespaces to determine the version.
+
+ An example URL conf that is namespaced into two seperate versions
+
+ # users/urls.py
+ urlpatterns = [
+ url(r'^/users/$', users_list, name='users-list'),
+ url(r'^/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail')
+ ]
+
+ # urls.py
+ urlpatterns = [
+ url(r'^v1/', include('users.urls', namespace='v1')),
+ url(r'^v2/', include('users.urls', namespace='v2'))
+ ]
+
+ GET /1.0/something/ HTTP/1.1
+ Host: example.com
+ Accept: application/json
+ """
+ invalid_version_message = _('Invalid version in URL path.')
+
+ def determine_version(self, request, *args, **kwargs):
+ resolver_match = getattr(request, 'resolver_match', None)
+ if (resolver_match is None or not resolver_match.namespace):
+ return self.default_version
+ version = resolver_match.namespace
+ if not self.is_allowed_version(version):
+ raise exceptions.NotFound(self.invalid_version_message)
+ return version
+
+ def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
+ if request.version is not None:
+ viewname = request.version + ':' + viewname
+ return super(NamespaceVersioning, self).reverse(
+ viewname, args, kwargs, request, format, **extra
+ )
+
+
+class HostNameVersioning(BaseVersioning):
+ """
+ GET /something/ HTTP/1.1
+ Host: v1.example.com
+ Accept: application/json
+ """
+ hostname_regex = re.compile(r'^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$')
+ invalid_version_message = _('Invalid version in hostname.')
+
+ def determine_version(self, request, *args, **kwargs):
+ hostname, seperator, port = request.get_host().partition(':')
+ match = self.hostname_regex.match(hostname)
+ if not match:
+ return self.default_version
+ version = match.group(1)
+ if not self.is_allowed_version(version):
+ raise exceptions.NotFound(self.invalid_version_message)
+ return version
+
+ # We don't need to implement `reverse`, as the hostname will already be
+ # preserved as part of the REST framework `reverse` implementation.
+
+
+class QueryParameterVersioning(BaseVersioning):
+ """
+ GET /something/?version=0.1 HTTP/1.1
+ Host: example.com
+ Accept: application/json
+ """
+ invalid_version_message = _('Invalid version in query parameter.')
+
+ def determine_version(self, request, *args, **kwargs):
+ version = request.query_params.get(self.version_param)
+ if not self.is_allowed_version(version):
+ raise exceptions.NotFound(self.invalid_version_message)
+ return version
+
+ def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
+ url = super(QueryParameterVersioning, self).reverse(
+ viewname, args, kwargs, request, format, **extra
+ )
+ if request.version is not None:
+ return replace_query_param(url, self.version_param, request.version)
+ return url
diff --git a/rest_framework/views.py b/rest_framework/views.py
index bc870417..12bb78bd 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -2,6 +2,8 @@
Provides an APIView class that is the base of all views in REST framework.
"""
from __future__ import unicode_literals
+import inspect
+import warnings
from django.core.exceptions import PermissionDenied
from django.http import Http404
@@ -46,7 +48,7 @@ def get_view_description(view_cls, html=False):
return description
-def exception_handler(exc):
+def exception_handler(exc, context):
"""
Returns the response that should be used for any given exception.
@@ -93,6 +95,7 @@ class APIView(View):
permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
metadata_class = api_settings.DEFAULT_METADATA_CLASS
+ versioning_class = api_settings.DEFAULT_VERSIONING_CLASS
# Allow dependency injection of other settings to make testing easier.
settings = api_settings
@@ -184,6 +187,18 @@ class APIView(View):
'request': getattr(self, 'request', None)
}
+ def get_exception_handler_context(self):
+ """
+ Returns a dict that is passed through to EXCEPTION_HANDLER,
+ as the `context` argument.
+ """
+ return {
+ 'view': self,
+ 'args': getattr(self, 'args', ()),
+ 'kwargs': getattr(self, 'kwargs', {}),
+ 'request': getattr(self, 'request', None)
+ }
+
def get_view_name(self):
"""
Return the view name, as used in OPTIONS responses and in the
@@ -300,6 +315,16 @@ class APIView(View):
if not throttle.allow_request(request, self):
self.throttled(request, throttle.wait())
+ def determine_version(self, request, *args, **kwargs):
+ """
+ If versioning is being used, then determine any API version for the
+ incoming request. Returns a two-tuple of (version, versioning_scheme)
+ """
+ if self.versioning_class is None:
+ return (None, None)
+ scheme = self.versioning_class()
+ return (scheme.determine_version(request, *args, **kwargs), scheme)
+
# Dispatch methods
def initialize_request(self, request, *args, **kwargs):
@@ -308,11 +333,13 @@ class APIView(View):
"""
parser_context = self.get_parser_context(request)
- return Request(request,
- parsers=self.get_parsers(),
- authenticators=self.get_authenticators(),
- negotiator=self.get_content_negotiator(),
- parser_context=parser_context)
+ return Request(
+ request,
+ parsers=self.get_parsers(),
+ authenticators=self.get_authenticators(),
+ negotiator=self.get_content_negotiator(),
+ parser_context=parser_context
+ )
def initial(self, request, *args, **kwargs):
"""
@@ -329,6 +356,10 @@ class APIView(View):
neg = self.perform_content_negotiation(request)
request.accepted_renderer, request.accepted_media_type = neg
+ # Determine the API version, if versioning is in use.
+ version, scheme = self.determine_version(request, *args, **kwargs)
+ request.version, request.versioning_scheme = version, scheme
+
def finalize_response(self, request, response, *args, **kwargs):
"""
Returns the final response object.
@@ -369,7 +400,18 @@ class APIView(View):
else:
exc.status_code = status.HTTP_403_FORBIDDEN
- response = self.settings.EXCEPTION_HANDLER(exc)
+ exception_handler = self.settings.EXCEPTION_HANDLER
+
+ if len(inspect.getargspec(exception_handler).args) == 1:
+ warnings.warn(
+ 'The `exception_handler(exc)` call signature is deprecated. '
+ 'Use `exception_handler(exc, context) instead.',
+ PendingDeprecationWarning
+ )
+ response = exception_handler(exc)
+ else:
+ context = self.get_exception_handler_context()
+ response = exception_handler(exc, context)
if response is None:
raise