diff options
| -rw-r--r-- | rest_framework/versioning.py | 31 | ||||
| -rw-r--r-- | tests/test_versioning.py | 67 | 
2 files changed, 94 insertions, 4 deletions
| diff --git a/rest_framework/versioning.py b/rest_framework/versioning.py index 2ca8efff..42df8b2c 100644 --- a/rest_framework/versioning.py +++ b/rest_framework/versioning.py @@ -1,6 +1,7 @@  # coding: utf-8  from __future__ import unicode_literals  from rest_framework.reverse import _reverse +from rest_framework.templatetags.rest_framework import replace_query_param  from rest_framework.utils.mediatypes import _MediaType  import re @@ -30,7 +31,7 @@ class QueryParameterVersioning(BaseVersioning):      def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):          url = super(QueryParameterVersioning, self).reverse( -            viewname, args, kwargs, request, format, **kwargs +            viewname, args, kwargs, request, format, **extra          )          if request.version is not None:              return replace_query_param(url, self.version_param, request.version) @@ -92,5 +93,31 @@ class URLPathVersioning(BaseVersioning):              kwargs[self.version_param] = request.version          return super(URLPathVersioning, self).reverse( -            viewname, args, kwargs, request, format, **kwargs +            viewname, args, kwargs, request, format, **extra +        ) + + +class NamespaceVersioning(BaseVersioning): +    """ +    To the client this is the same style as `URLPathVersioning`. +    The difference is in the backend - this implementation uses +    Django's URL namespaces to determine the version. + +    GET /1.0/something/ HTTP/1.1 +    Host: example.com +    Accept: application/json +    """ +    default_version = None + +    def determine_version(self, request, *args, **kwargs): +        resolver_match = getattr(request, 'resolver_match', None) +        if (resolver_match is None or not resolver_match.namespace): +            return self.default_version +        return resolver_match.namespace + +    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): +        if request.version is not None: +            viewname = request.version + ':' + viewname +        return super(NamespaceVersioning, self).reverse( +            viewname, args, kwargs, request, format, **extra          ) diff --git a/tests/test_versioning.py b/tests/test_versioning.py index d90b29a1..eaac5dfb 100644 --- a/tests/test_versioning.py +++ b/tests/test_versioning.py @@ -1,4 +1,4 @@ -from django.conf.urls import url +from django.conf.urls import include, url  from rest_framework import versioning  from rest_framework.decorators import APIView  from rest_framework.response import Response @@ -10,6 +10,7 @@ class RequestVersionView(APIView):      def get(self, request, *args, **kwargs):          return Response({'version': request.version}) +  class ReverseView(APIView):      def get(self, request, *args, **kwargs):          return Response({'url': reverse('another', request=request)}) @@ -19,8 +20,14 @@ factory = APIRequestFactory()  mock_view = lambda request: None +included_patterns = [ +    url(r'^namespaced/$', mock_view, name='another'), +] +  urlpatterns = [ -    url(r'^another/$', mock_view, name='another') +    url(r'^v1/', include(included_patterns, namespace='v1')), +    url(r'^another/$', mock_view, name='another'), +    url(r'^(?P<version>[^/]+)/another/$', mock_view, name='another')  ] @@ -80,6 +87,22 @@ class TestRequestVersion:          response = view(request)          assert response.data == {'version': None} +    def test_namespace_versioning(self): +        class FakeResolverMatch: +            namespace = 'v1' + +        scheme = versioning.NamespaceVersioning +        view = RequestVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/v1/endpoint/') +        request.resolver_match = FakeResolverMatch +        response = view(request, version='v1') +        assert response.data == {'version': 'v1'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'version': None} +  class TestURLReversing(APITestCase):      urls = 'tests.test_versioning' @@ -91,6 +114,18 @@ class TestURLReversing(APITestCase):          response = view(request)          assert response.data == {'url': 'http://testserver/another/'} +    def test_reverse_query_param_versioning(self): +        scheme = versioning.QueryParameterVersioning +        view = ReverseView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/?version=v1') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/?version=v1'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/'} +      def test_reverse_host_name_versioning(self):          scheme = versioning.HostNameVersioning          view = ReverseView.as_view(versioning_class=scheme) @@ -102,3 +137,31 @@ class TestURLReversing(APITestCase):          request = factory.get('/endpoint/')          response = view(request)          assert response.data == {'url': 'http://testserver/another/'} + +    def test_reverse_url_path_versioning(self): +        scheme = versioning.URLPathVersioning +        view = ReverseView.as_view(versioning_class=scheme) + +        request = factory.get('/v1/endpoint/') +        response = view(request, version='v1') +        assert response.data == {'url': 'http://testserver/v1/another/'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/'} + +    def test_namespace_versioning(self): +        class FakeResolverMatch: +            namespace = 'v1' + +        scheme = versioning.NamespaceVersioning +        view = ReverseView.as_view(versioning_class=scheme) + +        request = factory.get('/v1/endpoint/') +        request.resolver_match = FakeResolverMatch +        response = view(request, version='v1') +        assert response.data == {'url': 'http://testserver/v1/namespaced/'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/'} | 
