diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/relations.py | 17 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 13 | ||||
| -rw-r--r-- | rest_framework/tests/models.py | 2 | ||||
| -rw-r--r-- | rest_framework/tests/test_authentication.py | 4 | ||||
| -rw-r--r-- | rest_framework/tests/test_fields.py | 41 | ||||
| -rw-r--r-- | rest_framework/tests/test_routers.py | 67 | ||||
| -rw-r--r-- | rest_framework/tests/test_serializer.py | 28 | ||||
| -rw-r--r-- | rest_framework/tests/test_testcases.py | 66 |
8 files changed, 123 insertions, 115 deletions
diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 42abf3ca..e3675b51 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -72,7 +72,6 @@ class RelatedField(WritableField): else: # Reverse self.queryset = manager.field.rel.to._default_manager.all() except Exception: - raise msg = ('Serializer related fields must include a `queryset`' + ' argument or set `read_only=True') raise Exception(msg) @@ -488,13 +487,15 @@ class HyperlinkedIdentityField(Field): slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden def __init__(self, *args, **kwargs): - # TODO: Make view_name mandatory, and have the - # HyperlinkedModelSerializer set it on-the-fly - self.view_name = kwargs.pop('view_name', None) - # Optionally the format of the target hyperlink may be specified - self.format = kwargs.pop('format', None) + try: + self.view_name = kwargs.pop('view_name') + except KeyError: + msg = "HyperlinkedIdentityField requires 'view_name' argument" + raise ValueError(msg) - self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) + self.format = kwargs.pop('format', None) + lookup_field = kwargs.pop('lookup_field', None) + self.lookup_field = lookup_field or self.lookup_field # These are pending deprecation if 'pk_url_kwarg' in kwargs: @@ -517,7 +518,7 @@ class HyperlinkedIdentityField(Field): def field_to_native(self, obj, field_name): request = self.context.get('request', None) format = self.context.get('format', None) - view_name = self.view_name or self.parent.opts.view_name + view_name = self.view_name if request is None: warnings.warn("Using `HyperlinkedIdentityField` without including the " diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 3e5c366e..a4969f60 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -903,13 +903,24 @@ class HyperlinkedModelSerializer(ModelSerializer): _default_view_name = '%(model_name)s-detail' _hyperlink_field_class = HyperlinkedRelatedField - url = HyperlinkedIdentityField() + # Just a placeholder to ensure 'url' is the first field + # The field itself is actually created on initialization, + # when the view_name and lookup_field arguments are available. + url = Field() def __init__(self, *args, **kwargs): super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs) + if self.opts.view_name is None: self.opts.view_name = self._get_default_view_name(self.opts.model) + url_field = HyperlinkedIdentityField( + view_name=self.opts.view_name, + lookup_field=self.opts.lookup_field + ) + url_field.initialize(self, 'url') + self.fields['url'] = url_field + def _get_default_view_name(self, model): """ Return the view name to use if 'view_name' is not specified in 'Meta' diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index abf50a2d..e2d4eacd 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -162,8 +162,8 @@ class NullableOneToOneSource(RESTFrameworkModel): target = models.OneToOneField(OneToOneTarget, null=True, blank=True, related_name='nullable_source') + # Serializer used to test BasicModel class BasicModelSerializer(serializers.ModelSerializer): class Meta: model = BasicModel - diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index 90e1f5c4..05e9fbc3 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -48,7 +48,7 @@ urlpatterns = patterns('', (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])), - (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication], + (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication], permission_classes=[permissions.TokenHasReadWriteScope])) ) @@ -56,7 +56,7 @@ if oauth2_provider is not None: urlpatterns += patterns('', url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), - url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], + url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], permission_classes=[permissions.TokenHasReadWriteScope])), ) diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py index bff4400b..3f956051 100644 --- a/rest_framework/tests/test_fields.py +++ b/rest_framework/tests/test_fields.py @@ -11,8 +11,6 @@ from django.db import models from django.test import TestCase from django.utils.datastructures import SortedDict from rest_framework import serializers -from rest_framework.fields import Field, CharField -from rest_framework.serializers import Serializer from rest_framework.tests.models import RESTFrameworkModel @@ -590,7 +588,7 @@ class DecimalFieldTest(TestCase): """ Make sure the serializer works correctly """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(max_value=9010, min_value=9000, max_digits=6, @@ -608,7 +606,7 @@ class DecimalFieldTest(TestCase): """ Make sure max_value violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(max_value=100) s = DecimalSerializer(data={'decimal_field': '123'}) @@ -620,7 +618,7 @@ class DecimalFieldTest(TestCase): """ Make sure min_value violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(min_value=100) s = DecimalSerializer(data={'decimal_field': '99'}) @@ -632,7 +630,7 @@ class DecimalFieldTest(TestCase): """ Make sure max_digits violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(max_digits=5) s = DecimalSerializer(data={'decimal_field': '123.456'}) @@ -644,7 +642,7 @@ class DecimalFieldTest(TestCase): """ Make sure max_decimal_places violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(decimal_places=3) s = DecimalSerializer(data={'decimal_field': '123.4567'}) @@ -656,7 +654,7 @@ class DecimalFieldTest(TestCase): """ Make sure max_whole_digits violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3) s = DecimalSerializer(data={'decimal_field': '12345.6'}) @@ -837,11 +835,11 @@ class URLFieldTests(TestCase): class FieldMetadata(TestCase): def setUp(self): - self.required_field = Field() + self.required_field = serializers.Field() self.required_field.label = uuid4().hex self.required_field.required = True - self.optional_field = Field() + self.optional_field = serializers.Field() self.optional_field.label = uuid4().hex self.optional_field.required = False @@ -854,26 +852,3 @@ class FieldMetadata(TestCase): def test_label(self): for field in (self.required_field, self.optional_field): self.assertEqual(field.metadata()['label'], field.label) - - -class MetadataSerializer(Serializer): - field1 = CharField(3, required=True) - field2 = CharField(10, required=False) - - -class MetadataSerializerTestCase(TestCase): - def setUp(self): - self.serializer = MetadataSerializer() - - def test_serializer_metadata(self): - metadata = self.serializer.metadata() - expected = { - 'field1': {'required': True, - 'max_length': 3, - 'type': 'string', - 'read_only': False}, - 'field2': {'required': False, - 'max_length': 10, - 'type': 'string', - 'read_only': False}} - self.assertEqual(expected, metadata) diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py index 4e4765cb..fc3a87e9 100644 --- a/rest_framework/tests/test_routers.py +++ b/rest_framework/tests/test_routers.py @@ -1,15 +1,17 @@ from __future__ import unicode_literals +from django.db import models from django.test import TestCase from django.test.client import RequestFactory -from rest_framework import status -from rest_framework.response import Response -from rest_framework import viewsets +from rest_framework import serializers, viewsets +from rest_framework.compat import include, patterns, url from rest_framework.decorators import link, action +from rest_framework.response import Response from rest_framework.routers import SimpleRouter -import copy factory = RequestFactory() +urlpatterns = patterns('',) + class BasicViewSet(viewsets.ViewSet): def list(self, request, *args, **kwargs): @@ -53,3 +55,60 @@ class TestSimpleRouter(TestCase): self.assertEqual(route.mapping[method_map], endpoint) +class RouterTestModel(models.Model): + uuid = models.CharField(max_length=20) + text = models.CharField(max_length=200) + + +class TestCustomLookupFields(TestCase): + """ + Ensure that custom lookup fields are correctly routed. + """ + urls = 'rest_framework.tests.test_routers' + + def setUp(self): + class NoteSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = RouterTestModel + lookup_field = 'uuid' + fields = ('url', 'uuid', 'text') + + class NoteViewSet(viewsets.ModelViewSet): + queryset = RouterTestModel.objects.all() + serializer_class = NoteSerializer + lookup_field = 'uuid' + + RouterTestModel.objects.create(uuid='123', text='foo bar') + + self.router = SimpleRouter() + self.router.register(r'notes', NoteViewSet) + + from rest_framework.tests import test_routers + urls = getattr(test_routers, 'urlpatterns') + urls += patterns('', + url(r'^', include(self.router.urls)), + ) + + def test_custom_lookup_field_route(self): + detail_route = self.router.urls[-1] + detail_url_pattern = detail_route.regex.pattern + self.assertIn('<uuid>', detail_url_pattern) + + def test_retrieve_lookup_field_list_view(self): + response = self.client.get('/notes/') + self.assertEquals(response.data, + [{ + "url": "http://testserver/notes/123/", + "uuid": "123", "text": "foo bar" + }] + ) + + def test_retrieve_lookup_field_detail_view(self): + response = self.client.get('/notes/123/') + self.assertEquals(response.data, + { + "url": "http://testserver/notes/123/", + "uuid": "123", "text": "foo bar" + } + ) + diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index f2c31872..6cc913c5 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -1528,3 +1528,31 @@ class DefaultValuesOnAutogeneratedFieldsTests(TestCase): def test_url_field(self): self.field_test('url_field') + + +class MetadataSerializer(serializers.Serializer): + field1 = serializers.CharField(3, required=True) + field2 = serializers.CharField(10, required=False) + + +class MetadataSerializerTestCase(TestCase): + def setUp(self): + self.serializer = MetadataSerializer() + + def test_serializer_metadata(self): + metadata = self.serializer.metadata() + expected = { + 'field1': { + 'required': True, + 'max_length': 3, + 'type': 'string', + 'read_only': False + }, + 'field2': { + 'required': False, + 'max_length': 10, + 'type': 'string', + 'read_only': False + } + } + self.assertEqual(expected, metadata) diff --git a/rest_framework/tests/test_testcases.py b/rest_framework/tests/test_testcases.py deleted file mode 100644 index f8c2579e..00000000 --- a/rest_framework/tests/test_testcases.py +++ /dev/null @@ -1,66 +0,0 @@ -# http://djangosnippets.org/snippets/1011/ -from __future__ import unicode_literals -from django.conf import settings -from django.core.management import call_command -from django.db.models import loading -from django.test import TestCase - -NO_SETTING = ('!', None) - - -class TestSettingsManager(object): - """ - A class which can modify some Django settings temporarily for a - test and then revert them to their original values later. - - Automatically handles resyncing the DB if INSTALLED_APPS is - modified. - - """ - def __init__(self): - self._original_settings = {} - - def set(self, **kwargs): - for k, v in kwargs.iteritems(): - self._original_settings.setdefault(k, getattr(settings, k, - NO_SETTING)) - setattr(settings, k, v) - if 'INSTALLED_APPS' in kwargs: - self.syncdb() - - def syncdb(self): - loading.cache.loaded = False - call_command('syncdb', verbosity=0) - - def revert(self): - for k, v in self._original_settings.iteritems(): - if v == NO_SETTING: - delattr(settings, k) - else: - setattr(settings, k, v) - if 'INSTALLED_APPS' in self._original_settings: - self.syncdb() - self._original_settings = {} - - -class SettingsTestCase(TestCase): - """ - A subclass of the Django TestCase with a settings_manager - attribute which is an instance of TestSettingsManager. - - Comes with a tearDown() method that calls - self.settings_manager.revert(). - - """ - def __init__(self, *args, **kwargs): - super(SettingsTestCase, self).__init__(*args, **kwargs) - self.settings_manager = TestSettingsManager() - - def tearDown(self): - self.settings_manager.revert() - - -class TestModelsTestCase(SettingsTestCase): - def setUp(self, *args, **kwargs): - installed_apps = tuple(settings.INSTALLED_APPS) + ('rest_framework.tests',) - self.settings_manager.set(INSTALLED_APPS=installed_apps) |
