aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTom Christie2013-05-31 11:50:37 +0100
committerTom Christie2013-05-31 11:50:37 +0100
commitd7753123f60c6e76c732e02b9caddd57b0683a5a (patch)
tree17d3086191c0d7ec1c494677bd1a2fce0b60d835
parentc36ff4e052f0a6a188100908b03d2c3328fd97bc (diff)
downloaddjango-rest-framework-d7753123f60c6e76c732e02b9caddd57b0683a5a.tar.bz2
HyperlinkedModelSerializer lookup_field option should apply to HyperlinkedIdentityField
-rw-r--r--rest_framework/relations.py5
-rw-r--r--rest_framework/serializers.py2
-rw-r--r--rest_framework/tests/test_routers.py67
3 files changed, 70 insertions, 4 deletions
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 42abf3ca..4ecf795c 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -493,8 +493,9 @@ class HyperlinkedIdentityField(Field):
self.view_name = kwargs.pop('view_name', None)
# Optionally the format of the target hyperlink may be specified
self.format = kwargs.pop('format', None)
-
- self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
+ lookup_field = kwargs.pop('lookup_field', None)
+ if lookup_field is not None:
+ self.lookup_field = lookup_field
# These are pending deprecation
if 'pk_url_kwarg' in kwargs:
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 3e5c366e..4dde0d7c 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -907,6 +907,8 @@ class HyperlinkedModelSerializer(ModelSerializer):
def __init__(self, *args, **kwargs):
super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs)
+ lookup_field = self.opts.lookup_field
+ self.fields['url'] = HyperlinkedIdentityField(lookup_field=lookup_field)
if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name(self.opts.model)
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
index c73f5e72..fc3a87e9 100644
--- a/rest_framework/tests/test_routers.py
+++ b/rest_framework/tests/test_routers.py
@@ -1,13 +1,17 @@
from __future__ import unicode_literals
+from django.db import models
from django.test import TestCase
from django.test.client import RequestFactory
-from rest_framework.response import Response
-from rest_framework import viewsets
+from rest_framework import serializers, viewsets
+from rest_framework.compat import include, patterns, url
from rest_framework.decorators import link, action
+from rest_framework.response import Response
from rest_framework.routers import SimpleRouter
factory = RequestFactory()
+urlpatterns = patterns('',)
+
class BasicViewSet(viewsets.ViewSet):
def list(self, request, *args, **kwargs):
@@ -49,3 +53,62 @@ class TestSimpleRouter(TestCase):
else:
method_map = 'get'
self.assertEqual(route.mapping[method_map], endpoint)
+
+
+class RouterTestModel(models.Model):
+ uuid = models.CharField(max_length=20)
+ text = models.CharField(max_length=200)
+
+
+class TestCustomLookupFields(TestCase):
+ """
+ Ensure that custom lookup fields are correctly routed.
+ """
+ urls = 'rest_framework.tests.test_routers'
+
+ def setUp(self):
+ class NoteSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = RouterTestModel
+ lookup_field = 'uuid'
+ fields = ('url', 'uuid', 'text')
+
+ class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
+ serializer_class = NoteSerializer
+ lookup_field = 'uuid'
+
+ RouterTestModel.objects.create(uuid='123', text='foo bar')
+
+ self.router = SimpleRouter()
+ self.router.register(r'notes', NoteViewSet)
+
+ from rest_framework.tests import test_routers
+ urls = getattr(test_routers, 'urlpatterns')
+ urls += patterns('',
+ url(r'^', include(self.router.urls)),
+ )
+
+ def test_custom_lookup_field_route(self):
+ detail_route = self.router.urls[-1]
+ detail_url_pattern = detail_route.regex.pattern
+ self.assertIn('<uuid>', detail_url_pattern)
+
+ def test_retrieve_lookup_field_list_view(self):
+ response = self.client.get('/notes/')
+ self.assertEquals(response.data,
+ [{
+ "url": "http://testserver/notes/123/",
+ "uuid": "123", "text": "foo bar"
+ }]
+ )
+
+ def test_retrieve_lookup_field_detail_view(self):
+ response = self.client.get('/notes/123/')
+ self.assertEquals(response.data,
+ {
+ "url": "http://testserver/notes/123/",
+ "uuid": "123", "text": "foo bar"
+ }
+ )
+