aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2013-06-04 20:59:12 +0100
committerTom Christie2013-06-04 20:59:12 +0100
commitf1251e8c58e86db11028396d79f01db5dfcf9e52 (patch)
treee69f88d597b51a93c49b87c9dfa6d55f1d6be13e /rest_framework
parentffa27b840f66d24562de9f66a9ac7a4142da51db (diff)
downloaddjango-rest-framework-f1251e8c58e86db11028396d79f01db5dfcf9e52.tar.bz2
Added trailing_slash argument to routers. Closes #905
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/routers.py17
-rw-r--r--rest_framework/tests/test_routers.py35
2 files changed, 44 insertions, 8 deletions
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index 6c5fd004..9764e569 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -18,7 +18,6 @@ from __future__ import unicode_literals
from collections import namedtuple
from rest_framework import views
from rest_framework.compat import patterns, url
-from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.urlpatterns import format_suffix_patterns
@@ -72,7 +71,7 @@ class SimpleRouter(BaseRouter):
routes = [
# List route.
Route(
- url=r'^{prefix}/$',
+ url=r'^{prefix}{trailing_slash}$',
mapping={
'get': 'list',
'post': 'create'
@@ -82,7 +81,7 @@ class SimpleRouter(BaseRouter):
),
# Detail route.
Route(
- url=r'^{prefix}/{lookup}/$',
+ url=r'^{prefix}/{lookup}{trailing_slash}$',
mapping={
'get': 'retrieve',
'put': 'update',
@@ -95,7 +94,7 @@ class SimpleRouter(BaseRouter):
# Dynamically generated routes.
# Generated using @action or @link decorators on methods of the viewset.
Route(
- url=r'^{prefix}/{lookup}/{methodname}/$',
+ url=r'^{prefix}/{lookup}/{methodname}{trailing_slash}$',
mapping={
'{httpmethod}': '{methodname}',
},
@@ -104,6 +103,10 @@ class SimpleRouter(BaseRouter):
),
]
+ def __init__(self, trailing_slash=True):
+ self.trailing_slash = trailing_slash and '/' or ''
+ super(SimpleRouter, self).__init__()
+
def get_default_base_name(self, viewset):
"""
If `base_name` is not specified, attempt to automatically determine
@@ -193,7 +196,11 @@ class SimpleRouter(BaseRouter):
continue
# Build the url pattern
- regex = route.url.format(prefix=prefix, lookup=lookup)
+ regex = route.url.format(
+ prefix=prefix,
+ lookup=lookup,
+ trailing_slash=self.trailing_slash
+ )
view = viewset.as_view(mapping, **route.initkwargs)
name = route.name.format(basename=basename)
ret.append(url(regex, view, name=name))
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
index 10d3cc25..a7534f70 100644
--- a/rest_framework/tests/test_routers.py
+++ b/rest_framework/tests/test_routers.py
@@ -50,7 +50,7 @@ class TestSimpleRouter(TestCase):
route = decorator_routes[i]
# check url listing
self.assertEqual(route.url,
- '^{{prefix}}/{{lookup}}/{0}/$'.format(endpoint))
+ '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint))
# check method to function mapping
if endpoint == 'action3':
methods_map = ['post', 'delete']
@@ -103,7 +103,7 @@ class TestCustomLookupFields(TestCase):
def test_retrieve_lookup_field_list_view(self):
response = self.client.get('/notes/')
- self.assertEquals(response.data,
+ self.assertEqual(response.data,
[{
"url": "http://testserver/notes/123/",
"uuid": "123", "text": "foo bar"
@@ -112,10 +112,39 @@ class TestCustomLookupFields(TestCase):
def test_retrieve_lookup_field_detail_view(self):
response = self.client.get('/notes/123/')
- self.assertEquals(response.data,
+ self.assertEqual(response.data,
{
"url": "http://testserver/notes/123/",
"uuid": "123", "text": "foo bar"
}
)
+
+class TestTrailingSlash(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+
+ self.router = SimpleRouter()
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_urls_have_trailing_slash_by_default(self):
+ expected = ['^notes/$', '^notes/(?P<pk>[^/]+)/$']
+ for idx in range(len(expected)):
+ self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
+
+
+class TestTrailingSlash(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+
+ self.router = SimpleRouter(trailing_slash=False)
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_urls_can_have_trailing_slash_removed(self):
+ expected = ['^notes$', '^notes/(?P<pk>[^/]+)$']
+ for idx in range(len(expected)):
+ self.assertEqual(expected[idx], self.urls[idx].regex.pattern)