diff options
| -rw-r--r-- | docs/api-guide/versioning.md | 9 | ||||
| -rw-r--r-- | rest_framework/reverse.py | 12 | ||||
| -rw-r--r-- | rest_framework/settings.py | 6 | ||||
| -rw-r--r-- | rest_framework/versioning.py | 96 | ||||
| -rw-r--r-- | rest_framework/views.py | 27 | ||||
| -rw-r--r-- | tests/test_versioning.py | 104 | 
6 files changed, 248 insertions, 6 deletions
| diff --git a/docs/api-guide/versioning.md b/docs/api-guide/versioning.md new file mode 100644 index 00000000..df814894 --- /dev/null +++ b/docs/api-guide/versioning.md @@ -0,0 +1,9 @@ +source: versioning.py + +# Versioning + +> Versioning an interface is just a "polite" way to kill deployed clients. +>  +> — [Roy Fielding][cite]. + +[cite]: http://www.slideshare.net/evolve_conference/201308-fielding-evolve/31
\ No newline at end of file diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py index a74e8aa2..8fcca55b 100644 --- a/rest_framework/reverse.py +++ b/rest_framework/reverse.py @@ -9,6 +9,18 @@ from django.utils.functional import lazy  def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):      """ +    If versioning is being used then we pass any `reverse` calls through +    to the versioning scheme instance, so that the resulting URL +    can be modified if needed. +    """ +    scheme = getattr(request, 'versioning_scheme', None) +    if scheme is not None: +        return scheme.reverse(viewname, args, kwargs, request, format, **extra) +    return _reverse(viewname, args, kwargs, request, format, **extra) + + +def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra): +    """      Same as `django.core.urlresolvers.reverse`, but optionally takes a request      and returns a fully qualified URL, using the request to get the base URL.      """ diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 6c26256a..da3be38d 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -46,6 +46,7 @@ DEFAULTS = {      'DEFAULT_THROTTLE_CLASSES': (),      'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',      'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata', +    'DEFAULT_VERSIONING_CLASS': None,      # Generic view behavior      'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer', @@ -123,6 +124,7 @@ IMPORT_STRINGS = (      'DEFAULT_THROTTLE_CLASSES',      'DEFAULT_CONTENT_NEGOTIATION_CLASS',      'DEFAULT_METADATA_CLASS', +    'DEFAULT_VERSIONING_CLASS',      'DEFAULT_PAGINATION_SERIALIZER_CLASS',      'DEFAULT_FILTER_BACKENDS',      'EXCEPTION_HANDLER', @@ -139,7 +141,9 @@ def perform_import(val, setting_name):      If the given setting is a string import notation,      then perform the necessary import or imports.      """ -    if isinstance(val, six.string_types): +    if val is None: +        return None +    elif isinstance(val, six.string_types):          return import_from_string(val, setting_name)      elif isinstance(val, (list, tuple)):          return [import_from_string(item, setting_name) for item in val] diff --git a/rest_framework/versioning.py b/rest_framework/versioning.py new file mode 100644 index 00000000..2ca8efff --- /dev/null +++ b/rest_framework/versioning.py @@ -0,0 +1,96 @@ +# coding: utf-8 +from __future__ import unicode_literals +from rest_framework.reverse import _reverse +from rest_framework.utils.mediatypes import _MediaType +import re + + +class BaseVersioning(object): +    def determine_version(self, request, *args, **kwargs): +        msg = '{cls}.determine_version() must be implemented.' +        raise NotImplemented(msg.format( +            cls=self.__class__.__name__ +        )) + +    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): +        return _reverse(viewname, args, kwargs, request, format, **extra) + + +class QueryParameterVersioning(BaseVersioning): +    """ +    GET /something/?version=0.1 HTTP/1.1 +    Host: example.com +    Accept: application/json +    """ +    default_version = None +    version_param = 'version' + +    def determine_version(self, request, *args, **kwargs): +        return request.query_params.get(self.version_param) + +    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): +        url = super(QueryParameterVersioning, self).reverse( +            viewname, args, kwargs, request, format, **kwargs +        ) +        if request.version is not None: +            return replace_query_param(url, self.version_param, request.version) +        return url + + +class HostNameVersioning(BaseVersioning): +    """ +    GET /something/ HTTP/1.1 +    Host: v1.example.com +    Accept: application/json +    """ +    default_version = None +    hostname_regex = re.compile(r'^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$') + +    def determine_version(self, request, *args, **kwargs): +        hostname, seperator, port = request.get_host().partition(':') +        match = self.hostname_regex.match(hostname) +        if not match: +            return self.default_version +        return match.group(1) + +    # We don't need to implement `reverse`, as the hostname will already be +    # preserved as part of the standard `reverse` implementation. + + +class AcceptHeaderVersioning(BaseVersioning): +    """ +    GET /something/ HTTP/1.1 +    Host: example.com +    Accept: application/json; version=1.0 +    """ +    default_version = None +    version_param = 'version' + +    def determine_version(self, request, *args, **kwargs): +        media_type = _MediaType(request.accepted_media_type) +        return media_type.params.get(self.version_param, self.default_version) + +    # We don't need to implement `reverse`, as the versioning is based +    # on the `Accept` header, not on the request URL. + + +class URLPathVersioning(BaseVersioning): +    """ +    GET /1.0/something/ HTTP/1.1 +    Host: example.com +    Accept: application/json +    """ +    default_version = None +    version_param = 'version' + +    def determine_version(self, request, *args, **kwargs): +        return kwargs.get(self.version_param, self.default_version) + +    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): +        if request.version is not None: +            kwargs = {} if (kwargs is None) else kwargs +            kwargs[self.version_param] = request.version + +        return super(URLPathVersioning, self).reverse( +            viewname, args, kwargs, request, format, **kwargs +        ) diff --git a/rest_framework/views.py b/rest_framework/views.py index b39724c2..12bb78bd 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -95,6 +95,7 @@ class APIView(View):      permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES      content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS      metadata_class = api_settings.DEFAULT_METADATA_CLASS +    versioning_class = api_settings.DEFAULT_VERSIONING_CLASS      # Allow dependency injection of other settings to make testing easier.      settings = api_settings @@ -314,6 +315,16 @@ class APIView(View):              if not throttle.allow_request(request, self):                  self.throttled(request, throttle.wait()) +    def determine_version(self, request, *args, **kwargs): +        """ +        If versioning is being used, then determine any API version for the +        incoming request. Returns a two-tuple of (version, versioning_scheme) +        """ +        if self.versioning_class is None: +            return (None, None) +        scheme = self.versioning_class() +        return (scheme.determine_version(request, *args, **kwargs), scheme) +      # Dispatch methods      def initialize_request(self, request, *args, **kwargs): @@ -322,11 +333,13 @@ class APIView(View):          """          parser_context = self.get_parser_context(request) -        return Request(request, -                       parsers=self.get_parsers(), -                       authenticators=self.get_authenticators(), -                       negotiator=self.get_content_negotiator(), -                       parser_context=parser_context) +        return Request( +            request, +            parsers=self.get_parsers(), +            authenticators=self.get_authenticators(), +            negotiator=self.get_content_negotiator(), +            parser_context=parser_context +        )      def initial(self, request, *args, **kwargs):          """ @@ -343,6 +356,10 @@ class APIView(View):          neg = self.perform_content_negotiation(request)          request.accepted_renderer, request.accepted_media_type = neg +        # Determine the API version, if versioning is in use. +        version, scheme = self.determine_version(request, *args, **kwargs) +        request.version, request.versioning_scheme = version, scheme +      def finalize_response(self, request, response, *args, **kwargs):          """          Returns the final response object. diff --git a/tests/test_versioning.py b/tests/test_versioning.py new file mode 100644 index 00000000..d90b29a1 --- /dev/null +++ b/tests/test_versioning.py @@ -0,0 +1,104 @@ +from django.conf.urls import url +from rest_framework import versioning +from rest_framework.decorators import APIView +from rest_framework.response import Response +from rest_framework.reverse import reverse +from rest_framework.test import APIRequestFactory, APITestCase + + +class RequestVersionView(APIView): +    def get(self, request, *args, **kwargs): +        return Response({'version': request.version}) + +class ReverseView(APIView): +    def get(self, request, *args, **kwargs): +        return Response({'url': reverse('another', request=request)}) + + +factory = APIRequestFactory() + +mock_view = lambda request: None + +urlpatterns = [ +    url(r'^another/$', mock_view, name='another') +] + + +class TestRequestVersion: +    def test_unversioned(self): +        view = RequestVersionView.as_view() + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'version': None} + +    def test_query_param_versioning(self): +        scheme = versioning.QueryParameterVersioning +        view = RequestVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/?version=1.2.3') +        response = view(request) +        assert response.data == {'version': '1.2.3'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'version': None} + +    def test_host_name_versioning(self): +        scheme = versioning.HostNameVersioning +        view = RequestVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/', HTTP_HOST='v1.example.org') +        response = view(request) +        assert response.data == {'version': 'v1'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'version': None} + +    def test_accept_header_versioning(self): +        scheme = versioning.AcceptHeaderVersioning +        view = RequestVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/', HTTP_ACCEPT='application/json; version=1.2.3') +        response = view(request) +        assert response.data == {'version': '1.2.3'} + +        request = factory.get('/endpoint/', HTTP_ACCEPT='application/json') +        response = view(request) +        assert response.data == {'version': None} + +    def test_url_path_versioning(self): +        scheme = versioning.URLPathVersioning +        view = RequestVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/1.2.3/endpoint/') +        response = view(request, version='1.2.3') +        assert response.data == {'version': '1.2.3'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'version': None} + + +class TestURLReversing(APITestCase): +    urls = 'tests.test_versioning' + +    def test_reverse_unversioned(self): +        view = ReverseView.as_view() + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/'} + +    def test_reverse_host_name_versioning(self): +        scheme = versioning.HostNameVersioning +        view = ReverseView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/', HTTP_HOST='v1.example.org') +        response = view(request) +        assert response.data == {'url': 'http://v1.example.org/another/'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/'} | 
