diff options
| author | Tom Christie | 2013-06-04 20:59:12 +0100 | 
|---|---|---|
| committer | Tom Christie | 2013-06-04 20:59:12 +0100 | 
| commit | f1251e8c58e86db11028396d79f01db5dfcf9e52 (patch) | |
| tree | e69f88d597b51a93c49b87c9dfa6d55f1d6be13e /rest_framework | |
| parent | ffa27b840f66d24562de9f66a9ac7a4142da51db (diff) | |
| download | django-rest-framework-f1251e8c58e86db11028396d79f01db5dfcf9e52.tar.bz2 | |
Added trailing_slash argument to routers.  Closes #905
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/routers.py | 17 | ||||
| -rw-r--r-- | rest_framework/tests/test_routers.py | 35 | 
2 files changed, 44 insertions, 8 deletions
| diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 6c5fd004..9764e569 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -18,7 +18,6 @@ from __future__ import unicode_literals  from collections import namedtuple  from rest_framework import views  from rest_framework.compat import patterns, url -from rest_framework.decorators import api_view  from rest_framework.response import Response  from rest_framework.reverse import reverse  from rest_framework.urlpatterns import format_suffix_patterns @@ -72,7 +71,7 @@ class SimpleRouter(BaseRouter):      routes = [          # List route.          Route( -            url=r'^{prefix}/$', +            url=r'^{prefix}{trailing_slash}$',              mapping={                  'get': 'list',                  'post': 'create' @@ -82,7 +81,7 @@ class SimpleRouter(BaseRouter):          ),          # Detail route.          Route( -            url=r'^{prefix}/{lookup}/$', +            url=r'^{prefix}/{lookup}{trailing_slash}$',              mapping={                  'get': 'retrieve',                  'put': 'update', @@ -95,7 +94,7 @@ class SimpleRouter(BaseRouter):          # Dynamically generated routes.          # Generated using @action or @link decorators on methods of the viewset.          Route( -            url=r'^{prefix}/{lookup}/{methodname}/$', +            url=r'^{prefix}/{lookup}/{methodname}{trailing_slash}$',              mapping={                  '{httpmethod}': '{methodname}',              }, @@ -104,6 +103,10 @@ class SimpleRouter(BaseRouter):          ),      ] +    def __init__(self, trailing_slash=True): +        self.trailing_slash = trailing_slash and '/' or '' +        super(SimpleRouter, self).__init__() +      def get_default_base_name(self, viewset):          """          If `base_name` is not specified, attempt to automatically determine @@ -193,7 +196,11 @@ class SimpleRouter(BaseRouter):                      continue                  # Build the url pattern -                regex = route.url.format(prefix=prefix, lookup=lookup) +                regex = route.url.format( +                    prefix=prefix, +                    lookup=lookup, +                    trailing_slash=self.trailing_slash +                )                  view = viewset.as_view(mapping, **route.initkwargs)                  name = route.name.format(basename=basename)                  ret.append(url(regex, view, name=name)) diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py index 10d3cc25..a7534f70 100644 --- a/rest_framework/tests/test_routers.py +++ b/rest_framework/tests/test_routers.py @@ -50,7 +50,7 @@ class TestSimpleRouter(TestCase):              route = decorator_routes[i]              # check url listing              self.assertEqual(route.url, -                             '^{{prefix}}/{{lookup}}/{0}/$'.format(endpoint)) +                             '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint))              # check method to function mapping              if endpoint == 'action3':                  methods_map = ['post', 'delete'] @@ -103,7 +103,7 @@ class TestCustomLookupFields(TestCase):      def test_retrieve_lookup_field_list_view(self):          response = self.client.get('/notes/') -        self.assertEquals(response.data, +        self.assertEqual(response.data,              [{                  "url": "http://testserver/notes/123/",                  "uuid": "123", "text": "foo bar" @@ -112,10 +112,39 @@ class TestCustomLookupFields(TestCase):      def test_retrieve_lookup_field_detail_view(self):          response = self.client.get('/notes/123/') -        self.assertEquals(response.data, +        self.assertEqual(response.data,              {                  "url": "http://testserver/notes/123/",                  "uuid": "123", "text": "foo bar"              }          ) + +class TestTrailingSlash(TestCase): +    def setUp(self): +        class NoteViewSet(viewsets.ModelViewSet): +            model = RouterTestModel + +        self.router = SimpleRouter() +        self.router.register(r'notes', NoteViewSet) +        self.urls = self.router.urls + +    def test_urls_have_trailing_slash_by_default(self): +        expected = ['^notes/$', '^notes/(?P<pk>[^/]+)/$'] +        for idx in range(len(expected)): +            self.assertEqual(expected[idx], self.urls[idx].regex.pattern) + + +class TestTrailingSlash(TestCase): +    def setUp(self): +        class NoteViewSet(viewsets.ModelViewSet): +            model = RouterTestModel + +        self.router = SimpleRouter(trailing_slash=False) +        self.router.register(r'notes', NoteViewSet) +        self.urls = self.router.urls + +    def test_urls_can_have_trailing_slash_removed(self): +        expected = ['^notes$', '^notes/(?P<pk>[^/]+)$'] +        for idx in range(len(expected)): +            self.assertEqual(expected[idx], self.urls[idx].regex.pattern) | 
