diff options
| author | Tom Christie | 2014-12-17 12:41:46 +0000 | 
|---|---|---|
| committer | Tom Christie | 2014-12-17 12:41:46 +0000 | 
| commit | 05a6eaec8aebdca2248b9e1069a15769fd85a480 (patch) | |
| tree | 58488ee2e6533032d942dc65f038bbbc43462a87 /rest_framework | |
| parent | 70bd3a32f7cf57543e8ec08fddf001a718e40c7f (diff) | |
| download | django-rest-framework-05a6eaec8aebdca2248b9e1069a15769fd85a480.tar.bz2 | |
More docs, plus 'ALLOWED_VERSIONS' setting.
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/exceptions.py | 5 | ||||
| -rw-r--r-- | rest_framework/settings.py | 5 | ||||
| -rw-r--r-- | rest_framework/versioning.py | 120 | 
3 files changed, 82 insertions, 48 deletions
| diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index be41d08d..238934db 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -89,6 +89,11 @@ class PermissionDenied(APIException):      default_detail = _('You do not have permission to perform this action.') +class NotFound(APIException): +    status_code = status.HTTP_404_NOT_FOUND +    default_detail = _('Not found') + +  class MethodNotAllowed(APIException):      status_code = status.HTTP_405_METHOD_NOT_ALLOWED      default_detail = _("Method '%s' not allowed.") diff --git a/rest_framework/settings.py b/rest_framework/settings.py index da3be38d..877d461b 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -68,6 +68,11 @@ DEFAULTS = {      'SEARCH_PARAM': 'search',      'ORDERING_PARAM': 'ordering', +    # Versioning +    'DEFAULT_VERSION': None, +    'ALLOWED_VERSIONS': None, +    'VERSION_PARAM': 'version', +      # Authentication      'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',      'UNAUTHENTICATED_TOKEN': None, diff --git a/rest_framework/versioning.py b/rest_framework/versioning.py index 223d0f61..440efd13 100644 --- a/rest_framework/versioning.py +++ b/rest_framework/versioning.py @@ -1,13 +1,20 @@  # coding: utf-8  from __future__ import unicode_literals +from django.utils.translation import ugettext_lazy as _ +from rest_framework import exceptions  from rest_framework.compat import unicode_http_header  from rest_framework.reverse import _reverse +from rest_framework.settings import api_settings  from rest_framework.templatetags.rest_framework import replace_query_param  from rest_framework.utils.mediatypes import _MediaType  import re  class BaseVersioning(object): +    default_version = api_settings.DEFAULT_VERSION +    allowed_versions = api_settings.ALLOWED_VERSIONS +    version_param = api_settings.VERSION_PARAM +      def determine_version(self, request, *args, **kwargs):          msg = '{cls}.determine_version() must be implemented.'          raise NotImplemented(msg.format( @@ -17,46 +24,10 @@ class BaseVersioning(object):      def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):          return _reverse(viewname, args, kwargs, request, format, **extra) - -class QueryParameterVersioning(BaseVersioning): -    """ -    GET /something/?version=0.1 HTTP/1.1 -    Host: example.com -    Accept: application/json -    """ -    default_version = None -    version_param = 'version' - -    def determine_version(self, request, *args, **kwargs): -        return request.query_params.get(self.version_param) - -    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): -        url = super(QueryParameterVersioning, self).reverse( -            viewname, args, kwargs, request, format, **extra -        ) -        if request.version is not None: -            return replace_query_param(url, self.version_param, request.version) -        return url - - -class HostNameVersioning(BaseVersioning): -    """ -    GET /something/ HTTP/1.1 -    Host: v1.example.com -    Accept: application/json -    """ -    default_version = None -    hostname_regex = re.compile(r'^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$') - -    def determine_version(self, request, *args, **kwargs): -        hostname, seperator, port = request.get_host().partition(':') -        match = self.hostname_regex.match(hostname) -        if not match: -            return self.default_version -        return match.group(1) - -    # We don't need to implement `reverse`, as the hostname will already be -    # preserved as part of the REST framework `reverse` implementation. +    def is_allowed_version(self, version): +        if not self.allowed_versions: +            return True +        return (version == self.default_version) or (version in self.allowed_versions)  class AcceptHeaderVersioning(BaseVersioning): @@ -65,13 +36,15 @@ class AcceptHeaderVersioning(BaseVersioning):      Host: example.com      Accept: application/json; version=1.0      """ -    default_version = None -    version_param = 'version' +    invalid_version_message = _("Invalid version in 'Accept' header.")      def determine_version(self, request, *args, **kwargs):          media_type = _MediaType(request.accepted_media_type)          version = media_type.params.get(self.version_param, self.default_version) -        return unicode_http_header(version) +        version = unicode_http_header(version) +        if not self.is_allowed_version(version): +            raise exceptions.NotAcceptable(self.invalid_version_message) +        return version      # We don't need to implement `reverse`, as the versioning is based      # on the `Accept` header, not on the request URL. @@ -94,11 +67,13 @@ class URLPathVersioning(BaseVersioning):      Host: example.com      Accept: application/json      """ -    default_version = None -    version_param = 'version' +    invalid_version_message = _('Invalid version in URL path.')      def determine_version(self, request, *args, **kwargs): -        return kwargs.get(self.version_param, self.default_version) +        version = kwargs.get(self.version_param, self.default_version) +        if not self.is_allowed_version(version): +            raise exceptions.NotFound(self.invalid_version_message) +        return version      def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):          if request.version is not None: @@ -134,13 +109,16 @@ class NamespaceVersioning(BaseVersioning):      Host: example.com      Accept: application/json      """ -    default_version = None +    invalid_version_message = _('Invalid version in URL path.')      def determine_version(self, request, *args, **kwargs):          resolver_match = getattr(request, 'resolver_match', None)          if (resolver_match is None or not resolver_match.namespace):              return self.default_version -        return resolver_match.namespace +        version = resolver_match.namespace +        if not self.is_allowed_version(version): +            raise exceptions.NotFound(self.invalid_version_message) +        return version      def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):          if request.version is not None: @@ -148,3 +126,49 @@ class NamespaceVersioning(BaseVersioning):          return super(NamespaceVersioning, self).reverse(              viewname, args, kwargs, request, format, **extra          ) + + +class HostNameVersioning(BaseVersioning): +    """ +    GET /something/ HTTP/1.1 +    Host: v1.example.com +    Accept: application/json +    """ +    hostname_regex = re.compile(r'^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$') +    invalid_version_message = _('Invalid version in hostname.') + +    def determine_version(self, request, *args, **kwargs): +        hostname, seperator, port = request.get_host().partition(':') +        match = self.hostname_regex.match(hostname) +        if not match: +            return self.default_version +        version = match.group(1) +        if not self.is_allowed_version(version): +            raise exceptions.NotFound(self.invalid_version_message) +        return version + +    # We don't need to implement `reverse`, as the hostname will already be +    # preserved as part of the REST framework `reverse` implementation. + + +class QueryParameterVersioning(BaseVersioning): +    """ +    GET /something/?version=0.1 HTTP/1.1 +    Host: example.com +    Accept: application/json +    """ +    invalid_version_message = _('Invalid version in query parameter.') + +    def determine_version(self, request, *args, **kwargs): +        version = request.query_params.get(self.version_param) +        if not self.is_allowed_version(version): +            raise exceptions.NotFound(self.invalid_version_message) +        return version + +    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): +        url = super(QueryParameterVersioning, self).reverse( +            viewname, args, kwargs, request, format, **extra +        ) +        if request.version is not None: +            return replace_query_param(url, self.version_param, request.version) +        return url | 
