aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/relations.py17
-rw-r--r--rest_framework/serializers.py13
-rw-r--r--rest_framework/tests/models.py2
-rw-r--r--rest_framework/tests/test_authentication.py4
-rw-r--r--rest_framework/tests/test_fields.py41
-rw-r--r--rest_framework/tests/test_routers.py67
-rw-r--r--rest_framework/tests/test_serializer.py28
-rw-r--r--rest_framework/tests/test_testcases.py66
8 files changed, 123 insertions, 115 deletions
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 42abf3ca..e3675b51 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -72,7 +72,6 @@ class RelatedField(WritableField):
else: # Reverse
self.queryset = manager.field.rel.to._default_manager.all()
except Exception:
- raise
msg = ('Serializer related fields must include a `queryset`' +
' argument or set `read_only=True')
raise Exception(msg)
@@ -488,13 +487,15 @@ class HyperlinkedIdentityField(Field):
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
def __init__(self, *args, **kwargs):
- # TODO: Make view_name mandatory, and have the
- # HyperlinkedModelSerializer set it on-the-fly
- self.view_name = kwargs.pop('view_name', None)
- # Optionally the format of the target hyperlink may be specified
- self.format = kwargs.pop('format', None)
+ try:
+ self.view_name = kwargs.pop('view_name')
+ except KeyError:
+ msg = "HyperlinkedIdentityField requires 'view_name' argument"
+ raise ValueError(msg)
- self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
+ self.format = kwargs.pop('format', None)
+ lookup_field = kwargs.pop('lookup_field', None)
+ self.lookup_field = lookup_field or self.lookup_field
# These are pending deprecation
if 'pk_url_kwarg' in kwargs:
@@ -517,7 +518,7 @@ class HyperlinkedIdentityField(Field):
def field_to_native(self, obj, field_name):
request = self.context.get('request', None)
format = self.context.get('format', None)
- view_name = self.view_name or self.parent.opts.view_name
+ view_name = self.view_name
if request is None:
warnings.warn("Using `HyperlinkedIdentityField` without including the "
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 3e5c366e..a4969f60 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -903,13 +903,24 @@ class HyperlinkedModelSerializer(ModelSerializer):
_default_view_name = '%(model_name)s-detail'
_hyperlink_field_class = HyperlinkedRelatedField
- url = HyperlinkedIdentityField()
+ # Just a placeholder to ensure 'url' is the first field
+ # The field itself is actually created on initialization,
+ # when the view_name and lookup_field arguments are available.
+ url = Field()
def __init__(self, *args, **kwargs):
super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs)
+
if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name(self.opts.model)
+ url_field = HyperlinkedIdentityField(
+ view_name=self.opts.view_name,
+ lookup_field=self.opts.lookup_field
+ )
+ url_field.initialize(self, 'url')
+ self.fields['url'] = url_field
+
def _get_default_view_name(self, model):
"""
Return the view name to use if 'view_name' is not specified in 'Meta'
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index abf50a2d..e2d4eacd 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -162,8 +162,8 @@ class NullableOneToOneSource(RESTFrameworkModel):
target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
related_name='nullable_source')
+
# Serializer used to test BasicModel
class BasicModelSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
-
diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py
index 90e1f5c4..05e9fbc3 100644
--- a/rest_framework/tests/test_authentication.py
+++ b/rest_framework/tests/test_authentication.py
@@ -48,7 +48,7 @@ urlpatterns = patterns('',
(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
(r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])),
- (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication],
+ (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication],
permission_classes=[permissions.TokenHasReadWriteScope]))
)
@@ -56,7 +56,7 @@ if oauth2_provider is not None:
urlpatterns += patterns('',
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
- url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
+ url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
permission_classes=[permissions.TokenHasReadWriteScope])),
)
diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py
index bff4400b..3f956051 100644
--- a/rest_framework/tests/test_fields.py
+++ b/rest_framework/tests/test_fields.py
@@ -11,8 +11,6 @@ from django.db import models
from django.test import TestCase
from django.utils.datastructures import SortedDict
from rest_framework import serializers
-from rest_framework.fields import Field, CharField
-from rest_framework.serializers import Serializer
from rest_framework.tests.models import RESTFrameworkModel
@@ -590,7 +588,7 @@ class DecimalFieldTest(TestCase):
"""
Make sure the serializer works correctly
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(max_value=9010,
min_value=9000,
max_digits=6,
@@ -608,7 +606,7 @@ class DecimalFieldTest(TestCase):
"""
Make sure max_value violations raises ValidationError
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(max_value=100)
s = DecimalSerializer(data={'decimal_field': '123'})
@@ -620,7 +618,7 @@ class DecimalFieldTest(TestCase):
"""
Make sure min_value violations raises ValidationError
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(min_value=100)
s = DecimalSerializer(data={'decimal_field': '99'})
@@ -632,7 +630,7 @@ class DecimalFieldTest(TestCase):
"""
Make sure max_digits violations raises ValidationError
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(max_digits=5)
s = DecimalSerializer(data={'decimal_field': '123.456'})
@@ -644,7 +642,7 @@ class DecimalFieldTest(TestCase):
"""
Make sure max_decimal_places violations raises ValidationError
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(decimal_places=3)
s = DecimalSerializer(data={'decimal_field': '123.4567'})
@@ -656,7 +654,7 @@ class DecimalFieldTest(TestCase):
"""
Make sure max_whole_digits violations raises ValidationError
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3)
s = DecimalSerializer(data={'decimal_field': '12345.6'})
@@ -837,11 +835,11 @@ class URLFieldTests(TestCase):
class FieldMetadata(TestCase):
def setUp(self):
- self.required_field = Field()
+ self.required_field = serializers.Field()
self.required_field.label = uuid4().hex
self.required_field.required = True
- self.optional_field = Field()
+ self.optional_field = serializers.Field()
self.optional_field.label = uuid4().hex
self.optional_field.required = False
@@ -854,26 +852,3 @@ class FieldMetadata(TestCase):
def test_label(self):
for field in (self.required_field, self.optional_field):
self.assertEqual(field.metadata()['label'], field.label)
-
-
-class MetadataSerializer(Serializer):
- field1 = CharField(3, required=True)
- field2 = CharField(10, required=False)
-
-
-class MetadataSerializerTestCase(TestCase):
- def setUp(self):
- self.serializer = MetadataSerializer()
-
- def test_serializer_metadata(self):
- metadata = self.serializer.metadata()
- expected = {
- 'field1': {'required': True,
- 'max_length': 3,
- 'type': 'string',
- 'read_only': False},
- 'field2': {'required': False,
- 'max_length': 10,
- 'type': 'string',
- 'read_only': False}}
- self.assertEqual(expected, metadata)
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
index 4e4765cb..fc3a87e9 100644
--- a/rest_framework/tests/test_routers.py
+++ b/rest_framework/tests/test_routers.py
@@ -1,15 +1,17 @@
from __future__ import unicode_literals
+from django.db import models
from django.test import TestCase
from django.test.client import RequestFactory
-from rest_framework import status
-from rest_framework.response import Response
-from rest_framework import viewsets
+from rest_framework import serializers, viewsets
+from rest_framework.compat import include, patterns, url
from rest_framework.decorators import link, action
+from rest_framework.response import Response
from rest_framework.routers import SimpleRouter
-import copy
factory = RequestFactory()
+urlpatterns = patterns('',)
+
class BasicViewSet(viewsets.ViewSet):
def list(self, request, *args, **kwargs):
@@ -53,3 +55,60 @@ class TestSimpleRouter(TestCase):
self.assertEqual(route.mapping[method_map], endpoint)
+class RouterTestModel(models.Model):
+ uuid = models.CharField(max_length=20)
+ text = models.CharField(max_length=200)
+
+
+class TestCustomLookupFields(TestCase):
+ """
+ Ensure that custom lookup fields are correctly routed.
+ """
+ urls = 'rest_framework.tests.test_routers'
+
+ def setUp(self):
+ class NoteSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = RouterTestModel
+ lookup_field = 'uuid'
+ fields = ('url', 'uuid', 'text')
+
+ class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
+ serializer_class = NoteSerializer
+ lookup_field = 'uuid'
+
+ RouterTestModel.objects.create(uuid='123', text='foo bar')
+
+ self.router = SimpleRouter()
+ self.router.register(r'notes', NoteViewSet)
+
+ from rest_framework.tests import test_routers
+ urls = getattr(test_routers, 'urlpatterns')
+ urls += patterns('',
+ url(r'^', include(self.router.urls)),
+ )
+
+ def test_custom_lookup_field_route(self):
+ detail_route = self.router.urls[-1]
+ detail_url_pattern = detail_route.regex.pattern
+ self.assertIn('<uuid>', detail_url_pattern)
+
+ def test_retrieve_lookup_field_list_view(self):
+ response = self.client.get('/notes/')
+ self.assertEquals(response.data,
+ [{
+ "url": "http://testserver/notes/123/",
+ "uuid": "123", "text": "foo bar"
+ }]
+ )
+
+ def test_retrieve_lookup_field_detail_view(self):
+ response = self.client.get('/notes/123/')
+ self.assertEquals(response.data,
+ {
+ "url": "http://testserver/notes/123/",
+ "uuid": "123", "text": "foo bar"
+ }
+ )
+
diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py
index f2c31872..6cc913c5 100644
--- a/rest_framework/tests/test_serializer.py
+++ b/rest_framework/tests/test_serializer.py
@@ -1528,3 +1528,31 @@ class DefaultValuesOnAutogeneratedFieldsTests(TestCase):
def test_url_field(self):
self.field_test('url_field')
+
+
+class MetadataSerializer(serializers.Serializer):
+ field1 = serializers.CharField(3, required=True)
+ field2 = serializers.CharField(10, required=False)
+
+
+class MetadataSerializerTestCase(TestCase):
+ def setUp(self):
+ self.serializer = MetadataSerializer()
+
+ def test_serializer_metadata(self):
+ metadata = self.serializer.metadata()
+ expected = {
+ 'field1': {
+ 'required': True,
+ 'max_length': 3,
+ 'type': 'string',
+ 'read_only': False
+ },
+ 'field2': {
+ 'required': False,
+ 'max_length': 10,
+ 'type': 'string',
+ 'read_only': False
+ }
+ }
+ self.assertEqual(expected, metadata)
diff --git a/rest_framework/tests/test_testcases.py b/rest_framework/tests/test_testcases.py
deleted file mode 100644
index f8c2579e..00000000
--- a/rest_framework/tests/test_testcases.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# http://djangosnippets.org/snippets/1011/
-from __future__ import unicode_literals
-from django.conf import settings
-from django.core.management import call_command
-from django.db.models import loading
-from django.test import TestCase
-
-NO_SETTING = ('!', None)
-
-
-class TestSettingsManager(object):
- """
- A class which can modify some Django settings temporarily for a
- test and then revert them to their original values later.
-
- Automatically handles resyncing the DB if INSTALLED_APPS is
- modified.
-
- """
- def __init__(self):
- self._original_settings = {}
-
- def set(self, **kwargs):
- for k, v in kwargs.iteritems():
- self._original_settings.setdefault(k, getattr(settings, k,
- NO_SETTING))
- setattr(settings, k, v)
- if 'INSTALLED_APPS' in kwargs:
- self.syncdb()
-
- def syncdb(self):
- loading.cache.loaded = False
- call_command('syncdb', verbosity=0)
-
- def revert(self):
- for k, v in self._original_settings.iteritems():
- if v == NO_SETTING:
- delattr(settings, k)
- else:
- setattr(settings, k, v)
- if 'INSTALLED_APPS' in self._original_settings:
- self.syncdb()
- self._original_settings = {}
-
-
-class SettingsTestCase(TestCase):
- """
- A subclass of the Django TestCase with a settings_manager
- attribute which is an instance of TestSettingsManager.
-
- Comes with a tearDown() method that calls
- self.settings_manager.revert().
-
- """
- def __init__(self, *args, **kwargs):
- super(SettingsTestCase, self).__init__(*args, **kwargs)
- self.settings_manager = TestSettingsManager()
-
- def tearDown(self):
- self.settings_manager.revert()
-
-
-class TestModelsTestCase(SettingsTestCase):
- def setUp(self, *args, **kwargs):
- installed_apps = tuple(settings.INSTALLED_APPS) + ('rest_framework.tests',)
- self.settings_manager.set(INSTALLED_APPS=installed_apps)