diff options
| author | Alex Burgel | 2013-06-05 17:39:14 -0400 | 
|---|---|---|
| committer | Alex Burgel | 2013-07-15 17:59:03 -0400 | 
| commit | d72603bc6a16112008959c5267839f819c2bc43a (patch) | |
| tree | f055ed46f7a5f1684ff887535f63f9278370f7ec | |
| parent | 82145e2b06e402c9740ee970c74456a59683667a (diff) | |
| download | django-rest-framework-d72603bc6a16112008959c5267839f819c2bc43a.tar.bz2 | |
Add support for collection routes to SimpleRouter
| -rw-r--r-- | rest_framework/decorators.py | 26 | ||||
| -rw-r--r-- | rest_framework/routers.py | 33 | ||||
| -rw-r--r-- | rest_framework/tests/test_routers.py | 48 | 
3 files changed, 103 insertions, 4 deletions
| diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index c69756a4..dacd380f 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -113,6 +113,7 @@ def link(**kwargs):      """      def decorator(func):          func.bind_to_methods = ['get'] +        func.collection = False          func.kwargs = kwargs          return func      return decorator @@ -124,6 +125,31 @@ def action(methods=['post'], **kwargs):      """      def decorator(func):          func.bind_to_methods = methods +        func.collection = False +        func.kwargs = kwargs +        return func +    return decorator + + +def collection_link(**kwargs): +    """ +    Used to mark a method on a ViewSet that should be routed for GET requests. +    """ +    def decorator(func): +        func.bind_to_methods = ['get'] +        func.collection = True +        func.kwargs = kwargs +        return func +    return decorator + + +def collection_action(methods=['post'], **kwargs): +    """ +    Used to mark a method on a ViewSet that should be routed for POST requests. +    """ +    def decorator(func): +        func.bind_to_methods = methods +        func.collection = True          func.kwargs = kwargs          return func      return decorator diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 930011d3..9b859a7c 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -88,6 +88,17 @@ class SimpleRouter(BaseRouter):              name='{basename}-list',              initkwargs={'suffix': 'List'}          ), +        # Dynamically generated collection routes. +        # Generated using @collection_action or @collection_link decorators +        # on methods of the viewset. +        Route( +            url=r'^{prefix}/{methodname}{trailing_slash}$', +            mapping={ +                '{httpmethod}': '{methodname}', +            }, +            name='{basename}-collection-{methodnamehyphen}', +            initkwargs={} +        ),          # Detail route.          Route(              url=r'^{prefix}/{lookup}{trailing_slash}$', @@ -107,7 +118,7 @@ class SimpleRouter(BaseRouter):              mapping={                  '{httpmethod}': '{methodname}',              }, -            name='{basename}-{methodnamehyphen}', +            name='{basename}-dynamic-{methodnamehyphen}',              initkwargs={}          ),      ] @@ -142,20 +153,25 @@ class SimpleRouter(BaseRouter):          known_actions = flatten([route.mapping.values() for route in self.routes])          # Determine any `@action` or `@link` decorated methods on the viewset +        collection_routes = []          dynamic_routes = []          for methodname in dir(viewset):              attr = getattr(viewset, methodname)              httpmethods = getattr(attr, 'bind_to_methods', None) +            collection = getattr(attr, 'collection', False)              if httpmethods:                  if methodname in known_actions:                      raise ImproperlyConfigured('Cannot use @action or @link decorator on '                                                 'method "%s" as it is an existing route' % methodname)                  httpmethods = [method.lower() for method in httpmethods] -                dynamic_routes.append((httpmethods, methodname)) +                if collection: +                    collection_routes.append((httpmethods, methodname)) +                else: +                    dynamic_routes.append((httpmethods, methodname))          ret = []          for route in self.routes: -            if route.mapping == {'{httpmethod}': '{methodname}'}: +            if route.name == '{basename}-dynamic-{methodnamehyphen}':                  # Dynamic routes (@link or @action decorator)                  for httpmethods, methodname in dynamic_routes:                      initkwargs = route.initkwargs.copy() @@ -166,6 +182,17 @@ class SimpleRouter(BaseRouter):                          name=replace_methodname(route.name, methodname),                          initkwargs=initkwargs,                      )) +            elif route.name == '{basename}-collection-{methodnamehyphen}': +                # Dynamic routes (@collection_link or @collection_action decorator) +                for httpmethods, methodname in collection_routes: +                    initkwargs = route.initkwargs.copy() +                    initkwargs.update(getattr(viewset, methodname).kwargs) +                    ret.append(Route( +                        url=replace_methodname(route.url, methodname), +                        mapping=dict((httpmethod, methodname) for httpmethod in httpmethods), +                        name=replace_methodname(route.name, methodname), +                        initkwargs=initkwargs, +                    ))              else:                  # Standard route                  ret.append(route) diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py index 5fcccb74..60f150d2 100644 --- a/rest_framework/tests/test_routers.py +++ b/rest_framework/tests/test_routers.py @@ -4,7 +4,7 @@ from django.test import TestCase  from django.core.exceptions import ImproperlyConfigured  from rest_framework import serializers, viewsets, permissions  from rest_framework.compat import include, patterns, url -from rest_framework.decorators import link, action +from rest_framework.decorators import link, action, collection_link, collection_action  from rest_framework.response import Response  from rest_framework.routers import SimpleRouter, DefaultRouter  from rest_framework.test import APIRequestFactory @@ -214,3 +214,49 @@ class TestActionAppliedToExistingRoute(TestCase):          with self.assertRaises(ImproperlyConfigured):              self.router.urls + + +class StaticAndDynamicViewSet(viewsets.ViewSet): +    def list(self, request, *args, **kwargs): +        return Response({'method': 'list'}) + +    @collection_action() +    def collection_action(self, request, *args, **kwargs): +        return Response({'method': 'action1'}) + +    @action() +    def dynamic_action(self, request, *args, **kwargs): +        return Response({'method': 'action2'}) + +    @collection_link() +    def collection_link(self, request, *args, **kwargs): +        return Response({'method': 'link1'}) + +    @link() +    def dynamic_link(self, request, *args, **kwargs): +        return Response({'method': 'link2'}) + + +class TestStaticAndDynamicRouter(TestCase): +    def setUp(self): +        self.router = SimpleRouter() + +    def test_link_and_action_decorator(self): +        routes = self.router.get_routes(StaticAndDynamicViewSet) +        decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))] +        # Make sure all these endpoints exist and none have been clobbered +        for i, endpoint in enumerate(['collection_action', 'collection_link', 'dynamic_action', 'dynamic_link']): +            route = decorator_routes[i] +            # check url listing +            if endpoint.startswith('collection_'): +                self.assertEqual(route.url, +                                 '^{{prefix}}/{0}{{trailing_slash}}$'.format(endpoint)) +            else: +                self.assertEqual(route.url, +                                 '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint)) +            # check method to function mapping +            if endpoint.endswith('action'): +                method_map = 'post' +            else: +                method_map = 'get' +            self.assertEqual(route.mapping[method_map], endpoint) | 
