aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/tests
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/tests')
-rw-r--r--rest_framework/tests/__init__.py13
-rw-r--r--rest_framework/tests/authentication.py41
-rw-r--r--rest_framework/tests/breadcrumbs.py2
-rw-r--r--rest_framework/tests/decorators.py17
-rw-r--r--rest_framework/tests/extras/__init__.py0
-rw-r--r--rest_framework/tests/extras/bad_import.py1
-rw-r--r--rest_framework/tests/fields.py49
-rw-r--r--rest_framework/tests/files.py67
-rw-r--r--rest_framework/tests/filterset.py168
-rw-r--r--rest_framework/tests/genericrelations.py2
-rw-r--r--rest_framework/tests/generics.py103
-rw-r--r--rest_framework/tests/htmlrenderer.py73
-rw-r--r--rest_framework/tests/hyperlinkedserializers.py144
-rw-r--r--rest_framework/tests/models.py111
-rw-r--r--rest_framework/tests/modelviews.py2
-rw-r--r--rest_framework/tests/pagination.py206
-rw-r--r--rest_framework/tests/relations.py33
-rw-r--r--rest_framework/tests/relations_hyperlink.py434
-rw-r--r--rest_framework/tests/relations_nested.py114
-rw-r--r--rest_framework/tests/relations_pk.py411
-rw-r--r--rest_framework/tests/renderers.py92
-rw-r--r--rest_framework/tests/request.py43
-rw-r--r--rest_framework/tests/response.py11
-rw-r--r--rest_framework/tests/reverse.py2
-rw-r--r--rest_framework/tests/serializer.py553
-rw-r--r--rest_framework/tests/settings.py21
-rw-r--r--rest_framework/tests/testcases.py6
-rw-r--r--rest_framework/tests/tests.py13
-rw-r--r--rest_framework/tests/throttling.py2
-rw-r--r--rest_framework/tests/utils.py27
-rw-r--r--rest_framework/tests/validators.py10
-rw-r--r--rest_framework/tests/views.py4
32 files changed, 2636 insertions, 139 deletions
diff --git a/rest_framework/tests/__init__.py b/rest_framework/tests/__init__.py
index adeaf6da..e69de29b 100644
--- a/rest_framework/tests/__init__.py
+++ b/rest_framework/tests/__init__.py
@@ -1,13 +0,0 @@
-"""
-Force import of all modules in this package in order to get the standard test
-runner to pick up the tests. Yowzers.
-"""
-import os
-
-modules = [filename.rsplit('.', 1)[0]
- for filename in os.listdir(os.path.dirname(__file__))
- if filename.endswith('.py') and not filename.startswith('_')]
-__test__ = dict()
-
-for module in modules:
- exec("from rest_framework.tests.%s import *" % module)
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index 8ab4c4e4..e86041bc 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -1,16 +1,14 @@
-from django.conf.urls.defaults import patterns
from django.contrib.auth.models import User
-from django.test import Client, TestCase
-
-from django.utils import simplejson as json
from django.http import HttpResponse
+from django.test import Client, TestCase
-from rest_framework.views import APIView
from rest_framework import permissions
-
from rest_framework.authtoken.models import Token
from rest_framework.authentication import TokenAuthentication
+from rest_framework.compat import patterns
+from rest_framework.views import APIView
+import json
import base64
@@ -27,6 +25,7 @@ MockView.authentication_classes += (TokenAuthentication,)
urlpatterns = patterns('',
(r'^$', MockView.as_view()),
+ (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
)
@@ -152,3 +151,33 @@ class TokenAuthTests(TestCase):
self.token.delete()
token = Token.objects.create(user=self.user)
self.assertTrue(bool(token.key))
+
+ def test_token_login_json(self):
+ """Ensure token login view using JSON POST works."""
+ client = Client(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ json.dumps({'username': self.username, 'password': self.password}), 'application/json')
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(json.loads(response.content)['token'], self.key)
+
+ def test_token_login_json_bad_creds(self):
+ """Ensure token login view using JSON POST fails if bad credentials are used."""
+ client = Client(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ json.dumps({'username': self.username, 'password': "badpass"}), 'application/json')
+ self.assertEqual(response.status_code, 400)
+
+ def test_token_login_json_missing_fields(self):
+ """Ensure token login view using JSON POST fails if missing fields."""
+ client = Client(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ json.dumps({'username': self.username}), 'application/json')
+ self.assertEqual(response.status_code, 400)
+
+ def test_token_login_form(self):
+ """Ensure token login view using form POST works."""
+ client = Client(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ {'username': self.username, 'password': self.password})
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(json.loads(response.content)['token'], self.key)
diff --git a/rest_framework/tests/breadcrumbs.py b/rest_framework/tests/breadcrumbs.py
index 647ab96d..df891683 100644
--- a/rest_framework/tests/breadcrumbs.py
+++ b/rest_framework/tests/breadcrumbs.py
@@ -1,5 +1,5 @@
-from django.conf.urls.defaults import patterns, url
from django.test import TestCase
+from rest_framework.compat import patterns, url
from rest_framework.utils.breadcrumbs import get_breadcrumbs
from rest_framework.views import APIView
diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py
index 41864d71..5e6bce4e 100644
--- a/rest_framework/tests/decorators.py
+++ b/rest_framework/tests/decorators.py
@@ -1,7 +1,6 @@
from django.test import TestCase
from rest_framework import status
from rest_framework.response import Response
-from django.test.client import RequestFactory
from rest_framework.renderers import JSONRenderer
from rest_framework.parsers import JSONParser
from rest_framework.authentication import BasicAuthentication
@@ -17,6 +16,8 @@ from rest_framework.decorators import (
permission_classes,
)
+from rest_framework.tests.utils import RequestFactory
+
class DecoratorTestCase(TestCase):
@@ -63,6 +64,20 @@ class DecoratorTestCase(TestCase):
response = view(request)
self.assertEqual(response.status_code, 405)
+ def test_calling_patch_method(self):
+
+ @api_view(['GET', 'PATCH'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.patch('/')
+ response = view(request)
+ self.assertEqual(response.status_code, 200)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, 405)
+
def test_renderer_classes(self):
@api_view(['GET'])
diff --git a/rest_framework/tests/extras/__init__.py b/rest_framework/tests/extras/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/rest_framework/tests/extras/__init__.py
diff --git a/rest_framework/tests/extras/bad_import.py b/rest_framework/tests/extras/bad_import.py
new file mode 100644
index 00000000..68263d94
--- /dev/null
+++ b/rest_framework/tests/extras/bad_import.py
@@ -0,0 +1 @@
+raise ValueError
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py
new file mode 100644
index 00000000..8068272d
--- /dev/null
+++ b/rest_framework/tests/fields.py
@@ -0,0 +1,49 @@
+"""
+General serializer field tests.
+"""
+
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class TimestampedModel(models.Model):
+ added = models.DateTimeField(auto_now_add=True)
+ updated = models.DateTimeField(auto_now=True)
+
+
+class CharPrimaryKeyModel(models.Model):
+ id = models.CharField(max_length=20, primary_key=True)
+
+
+class TimestampedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TimestampedModel
+
+
+class CharPrimaryKeyModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = CharPrimaryKeyModel
+
+
+class ReadOnlyFieldTests(TestCase):
+ def test_auto_now_fields_read_only(self):
+ """
+ auto_now and auto_now_add fields should be read_only by default.
+ """
+ serializer = TimestampedModelSerializer()
+ self.assertEquals(serializer.fields['added'].read_only, True)
+
+ def test_auto_pk_fields_read_only(self):
+ """
+ AutoField fields should be read_only by default.
+ """
+ serializer = TimestampedModelSerializer()
+ self.assertEquals(serializer.fields['id'].read_only, True)
+
+ def test_non_auto_pk_fields_not_read_only(self):
+ """
+ PK fields other than AutoField fields should not be read_only by default.
+ """
+ serializer = CharPrimaryKeyModelSerializer()
+ self.assertEquals(serializer.fields['id'].read_only, False)
diff --git a/rest_framework/tests/files.py b/rest_framework/tests/files.py
index 61d7f7b1..446e23c0 100644
--- a/rest_framework/tests/files.py
+++ b/rest_framework/tests/files.py
@@ -1,34 +1,51 @@
-# from django.test import TestCase
-# from django import forms
+import StringIO
+import datetime
-# from django.test.client import RequestFactory
-# from rest_framework.views import View
-# from rest_framework.response import Response
+from django.test import TestCase
-# import StringIO
+from rest_framework import serializers
-# class UploadFilesTests(TestCase):
-# """Check uploading of files"""
-# def setUp(self):
-# self.factory = RequestFactory()
+class UploadedFile(object):
+ def __init__(self, file, created=None):
+ self.file = file
+ self.created = created or datetime.datetime.now()
-# def test_upload_file(self):
-# class FileForm(forms.Form):
-# file = forms.FileField()
+class UploadedFileSerializer(serializers.Serializer):
+ file = serializers.FileField()
+ created = serializers.DateTimeField()
-# class MockView(View):
-# permissions = ()
-# form = FileForm
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.file = attrs['file']
+ instance.created = attrs['created']
+ return instance
+ return UploadedFile(**attrs)
-# def post(self, request, *args, **kwargs):
-# return Response({'FILE_NAME': self.CONTENT['file'].name,
-# 'FILE_CONTENT': self.CONTENT['file'].read()})
-# file = StringIO.StringIO('stuff')
-# file.name = 'stuff.txt'
-# request = self.factory.post('/', {'file': file})
-# view = MockView.as_view()
-# response = view(request)
-# self.assertEquals(response.raw_content, {"FILE_CONTENT": "stuff", "FILE_NAME": "stuff.txt"})
+class FileSerializerTests(TestCase):
+ def test_create(self):
+ now = datetime.datetime.now()
+ file = StringIO.StringIO('stuff')
+ file.name = 'stuff.txt'
+ file.size = file.len
+ serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})
+ uploaded_file = UploadedFile(file=file, created=now)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.object.created, uploaded_file.created)
+ self.assertEquals(serializer.object.file, uploaded_file.file)
+ self.assertFalse(serializer.object is uploaded_file)
+
+ def test_creation_failure(self):
+ """
+ Passing files=None should result in an ValidationError
+
+ Regression test for:
+ https://github.com/tomchristie/django-rest-framework/issues/542
+ """
+ now = datetime.datetime.now()
+
+ serializer = UploadedFileSerializer(data={'created': now})
+ self.assertFalse(serializer.is_valid())
+ self.assertIn('file', serializer.errors)
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
new file mode 100644
index 00000000..af2e6c2e
--- /dev/null
+++ b/rest_framework/tests/filterset.py
@@ -0,0 +1,168 @@
+import datetime
+from decimal import Decimal
+from django.test import TestCase
+from django.test.client import RequestFactory
+from django.utils import unittest
+from rest_framework import generics, status, filters
+from rest_framework.compat import django_filters
+from rest_framework.tests.models import FilterableItem, BasicModel
+
+factory = RequestFactory()
+
+
+if django_filters:
+ # Basic filter on a list view.
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_fields = ['decimal', 'date']
+ filter_backend = filters.DjangoFilterBackend
+
+ # These class are used to test a filter class.
+ class SeveralFieldsFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+ date = django_filters.DateFilter(lookup_type='gt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterClassRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backend = filters.DjangoFilterBackend
+
+ # These classes are used to test a misconfigured filter class.
+ class MisconfiguredFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+
+ class Meta:
+ model = BasicModel
+ fields = ['text']
+
+ class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = MisconfiguredFilter
+ filter_backend = filters.DjangoFilterBackend
+
+
+class IntegrationTestFiltering(TestCase):
+ """
+ Integration tests for filtered list views.
+ """
+
+ def setUp(self):
+ """
+ Create 10 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(10):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+ for obj in self.objects.all()
+ ]
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_fields_root_view(self):
+ """
+ GET requests to paginated ListCreateAPIView should return paginated results.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that the date filter works.
+ search_date = datetime.date(2012, 9, 22)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] == search_date]
+ self.assertEquals(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_class_root_view(self):
+ """
+ GET requests to filtered ListCreateAPIView that have a filter_class set
+ should return filtered results.
+ """
+ view = FilterClassRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
+
+ # Tests that the decimal filter set with 'lt' in the filter class works.
+ search_decimal = Decimal('4.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] < search_decimal]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that the date filter set with 'gt' in the filter class works.
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that the text filter set with 'icontains' in the filter class works.
+ search_text = 'ff'
+ request = factory.get('/?text=%s' % search_text)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if search_text in f['text'].lower()]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date and
+ f['decimal'] < search_decimal]
+ self.assertEquals(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_incorrectly_configured_filter(self):
+ """
+ An error should be displayed when the filter class is misconfigured.
+ """
+ view = IncorrectlyConfiguredRootView.as_view()
+
+ request = factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_unknown_filter(self):
+ """
+ GET requests with filters that aren't configured should return 200.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ search_integer = 10
+ request = factory.get('/?integer=%s' % search_integer)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py
index 1d7e33bc..bc7378e1 100644
--- a/rest_framework/tests/genericrelations.py
+++ b/rest_framework/tests/genericrelations.py
@@ -25,7 +25,7 @@ class TestGenericRelations(TestCase):
model = Bookmark
exclude = ('id',)
- serializer = BookmarkSerializer(instance=self.bookmark)
+ serializer = BookmarkSerializer(self.bookmark)
expected = {
'tags': [u'django', u'python'],
'url': u'https://www.djangoproject.com/'
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
index f4263478..4799a04b 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/generics.py
@@ -1,8 +1,9 @@
+import json
+from django.db import models
from django.test import TestCase
-from django.test.client import RequestFactory
-from django.utils import simplejson as json
from rest_framework import generics, serializers, status
-from rest_framework.tests.models import BasicModel, Comment
+from rest_framework.tests.utils import RequestFactory
+from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
factory = RequestFactory()
@@ -22,6 +23,22 @@ class InstanceView(generics.RetrieveUpdateDestroyAPIView):
model = BasicModel
+class SlugSerializer(serializers.ModelSerializer):
+ slug = serializers.Field() # read only
+
+ class Meta:
+ model = SlugBasedModel
+ exclude = ('id',)
+
+
+class SlugBasedInstanceView(InstanceView):
+ """
+ A model with a slug-field.
+ """
+ model = SlugBasedModel
+ serializer_class = SlugSerializer
+
+
class TestRootView(TestCase):
def setUp(self):
"""
@@ -129,6 +146,7 @@ class TestInstanceView(TestCase):
for obj in self.objects.all()
]
self.view = InstanceView.as_view()
+ self.slug_based_view = SlugBasedInstanceView.as_view()
def test_get_instance_view(self):
"""
@@ -157,6 +175,20 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
+ response = self.view(request, pk='1').render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEquals(updated.text, 'foobar')
+
+ def test_patch_instance_view(self):
+ """
+ PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
+ """
+ content = {'text': 'foobar'}
+ request = factory.patch('/1', json.dumps(content),
+ content_type='application/json')
+
response = self.view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
@@ -198,7 +230,7 @@ class TestInstanceView(TestCase):
def test_put_cannot_set_id(self):
"""
- POST requests to create a new object should not be able to set the id.
+ PUT requests to create a new object should not be able to set the id.
"""
content = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
@@ -219,11 +251,39 @@ class TestInstanceView(TestCase):
request = factory.put('/1', json.dumps(content),
content_type='application/json')
response = self.view(request, pk=1).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.status_code, status.HTTP_201_CREATED)
self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
self.assertEquals(updated.text, 'foobar')
+ def test_put_as_create_on_id_based_url(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ at the requested url if it doesn't exist.
+ """
+ content = {'text': 'foobar'}
+ # pk fields can not be created on demand, only the database can set th pk for a new object
+ request = factory.put('/5', json.dumps(content),
+ content_type='application/json')
+ response = self.view(request, pk=5).render()
+ self.assertEquals(response.status_code, status.HTTP_201_CREATED)
+ new_obj = self.objects.get(pk=5)
+ self.assertEquals(new_obj.text, 'foobar')
+
+ def test_put_as_create_on_slug_based_url(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ at the requested url if possible, else return HTTP_403_FORBIDDEN error-response.
+ """
+ content = {'text': 'foobar'}
+ request = factory.put('/test_slug', json.dumps(content),
+ content_type='application/json')
+ response = self.slug_based_view(request, slug='test_slug').render()
+ self.assertEquals(response.status_code, status.HTTP_201_CREATED)
+ self.assertEquals(response.data, {'slug': 'test_slug', 'text': 'foobar'})
+ new_obj = SlugBasedModel.objects.get(slug='test_slug')
+ self.assertEquals(new_obj.text, 'foobar')
+
# Regression test for #285
@@ -256,3 +316,36 @@ class TestCreateModelWithAutoNowAddField(TestCase):
self.assertEquals(response.status_code, status.HTTP_201_CREATED)
created = self.objects.get(id=1)
self.assertEquals(created.content, 'foobar')
+
+
+# Test for particularly ugly reression with m2m in browseable API
+class ClassB(models.Model):
+ name = models.CharField(max_length=255)
+
+
+class ClassA(models.Model):
+ name = models.CharField(max_length=255)
+ childs = models.ManyToManyField(ClassB, blank=True, null=True)
+
+
+class ClassASerializer(serializers.ModelSerializer):
+ childs = serializers.ManyPrimaryKeyRelatedField(source='childs')
+
+ class Meta:
+ model = ClassA
+
+
+class ExampleView(generics.ListCreateAPIView):
+ serializer_class = ClassASerializer
+ model = ClassA
+
+
+class TestM2MBrowseableAPI(TestCase):
+ def test_m2m_in_browseable_api(self):
+ """
+ Test for particularly ugly reression with m2m in browseable API
+ """
+ request = factory.get('/', HTTP_ACCEPT='text/html')
+ view = ExampleView().as_view()
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/htmlrenderer.py
index da2f83c3..54096206 100644
--- a/rest_framework/tests/htmlrenderer.py
+++ b/rest_framework/tests/htmlrenderer.py
@@ -1,14 +1,16 @@
-from django.conf.urls.defaults import patterns, url
+from django.core.exceptions import PermissionDenied
+from django.http import Http404
from django.test import TestCase
from django.template import TemplateDoesNotExist, Template
import django.template.loader
+from rest_framework.compat import patterns, url
from rest_framework.decorators import api_view, renderer_classes
-from rest_framework.renderers import HTMLRenderer
+from rest_framework.renderers import TemplateHTMLRenderer
from rest_framework.response import Response
@api_view(('GET',))
-@renderer_classes((HTMLRenderer,))
+@renderer_classes((TemplateHTMLRenderer,))
def example(request):
"""
A view that can returns an HTML representation.
@@ -17,12 +19,26 @@ def example(request):
return Response(data, template_name='example.html')
+@api_view(('GET',))
+@renderer_classes((TemplateHTMLRenderer,))
+def permission_denied(request):
+ raise PermissionDenied()
+
+
+@api_view(('GET',))
+@renderer_classes((TemplateHTMLRenderer,))
+def not_found(request):
+ raise Http404()
+
+
urlpatterns = patterns('',
url(r'^$', example),
+ url(r'^permission_denied$', permission_denied),
+ url(r'^not_found$', not_found),
)
-class HTMLRendererTests(TestCase):
+class TemplateHTMLRendererTests(TestCase):
urls = 'rest_framework.tests.htmlrenderer'
def setUp(self):
@@ -48,3 +64,52 @@ class HTMLRendererTests(TestCase):
response = self.client.get('/')
self.assertContains(response, "example: foobar")
self.assertEquals(response['Content-Type'], 'text/html')
+
+ def test_not_found_html_view(self):
+ response = self.client.get('/not_found')
+ self.assertEquals(response.status_code, 404)
+ self.assertEquals(response.content, "404 Not Found")
+ self.assertEquals(response['Content-Type'], 'text/html')
+
+ def test_permission_denied_html_view(self):
+ response = self.client.get('/permission_denied')
+ self.assertEquals(response.status_code, 403)
+ self.assertEquals(response.content, "403 Forbidden")
+ self.assertEquals(response['Content-Type'], 'text/html')
+
+
+class TemplateHTMLRendererExceptionTests(TestCase):
+ urls = 'rest_framework.tests.htmlrenderer'
+
+ def setUp(self):
+ """
+ Monkeypatch get_template
+ """
+ self.get_template = django.template.loader.get_template
+
+ def get_template(template_name):
+ if template_name == '404.html':
+ return Template("404: {{ detail }}")
+ if template_name == '403.html':
+ return Template("403: {{ detail }}")
+ raise TemplateDoesNotExist(template_name)
+
+ django.template.loader.get_template = get_template
+
+ def tearDown(self):
+ """
+ Revert monkeypatching
+ """
+ django.template.loader.get_template = self.get_template
+
+ def test_not_found_html_view_with_template(self):
+ response = self.client.get('/not_found')
+ self.assertEquals(response.status_code, 404)
+ self.assertEquals(response.content, "404: Not found")
+ self.assertEquals(response['Content-Type'], 'text/html')
+
+ def test_permission_denied_html_view_with_template(self):
+ response = self.client.get('/permission_denied')
+ self.assertEquals(response.status_code, 403)
+ self.assertEquals(response.content, "403: Permission denied")
+ self.assertEquals(response['Content-Type'], 'text/html')
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py
index 5532a8ee..c6a8224b 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/hyperlinkedserializers.py
@@ -1,12 +1,31 @@
-from django.conf.urls.defaults import patterns, url
+import json
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import generics, status, serializers
-from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel
+from rest_framework.compat import patterns, url
+from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel
factory = RequestFactory()
+class BlogPostCommentSerializer(serializers.ModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail')
+ text = serializers.CharField()
+ blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail')
+
+ class Meta:
+ model = BlogPostComment
+ fields = ('text', 'blog_post_url', 'url')
+
+
+class PhotoSerializer(serializers.Serializer):
+ description = serializers.CharField()
+ album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), slug_field='title', slug_url_kwarg='title')
+
+ def restore_object(self, attrs, instance=None):
+ return Photo(**attrs)
+
+
class BasicList(generics.ListCreateAPIView):
model = BasicModel
model_serializer_class = serializers.HyperlinkedModelSerializer
@@ -32,12 +51,46 @@ class ManyToManyDetail(generics.RetrieveAPIView):
model_serializer_class = serializers.HyperlinkedModelSerializer
+class BlogPostCommentListCreate(generics.ListCreateAPIView):
+ model = BlogPostComment
+ serializer_class = BlogPostCommentSerializer
+
+
+class BlogPostCommentDetail(generics.RetrieveAPIView):
+ model = BlogPostComment
+ serializer_class = BlogPostCommentSerializer
+
+
+class BlogPostDetail(generics.RetrieveAPIView):
+ model = BlogPost
+
+
+class PhotoListCreate(generics.ListCreateAPIView):
+ model = Photo
+ model_serializer_class = PhotoSerializer
+
+
+class AlbumDetail(generics.RetrieveAPIView):
+ model = Album
+
+
+class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
+ model = OptionalRelationModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
urlpatterns = patterns('',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),
url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'),
url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
+ url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'),
+ url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'),
+ url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'),
+ url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'),
+ url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'),
+ url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'),
)
@@ -112,7 +165,7 @@ class TestManyToManyHyperlinkedView(TestCase):
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/manytomany/')
- response = self.list_view(request).render()
+ response = self.list_view(request)
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data)
@@ -121,6 +174,89 @@ class TestManyToManyHyperlinkedView(TestCase):
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/manytomany/1/')
- response = self.detail_view(request, pk=1).render()
+ response = self.detail_view(request, pk=1)
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data[0])
+
+
+class TestCreateWithForeignKeys(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create a blog post
+ """
+ self.post = BlogPost.objects.create(title="Test post")
+ self.create_view = BlogPostCommentListCreate.as_view()
+
+ def test_create_comment(self):
+
+ data = {
+ 'text': 'A test comment',
+ 'blog_post_url': 'http://testserver/posts/1/'
+ }
+
+ request = factory.post('/comments/', data=data)
+ response = self.create_view(request)
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response['Location'], 'http://testserver/comments/1/')
+ self.assertEqual(self.post.blogpostcomment_set.count(), 1)
+ self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment')
+
+
+class TestCreateWithForeignKeysAndCustomSlug(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create an Album
+ """
+ self.post = Album.objects.create(title='test-album')
+ self.list_create_view = PhotoListCreate.as_view()
+
+ def test_create_photo(self):
+
+ data = {
+ 'description': 'A test photo',
+ 'album_url': 'http://testserver/albums/test-album/'
+ }
+
+ request = factory.post('/photos/', data=data)
+ response = self.list_create_view(request)
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')
+ self.assertEqual(self.post.photo_set.count(), 1)
+ self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo')
+
+
+class TestOptionalRelationHyperlinkedView(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 1 OptionalRelationModel intances.
+ """
+ OptionalRelationModel().save()
+ self.objects = OptionalRelationModel.objects
+ self.detail_view = OptionalRelationDetail.as_view()
+ self.data = {"url": "http://testserver/optionalrelation/1/", "other": None}
+
+ def test_get_detail_view(self):
+ """
+ GET requests to RetrieveAPIView with optional relations should return None
+ for non existing relations.
+ """
+ request = factory.get('/optionalrelationmodel-detail/1')
+ response = self.detail_view(request, pk=1)
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
+
+ def test_put_detail_view(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView with optional relations
+ should accept None for non existing relations.
+ """
+ response = self.client.put('/optionalrelation/1/',
+ data=json.dumps(self.data),
+ content_type='application/json')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index 97cd0849..93f09761 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -35,15 +35,27 @@ def foobar():
return 'foobar'
+class CustomField(models.CharField):
+
+ def __init__(self, *args, **kwargs):
+ kwargs['max_length'] = 12
+ super(CustomField, self).__init__(*args, **kwargs)
+
+
class RESTFrameworkModel(models.Model):
"""
Base for test models that sets app_label, so they play nicely.
"""
class Meta:
- app_label = 'rest_framework'
+ app_label = 'tests'
abstract = True
+class HasPositiveIntegerAsChoice(RESTFrameworkModel):
+ some_choices = ((1, 'A'), (2, 'B'), (3, 'C'))
+ some_integer = models.PositiveIntegerField(choices=some_choices)
+
+
class Anchor(RESTFrameworkModel):
text = models.CharField(max_length=100, default='anchor')
@@ -52,8 +64,14 @@ class BasicModel(RESTFrameworkModel):
text = models.CharField(max_length=100)
+class SlugBasedModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100)
+ slug = models.SlugField(max_length=32)
+
+
class DefaultValueModel(RESTFrameworkModel):
text = models.CharField(default='foobar', max_length=100)
+ extra = models.CharField(blank=True, null=True, max_length=100)
class CallableDefaultValueModel(RESTFrameworkModel):
@@ -62,12 +80,12 @@ class CallableDefaultValueModel(RESTFrameworkModel):
class ManyToManyModel(RESTFrameworkModel):
rel = models.ManyToManyField(Anchor)
-
+
class ReadOnlyManyToManyModel(RESTFrameworkModel):
text = models.CharField(max_length=100, default='anchor')
rel = models.ManyToManyField(Anchor)
-
+
# Models to test generic relations
@@ -90,6 +108,13 @@ class Bookmark(RESTFrameworkModel):
tags = GenericRelation(TaggedItem)
+# Model to test filtering.
+class FilterableItem(RESTFrameworkModel):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
# Model for regression test for #285
class Comment(RESTFrameworkModel):
@@ -101,13 +126,93 @@ class Comment(RESTFrameworkModel):
class ActionItem(RESTFrameworkModel):
title = models.CharField(max_length=200)
done = models.BooleanField(default=False)
+ info = CustomField(default='---', max_length=12)
# Models for reverse relations
+class Person(RESTFrameworkModel):
+ name = models.CharField(max_length=10)
+ age = models.IntegerField(null=True, blank=True)
+
+ @property
+ def info(self):
+ return {
+ 'name': self.name,
+ 'age': self.age,
+ }
+
+
class BlogPost(RESTFrameworkModel):
title = models.CharField(max_length=100)
+ writer = models.ForeignKey(Person, null=True, blank=True)
+
+ def get_first_comment(self):
+ return self.blogpostcomment_set.all()[0]
class BlogPostComment(RESTFrameworkModel):
text = models.TextField()
blog_post = models.ForeignKey(BlogPost)
+
+
+class Album(RESTFrameworkModel):
+ title = models.CharField(max_length=100, unique=True)
+
+
+class Photo(RESTFrameworkModel):
+ description = models.TextField()
+ album = models.ForeignKey(Album)
+
+
+# Model for issue #324
+class BlankFieldModel(RESTFrameworkModel):
+ title = models.CharField(max_length=100, blank=True, null=False)
+
+
+# Model for issue #380
+class OptionalRelationModel(RESTFrameworkModel):
+ other = models.ForeignKey('OptionalRelationModel', blank=True, null=True)
+
+
+# Model for RegexField
+class Book(RESTFrameworkModel):
+ isbn = models.CharField(max_length=13)
+
+
+# Models for relations tests
+# ManyToMany
+class ManyToManyTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class ManyToManySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
+
+
+# ForeignKey
+class ForeignKeyTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class ForeignKeySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
+
+
+# Nullable ForeignKey
+class NullableForeignKeySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
+ related_name='nullable_sources')
+
+
+# OneToOne
+class OneToOneTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class NullableOneToOneSource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
+ related_name='nullable_source')
diff --git a/rest_framework/tests/modelviews.py b/rest_framework/tests/modelviews.py
index 1f8468e8..f12e3b97 100644
--- a/rest_framework/tests/modelviews.py
+++ b/rest_framework/tests/modelviews.py
@@ -1,4 +1,4 @@
-# from django.conf.urls.defaults import patterns, url
+# from rest_framework.compat import patterns, url
# from django.forms import ModelForm
# from django.contrib.auth.models import Group, User
# from rest_framework.resources import ModelResource
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index a939c9ef..3b550877 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -1,8 +1,12 @@
+import datetime
+from decimal import Decimal
from django.core.paginator import Paginator
from django.test import TestCase
from django.test.client import RequestFactory
-from rest_framework import generics, status, pagination
-from rest_framework.tests.models import BasicModel
+from django.utils import unittest
+from rest_framework import generics, status, pagination, filters, serializers
+from rest_framework.compat import django_filters
+from rest_framework.tests.models import BasicModel, FilterableItem
factory = RequestFactory()
@@ -15,6 +19,36 @@ class RootView(generics.ListCreateAPIView):
paginate_by = 10
+if django_filters:
+ class DecimalFilter(django_filters.FilterSet):
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_class = DecimalFilter
+ filter_backend = filters.DjangoFilterBackend
+
+
+class DefaultPageSizeKwargView(generics.ListAPIView):
+ """
+ View for testing default paginate_by_param usage
+ """
+ model = BasicModel
+
+
+class PaginateByParamView(generics.ListAPIView):
+ """
+ View for testing custom paginate_by_param usage
+ """
+ model = BasicModel
+ paginate_by_param = 'page_size'
+
+
class IntegrationTestPagination(TestCase):
"""
Integration tests for paginated list views.
@@ -22,7 +56,7 @@ class IntegrationTestPagination(TestCase):
def setUp(self):
"""
- Create 26 BasicModel intances.
+ Create 26 BasicModel instances.
"""
for char in 'abcdefghijklmnopqrstuvwxyz':
BasicModel(text=char * 3).save()
@@ -62,9 +96,66 @@ class IntegrationTestPagination(TestCase):
self.assertNotEquals(response.data['previous'], None)
+class IntegrationTestPaginationAndFiltering(TestCase):
+
+ def setUp(self):
+ """
+ Create 50 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(26):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+ for obj in self.objects.all()
+ ]
+ self.view = FilterFieldsRootView.as_view()
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_paginated_filtered_root_view(self):
+ """
+ GET requests to paginated filtered ListCreateAPIView should return
+ paginated results. The next and previous links should preserve the
+ filtered parameters.
+ """
+ request = factory.get('/?decimal=15.20')
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 15)
+ self.assertEquals(response.data['results'], self.data[:10])
+ self.assertNotEquals(response.data['next'], None)
+ self.assertEquals(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 15)
+ self.assertEquals(response.data['results'], self.data[10:15])
+ self.assertEquals(response.data['next'], None)
+ self.assertNotEquals(response.data['previous'], None)
+
+ request = factory.get(response.data['previous'])
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 15)
+ self.assertEquals(response.data['results'], self.data[:10])
+ self.assertNotEquals(response.data['next'], None)
+ self.assertEquals(response.data['previous'], None)
+
+
+class PassOnContextPaginationSerializer(pagination.PaginationSerializer):
+ class Meta:
+ object_serializer_class = serializers.Serializer
+
+
class UnitTestPagination(TestCase):
"""
- Unit tests for pagination of primative objects.
+ Unit tests for pagination of primitive objects.
"""
def setUp(self):
@@ -74,14 +165,117 @@ class UnitTestPagination(TestCase):
self.last_page = paginator.page(3)
def test_native_pagination(self):
- serializer = pagination.PaginationSerializer(instance=self.first_page)
+ serializer = pagination.PaginationSerializer(self.first_page)
self.assertEquals(serializer.data['count'], 26)
self.assertEquals(serializer.data['next'], '?page=2')
self.assertEquals(serializer.data['previous'], None)
self.assertEquals(serializer.data['results'], self.objects[:10])
- serializer = pagination.PaginationSerializer(instance=self.last_page)
+ serializer = pagination.PaginationSerializer(self.last_page)
self.assertEquals(serializer.data['count'], 26)
self.assertEquals(serializer.data['next'], None)
self.assertEquals(serializer.data['previous'], '?page=2')
self.assertEquals(serializer.data['results'], self.objects[20:])
+
+ def test_context_available_in_result(self):
+ """
+ Ensure context gets passed through to the object serializer.
+ """
+ serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})
+ serializer.data
+ results = serializer.fields[serializer.results_field]
+ self.assertEquals(serializer.context, results.context)
+
+
+class TestUnpaginated(TestCase):
+ """
+ Tests for list views without pagination.
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = DefaultPageSizeKwargView.as_view()
+
+ def test_unpaginated(self):
+ """
+ Tests the default page size for this view.
+ no page size --> no limit --> no meta data
+ """
+ request = factory.get('/')
+ response = self.view(request)
+ self.assertEquals(response.data, self.data)
+
+
+class TestCustomPaginateByParam(TestCase):
+ """
+ Tests for list views with default page size kwarg
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = PaginateByParamView.as_view()
+
+ def test_default_page_size(self):
+ """
+ Tests the default page size for this view.
+ no page size --> no limit --> no meta data
+ """
+ request = factory.get('/')
+ response = self.view(request).render()
+ self.assertEquals(response.data, self.data)
+
+ def test_paginate_by_param(self):
+ """
+ If paginate_by_param is set, the new kwarg should limit per view requests.
+ """
+ request = factory.get('/?page_size=5')
+ response = self.view(request).render()
+ self.assertEquals(response.data['count'], 13)
+ self.assertEquals(response.data['results'], self.data[:5])
+
+
+class CustomField(serializers.Field):
+ def to_native(self, value):
+ if not 'view' in self.context:
+ raise RuntimeError("context isn't getting passed into custom field")
+ return "value"
+
+
+class BasicModelSerializer(serializers.Serializer):
+ text = CustomField()
+
+
+class TestContextPassedToCustomField(TestCase):
+ def setUp(self):
+ BasicModel.objects.create(text='ala ma kota')
+
+ def test_with_pagination(self):
+ class ListView(generics.ListCreateAPIView):
+ model = BasicModel
+ serializer_class = BasicModelSerializer
+ paginate_by = 1
+
+ self.view = ListView.as_view()
+ request = factory.get('/')
+ response = self.view(request).render()
+
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+
diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py
new file mode 100644
index 00000000..91daea8a
--- /dev/null
+++ b/rest_framework/tests/relations.py
@@ -0,0 +1,33 @@
+"""
+General tests for relational fields.
+"""
+
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class NullModel(models.Model):
+ pass
+
+
+class FieldTests(TestCase):
+ def test_pk_related_field_with_empty_string(self):
+ """
+ Regression test for #446
+
+ https://github.com/tomchristie/django-rest-framework/issues/446
+ """
+ field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+ def test_hyperlinked_related_field_with_empty_string(self):
+ field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+ def test_slug_related_field_with_empty_string(self):
+ field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py
new file mode 100644
index 00000000..57913670
--- /dev/null
+++ b/rest_framework/tests/relations_hyperlink.py
@@ -0,0 +1,434 @@
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.compat import patterns, url
+from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+
+def dummy_view(request, pk):
+ pass
+
+urlpatterns = patterns('',
+ url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
+ url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
+ url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),
+ url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'),
+ url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'),
+ url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'),
+ url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),
+)
+
+class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
+ sources = serializers.ManyHyperlinkedRelatedField(view_name='manytomanysource-detail')
+
+ class Meta:
+ model = ManyToManyTarget
+
+
+class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ManyToManySource
+
+
+class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
+ sources = serializers.ManyHyperlinkedRelatedField(view_name='foreignkeysource-detail')
+
+ class Meta:
+ model = ForeignKeyTarget
+
+
+class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+
+
+# Nullable ForeignKey
+class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
+
+
+# OneToOne
+class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
+ nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail')
+
+ class Meta:
+ model = OneToOneTarget
+
+
+# TODO: Add test that .data cannot be accessed prior to .is_valid
+
+class HyperlinkedManyToManyTests(TestCase):
+ urls = 'rest_framework.tests.relations_hyperlink'
+
+ def setUp(self):
+ for idx in range(1, 4):
+ target = ManyToManyTarget(name='target-%d' % idx)
+ target.save()
+ source = ManyToManySource(name='source-%d' % idx)
+ source.save()
+ for target in ManyToManyTarget.objects.all():
+ source.targets.add(target)
+
+ def test_many_to_many_retrieve(self):
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset)
+ expected = [
+ {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']},
+ {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
+ {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_many_to_many_retrieve(self):
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset)
+ expected = [
+ {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']},
+ {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
+ {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_many_to_many_update(self):
+ data = {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
+ instance = ManyToManySource.objects.get(pk=1)
+ serializer = ManyToManySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEquals(serializer.data, data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset)
+ expected = [
+ {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']},
+ {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
+ {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_many_to_many_update(self):
+ data = {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']}
+ instance = ManyToManyTarget.objects.get(pk=1)
+ serializer = ManyToManyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEquals(serializer.data, data)
+
+ # Ensure target 1 is updated, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset)
+ expected = [
+ {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']},
+ {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
+ {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']}
+
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_many_to_many_create(self):
+ data = {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']}
+ serializer = ManyToManySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset)
+ expected = [
+ {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']},
+ {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
+ {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']},
+ {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_many_to_many_create(self):
+ data = {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']}
+ serializer = ManyToManyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'target-4')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset)
+ expected = [
+ {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']},
+ {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
+ {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']},
+ {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+
+class HyperlinkedForeignKeyTests(TestCase):
+ urls = 'rest_framework.tests.relations_hyperlink'
+
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
+ {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
+ {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update(self):
+ data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'},
+ {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_update(self):
+ data = {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
+ {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []},
+ ]
+ self.assertEquals(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEquals(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']},
+ {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'}
+ serializer = ForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
+ {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'},
+ {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'target-3')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']},
+ {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []},
+ {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'target': [u'Value may not be null']})
+
+
+class HyperlinkedNullableForeignKeyTests(TestCase):
+ urls = 'rest_framework.tests.relations_hyperlink'
+
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': ''}
+ expected_data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, expected_data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None},
+ {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': ''}
+ expected_data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, expected_data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None},
+ {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ # reverse foreign keys MUST be read_only
+ # In the general case they do not provide .remove() or .clear()
+ # and cannot be arbitrarily set.
+
+ # def test_reverse_foreign_key_update(self):
+ # data = {'id': 1, 'name': u'target-1', 'sources': [1]}
+ # instance = ForeignKeyTarget.objects.get(pk=1)
+ # serializer = ForeignKeyTargetSerializer(instance, data=data)
+ # self.assertTrue(serializer.is_valid())
+ # self.assertEquals(serializer.data, data)
+ # serializer.save()
+
+ # # Ensure target 1 is updated, and everything else is as expected
+ # queryset = ForeignKeyTarget.objects.all()
+ # serializer = ForeignKeyTargetSerializer(queryset)
+ # expected = [
+ # {'id': 1, 'name': u'target-1', 'sources': [1]},
+ # {'id': 2, 'name': u'target-2', 'sources': []},
+ # ]
+ # self.assertEquals(serializer.data, expected)
+
+
+class HyperlinkedNullableOneToOneTests(TestCase):
+ urls = 'rest_framework.tests.relations_hyperlink'
+
+ def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=target)
+ source.save()
+
+ def test_reverse_foreign_key_retrieve_with_null(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset)
+ expected = [
+ {'url': '/onetoonetarget/1/', 'name': u'target-1', 'nullable_source': '/nullableonetoonesource/1/'},
+ {'url': '/onetoonetarget/2/', 'name': u'target-2', 'nullable_source': None},
+ ]
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py
new file mode 100644
index 00000000..0e129fae
--- /dev/null
+++ b/rest_framework/tests/relations_nested.py
@@ -0,0 +1,114 @@
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+
+
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ depth = 1
+ model = ForeignKeySource
+
+
+class FlatForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+
+
+class ForeignKeyTargetSerializer(serializers.ModelSerializer):
+ sources = FlatForeignKeySourceSerializer()
+
+ class Meta:
+ model = ForeignKeyTarget
+
+
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ depth = 1
+ model = NullableForeignKeySource
+
+
+class NullableOneToOneSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = NullableOneToOneSource
+
+
+class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
+ nullable_source = NullableOneToOneSourceSerializer()
+
+ class Meta:
+ model = OneToOneTarget
+
+
+class ReverseForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}},
+ {'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}},
+ {'id': 3, 'name': u'source-3', 'target': {'id': 1, 'name': u'target-1'}},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [
+ {'id': 1, 'name': u'source-1', 'target': 1},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': 1},
+ ]},
+ {'id': 2, 'name': u'target-2', 'sources': [
+ ]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+
+class NestedNullableForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}},
+ {'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}},
+ {'id': 3, 'name': u'source-3', 'target': None},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+
+class NestedNullableOneToOneTests(TestCase):
+ def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=target)
+ source.save()
+
+ def test_reverse_foreign_key_retrieve_with_null(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'nullable_source': {'id': 1, 'name': u'source-1', 'target': 1}},
+ {'id': 2, 'name': u'target-2', 'nullable_source': None},
+ ]
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py
new file mode 100644
index 00000000..54835860
--- /dev/null
+++ b/rest_framework/tests/relations_pk.py
@@ -0,0 +1,411 @@
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+
+
+class ManyToManyTargetSerializer(serializers.ModelSerializer):
+ sources = serializers.ManyPrimaryKeyRelatedField()
+
+ class Meta:
+ model = ManyToManyTarget
+
+
+class ManyToManySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManySource
+
+
+class ForeignKeyTargetSerializer(serializers.ModelSerializer):
+ sources = serializers.ManyPrimaryKeyRelatedField()
+
+ class Meta:
+ model = ForeignKeyTarget
+
+
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+
+
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
+
+
+# OneToOne
+class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
+ nullable_source = serializers.PrimaryKeyRelatedField()
+
+ class Meta:
+ model = OneToOneTarget
+
+
+# TODO: Add test that .data cannot be accessed prior to .is_valid
+
+class PKManyToManyTests(TestCase):
+ def setUp(self):
+ for idx in range(1, 4):
+ target = ManyToManyTarget(name='target-%d' % idx)
+ target.save()
+ source = ManyToManySource(name='source-%d' % idx)
+ source.save()
+ for target in ManyToManyTarget.objects.all():
+ source.targets.add(target)
+
+ def test_many_to_many_retrieve(self):
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'targets': [1]},
+ {'id': 2, 'name': u'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_many_to_many_retrieve(self):
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': u'target-3', 'sources': [3]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_many_to_many_update(self):
+ data = {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]}
+ instance = ManyToManySource.objects.get(pk=1)
+ serializer = ManyToManySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEquals(serializer.data, data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]},
+ {'id': 2, 'name': u'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_many_to_many_update(self):
+ data = {'id': 1, 'name': u'target-1', 'sources': [1]}
+ instance = ManyToManyTarget.objects.get(pk=1)
+ serializer = ManyToManyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEquals(serializer.data, data)
+
+ # Ensure target 1 is updated, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [1]},
+ {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': u'target-3', 'sources': [3]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_many_to_many_create(self):
+ data = {'id': 4, 'name': u'source-4', 'targets': [1, 3]}
+ serializer = ManyToManySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'targets': [1]},
+ {'id': 2, 'name': u'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]},
+ {'id': 4, 'name': u'source-4', 'targets': [1, 3]},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_many_to_many_create(self):
+ data = {'id': 4, 'name': u'target-4', 'sources': [1, 3]}
+ serializer = ManyToManyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'target-4')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': u'target-3', 'sources': [3]},
+ {'id': 4, 'name': u'target-4', 'sources': [1, 3]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+
+class PKForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 1},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': 1}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': u'target-2', 'sources': []},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update(self):
+ data = {'id': 1, 'name': u'source-1', 'target': 2}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 2},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': 1}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_update(self):
+ data = {'id': 2, 'name': u'target-2', 'sources': [1, 3]}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': u'target-2', 'sources': []},
+ ]
+ self.assertEquals(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEquals(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [2]},
+ {'id': 2, 'name': u'target-2', 'sources': [1, 3]},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'id': 4, 'name': u'source-4', 'target': 2}
+ serializer = ForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 1},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': 1},
+ {'id': 4, 'name': u'source-4', 'target': 2},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'id': 3, 'name': u'target-3', 'sources': [1, 3]}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'target-3')
+
+ # Ensure target 3 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [2]},
+ {'id': 2, 'name': u'target-2', 'sources': []},
+ {'id': 3, 'name': u'target-3', 'sources': [1, 3]},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'id': 1, 'name': u'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'target': [u'Value may not be null']})
+
+
+class PKNullableForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 1},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': None},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'id': 4, 'name': u'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 1},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': None},
+ {'id': 4, 'name': u'source-4', 'target': None}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 4, 'name': u'source-4', 'target': ''}
+ expected_data = {'id': 4, 'name': u'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, expected_data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 1},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': None},
+ {'id': 4, 'name': u'source-4', 'target': None}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'id': 1, 'name': u'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': None},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': None}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 1, 'name': u'source-1', 'target': ''}
+ expected_data = {'id': 1, 'name': u'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, expected_data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': None},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': None}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ # reverse foreign keys MUST be read_only
+ # In the general case they do not provide .remove() or .clear()
+ # and cannot be arbitrarily set.
+
+ # def test_reverse_foreign_key_update(self):
+ # data = {'id': 1, 'name': u'target-1', 'sources': [1]}
+ # instance = ForeignKeyTarget.objects.get(pk=1)
+ # serializer = ForeignKeyTargetSerializer(instance, data=data)
+ # self.assertTrue(serializer.is_valid())
+ # self.assertEquals(serializer.data, data)
+ # serializer.save()
+
+ # # Ensure target 1 is updated, and everything else is as expected
+ # queryset = ForeignKeyTarget.objects.all()
+ # serializer = ForeignKeyTargetSerializer(queryset)
+ # expected = [
+ # {'id': 1, 'name': u'target-1', 'sources': [1]},
+ # {'id': 2, 'name': u'target-2', 'sources': []},
+ # ]
+ # self.assertEquals(serializer.data, expected)
+
+
+class PKNullableOneToOneTests(TestCase):
+ def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=target)
+ source.save()
+
+ def test_reverse_foreign_key_retrieve_with_null(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'nullable_source': 1},
+ {'id': 2, 'name': u'target-2', 'nullable_source': None},
+ ]
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/renderers.py
index 48d8d9bd..c1b4e624 100644
--- a/rest_framework/tests/renderers.py
+++ b/rest_framework/tests/renderers.py
@@ -1,11 +1,12 @@
+import pickle
import re
-from django.conf.urls.defaults import patterns, url, include
+from django.core.cache import cache
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import status, permissions
-from rest_framework.compat import yaml
+from rest_framework.compat import yaml, patterns, url, include
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
@@ -83,6 +84,7 @@ class HTMLView1(APIView):
urlpatterns = patterns('',
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
+ url(r'^cache$', MockGETView.as_view()),
url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])),
url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])),
url(r'^html$', HTMLView.as_view()),
@@ -416,3 +418,89 @@ class XMLRendererTestCase(TestCase):
self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
self.assertTrue(xml.endswith('</root>'))
self.assertTrue(string in xml, '%r not in %r' % (string, xml))
+
+
+# Tests for caching issue, #346
+class CacheRenderTest(TestCase):
+ """
+ Tests specific to caching responses
+ """
+
+ urls = 'rest_framework.tests.renderers'
+
+ cache_key = 'just_a_cache_key'
+
+ @classmethod
+ def _get_pickling_errors(cls, obj, seen=None):
+ """ Return any errors that would be raised if `obj' is pickled
+ Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897
+ """
+ if seen == None:
+ seen = []
+ try:
+ state = obj.__getstate__()
+ except AttributeError:
+ return
+ if state == None:
+ return
+ if isinstance(state, tuple):
+ if not isinstance(state[0], dict):
+ state = state[1]
+ else:
+ state = state[0].update(state[1])
+ result = {}
+ for i in state:
+ try:
+ pickle.dumps(state[i], protocol=2)
+ except pickle.PicklingError:
+ if not state[i] in seen:
+ seen.append(state[i])
+ result[i] = cls._get_pickling_errors(state[i], seen)
+ return result
+
+ def http_resp(self, http_method, url):
+ """
+ Simple wrapper for Client http requests
+ Removes the `client' and `request' attributes from as they are
+ added by django.test.client.Client and not part of caching
+ responses outside of tests.
+ """
+ method = getattr(self.client, http_method)
+ resp = method(url)
+ del resp.client, resp.request
+ return resp
+
+ def test_obj_pickling(self):
+ """
+ Test that responses are properly pickled
+ """
+ resp = self.http_resp('get', '/cache')
+
+ # Make sure that no pickling errors occurred
+ self.assertEqual(self._get_pickling_errors(resp), {})
+
+ # Unfortunately LocMem backend doesn't raise PickleErrors but returns
+ # None instead.
+ cache.set(self.cache_key, resp)
+ self.assertTrue(cache.get(self.cache_key) is not None)
+
+ def test_head_caching(self):
+ """
+ Test caching of HEAD requests
+ """
+ resp = self.http_resp('head', '/cache')
+ cache.set(self.cache_key, resp)
+
+ cached_resp = cache.get(self.cache_key)
+ self.assertIsInstance(cached_resp, Response)
+
+ def test_get_caching(self):
+ """
+ Test caching of GET requests
+ """
+ resp = self.http_resp('get', '/cache')
+ cache.set(self.cache_key, resp)
+
+ cached_resp = cache.get(self.cache_key)
+ self.assertIsInstance(cached_resp, Response)
+ self.assertEqual(cached_resp.content, resp.content)
diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py
index ff48f3fa..4b032405 100644
--- a/rest_framework/tests/request.py
+++ b/rest_framework/tests/request.py
@@ -1,14 +1,15 @@
"""
Tests for content parsing, and form-overloaded content parsing.
"""
-from django.conf.urls.defaults import patterns
+import json
from django.contrib.auth.models import User
+from django.contrib.auth import authenticate, login, logout
+from django.contrib.sessions.middleware import SessionMiddleware
from django.test import TestCase, Client
-from django.utils import simplejson as json
-
+from django.test.client import RequestFactory
from rest_framework import status
from rest_framework.authentication import SessionAuthentication
-from django.test.client import RequestFactory
+from rest_framework.compat import patterns
from rest_framework.parsers import (
BaseParser,
FormParser,
@@ -276,3 +277,37 @@ class TestContentParsingWithAuthentication(TestCase):
# response = self.csrf_client.post('/', content)
# self.assertEqual(status.OK, response.status_code, "POST data is malformed")
+
+
+class TestUserSetter(TestCase):
+
+ def setUp(self):
+ # Pass request object through session middleware so session is
+ # available to login and logout functions
+ self.request = Request(factory.get('/'))
+ SessionMiddleware().process_request(self.request)
+
+ User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow')
+ self.user = authenticate(username='ringo', password='yellow')
+
+ def test_user_can_be_set(self):
+ self.request.user = self.user
+ self.assertEqual(self.request.user, self.user)
+
+ def test_user_can_login(self):
+ login(self.request, self.user)
+ self.assertEqual(self.request.user, self.user)
+
+ def test_user_can_logout(self):
+ self.request.user = self.user
+ self.assertFalse(self.request.user.is_anonymous())
+ logout(self.request)
+ self.assertTrue(self.request.user.is_anonymous())
+
+
+class TestAuthSetter(TestCase):
+
+ def test_auth_can_be_set(self):
+ request = Request(factory.get('/'))
+ request.auth = 'DUMMY'
+ self.assertEqual(request.auth, 'DUMMY')
diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py
index 18b6af39..875f4d42 100644
--- a/rest_framework/tests/response.py
+++ b/rest_framework/tests/response.py
@@ -1,8 +1,5 @@
-import unittest
-
-from django.conf.urls.defaults import patterns, url, include
from django.test import TestCase
-
+from rest_framework.compat import patterns, url, include
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework import status
@@ -131,12 +128,6 @@ class RendererIntegrationTests(TestCase):
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
- @unittest.skip('can\'t pass because view is a simple Django view and response is an ImmediateResponse')
- def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
- """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
- resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
- self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
-
def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
diff --git a/rest_framework/tests/reverse.py b/rest_framework/tests/reverse.py
index fd9a7d64..8c86e1fb 100644
--- a/rest_framework/tests/reverse.py
+++ b/rest_framework/tests/reverse.py
@@ -1,6 +1,6 @@
-from django.conf.urls.defaults import patterns, url
from django.test import TestCase
from django.test.client import RequestFactory
+from rest_framework.compat import patterns, url
from rest_framework.reverse import reverse
factory = RequestFactory()
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index 936f15aa..bd96ba23 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -1,13 +1,16 @@
import datetime
+import pickle
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import *
+from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
+ BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel,
+ ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo)
class SubComment(object):
def __init__(self, sub_comment):
self.sub_comment = sub_comment
-
+
class Comment(object):
def __init__(self, email, content, created):
@@ -18,7 +21,7 @@ class Comment(object):
def __eq__(self, other):
return all([getattr(self, attr) == getattr(other, attr)
for attr in ('email', 'content', 'created')])
-
+
def get_sub_comment(self):
sub_comment = SubComment('And Merry Christmas!')
return sub_comment
@@ -29,7 +32,7 @@ class CommentSerializer(serializers.Serializer):
content = serializers.CharField(max_length=1000)
created = serializers.DateTimeField()
sub_comment = serializers.Field(source='get_sub_comment.sub_comment')
-
+
def restore_object(self, data, instance=None):
if instance is None:
return Comment(**data)
@@ -38,10 +41,41 @@ class CommentSerializer(serializers.Serializer):
return instance
+class BookSerializer(serializers.ModelSerializer):
+ isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'})
+
+ class Meta:
+ model = Book
+
+
class ActionItemSerializer(serializers.ModelSerializer):
+
class Meta:
model = ActionItem
+
+class PersonSerializer(serializers.ModelSerializer):
+ info = serializers.Field(source='info')
+
+ class Meta:
+ model = Person
+ fields = ('name', 'age', 'info')
+ read_only_fields = ('age',)
+
+
+class AlbumsSerializer(serializers.ModelSerializer):
+
+ class Meta:
+ model = Album
+ fields = ['title'] # lists are also valid options
+
+
+class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = HasPositiveIntegerAsChoice
+ fields = ['some_integer']
+
+
class BasicTests(TestCase):
def setUp(self):
self.comment = Comment(
@@ -61,6 +95,9 @@ class BasicTests(TestCase):
'created': datetime.datetime(2012, 1, 1),
'sub_comment': 'And Merry Christmas!'
}
+ self.person_data = {'name': 'dwight', 'age': 35}
+ self.person = Person(**self.person_data)
+ self.person.save()
def test_empty(self):
serializer = CommentSerializer()
@@ -73,11 +110,11 @@ class BasicTests(TestCase):
self.assertEquals(serializer.data, expected)
def test_retrieve(self):
- serializer = CommentSerializer(instance=self.comment)
+ serializer = CommentSerializer(self.comment)
self.assertEquals(serializer.data, self.expected)
def test_create(self):
- serializer = CommentSerializer(self.data)
+ serializer = CommentSerializer(data=self.data)
expected = self.comment
self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.object, expected)
@@ -85,13 +122,54 @@ class BasicTests(TestCase):
self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!')
def test_update(self):
- serializer = CommentSerializer(self.data, instance=self.comment)
+ serializer = CommentSerializer(self.comment, data=self.data)
expected = self.comment
self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.object, expected)
self.assertTrue(serializer.object is expected)
self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!')
+ def test_partial_update(self):
+ msg = 'Merry New Year!'
+ partial_data = {'content': msg}
+ serializer = CommentSerializer(self.comment, data=partial_data)
+ self.assertEquals(serializer.is_valid(), False)
+ serializer = CommentSerializer(self.comment, data=partial_data, partial=True)
+ expected = self.comment
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEquals(serializer.object, expected)
+ self.assertTrue(serializer.object is expected)
+ self.assertEquals(serializer.data['content'], msg)
+
+ def test_model_fields_as_expected(self):
+ """
+ Make sure that the fields returned are the same as defined
+ in the Meta data
+ """
+ serializer = PersonSerializer(self.person)
+ self.assertEquals(set(serializer.data.keys()),
+ set(['name', 'age', 'info']))
+
+ def test_field_with_dictionary(self):
+ """
+ Make sure that dictionaries from fields are left intact
+ """
+ serializer = PersonSerializer(self.person)
+ expected = self.person_data
+ self.assertEquals(serializer.data['info'], expected)
+
+ def test_read_only_fields(self):
+ """
+ Attempting to update fields set as read_only should have no effect.
+ """
+
+ serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99})
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(serializer.errors, {})
+ # Assert age is unchanged (35)
+ self.assertEquals(instance.age, self.person_data['age'])
+
class ValidationTests(TestCase):
def setUp(self):
@@ -104,17 +182,17 @@ class ValidationTests(TestCase):
'email': 'tom@example.com',
'content': 'x' * 1001,
'created': datetime.datetime(2012, 1, 1)
- }
- self.actionitem = ActionItem('Some to do item',
+ }
+ self.actionitem = ActionItem(title='Some to do item',
)
def test_create(self):
- serializer = CommentSerializer(self.data)
+ serializer = CommentSerializer(data=self.data)
self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']})
def test_update(self):
- serializer = CommentSerializer(self.data, instance=self.comment)
+ serializer = CommentSerializer(self.comment, data=self.data)
self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']})
@@ -123,7 +201,7 @@ class ValidationTests(TestCase):
'content': 'xxx',
'created': datetime.datetime(2012, 1, 1)
}
- serializer = CommentSerializer(data, instance=self.comment)
+ serializer = CommentSerializer(self.comment, data=data)
self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'email': [u'This field is required.']})
@@ -131,10 +209,10 @@ class ValidationTests(TestCase):
"""Make sure that a boolean value with a 'False' value is not
mistaken for not having a default."""
data = {
- 'title':'Some action item',
+ 'title': 'Some action item',
#No 'done' value.
}
- serializer = ActionItemSerializer(data, instance=self.actionitem)
+ serializer = ActionItemSerializer(self.actionitem, data=data)
self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.errors, {})
@@ -154,15 +232,34 @@ class ValidationTests(TestCase):
'created': datetime.datetime(2012, 1, 1)
}
- serializer = CommentSerializerWithFieldValidator(data)
+ serializer = CommentSerializerWithFieldValidator(data=data)
self.assertTrue(serializer.is_valid())
data['content'] = 'This should not validate'
- serializer = CommentSerializerWithFieldValidator(data)
+ serializer = CommentSerializerWithFieldValidator(data=data)
self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'content': [u'Test not in value']})
+ def test_bad_type_data_is_false(self):
+ """
+ Data of the wrong type is not valid.
+ """
+ data = ['i am', 'a', 'list']
+ serializer = CommentSerializer(self.comment, data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']})
+
+ data = 'and i am a string'
+ serializer = CommentSerializer(self.comment, data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']})
+
+ data = 42
+ serializer = CommentSerializer(self.comment, data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']})
+
def test_cross_field_validation(self):
class CommentSerializerWithCrossFieldValidator(CommentSerializer):
@@ -178,15 +275,109 @@ class ValidationTests(TestCase):
'created': datetime.datetime(2012, 1, 1)
}
- serializer = CommentSerializerWithCrossFieldValidator(data)
+ serializer = CommentSerializerWithCrossFieldValidator(data=data)
self.assertTrue(serializer.is_valid())
data['content'] = 'A comment from foo@bar.com'
- serializer = CommentSerializerWithCrossFieldValidator(data)
+ serializer = CommentSerializerWithCrossFieldValidator(data=data)
self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'non_field_errors': [u'Email address not in content']})
+ def test_null_is_true_fields(self):
+ """
+ Omitting a value for null-field should validate.
+ """
+ serializer = PersonSerializer(data={'name': 'marko'})
+ self.assertEquals(serializer.is_valid(), True)
+ self.assertEquals(serializer.errors, {})
+
+ def test_modelserializer_max_length_exceeded(self):
+ data = {
+ 'title': 'x' * 201,
+ }
+ serializer = ActionItemSerializer(data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'title': [u'Ensure this value has at most 200 characters (it has 201).']})
+
+ def test_default_modelfield_max_length_exceeded(self):
+ data = {
+ 'title': 'Testing "info" field...',
+ 'info': 'x' * 13,
+ }
+ serializer = ActionItemSerializer(data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'info': [u'Ensure this value has at most 12 characters (it has 13).']})
+
+
+class PositiveIntegerAsChoiceTests(TestCase):
+ def test_positive_integer_in_json_is_correctly_parsed(self):
+ data = {'some_integer':1}
+ serializer = PositiveIntegerAsChoiceSerializer(data=data)
+ self.assertEquals(serializer.is_valid(), True)
+
+class ModelValidationTests(TestCase):
+ def test_validate_unique(self):
+ """
+ Just check if serializers.ModelSerializer handles unique checks via .full_clean()
+ """
+ serializer = AlbumsSerializer(data={'title': 'a'})
+ serializer.is_valid()
+ serializer.save()
+ second_serializer = AlbumsSerializer(data={'title': 'a'})
+ self.assertFalse(second_serializer.is_valid())
+ self.assertEqual(second_serializer.errors, {'title': [u'Album with this Title already exists.']})
+
+ def test_foreign_key_with_partial(self):
+ """
+ Test ModelSerializer validation with partial=True
+
+ Specifically test foreign key validation.
+ """
+
+ album = Album(title='test')
+ album.save()
+
+ class PhotoSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Photo
+
+ photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk})
+ self.assertTrue(photo_serializer.is_valid())
+ photo = photo_serializer.save()
+
+ # Updating only the album (foreign key)
+ photo_serializer = PhotoSerializer(instance=photo, data={'album': album.pk}, partial=True)
+ self.assertTrue(photo_serializer.is_valid())
+ self.assertTrue(photo_serializer.save())
+
+ # Updating only the description
+ photo_serializer = PhotoSerializer(instance=photo,
+ data={'description': 'new'},
+ partial=True)
+
+ self.assertTrue(photo_serializer.is_valid())
+ self.assertTrue(photo_serializer.save())
+
+
+class RegexValidationTest(TestCase):
+ def test_create_failed(self):
+ serializer = BookSerializer(data={'isbn': '1234567890'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']})
+
+ serializer = BookSerializer(data={'isbn': '12345678901234'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']})
+
+ serializer = BookSerializer(data={'isbn': 'abcdefghijklm'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']})
+
+ def test_create_success(self):
+ serializer = BookSerializer(data={'isbn': '1234567890123'})
+ self.assertTrue(serializer.is_valid())
+
class MetadataTests(TestCase):
def test_empty(self):
@@ -233,7 +424,7 @@ class ManyToManyTests(TestCase):
Create an instance of a model with a ManyToMany relationship.
"""
data = {'rel': [self.anchor.id]}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 2)
@@ -247,7 +438,7 @@ class ManyToManyTests(TestCase):
new_anchor = Anchor()
new_anchor.save()
data = {'rel': [self.anchor.id, new_anchor.id]}
- serializer = self.serializer_class(data, instance=self.instance)
+ serializer = self.serializer_class(self.instance, data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 1)
@@ -260,7 +451,7 @@ class ManyToManyTests(TestCase):
containing no items.
"""
data = {'rel': []}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 2)
@@ -275,7 +466,7 @@ class ManyToManyTests(TestCase):
new_anchor = Anchor()
new_anchor.save()
data = {'rel': []}
- serializer = self.serializer_class(data, instance=self.instance)
+ serializer = self.serializer_class(self.instance, data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 1)
@@ -289,17 +480,19 @@ class ManyToManyTests(TestCase):
lists (eg form data).
"""
data = {'rel': ''}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 2)
self.assertEquals(instance.pk, 2)
self.assertEquals(list(instance.rel.all()), [])
-
+
+
class ReadOnlyManyToManyTests(TestCase):
def setUp(self):
class ReadOnlyManyToManySerializer(serializers.ModelSerializer):
- rel = serializers.ManyRelatedField(readonly=True)
+ rel = serializers.ManyRelatedField(read_only=True)
+
class Meta:
model = ReadOnlyManyToManyModel
@@ -317,16 +510,15 @@ class ReadOnlyManyToManyTests(TestCase):
# A serialized representation of the model instance
self.data = {'rel': [self.anchor.id], 'id': 1, 'text': 'anchor'}
-
def test_update(self):
"""
Attempt to update an instance of a model with a ManyToMany
- relationship. Not updated due to readonly=True
+ relationship. Not updated due to read_only=True
"""
new_anchor = Anchor()
new_anchor.save()
data = {'rel': [self.anchor.id, new_anchor.id]}
- serializer = self.serializer_class(data, instance=self.instance)
+ serializer = self.serializer_class(self.instance, data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1)
@@ -337,12 +529,12 @@ class ReadOnlyManyToManyTests(TestCase):
def test_update_without_relationship(self):
"""
Attempt to update an instance of a model where many to ManyToMany
- relationship is not supplied. Not updated due to readonly=True
+ relationship is not supplied. Not updated due to read_only=True
"""
new_anchor = Anchor()
new_anchor.save()
data = {}
- serializer = self.serializer_class(data, instance=self.instance)
+ serializer = self.serializer_class(self.instance, data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1)
@@ -362,7 +554,7 @@ class DefaultValueTests(TestCase):
def test_create_using_default(self):
data = {}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(self.objects.all()), 1)
@@ -371,13 +563,28 @@ class DefaultValueTests(TestCase):
def test_create_overriding_default(self):
data = {'text': 'overridden'}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(self.objects.all()), 1)
self.assertEquals(instance.pk, 1)
self.assertEquals(instance.text, 'overridden')
+ def test_partial_update_default(self):
+ """ Regression test for issue #532 """
+ data = {'text': 'overridden'}
+ serializer = self.serializer_class(data=data, partial=True)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+
+ data = {'extra': 'extra_value'}
+ serializer = self.serializer_class(instance=instance, data=data, partial=True)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+
+ self.assertEquals(instance.extra, 'extra_value')
+ self.assertEquals(instance.text, 'overridden')
+
class CallableDefaultValueTests(TestCase):
def setUp(self):
@@ -390,7 +597,7 @@ class CallableDefaultValueTests(TestCase):
def test_create_using_default(self):
data = {}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(self.objects.all()), 1)
@@ -399,7 +606,7 @@ class CallableDefaultValueTests(TestCase):
def test_create_overriding_default(self):
data = {'text': 'overridden'}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(self.objects.all()), 1)
@@ -408,7 +615,10 @@ class CallableDefaultValueTests(TestCase):
class ManyRelatedTests(TestCase):
- def setUp(self):
+ def test_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
class BlogPostCommentSerializer(serializers.Serializer):
text = serializers.CharField()
@@ -417,14 +627,7 @@ class ManyRelatedTests(TestCase):
title = serializers.CharField()
comments = BlogPostCommentSerializer(source='blogpostcomment_set')
- self.serializer_class = BlogPostSerializer
-
- def test_reverse_relations(self):
- post = BlogPost.objects.create(title="Test blog post")
- post.blogpostcomment_set.create(text="I hate this blog post")
- post.blogpostcomment_set.create(text="I love this blog post")
-
- serializer = self.serializer_class(instance=post)
+ serializer = BlogPostSerializer(instance=post)
expected = {
'title': 'Test blog post',
'comments': [
@@ -434,3 +637,267 @@ class ManyRelatedTests(TestCase):
}
self.assertEqual(serializer.data, expected)
+
+ def test_callable_source(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostCommentSerializer(serializers.Serializer):
+ text = serializers.CharField()
+
+ class BlogPostSerializer(serializers.Serializer):
+ title = serializers.CharField()
+ first_comment = BlogPostCommentSerializer(source='get_first_comment')
+
+ serializer = BlogPostSerializer(post)
+
+ expected = {
+ 'title': 'Test blog post',
+ 'first_comment': {'text': 'I love this blog post'}
+ }
+ self.assertEqual(serializer.data, expected)
+
+
+class RelatedTraversalTest(TestCase):
+ def test_nested_traversal(self):
+ user = Person.objects.create(name="django")
+ post = BlogPost.objects.create(title="Test blog post", writer=user)
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ from rest_framework.tests.models import BlogPostComment
+
+ class PersonSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Person
+ fields = ("name", "age")
+
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPostComment
+ fields = ("text", "post_owner")
+
+ text = serializers.CharField()
+ post_owner = PersonSerializer(source='blog_post.writer')
+
+ class BlogPostSerializer(serializers.Serializer):
+ title = serializers.CharField()
+ comments = BlogPostCommentSerializer(source='blogpostcomment_set')
+
+ serializer = BlogPostSerializer(instance=post)
+
+ expected = {
+ 'title': u'Test blog post',
+ 'comments': [{
+ 'text': u'I love this blog post',
+ 'post_owner': {
+ "name": u"django",
+ "age": None
+ }
+ }]
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+
+class SerializerMethodFieldTests(TestCase):
+ def setUp(self):
+
+ class BoopSerializer(serializers.Serializer):
+ beep = serializers.SerializerMethodField('get_beep')
+ boop = serializers.Field()
+ boop_count = serializers.SerializerMethodField('get_boop_count')
+
+ def get_beep(self, obj):
+ return 'hello!'
+
+ def get_boop_count(self, obj):
+ return len(obj.boop)
+
+ self.serializer_class = BoopSerializer
+
+ def test_serializer_method_field(self):
+
+ class MyModel(object):
+ boop = ['a', 'b', 'c']
+
+ source_data = MyModel()
+
+ serializer = self.serializer_class(source_data)
+
+ expected = {
+ 'beep': u'hello!',
+ 'boop': [u'a', u'b', u'c'],
+ 'boop_count': 3,
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+
+# Test for issue #324
+class BlankFieldTests(TestCase):
+ def setUp(self):
+
+ class BlankFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlankFieldModel
+
+ class BlankFieldSerializer(serializers.Serializer):
+ title = serializers.CharField(blank=True)
+
+ class NotBlankFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
+
+ class NotBlankFieldSerializer(serializers.Serializer):
+ title = serializers.CharField()
+
+ self.model_serializer_class = BlankFieldModelSerializer
+ self.serializer_class = BlankFieldSerializer
+ self.not_blank_model_serializer_class = NotBlankFieldModelSerializer
+ self.not_blank_serializer_class = NotBlankFieldSerializer
+ self.data = {'title': ''}
+
+ def test_create_blank_field(self):
+ serializer = self.serializer_class(data=self.data)
+ self.assertEquals(serializer.is_valid(), True)
+
+ def test_create_model_blank_field(self):
+ serializer = self.model_serializer_class(data=self.data)
+ self.assertEquals(serializer.is_valid(), True)
+
+ def test_create_model_null_field(self):
+ serializer = self.model_serializer_class(data={'title': None})
+ self.assertEquals(serializer.is_valid(), True)
+
+ def test_create_not_blank_field(self):
+ """
+ Test to ensure blank data in a field not marked as blank=True
+ is considered invalid in a non-model serializer
+ """
+ serializer = self.not_blank_serializer_class(data=self.data)
+ self.assertEquals(serializer.is_valid(), False)
+
+ def test_create_model_not_blank_field(self):
+ """
+ Test to ensure blank data in a field not marked as blank=True
+ is considered invalid in a model serializer
+ """
+ serializer = self.not_blank_model_serializer_class(data=self.data)
+ self.assertEquals(serializer.is_valid(), False)
+
+ def test_create_model_null_field(self):
+ serializer = self.model_serializer_class(data={})
+ self.assertEquals(serializer.is_valid(), True)
+
+
+#test for issue #460
+class SerializerPickleTests(TestCase):
+ """
+ Test pickleability of the output of Serializers
+ """
+ def test_pickle_simple_model_serializer_data(self):
+ """
+ Test simple serializer
+ """
+ pickle.dumps(PersonSerializer(Person(name="Methusela", age=969)).data)
+
+ def test_pickle_inner_serializer(self):
+ """
+ Test pickling a serializer whose resulting .data (a SortedDictWithMetadata) will
+ have unpickleable meta data--in order to make sure metadata doesn't get pulled into the pickle.
+ See DictWithMetadata.__getstate__
+ """
+ class InnerPersonSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Person
+ fields = ('name', 'age')
+ pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data)
+
+
+class DepthTest(TestCase):
+ def test_implicit_nesting(self):
+ writer = Person.objects.create(name="django", age=1)
+ post = BlogPost.objects.create(title="Test blog post", writer=writer)
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ depth = 1
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {'id': 1, 'title': u'Test blog post',
+ 'writer': {'id': 1, 'name': u'django', 'age': 1}}
+
+ self.assertEqual(serializer.data, expected)
+
+ def test_explicit_nesting(self):
+ writer = Person.objects.create(name="django", age=1)
+ post = BlogPost.objects.create(title="Test blog post", writer=writer)
+
+ class PersonSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Person
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ writer = PersonSerializer()
+
+ class Meta:
+ model = BlogPost
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {'id': 1, 'title': u'Test blog post',
+ 'writer': {'id': 1, 'name': u'django', 'age': 1}}
+
+ self.assertEqual(serializer.data, expected)
+
+
+class NestedSerializerContextTests(TestCase):
+
+ def test_nested_serializer_context(self):
+ """
+ Regression for #497
+
+ https://github.com/tomchristie/django-rest-framework/issues/497
+ """
+ class PhotoSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Photo
+ fields = ("description", "callable")
+
+ callable = serializers.SerializerMethodField('_callable')
+
+ def _callable(self, instance):
+ if not 'context_item' in self.context:
+ raise RuntimeError("context isn't getting passed into 2nd level nested serializer")
+ return "success"
+
+ class AlbumSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Album
+ fields = ("photo_set", "callable")
+
+ photo_set = PhotoSerializer(source="photo_set")
+ callable = serializers.SerializerMethodField("_callable")
+
+ def _callable(self, instance):
+ if not 'context_item' in self.context:
+ raise RuntimeError("context isn't getting passed into 1st level nested serializer")
+ return "success"
+
+ class AlbumCollection(object):
+ albums = None
+
+ class AlbumCollectionSerializer(serializers.Serializer):
+ albums = AlbumSerializer(source="albums")
+
+ album1 = Album.objects.create(title="album 1")
+ album2 = Album.objects.create(title="album 2")
+ Photo.objects.create(description="Bigfoot", album=album1)
+ Photo.objects.create(description="Unicorn", album=album1)
+ Photo.objects.create(description="Yeti", album=album2)
+ Photo.objects.create(description="Sasquatch", album=album2)
+ album_collection = AlbumCollection()
+ album_collection.albums = [album1, album2]
+
+ # This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers
+ AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data
diff --git a/rest_framework/tests/settings.py b/rest_framework/tests/settings.py
new file mode 100644
index 00000000..0293fdc3
--- /dev/null
+++ b/rest_framework/tests/settings.py
@@ -0,0 +1,21 @@
+"""Tests for the settings module"""
+from django.test import TestCase
+
+from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS
+
+
+class TestSettings(TestCase):
+ """Tests relating to the api settings"""
+
+ def test_non_import_errors(self):
+ """Make sure other errors aren't suppressed."""
+ settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS)
+ with self.assertRaises(ValueError):
+ settings.DEFAULT_MODEL_SERIALIZER_CLASS
+
+ def test_import_error_message_maintained(self):
+ """Make sure real import errors are captured and raised sensibly."""
+ settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS)
+ with self.assertRaises(ImportError) as cm:
+ settings.DEFAULT_MODEL_SERIALIZER_CLASS
+ self.assertTrue('ImportError' in str(cm.exception))
diff --git a/rest_framework/tests/testcases.py b/rest_framework/tests/testcases.py
index c90224aa..97f492ff 100644
--- a/rest_framework/tests/testcases.py
+++ b/rest_framework/tests/testcases.py
@@ -6,6 +6,7 @@ from django.test import TestCase
NO_SETTING = ('!', None)
+
class TestSettingsManager(object):
"""
A class which can modify some Django settings temporarily for a
@@ -19,7 +20,7 @@ class TestSettingsManager(object):
self._original_settings = {}
def set(self, **kwargs):
- for k,v in kwargs.iteritems():
+ for k, v in kwargs.iteritems():
self._original_settings.setdefault(k, getattr(settings, k,
NO_SETTING))
setattr(settings, k, v)
@@ -31,7 +32,7 @@ class TestSettingsManager(object):
call_command('syncdb', verbosity=0)
def revert(self):
- for k,v in self._original_settings.iteritems():
+ for k, v in self._original_settings.iteritems():
if v == NO_SETTING:
delattr(settings, k)
else:
@@ -57,6 +58,7 @@ class SettingsTestCase(TestCase):
def tearDown(self):
self.settings_manager.revert()
+
class TestModelsTestCase(SettingsTestCase):
def setUp(self, *args, **kwargs):
installed_apps = tuple(settings.INSTALLED_APPS) + ('rest_framework.tests',)
diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py
new file mode 100644
index 00000000..adeaf6da
--- /dev/null
+++ b/rest_framework/tests/tests.py
@@ -0,0 +1,13 @@
+"""
+Force import of all modules in this package in order to get the standard test
+runner to pick up the tests. Yowzers.
+"""
+import os
+
+modules = [filename.rsplit('.', 1)[0]
+ for filename in os.listdir(os.path.dirname(__file__))
+ if filename.endswith('.py') and not filename.startswith('_')]
+__test__ = dict()
+
+for module in modules:
+ exec("from rest_framework.tests.%s import *" % module)
diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/throttling.py
index 0b94c25b..4b98b941 100644
--- a/rest_framework/tests/throttling.py
+++ b/rest_framework/tests/throttling.py
@@ -106,7 +106,7 @@ class ThrottlingTests(TestCase):
if expect is not None:
self.assertEquals(response['X-Throttle-Wait-Seconds'], expect)
else:
- self.assertFalse('X-Throttle-Wait-Seconds' in response.headers)
+ self.assertFalse('X-Throttle-Wait-Seconds' in response)
def test_seconds_fields(self):
"""
diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py
new file mode 100644
index 00000000..3906adb9
--- /dev/null
+++ b/rest_framework/tests/utils.py
@@ -0,0 +1,27 @@
+from django.test.client import RequestFactory, FakePayload
+from django.test.client import MULTIPART_CONTENT
+from urlparse import urlparse
+
+
+class RequestFactory(RequestFactory):
+
+ def __init__(self, **defaults):
+ super(RequestFactory, self).__init__(**defaults)
+
+ def patch(self, path, data={}, content_type=MULTIPART_CONTENT,
+ **extra):
+ "Construct a PATCH request."
+
+ patch_data = self._encode_data(data, content_type)
+
+ parsed = urlparse(path)
+ r = {
+ 'CONTENT_LENGTH': len(patch_data),
+ 'CONTENT_TYPE': content_type,
+ 'PATH_INFO': self._get_path(parsed),
+ 'QUERY_STRING': parsed[4],
+ 'REQUEST_METHOD': 'PATCH',
+ 'wsgi.input': FakePayload(patch_data),
+ }
+ r.update(extra)
+ return self.request(**r)
diff --git a/rest_framework/tests/validators.py b/rest_framework/tests/validators.py
index b390c42f..c032985e 100644
--- a/rest_framework/tests/validators.py
+++ b/rest_framework/tests/validators.py
@@ -285,7 +285,7 @@
# uiop = models.CharField(max_length=256, blank=True)
# @property
-# def readonly(self):
+# def read_only(self):
# return 'read only'
# class MockResource(ModelResource):
@@ -298,7 +298,7 @@
# def test_property_fields_are_allowed_on_model_forms(self):
# """Validation on ModelForms may include property fields that exist on the Model to be included in the input."""
-# content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only'}
+# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only'}
# self.assertEqual(self.validator.validate_request(content, None), content)
# def test_property_fields_are_not_required_on_model_forms(self):
@@ -310,19 +310,19 @@
# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
# broken clients more easily (eg submitting content with a misnamed field)"""
-# content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only', 'extra': 'extra'}
+# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only', 'extra': 'extra'}
# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None)
# def test_validate_requires_fields_on_model_forms(self):
# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
# broken clients more easily (eg submitting content with a misnamed field)"""
-# content = {'readonly': 'read only'}
+# content = {'read_only': 'read only'}
# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None)
# def test_validate_does_not_require_blankable_fields_on_model_forms(self):
# """Test standard ModelForm validation behaviour - fields with blank=True are not required."""
-# content = {'qwerty': 'example', 'readonly': 'read only'}
+# content = {'qwerty': 'example', 'read_only': 'read only'}
# self.validator.validate_request(content, None)
# def test_model_form_validator_uses_model_forms(self):
diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py
index 43365e07..7cd82656 100644
--- a/rest_framework/tests/views.py
+++ b/rest_framework/tests/views.py
@@ -18,7 +18,7 @@ class BasicView(APIView):
return Response({'method': 'POST', 'data': request.DATA})
-@api_view(['GET', 'POST', 'PUT'])
+@api_view(['GET', 'POST', 'PUT', 'PATCH'])
def basic_view(request):
if request.method == 'GET':
return {'method': 'GET'}
@@ -26,6 +26,8 @@ def basic_view(request):
return {'method': 'POST', 'data': request.DATA}
elif request.method == 'PUT':
return {'method': 'PUT', 'data': request.DATA}
+ elif request.method == 'PATCH':
+ return {'method': 'PATCH', 'data': request.DATA}
def sanitise_json_error(error_dict):