aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/versioning.py
diff options
context:
space:
mode:
authorTom Christie2014-12-17 12:41:46 +0000
committerTom Christie2014-12-17 12:41:46 +0000
commit05a6eaec8aebdca2248b9e1069a15769fd85a480 (patch)
tree58488ee2e6533032d942dc65f038bbbc43462a87 /rest_framework/versioning.py
parent70bd3a32f7cf57543e8ec08fddf001a718e40c7f (diff)
downloaddjango-rest-framework-05a6eaec8aebdca2248b9e1069a15769fd85a480.tar.bz2
More docs, plus 'ALLOWED_VERSIONS' setting.
Diffstat (limited to 'rest_framework/versioning.py')
-rw-r--r--rest_framework/versioning.py120
1 files changed, 72 insertions, 48 deletions
diff --git a/rest_framework/versioning.py b/rest_framework/versioning.py
index 223d0f61..440efd13 100644
--- a/rest_framework/versioning.py
+++ b/rest_framework/versioning.py
@@ -1,13 +1,20 @@
# coding: utf-8
from __future__ import unicode_literals
+from django.utils.translation import ugettext_lazy as _
+from rest_framework import exceptions
from rest_framework.compat import unicode_http_header
from rest_framework.reverse import _reverse
+from rest_framework.settings import api_settings
from rest_framework.templatetags.rest_framework import replace_query_param
from rest_framework.utils.mediatypes import _MediaType
import re
class BaseVersioning(object):
+ default_version = api_settings.DEFAULT_VERSION
+ allowed_versions = api_settings.ALLOWED_VERSIONS
+ version_param = api_settings.VERSION_PARAM
+
def determine_version(self, request, *args, **kwargs):
msg = '{cls}.determine_version() must be implemented.'
raise NotImplemented(msg.format(
@@ -17,46 +24,10 @@ class BaseVersioning(object):
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
return _reverse(viewname, args, kwargs, request, format, **extra)
-
-class QueryParameterVersioning(BaseVersioning):
- """
- GET /something/?version=0.1 HTTP/1.1
- Host: example.com
- Accept: application/json
- """
- default_version = None
- version_param = 'version'
-
- def determine_version(self, request, *args, **kwargs):
- return request.query_params.get(self.version_param)
-
- def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
- url = super(QueryParameterVersioning, self).reverse(
- viewname, args, kwargs, request, format, **extra
- )
- if request.version is not None:
- return replace_query_param(url, self.version_param, request.version)
- return url
-
-
-class HostNameVersioning(BaseVersioning):
- """
- GET /something/ HTTP/1.1
- Host: v1.example.com
- Accept: application/json
- """
- default_version = None
- hostname_regex = re.compile(r'^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$')
-
- def determine_version(self, request, *args, **kwargs):
- hostname, seperator, port = request.get_host().partition(':')
- match = self.hostname_regex.match(hostname)
- if not match:
- return self.default_version
- return match.group(1)
-
- # We don't need to implement `reverse`, as the hostname will already be
- # preserved as part of the REST framework `reverse` implementation.
+ def is_allowed_version(self, version):
+ if not self.allowed_versions:
+ return True
+ return (version == self.default_version) or (version in self.allowed_versions)
class AcceptHeaderVersioning(BaseVersioning):
@@ -65,13 +36,15 @@ class AcceptHeaderVersioning(BaseVersioning):
Host: example.com
Accept: application/json; version=1.0
"""
- default_version = None
- version_param = 'version'
+ invalid_version_message = _("Invalid version in 'Accept' header.")
def determine_version(self, request, *args, **kwargs):
media_type = _MediaType(request.accepted_media_type)
version = media_type.params.get(self.version_param, self.default_version)
- return unicode_http_header(version)
+ version = unicode_http_header(version)
+ if not self.is_allowed_version(version):
+ raise exceptions.NotAcceptable(self.invalid_version_message)
+ return version
# We don't need to implement `reverse`, as the versioning is based
# on the `Accept` header, not on the request URL.
@@ -94,11 +67,13 @@ class URLPathVersioning(BaseVersioning):
Host: example.com
Accept: application/json
"""
- default_version = None
- version_param = 'version'
+ invalid_version_message = _('Invalid version in URL path.')
def determine_version(self, request, *args, **kwargs):
- return kwargs.get(self.version_param, self.default_version)
+ version = kwargs.get(self.version_param, self.default_version)
+ if not self.is_allowed_version(version):
+ raise exceptions.NotFound(self.invalid_version_message)
+ return version
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
if request.version is not None:
@@ -134,13 +109,16 @@ class NamespaceVersioning(BaseVersioning):
Host: example.com
Accept: application/json
"""
- default_version = None
+ invalid_version_message = _('Invalid version in URL path.')
def determine_version(self, request, *args, **kwargs):
resolver_match = getattr(request, 'resolver_match', None)
if (resolver_match is None or not resolver_match.namespace):
return self.default_version
- return resolver_match.namespace
+ version = resolver_match.namespace
+ if not self.is_allowed_version(version):
+ raise exceptions.NotFound(self.invalid_version_message)
+ return version
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
if request.version is not None:
@@ -148,3 +126,49 @@ class NamespaceVersioning(BaseVersioning):
return super(NamespaceVersioning, self).reverse(
viewname, args, kwargs, request, format, **extra
)
+
+
+class HostNameVersioning(BaseVersioning):
+ """
+ GET /something/ HTTP/1.1
+ Host: v1.example.com
+ Accept: application/json
+ """
+ hostname_regex = re.compile(r'^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$')
+ invalid_version_message = _('Invalid version in hostname.')
+
+ def determine_version(self, request, *args, **kwargs):
+ hostname, seperator, port = request.get_host().partition(':')
+ match = self.hostname_regex.match(hostname)
+ if not match:
+ return self.default_version
+ version = match.group(1)
+ if not self.is_allowed_version(version):
+ raise exceptions.NotFound(self.invalid_version_message)
+ return version
+
+ # We don't need to implement `reverse`, as the hostname will already be
+ # preserved as part of the REST framework `reverse` implementation.
+
+
+class QueryParameterVersioning(BaseVersioning):
+ """
+ GET /something/?version=0.1 HTTP/1.1
+ Host: example.com
+ Accept: application/json
+ """
+ invalid_version_message = _('Invalid version in query parameter.')
+
+ def determine_version(self, request, *args, **kwargs):
+ version = request.query_params.get(self.version_param)
+ if not self.is_allowed_version(version):
+ raise exceptions.NotFound(self.invalid_version_message)
+ return version
+
+ def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
+ url = super(QueryParameterVersioning, self).reverse(
+ viewname, args, kwargs, request, format, **extra
+ )
+ if request.version is not None:
+ return replace_query_param(url, self.version_param, request.version)
+ return url