aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2013-01-15 17:53:24 +0000
committerTom Christie2013-01-15 17:53:24 +0000
commit71e55cc4f6300959398f7aef4a8d91b6a6a2af57 (patch)
tree68c2080034263d897741da33cbc5e09746006257 /rest_framework
parent52847a215d4e8de88e81d9ae79ce8bee9a36a9a2 (diff)
parente1076cfb49b6293aa837cf7bdb4c11988892c598 (diff)
downloaddjango-rest-framework-71e55cc4f6300959398f7aef4a8d91b6a6a2af57.tar.bz2
Merge with latest master
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py2
-rw-r--r--rest_framework/authtoken/migrations/0001_initial.py14
-rw-r--r--rest_framework/authtoken/models.py3
-rw-r--r--rest_framework/authtoken/serializers.py24
-rw-r--r--rest_framework/authtoken/views.py26
-rw-r--r--rest_framework/compat.py76
-rw-r--r--rest_framework/decorators.py14
-rw-r--r--rest_framework/exceptions.py8
-rw-r--r--rest_framework/fields.py455
-rw-r--r--rest_framework/filters.py59
-rw-r--r--rest_framework/generics.py92
-rw-r--r--rest_framework/mixins.py50
-rw-r--r--rest_framework/negotiation.py7
-rw-r--r--rest_framework/pagination.py17
-rw-r--r--rest_framework/parsers.py2
-rw-r--r--rest_framework/permissions.py13
-rw-r--r--rest_framework/relations.py509
-rw-r--r--rest_framework/renderers.py153
-rw-r--r--rest_framework/request.py21
-rw-r--r--rest_framework/response.py21
-rw-r--r--rest_framework/reverse.py8
-rwxr-xr-xrest_framework/runtests/runcoverage.py13
-rwxr-xr-xrest_framework/runtests/runtests.py5
-rw-r--r--rest_framework/runtests/settings.py14
-rw-r--r--rest_framework/runtests/urls.py2
-rw-r--r--rest_framework/serializers.py282
-rw-r--r--rest_framework/settings.py31
-rw-r--r--rest_framework/static/rest_framework/css/default.css21
-rw-r--r--rest_framework/status.py2
-rw-r--r--rest_framework/templates/rest_framework/base.html33
-rw-r--r--rest_framework/templates/rest_framework/login.html58
-rw-r--r--rest_framework/templatetags/rest_framework.py108
-rw-r--r--rest_framework/tests/__init__.py13
-rw-r--r--rest_framework/tests/authentication.py41
-rw-r--r--rest_framework/tests/breadcrumbs.py2
-rw-r--r--rest_framework/tests/decorators.py17
-rw-r--r--rest_framework/tests/extras/__init__.py0
-rw-r--r--rest_framework/tests/extras/bad_import.py1
-rw-r--r--rest_framework/tests/fields.py49
-rw-r--r--rest_framework/tests/files.py67
-rw-r--r--rest_framework/tests/filterset.py168
-rw-r--r--rest_framework/tests/genericrelations.py2
-rw-r--r--rest_framework/tests/generics.py103
-rw-r--r--rest_framework/tests/htmlrenderer.py73
-rw-r--r--rest_framework/tests/hyperlinkedserializers.py144
-rw-r--r--rest_framework/tests/models.py111
-rw-r--r--rest_framework/tests/modelviews.py2
-rw-r--r--rest_framework/tests/pagination.py206
-rw-r--r--rest_framework/tests/relations.py33
-rw-r--r--rest_framework/tests/relations_hyperlink.py434
-rw-r--r--rest_framework/tests/relations_nested.py114
-rw-r--r--rest_framework/tests/relations_pk.py411
-rw-r--r--rest_framework/tests/renderers.py92
-rw-r--r--rest_framework/tests/request.py43
-rw-r--r--rest_framework/tests/response.py11
-rw-r--r--rest_framework/tests/reverse.py2
-rw-r--r--rest_framework/tests/serializer.py553
-rw-r--r--rest_framework/tests/settings.py21
-rw-r--r--rest_framework/tests/testcases.py6
-rw-r--r--rest_framework/tests/tests.py13
-rw-r--r--rest_framework/tests/throttling.py2
-rw-r--r--rest_framework/tests/utils.py27
-rw-r--r--rest_framework/tests/validators.py10
-rw-r--r--rest_framework/tests/views.py4
-rw-r--r--rest_framework/throttling.py2
-rw-r--r--rest_framework/urlpatterns.py4
-rw-r--r--rest_framework/urls.py6
-rw-r--r--rest_framework/utils/__init__.py1
-rw-r--r--rest_framework/utils/breadcrumbs.py14
-rw-r--r--rest_framework/utils/encoders.py6
-rw-r--r--rest_framework/views.py14
71 files changed, 4268 insertions, 697 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 557f5943..bc267fad 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,3 +1,3 @@
-__version__ = '2.0.0'
+__version__ = '2.1.16'
VERSION = __version__ # synonym
diff --git a/rest_framework/authtoken/migrations/0001_initial.py b/rest_framework/authtoken/migrations/0001_initial.py
index 9d750381..f4e052e4 100644
--- a/rest_framework/authtoken/migrations/0001_initial.py
+++ b/rest_framework/authtoken/migrations/0001_initial.py
@@ -5,13 +5,21 @@ from south.v2 import SchemaMigration
from django.db import models
+try:
+ from django.contrib.auth import get_user_model
+except ImportError: # django < 1.5
+ from django.contrib.auth.models import User
+else:
+ User = get_user_model()
+
+
class Migration(SchemaMigration):
def forwards(self, orm):
# Adding model 'Token'
db.create_table('authtoken_token', (
('key', self.gf('django.db.models.fields.CharField')(max_length=40, primary_key=True)),
- ('user', self.gf('django.db.models.fields.related.OneToOneField')(related_name='auth_token', unique=True, to=orm['auth.User'])),
+ ('user', self.gf('django.db.models.fields.related.OneToOneField')(related_name='auth_token', unique=True, to=orm['%s.%s' % (User._meta.app_label, User._meta.object_name)])),
('created', self.gf('django.db.models.fields.DateTimeField')(auto_now_add=True, blank=True)),
))
db.send_create_signal('authtoken', ['Token'])
@@ -36,7 +44,7 @@ class Migration(SchemaMigration):
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'name': ('django.db.models.fields.CharField', [], {'max_length': '50'})
},
- 'auth.user': {
+ "%s.%s" % (User._meta.app_label, User._meta.module_name): {
'Meta': {'object_name': 'User'},
'date_joined': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}),
'email': ('django.db.models.fields.EmailField', [], {'max_length': '75', 'blank': 'True'}),
@@ -56,7 +64,7 @@ class Migration(SchemaMigration):
'Meta': {'object_name': 'Token'},
'created': ('django.db.models.fields.DateTimeField', [], {'auto_now_add': 'True', 'blank': 'True'}),
'key': ('django.db.models.fields.CharField', [], {'max_length': '40', 'primary_key': 'True'}),
- 'user': ('django.db.models.fields.related.OneToOneField', [], {'related_name': "'auth_token'", 'unique': 'True', 'to': "orm['auth.User']"})
+ 'user': ('django.db.models.fields.related.OneToOneField', [], {'related_name': "'auth_token'", 'unique': 'True', 'to': "orm['%s.%s']" % (User._meta.app_label, User._meta.object_name)})
},
'contenttypes.contenttype': {
'Meta': {'ordering': "('name',)", 'unique_together': "(('app_label', 'model'),)", 'object_name': 'ContentType', 'db_table': "'django_content_type'"},
diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py
index 5b3071aa..4da2aa62 100644
--- a/rest_framework/authtoken/models.py
+++ b/rest_framework/authtoken/models.py
@@ -1,6 +1,7 @@
import uuid
import hmac
from hashlib import sha1
+from rest_framework.compat import User
from django.db import models
@@ -9,7 +10,7 @@ class Token(models.Model):
The default authorization token model.
"""
key = models.CharField(max_length=40, primary_key=True)
- user = models.OneToOneField('auth.User', related_name='auth_token')
+ user = models.OneToOneField(User, related_name='auth_token')
created = models.DateTimeField(auto_now_add=True)
def save(self, *args, **kwargs):
diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py
new file mode 100644
index 00000000..60a3740e
--- /dev/null
+++ b/rest_framework/authtoken/serializers.py
@@ -0,0 +1,24 @@
+from django.contrib.auth import authenticate
+from rest_framework import serializers
+
+
+class AuthTokenSerializer(serializers.Serializer):
+ username = serializers.CharField()
+ password = serializers.CharField()
+
+ def validate(self, attrs):
+ username = attrs.get('username')
+ password = attrs.get('password')
+
+ if username and password:
+ user = authenticate(username=username, password=password)
+
+ if user:
+ if not user.is_active:
+ raise serializers.ValidationError('User account is disabled.')
+ attrs['user'] = user
+ return attrs
+ else:
+ raise serializers.ValidationError('Unable to login with provided credentials.')
+ else:
+ raise serializers.ValidationError('Must include "username" and "password"')
diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py
index e69de29b..7c03cb76 100644
--- a/rest_framework/authtoken/views.py
+++ b/rest_framework/authtoken/views.py
@@ -0,0 +1,26 @@
+from rest_framework.views import APIView
+from rest_framework import status
+from rest_framework import parsers
+from rest_framework import renderers
+from rest_framework.response import Response
+from rest_framework.authtoken.models import Token
+from rest_framework.authtoken.serializers import AuthTokenSerializer
+
+
+class ObtainAuthToken(APIView):
+ throttle_classes = ()
+ permission_classes = ()
+ parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
+ renderer_classes = (renderers.JSONRenderer,)
+ serializer_class = AuthTokenSerializer
+ model = Token
+
+ def post(self, request):
+ serializer = self.serializer_class(data=request.DATA)
+ if serializer.is_valid():
+ token, created = Token.objects.get_or_create(user=serializer.object['user'])
+ return Response({'token': token.key})
+ return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+
+
+obtain_auth_token = ObtainAuthToken.as_view()
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 7664c400..5508f6c0 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -1,8 +1,23 @@
"""
-The :mod:`compat` module provides support for backwards compatibility with older versions of django/python.
+The `compat` module provides support for backwards compatibility with older
+versions of django/python, and compatibility wrappers around optional packages.
"""
+# flake8: noqa
import django
+# location of patterns, url, include changes in 1.4 onwards
+try:
+ from django.conf.urls import patterns, url, include
+except:
+ from django.conf.urls.defaults import patterns, url, include
+
+# django-filter is optional
+try:
+ import django_filters
+except:
+ django_filters = None
+
+
# cStringIO only if it's available, otherwise StringIO
try:
import cStringIO as StringIO
@@ -10,6 +25,16 @@ except ImportError:
import StringIO
+# Try to import PIL in either of the two ways it can end up installed.
+try:
+ from PIL import Image
+except ImportError:
+ try:
+ import Image
+ except ImportError:
+ Image = None
+
+
def get_concrete_model(model_cls):
try:
return model_cls._meta.concrete_model
@@ -18,6 +43,20 @@ def get_concrete_model(model_cls):
return model_cls
+# Django 1.5 add support for custom auth user model
+if django.VERSION >= (1, 5):
+ from django.conf import settings
+ if hasattr(settings, 'AUTH_USER_MODEL'):
+ User = settings.AUTH_USER_MODEL
+ else:
+ from django.contrib.auth.models import User
+else:
+ try:
+ from django.contrib.auth.models import User
+ except ImportError:
+ raise ImportError(u"User model is not to be found.")
+
+
# First implementation of Django class-based views did not include head method
# in base View class - https://code.djangoproject.com/ticket/15668
if django.VERSION >= (1, 4):
@@ -57,6 +96,12 @@ else:
update_wrapper(view, cls.dispatch, assigned=())
return view
+# Taken from @markotibold's attempt at supporting PATCH.
+# https://github.com/markotibold/django-rest-framework/tree/patch
+http_method_names = set(View.http_method_names)
+http_method_names.add('patch')
+View.http_method_names = list(http_method_names) # PATCH method is not implemented by Django
+
# PUT, DELETE do not require CSRF until 1.4. They should. Make it better.
if django.VERSION >= (1, 4):
from django.middleware.csrf import CsrfViewMiddleware
@@ -331,7 +376,7 @@ try:
"""
extensions = ['headerid(level=2)']
- safe_mode = False,
+ safe_mode = False
md = markdown.Markdown(extensions=extensions, safe_mode=safe_mode)
return md.convert(text)
@@ -346,33 +391,6 @@ except ImportError:
yaml = None
-import unittest
-try:
- import unittest.skip
-except ImportError: # python < 2.7
- from unittest import TestCase
- import functools
-
- def skip(reason):
- # Pasted from py27/lib/unittest/case.py
- """
- Unconditionally skip a test.
- """
- def decorator(test_item):
- if not (isinstance(test_item, type) and issubclass(test_item, TestCase)):
- @functools.wraps(test_item)
- def skip_wrapper(*args, **kwargs):
- pass
- test_item = skip_wrapper
-
- test_item.__unittest_skip__ = True
- test_item.__unittest_skip_why__ = reason
- return test_item
- return decorator
-
- unittest.skip = skip
-
-
# xml.etree.parse only throws ParseError for python >= 2.7
try:
from xml.etree import ParseError as ETParseError
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 948973ae..1b710a03 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -10,8 +10,18 @@ def api_view(http_method_names):
def decorator(func):
- class WrappedAPIView(APIView):
- pass
+ WrappedAPIView = type(
+ 'WrappedAPIView',
+ (APIView,),
+ {'__doc__': func.__doc__}
+ )
+
+ # Note, the above allows us to set the docstring.
+ # It is the equivalent of:
+ #
+ # class WrappedAPIView(APIView):
+ # pass
+ # WrappedAPIView.__doc__ = func.doc <--- Not possible to do this
allowed_methods = set(http_method_names) | set(('options',))
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index 572425b9..89479deb 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -31,14 +31,6 @@ class PermissionDenied(APIException):
self.detail = detail or self.default_detail
-class InvalidFormat(APIException):
- status_code = status.HTTP_404_NOT_FOUND
- default_detail = "Format suffix '.%s' not found."
-
- def __init__(self, format, detail=None):
- self.detail = (detail or self.default_detail) % format
-
-
class MethodNotAllowed(APIException):
status_code = status.HTTP_405_METHOD_NOT_ALLOWED
default_detail = "Method '%s' not allowed."
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 6ed37823..998911e1 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,16 +1,18 @@
import copy
import datetime
import inspect
+import re
import warnings
+from io import BytesIO
+
from django.core import validators
-from django.core.exceptions import ObjectDoesNotExist, ValidationError
-from django.core.urlresolvers import resolve
+from django.core.exceptions import ValidationError
from django.conf import settings
+from django import forms
from django.forms import widgets
from django.utils.encoding import is_protected_type, smart_unicode
from django.utils.translation import ugettext_lazy as _
-from rest_framework.reverse import reverse
from rest_framework.compat import parse_date, parse_datetime
from rest_framework.compat import timezone
@@ -26,9 +28,12 @@ def is_simple_callable(obj):
class Field(object):
+ read_only = True
creation_counter = 0
empty = ''
type_name = None
+ _use_files = None
+ form_field_class = forms.CharField
def __init__(self, source=None):
self.parent = None
@@ -38,18 +43,20 @@ class Field(object):
self.source = source
- def initialize(self, parent):
+ def initialize(self, parent, field_name):
"""
Called to set up a field prior to field_to_native or field_from_native.
parent - The parent serializer.
- model_field - The model field this field corrosponds to, if one exists.
+ model_field - The model field this field corresponds to, if one exists.
"""
self.parent = parent
self.root = parent.root or parent
self.context = self.root.context
+ if self.root.partial:
+ self.required = False
- def field_from_native(self, data, field_name, into):
+ def field_from_native(self, data, files, field_name, into):
"""
Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value.
@@ -88,6 +95,8 @@ class Field(object):
return value
elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)):
return [self.to_native(item) for item in value]
+ elif isinstance(value, dict):
+ return dict(map(self.to_native, (k, v)) for k, v in value.items())
return smart_unicode(value)
def attributes(self):
@@ -111,17 +120,17 @@ class WritableField(Field):
widget = widgets.TextInput
default = None
- def __init__(self, source=None, readonly=False, required=None,
+ def __init__(self, source=None, read_only=False, required=None,
validators=[], error_messages=None, widget=None,
- default=None):
+ default=None, blank=None):
super(WritableField, self).__init__(source=source)
- self.readonly = readonly
+ self.read_only = read_only
if required is None:
- self.required = not(readonly)
+ self.required = not(read_only)
else:
- assert not readonly, "Cannot set required=True and readonly=True"
+ assert not (read_only and required), "Cannot set required=True and read_only=True"
self.required = required
messages = {}
@@ -131,7 +140,8 @@ class WritableField(Field):
self.error_messages = messages
self.validators = self.default_validators + validators
- self.default = default or self.default
+ self.default = default if default is not None else self.default
+ self.blank = blank
# Widgets are ony used for HTML forms.
widget = widget or self.widget
@@ -161,18 +171,23 @@ class WritableField(Field):
if errors:
raise ValidationError(errors)
- def field_from_native(self, data, field_name, into):
+ def field_from_native(self, data, files, field_name, into):
"""
Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value.
"""
- if self.readonly:
+ if self.read_only:
return
try:
- native = data[field_name]
+ if self._use_files:
+ files = files or {}
+ native = files[field_name]
+ else:
+ native = data[field_name]
except KeyError:
- if self.default is not None:
+ if self.default is not None and not self.root.partial:
+ # Note: partial updates shouldn't set defaults
native = self.default
else:
if self.required:
@@ -197,21 +212,32 @@ class WritableField(Field):
class ModelField(WritableField):
"""
- A generic field that can be used against an arbirtrary model field.
+ A generic field that can be used against an arbitrary model field.
"""
def __init__(self, *args, **kwargs):
try:
self.model_field = kwargs.pop('model_field')
except:
raise ValueError("ModelField requires 'model_field' kwarg")
+
+ self.min_length = kwargs.pop('min_length',
+ getattr(self.model_field, 'min_length', None))
+ self.max_length = kwargs.pop('max_length',
+ getattr(self.model_field, 'max_length', None))
+
super(ModelField, self).__init__(*args, **kwargs)
+ if self.min_length is not None:
+ self.validators.append(validators.MinLengthValidator(self.min_length))
+ if self.max_length is not None:
+ self.validators.append(validators.MaxLengthValidator(self.max_length))
+
def from_native(self, value):
- try:
- rel = self.model_field.rel
- except:
+ rel = getattr(self.model_field, "rel", None)
+ if rel is not None:
+ return rel.to._meta.get_field(rel.field_name).to_python(value)
+ else:
return self.model_field.to_python(value)
- return rel.to._meta.get_field(rel.field_name).to_python(value)
def field_to_native(self, obj, field_name):
value = self.model_field._get_val_from_obj(obj)
@@ -224,200 +250,12 @@ class ModelField(WritableField):
"type": self.model_field.get_internal_type()
}
-##### Relational fields #####
-
-
-class RelatedField(WritableField):
- """
- Base class for related model fields.
- """
- def __init__(self, *args, **kwargs):
- self.queryset = kwargs.pop('queryset', None)
- super(RelatedField, self).__init__(*args, **kwargs)
-
- def field_to_native(self, obj, field_name):
- value = getattr(obj, self.source or field_name)
- return self.to_native(value)
-
- def field_from_native(self, data, field_name, into):
- if self.readonly:
- return
-
- value = data.get(field_name)
- into[(self.source or field_name) + '_id'] = self.from_native(value)
-
-
-class ManyRelatedMixin(object):
- """
- Mixin to convert a related field to a many related field.
- """
- def field_to_native(self, obj, field_name):
- value = getattr(obj, self.source or field_name)
- return [self.to_native(item) for item in value.all()]
-
- def field_from_native(self, data, field_name, into):
- if self.readonly:
- return
-
- try:
- # Form data
- value = data.getlist(self.source or field_name)
- except:
- # Non-form data
- value = data.get(self.source or field_name)
- else:
- if value == ['']:
- value = []
- into[field_name] = [self.from_native(item) for item in value]
-
-
-class ManyRelatedField(ManyRelatedMixin, RelatedField):
- """
- Base class for related model managers.
- """
- pass
-
-
-### PrimaryKey relationships
-
-class PrimaryKeyRelatedField(RelatedField):
- """
- Serializes a related field or related object to a pk value.
- """
-
- def to_native(self, pk):
- return pk
-
- def field_to_native(self, obj, field_name):
- try:
- # Prefer obj.serializable_value for performance reasons
- pk = obj.serializable_value(self.source or field_name)
- except AttributeError:
- # RelatedObject (reverse relationship)
- obj = getattr(obj, self.source or field_name)
- return self.to_native(obj.pk)
- # Forward relationship
- return self.to_native(pk)
-
-
-class ManyPrimaryKeyRelatedField(ManyRelatedField):
- """
- Serializes a to-many related field or related manager to a pk value.
- """
- def to_native(self, pk):
- return pk
-
- def field_to_native(self, obj, field_name):
- try:
- # Prefer obj.serializable_value for performance reasons
- queryset = obj.serializable_value(self.source or field_name)
- except AttributeError:
- # RelatedManager (reverse relationship)
- queryset = getattr(obj, self.source or field_name)
- return [self.to_native(item.pk) for item in queryset.all()]
- # Forward relationship
- return [self.to_native(item.pk) for item in queryset.all()]
-
-
-### Hyperlinked relationships
-
-class HyperlinkedRelatedField(RelatedField):
- pk_url_kwarg = 'pk'
- slug_url_kwarg = 'slug'
- slug_field = 'slug'
-
- def __init__(self, *args, **kwargs):
- try:
- self.view_name = kwargs.pop('view_name')
- except:
- raise ValueError("Hyperlinked field requires 'view_name' kwarg")
- super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
-
- def to_native(self, obj):
- view_name = self.view_name
- request = self.context.get('request', None)
- kwargs = {self.pk_url_kwarg: obj.pk}
- try:
- return reverse(view_name, kwargs=kwargs, request=request)
- except:
- pass
-
- slug = getattr(obj, self.slug_field, None)
-
- if not slug:
- raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
-
- kwargs = {self.slug_url_kwarg: slug}
- try:
- return reverse(self.view_name, kwargs=kwargs, request=request)
- except:
- pass
-
- kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
- try:
- return reverse(self.view_name, kwargs=kwargs, request=request)
- except:
- pass
-
- raise ValidationError('Could not resolve URL for field using view name "%s"', view_name)
-
- def from_native(self, value):
- # Convert URL -> model instance pk
- # TODO: Use values_list
- try:
- match = resolve(value)
- except:
- raise ValidationError('Invalid hyperlink - No URL match')
-
- if match.url_name != self.view_name:
- raise ValidationError('Invalid hyperlink - Incorrect URL match')
-
- pk = match.kwargs.get(self.pk_url_kwarg, None)
- slug = match.kwargs.get(self.slug_url_kwarg, None)
-
- # Try explicit primary key.
- if pk is not None:
- return pk
- # Next, try looking up by slug.
- elif slug is not None:
- slug_field = self.get_slug_field()
- queryset = self.queryset.filter(**{slug_field: slug})
- # If none of those are defined, it's an error.
- else:
- raise ValidationError('Invalid hyperlink')
-
- try:
- obj = queryset.get()
- except ObjectDoesNotExist:
- raise ValidationError('Invalid hyperlink - object does not exist.')
- return obj.pk
-
-
-class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField):
- pass
-
-
-class HyperlinkedIdentityField(Field):
- """
- A field that represents the model's identity using a hyperlink.
- """
- def __init__(self, *args, **kwargs):
- # TODO: Make this mandatory, and have the HyperlinkedModelSerializer
- # set it on-the-fly
- self.view_name = kwargs.pop('view_name', None)
- super(HyperlinkedIdentityField, self).__init__(*args, **kwargs)
-
- def field_to_native(self, obj, field_name):
- request = self.context.get('request', None)
- view_name = self.view_name or self.parent.opts.view_name
- view_kwargs = {'pk': obj.pk}
- return reverse(view_name, kwargs=view_kwargs, request=request)
-
##### Typed Fields #####
class BooleanField(WritableField):
type_name = 'BooleanField'
+ form_field_class = forms.BooleanField
widget = widgets.CheckboxInput
default_error_messages = {
'invalid': _(u"'%s' value must be either True or False."),
@@ -430,15 +268,16 @@ class BooleanField(WritableField):
default = False
def from_native(self, value):
- if value in ('t', 'True', '1'):
+ if value in ('true', 't', 'True', '1'):
return True
- if value in ('f', 'False', '0'):
+ if value in ('false', 'f', 'False', '0'):
return False
return bool(value)
class CharField(WritableField):
type_name = 'CharField'
+ form_field_class = forms.CharField
def __init__(self, max_length=None, min_length=None, *args, **kwargs):
self.max_length, self.min_length = max_length, min_length
@@ -448,14 +287,42 @@ class CharField(WritableField):
if max_length is not None:
self.validators.append(validators.MaxLengthValidator(max_length))
+ def validate(self, value):
+ """
+ Validates that the value is supplied (if required).
+ """
+ # if empty string and allow blank
+ if self.blank and not value:
+ return
+ else:
+ super(CharField, self).validate(value)
+
def from_native(self, value):
if isinstance(value, basestring) or value is None:
return value
return smart_unicode(value)
+class URLField(CharField):
+ type_name = 'URLField'
+
+ def __init__(self, **kwargs):
+ kwargs['max_length'] = kwargs.get('max_length', 200)
+ kwargs['validators'] = [validators.URLValidator()]
+ super(URLField, self).__init__(**kwargs)
+
+
+class SlugField(CharField):
+ type_name = 'SlugField'
+
+ def __init__(self, *args, **kwargs):
+ kwargs['max_length'] = kwargs.get('max_length', 50)
+ super(SlugField, self).__init__(*args, **kwargs)
+
+
class ChoiceField(WritableField):
type_name = 'ChoiceField'
+ form_field_class = forms.ChoiceField
widget = widgets.Select
default_error_messages = {
'invalid_choice': _('Select a valid choice. %(value)s is not one of the available choices.'),
@@ -495,13 +362,14 @@ class ChoiceField(WritableField):
if value == smart_unicode(k2):
return True
else:
- if value == smart_unicode(k):
+ if value == smart_unicode(k) or value == k:
return True
return False
class EmailField(CharField):
type_name = 'EmailField'
+ form_field_class = forms.EmailField
default_error_messages = {
'invalid': _('Enter a valid e-mail address.'),
@@ -509,7 +377,10 @@ class EmailField(CharField):
default_validators = [validators.validate_email]
def from_native(self, value):
- return super(EmailField, self).from_native(value).strip()
+ ret = super(EmailField, self).from_native(value)
+ if ret is None:
+ return None
+ return ret.strip()
def __deepcopy__(self, memo):
result = copy.copy(self)
@@ -519,8 +390,39 @@ class EmailField(CharField):
return result
+class RegexField(CharField):
+ type_name = 'RegexField'
+ form_field_class = forms.RegexField
+
+ def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs):
+ super(RegexField, self).__init__(max_length, min_length, *args, **kwargs)
+ self.regex = regex
+
+ def _get_regex(self):
+ return self._regex
+
+ def _set_regex(self, regex):
+ if isinstance(regex, basestring):
+ regex = re.compile(regex)
+ self._regex = regex
+ if hasattr(self, '_regex_validator') and self._regex_validator in self.validators:
+ self.validators.remove(self._regex_validator)
+ self._regex_validator = validators.RegexValidator(regex=regex)
+ self.validators.append(self._regex_validator)
+
+ regex = property(_get_regex, _set_regex)
+
+ def __deepcopy__(self, memo):
+ result = copy.copy(self)
+ memo[id(self)] = result
+ result.validators = self.validators[:]
+ return result
+
+
class DateField(WritableField):
type_name = 'DateField'
+ widget = widgets.DateInput
+ form_field_class = forms.DateField
default_error_messages = {
'invalid': _(u"'%s' value has an invalid date format. It must be "
@@ -531,8 +433,9 @@ class DateField(WritableField):
empty = None
def from_native(self, value):
- if value is None:
- return value
+ if value in validators.EMPTY_VALUES:
+ return None
+
if isinstance(value, datetime.datetime):
if timezone and settings.USE_TZ and timezone.is_aware(value):
# Convert aware datetimes to the default time zone
@@ -557,6 +460,8 @@ class DateField(WritableField):
class DateTimeField(WritableField):
type_name = 'DateTimeField'
+ widget = widgets.DateTimeInput
+ form_field_class = forms.DateTimeField
default_error_messages = {
'invalid': _(u"'%s' value has an invalid format. It must be in "
@@ -570,8 +475,9 @@ class DateTimeField(WritableField):
empty = None
def from_native(self, value):
- if value is None:
- return value
+ if value in validators.EMPTY_VALUES:
+ return None
+
if isinstance(value, datetime.datetime):
return value
if isinstance(value, datetime.date):
@@ -610,6 +516,7 @@ class DateTimeField(WritableField):
class IntegerField(WritableField):
type_name = 'IntegerField'
+ form_field_class = forms.IntegerField
default_error_messages = {
'invalid': _('Enter a whole number.'),
@@ -629,6 +536,7 @@ class IntegerField(WritableField):
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
+
try:
value = int(str(value))
except (ValueError, TypeError):
@@ -638,16 +546,123 @@ class IntegerField(WritableField):
class FloatField(WritableField):
type_name = 'FloatField'
+ form_field_class = forms.FloatField
default_error_messages = {
'invalid': _("'%s' value must be a float."),
}
def from_native(self, value):
- if value is None:
- return value
+ if value in validators.EMPTY_VALUES:
+ return None
+
try:
return float(value)
except (TypeError, ValueError):
msg = self.error_messages['invalid'] % value
raise ValidationError(msg)
+
+
+class FileField(WritableField):
+ _use_files = True
+ type_name = 'FileField'
+ form_field_class = forms.FileField
+ widget = widgets.FileInput
+
+ default_error_messages = {
+ 'invalid': _("No file was submitted. Check the encoding type on the form."),
+ 'missing': _("No file was submitted."),
+ 'empty': _("The submitted file is empty."),
+ 'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'),
+ 'contradiction': _('Please either submit a file or check the clear checkbox, not both.')
+ }
+
+ def __init__(self, *args, **kwargs):
+ self.max_length = kwargs.pop('max_length', None)
+ self.allow_empty_file = kwargs.pop('allow_empty_file', False)
+ super(FileField, self).__init__(*args, **kwargs)
+
+ def from_native(self, data):
+ if data in validators.EMPTY_VALUES:
+ return None
+
+ # UploadedFile objects should have name and size attributes.
+ try:
+ file_name = data.name
+ file_size = data.size
+ except AttributeError:
+ raise ValidationError(self.error_messages['invalid'])
+
+ if self.max_length is not None and len(file_name) > self.max_length:
+ error_values = {'max': self.max_length, 'length': len(file_name)}
+ raise ValidationError(self.error_messages['max_length'] % error_values)
+ if not file_name:
+ raise ValidationError(self.error_messages['invalid'])
+ if not self.allow_empty_file and not file_size:
+ raise ValidationError(self.error_messages['empty'])
+
+ return data
+
+ def to_native(self, value):
+ return value.name
+
+
+class ImageField(FileField):
+ _use_files = True
+ form_field_class = forms.ImageField
+
+ default_error_messages = {
+ 'invalid_image': _("Upload a valid image. The file you uploaded was either not an image or a corrupted image."),
+ }
+
+ def from_native(self, data):
+ """
+ Checks that the file-upload field data contains a valid image (GIF, JPG,
+ PNG, possibly others -- whatever the Python Imaging Library supports).
+ """
+ f = super(ImageField, self).from_native(data)
+ if f is None:
+ return None
+
+ from compat import Image
+ assert Image is not None, 'PIL must be installed for ImageField support'
+
+ # We need to get a file object for PIL. We might have a path or we might
+ # have to read the data into memory.
+ if hasattr(data, 'temporary_file_path'):
+ file = data.temporary_file_path()
+ else:
+ if hasattr(data, 'read'):
+ file = BytesIO(data.read())
+ else:
+ file = BytesIO(data['content'])
+
+ try:
+ # load() could spot a truncated JPEG, but it loads the entire
+ # image in memory, which is a DoS vector. See #3848 and #18520.
+ # verify() must be called immediately after the constructor.
+ Image.open(file).verify()
+ except ImportError:
+ # Under PyPy, it is possible to import PIL. However, the underlying
+ # _imaging C module isn't available, so an ImportError will be
+ # raised. Catch and re-raise.
+ raise
+ except Exception: # Python Imaging Library doesn't recognize it as an image
+ raise ValidationError(self.error_messages['invalid_image'])
+ if hasattr(f, 'seek') and callable(f.seek):
+ f.seek(0)
+ return f
+
+
+class SerializerMethodField(Field):
+ """
+ A field that gets its value by calling a method on the serializer it's attached to.
+ """
+
+ def __init__(self, method_name):
+ self.method_name = method_name
+ super(SerializerMethodField, self).__init__()
+
+ def field_to_native(self, obj, field_name):
+ value = getattr(self.parent, self.method_name)(obj)
+ return self.to_native(value)
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
new file mode 100644
index 00000000..bcc87660
--- /dev/null
+++ b/rest_framework/filters.py
@@ -0,0 +1,59 @@
+from rest_framework.compat import django_filters
+
+FilterSet = django_filters and django_filters.FilterSet or None
+
+
+class BaseFilterBackend(object):
+ """
+ A base class from which all filter backend classes should inherit.
+ """
+
+ def filter_queryset(self, request, queryset, view):
+ """
+ Return a filtered queryset.
+ """
+ raise NotImplementedError(".filter_queryset() must be overridden.")
+
+
+class DjangoFilterBackend(BaseFilterBackend):
+ """
+ A filter backend that uses django-filter.
+ """
+ default_filter_set = FilterSet
+
+ def __init__(self):
+ assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed'
+
+ def get_filter_class(self, view):
+ """
+ Return the django-filters `FilterSet` used to filter the queryset.
+ """
+ filter_class = getattr(view, 'filter_class', None)
+ filter_fields = getattr(view, 'filter_fields', None)
+ view_model = getattr(view, 'model', None)
+
+ if filter_class:
+ filter_model = filter_class.Meta.model
+
+ assert issubclass(filter_model, view_model), \
+ 'FilterSet model %s does not match view model %s' % \
+ (filter_model, view_model)
+
+ return filter_class
+
+ if filter_fields:
+ class AutoFilterSet(self.default_filter_set):
+ class Meta:
+ model = view_model
+ fields = filter_fields
+ return AutoFilterSet
+
+ return None
+
+ def filter_queryset(self, request, queryset, view):
+ filter_class = self.get_filter_class(view)
+
+ if filter_class:
+ return filter_class(request.GET, queryset=queryset)
+
+ return queryset
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 81014026..f575470e 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -1,5 +1,5 @@
"""
-Generic views that provide commmonly needed behaviour.
+Generic views that provide commonly needed behaviour.
"""
from rest_framework import views, mixins
@@ -14,6 +14,8 @@ class GenericAPIView(views.APIView):
"""
Base class for all other generic views.
"""
+
+ model = None
serializer_class = None
model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
@@ -30,8 +32,10 @@ class GenericAPIView(views.APIView):
def get_serializer_class(self):
"""
Return the class to use for the serializer.
- Use `self.serializer_class`, falling back to constructing a
- model serializer class from `self.model_serializer_class`
+
+ Defaults to using `self.serializer_class`, falls back to constructing a
+ model serializer class using `self.model_serializer_class`, with
+ `self.model` as the model.
"""
serializer_class = self.serializer_class
@@ -43,12 +47,19 @@ class GenericAPIView(views.APIView):
return serializer_class
- def get_serializer(self, data=None, files=None, instance=None):
- # TODO: add support for files
- # TODO: add support for seperate serializer/deserializer
+ def get_serializer(self, instance=None, data=None,
+ files=None, partial=False):
+ """
+ Return the serializer instance that should be used for validating and
+ deserializing input, and for serializing output.
+ """
serializer_class = self.get_serializer_class()
context = self.get_serializer_context()
- return serializer_class(data, instance=instance, context=context)
+ return serializer_class(instance, data=data, files=files,
+ partial=partial, context=context)
+
+ def pre_save(self, obj):
+ pass
class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
@@ -56,37 +67,59 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
Base class for generic views onto a queryset.
"""
- pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
paginate_by = api_settings.PAGINATE_BY
+ paginate_by_param = api_settings.PAGINATE_BY_PARAM
+ pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
+ filter_backend = api_settings.FILTER_BACKEND
- def get_pagination_serializer_class(self):
+ def filter_queryset(self, queryset):
"""
- Return the class to use for the pagination serializer.
+ Given a queryset, filter it with whichever filter backend is in use.
+ """
+ if not self.filter_backend:
+ return queryset
+ backend = self.filter_backend()
+ return backend.filter_queryset(self.request, queryset, self)
+
+ def get_pagination_serializer(self, page=None):
+ """
+ Return a serializer instance to use with paginated data.
"""
class SerializerClass(self.pagination_serializer_class):
class Meta:
object_serializer_class = self.get_serializer_class()
- return SerializerClass
-
- def get_pagination_serializer(self, page=None):
- pagination_serializer_class = self.get_pagination_serializer_class()
+ pagination_serializer_class = SerializerClass
context = self.get_serializer_context()
return pagination_serializer_class(instance=page, context=context)
+ def get_paginate_by(self, queryset):
+ """
+ Return the size of pages to use with pagination.
+ """
+ if self.paginate_by_param:
+ query_params = self.request.QUERY_PARAMS
+ try:
+ return int(query_params[self.paginate_by_param])
+ except (KeyError, ValueError):
+ pass
+ return self.paginate_by
+
class SingleObjectAPIView(SingleObjectMixin, GenericAPIView):
"""
Base class for generic views onto a model instance.
"""
+
pk_url_kwarg = 'pk' # Not provided in Django 1.3
slug_url_kwarg = 'slug' # Not provided in Django 1.3
+ slug_field = 'slug'
- def get_object(self):
+ def get_object(self, queryset=None):
"""
Override default to add support for object-level permissions.
"""
- obj = super(SingleObjectAPIView, self).get_object()
+ obj = super(SingleObjectAPIView, self).get_object(queryset)
if not self.has_permission(self.request, obj):
self.permission_denied(self.request)
return obj
@@ -125,7 +158,7 @@ class RetrieveAPIView(mixins.RetrieveModelMixin,
class DestroyAPIView(mixins.DestroyModelMixin,
- SingleObjectAPIView):
+ SingleObjectAPIView):
"""
Concrete view for deleting a model instance.
@@ -143,6 +176,10 @@ class UpdateAPIView(mixins.UpdateModelMixin,
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
+ def patch(self, request, *args, **kwargs):
+ kwargs['partial'] = True
+ return self.update(request, *args, **kwargs)
+
class ListCreateAPIView(mixins.ListModelMixin,
mixins.CreateModelMixin,
@@ -157,6 +194,23 @@ class ListCreateAPIView(mixins.ListModelMixin,
return self.create(request, *args, **kwargs)
+class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
+ mixins.UpdateModelMixin,
+ SingleObjectAPIView):
+ """
+ Concrete view for retrieving, updating a model instance.
+ """
+ def get(self, request, *args, **kwargs):
+ return self.retrieve(request, *args, **kwargs)
+
+ def put(self, request, *args, **kwargs):
+ return self.update(request, *args, **kwargs)
+
+ def patch(self, request, *args, **kwargs):
+ kwargs['partial'] = True
+ return self.update(request, *args, **kwargs)
+
+
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
SingleObjectAPIView):
@@ -183,5 +237,9 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
+ def patch(self, request, *args, **kwargs):
+ kwargs['partial'] = True
+ return self.update(request, *args, **kwargs)
+
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 9bd566da..e0ae216e 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -3,9 +3,6 @@ Basic building blocks for generic class based views.
We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways.
-
-Eg. Use mixins to build a Resource class, and have a Router class
- perform the binding of http methods to actions for us.
"""
from django.http import Http404
from rest_framework import status
@@ -18,30 +15,42 @@ class CreateModelMixin(object):
Should be mixed in with any `BaseView`.
"""
def create(self, request, *args, **kwargs):
- serializer = self.get_serializer(data=request.DATA)
+ serializer = self.get_serializer(data=request.DATA, files=request.FILES)
+
if serializer.is_valid():
self.pre_save(serializer.object)
self.object = serializer.save()
- return Response(serializer.data, status=status.HTTP_201_CREATED)
+ headers = self.get_success_headers(serializer.data)
+ return Response(serializer.data, status=status.HTTP_201_CREATED,
+ headers=headers)
+
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+ def get_success_headers(self, data):
+ try:
+ return {'Location': data['url']}
+ except (TypeError, KeyError):
+ return {}
+
class ListModelMixin(object):
"""
List a queryset.
- Should be mixed in with `MultipleObjectBaseView`.
+ Should be mixed in with `MultipleObjectAPIView`.
"""
empty_error = u"Empty list and '%(class_name)s.allow_empty' is False."
def list(self, request, *args, **kwargs):
- self.object_list = self.get_queryset()
+ queryset = self.get_queryset()
+ self.object_list = self.filter_queryset(queryset)
# Default is to allow empty querysets. This can be altered by setting
# `.allow_empty = False`, to raise 404 errors on empty querysets.
allow_empty = self.get_allow_empty()
- if not allow_empty and len(self.object_list) == 0:
- error_args = {'class_name': self.__class__.__name__}
- raise Http404(self.empty_error % error_args)
+ if not allow_empty and not self.object_list:
+ class_name = self.__class__.__name__
+ error_msg = self.empty_error % {'class_name': class_name}
+ raise Http404(error_msg)
# Pagination size is set by the `.paginate_by` attribute,
# which may be `None` to disable pagination.
@@ -51,7 +60,7 @@ class ListModelMixin(object):
paginator, page, queryset, is_paginated = packed
serializer = self.get_pagination_serializer(page)
else:
- serializer = self.get_serializer(instance=self.object_list)
+ serializer = self.get_serializer(self.object_list)
return Response(serializer.data)
@@ -63,7 +72,7 @@ class RetrieveModelMixin(object):
"""
def retrieve(self, request, *args, **kwargs):
self.object = self.get_object()
- serializer = self.get_serializer(instance=self.object)
+ serializer = self.get_serializer(self.object)
return Response(serializer.data)
@@ -73,17 +82,21 @@ class UpdateModelMixin(object):
Should be mixed in with `SingleObjectBaseView`.
"""
def update(self, request, *args, **kwargs):
+ partial = kwargs.pop('partial', False)
try:
self.object = self.get_object()
+ success_status_code = status.HTTP_200_OK
except Http404:
self.object = None
+ success_status_code = status.HTTP_201_CREATED
- serializer = self.get_serializer(data=request.DATA, instance=self.object)
+ serializer = self.get_serializer(self.object, data=request.DATA,
+ files=request.FILES, partial=partial)
if serializer.is_valid():
self.pre_save(serializer.object)
self.object = serializer.save()
- return Response(serializer.data)
+ return Response(serializer.data, status=success_status_code)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
@@ -101,6 +114,11 @@ class UpdateModelMixin(object):
slug_field = self.get_slug_field()
setattr(obj, slug_field, slug)
+ # Ensure we clean the attributes so that we don't eg return integer
+ # pk using a string representation, as provided by the url conf kwarg.
+ if hasattr(obj, 'full_clean'):
+ obj.full_clean()
+
class DestroyModelMixin(object):
"""
@@ -108,6 +126,6 @@ class DestroyModelMixin(object):
Should be mixed in with `SingleObjectBaseView`.
"""
def destroy(self, request, *args, **kwargs):
- self.object = self.get_object()
- self.object.delete()
+ obj = self.get_object()
+ obj.delete()
return Response(status=status.HTTP_204_NO_CONTENT)
diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py
index 444f8056..ee2800a6 100644
--- a/rest_framework/negotiation.py
+++ b/rest_framework/negotiation.py
@@ -1,6 +1,8 @@
+from django.http import Http404
from rest_framework import exceptions
from rest_framework.settings import api_settings
from rest_framework.utils.mediatypes import order_by_precedence, media_type_matches
+from rest_framework.utils.mediatypes import _MediaType
class BaseContentNegotiation(object):
@@ -47,7 +49,8 @@ class DefaultContentNegotiation(BaseContentNegotiation):
for media_type in media_type_set:
if media_type_matches(renderer.media_type, media_type):
# Return the most specific media type as accepted.
- if len(renderer.media_type) > len(media_type):
+ if (_MediaType(renderer.media_type).precedence >
+ _MediaType(media_type).precedence):
# Eg client requests '*/*'
# Accepted media type is 'application/json'
return renderer, renderer.media_type
@@ -66,7 +69,7 @@ class DefaultContentNegotiation(BaseContentNegotiation):
renderers = [renderer for renderer in renderers
if renderer.format == format]
if not renderers:
- raise exceptions.InvalidFormat(format)
+ raise Http404
return renderers
def get_accept_list(self, request):
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index 131718fd..d241ade7 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -1,4 +1,5 @@
from rest_framework import serializers
+from rest_framework.templatetags.rest_framework import replace_query_param
# TODO: Support URLconf kwarg-style paging
@@ -7,30 +8,30 @@ class NextPageField(serializers.Field):
"""
Field that returns a link to the next page in paginated results.
"""
+ page_field = 'page'
+
def to_native(self, value):
if not value.has_next():
return None
page = value.next_page_number()
request = self.context.get('request')
- relative_url = '?page=%d' % page
- if request:
- return request.build_absolute_uri(relative_url)
- return relative_url
+ url = request and request.build_absolute_uri() or ''
+ return replace_query_param(url, self.page_field, page)
class PreviousPageField(serializers.Field):
"""
Field that returns a link to the previous page in paginated results.
"""
+ page_field = 'page'
+
def to_native(self, value):
if not value.has_previous():
return None
page = value.previous_page_number()
request = self.context.get('request')
- relative_url = '?page=%d' % page
- if request:
- return request.build_absolute_uri('?page=%d' % page)
- return relative_url
+ url = request and request.build_absolute_uri() or ''
+ return replace_query_param(url, self.page_field, page)
class PaginationSerializerOptions(serializers.SerializerOptions):
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index 4841676c..149d6431 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -8,11 +8,11 @@ on the request, such as form content or json encoded data.
from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError
-from django.utils import simplejson as json
from rest_framework.compat import yaml, ETParseError
from rest_framework.exceptions import ParseError
from xml.etree import ElementTree as ET
from xml.parsers.expat import ExpatError
+import json
import datetime
import decimal
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index 6f848cee..655b78a3 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -18,6 +18,17 @@ class BasePermission(object):
raise NotImplementedError(".has_permission() must be overridden.")
+class AllowAny(BasePermission):
+ """
+ Allow any access.
+ This isn't strictly required, since you could use an empty
+ permission_classes list, but it's useful because it makes the intention
+ more explicit.
+ """
+ def has_permission(self, request, view, obj=None):
+ return True
+
+
class IsAuthenticated(BasePermission):
"""
Allows access only to authenticated users.
@@ -85,7 +96,7 @@ class DjangoModelPermissions(BasePermission):
"""
kwargs = {
'app_label': model_cls._meta.app_label,
- 'model_name': model_cls._meta.module_name
+ 'model_name': model_cls._meta.module_name
}
return [perm % kwargs for perm in self.perms_map[method]]
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
new file mode 100644
index 00000000..5e4552b7
--- /dev/null
+++ b/rest_framework/relations.py
@@ -0,0 +1,509 @@
+from django.core.exceptions import ObjectDoesNotExist, ValidationError
+from django.core.urlresolvers import resolve, get_script_prefix
+from django import forms
+from django.forms import widgets
+from django.forms.models import ModelChoiceIterator
+from django.utils.encoding import smart_unicode
+from django.utils.translation import ugettext_lazy as _
+from rest_framework.fields import Field, WritableField
+from rest_framework.reverse import reverse
+from urlparse import urlparse
+
+##### Relational fields #####
+
+
+# Not actually Writable, but subclasses may need to be.
+class RelatedField(WritableField):
+ """
+ Base class for related model fields.
+
+ If not overridden, this represents a to-one relationship, using the unicode
+ representation of the target.
+ """
+ widget = widgets.Select
+ cache_choices = False
+ empty_label = None
+ default_read_only = True # TODO: Remove this
+
+ def __init__(self, *args, **kwargs):
+ self.queryset = kwargs.pop('queryset', None)
+ self.null = kwargs.pop('null', False)
+ super(RelatedField, self).__init__(*args, **kwargs)
+ self.read_only = kwargs.pop('read_only', self.default_read_only)
+
+ def initialize(self, parent, field_name):
+ super(RelatedField, self).initialize(parent, field_name)
+ if self.queryset is None and not self.read_only:
+ try:
+ manager = getattr(self.parent.opts.model, self.source or field_name)
+ if hasattr(manager, 'related'): # Forward
+ self.queryset = manager.related.model._default_manager.all()
+ else: # Reverse
+ self.queryset = manager.field.rel.to._default_manager.all()
+ except:
+ raise
+ msg = ('Serializer related fields must include a `queryset`' +
+ ' argument or set `read_only=True')
+ raise Exception(msg)
+
+ ### We need this stuff to make form choices work...
+
+ # def __deepcopy__(self, memo):
+ # result = super(RelatedField, self).__deepcopy__(memo)
+ # result.queryset = result.queryset
+ # return result
+
+ def prepare_value(self, obj):
+ return self.to_native(obj)
+
+ def label_from_instance(self, obj):
+ """
+ Return a readable representation for use with eg. select widgets.
+ """
+ desc = smart_unicode(obj)
+ ident = smart_unicode(self.to_native(obj))
+ if desc == ident:
+ return desc
+ return "%s - %s" % (desc, ident)
+
+ def _get_queryset(self):
+ return self._queryset
+
+ def _set_queryset(self, queryset):
+ self._queryset = queryset
+ self.widget.choices = self.choices
+
+ queryset = property(_get_queryset, _set_queryset)
+
+ def _get_choices(self):
+ # If self._choices is set, then somebody must have manually set
+ # the property self.choices. In this case, just return self._choices.
+ if hasattr(self, '_choices'):
+ return self._choices
+
+ # Otherwise, execute the QuerySet in self.queryset to determine the
+ # choices dynamically. Return a fresh ModelChoiceIterator that has not been
+ # consumed. Note that we're instantiating a new ModelChoiceIterator *each*
+ # time _get_choices() is called (and, thus, each time self.choices is
+ # accessed) so that we can ensure the QuerySet has not been consumed. This
+ # construct might look complicated but it allows for lazy evaluation of
+ # the queryset.
+ return ModelChoiceIterator(self)
+
+ def _set_choices(self, value):
+ # Setting choices also sets the choices on the widget.
+ # choices can be any iterable, but we call list() on it because
+ # it will be consumed more than once.
+ self._choices = self.widget.choices = list(value)
+
+ choices = property(_get_choices, _set_choices)
+
+ ### Regular serializer stuff...
+
+ def field_to_native(self, obj, field_name):
+ try:
+ value = getattr(obj, self.source or field_name)
+ except ObjectDoesNotExist:
+ return None
+ return self.to_native(value)
+
+ def field_from_native(self, data, files, field_name, into):
+ if self.read_only:
+ return
+
+ try:
+ value = data[field_name]
+ except KeyError:
+ if self.required:
+ raise ValidationError(self.error_messages['required'])
+ return
+
+ if value in (None, '') and not self.null:
+ raise ValidationError('Value may not be null')
+ elif value in (None, '') and self.null:
+ into[(self.source or field_name)] = None
+ else:
+ into[(self.source or field_name)] = self.from_native(value)
+
+
+class ManyRelatedMixin(object):
+ """
+ Mixin to convert a related field to a many related field.
+ """
+ widget = widgets.SelectMultiple
+
+ def field_to_native(self, obj, field_name):
+ value = getattr(obj, self.source or field_name)
+ return [self.to_native(item) for item in value.all()]
+
+ def field_from_native(self, data, files, field_name, into):
+ if self.read_only:
+ return
+
+ try:
+ # Form data
+ value = data.getlist(self.source or field_name)
+ except:
+ # Non-form data
+ value = data.get(self.source or field_name)
+ else:
+ if value == ['']:
+ value = []
+
+ into[field_name] = [self.from_native(item) for item in value]
+
+
+class ManyRelatedField(ManyRelatedMixin, RelatedField):
+ """
+ Base class for related model managers.
+
+ If not overridden, this represents a to-many relationship, using the unicode
+ representations of the target, and is read-only.
+ """
+ pass
+
+
+### PrimaryKey relationships
+
+class PrimaryKeyRelatedField(RelatedField):
+ """
+ Represents a to-one relationship as a pk value.
+ """
+ default_read_only = False
+ form_field_class = forms.ChoiceField
+
+ default_error_messages = {
+ 'does_not_exist': _("Invalid pk '%s' - object does not exist."),
+ 'invalid': _('Invalid value.'),
+ }
+
+ # TODO: Remove these field hacks...
+ def prepare_value(self, obj):
+ return self.to_native(obj.pk)
+
+ def label_from_instance(self, obj):
+ """
+ Return a readable representation for use with eg. select widgets.
+ """
+ desc = smart_unicode(obj)
+ ident = smart_unicode(self.to_native(obj.pk))
+ if desc == ident:
+ return desc
+ return "%s - %s" % (desc, ident)
+
+ # TODO: Possibly change this to just take `obj`, through prob less performant
+ def to_native(self, pk):
+ return pk
+
+ def from_native(self, data):
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ try:
+ return self.queryset.get(pk=data)
+ except ObjectDoesNotExist:
+ msg = self.error_messages['does_not_exist'] % smart_unicode(data)
+ raise ValidationError(msg)
+ except (TypeError, ValueError):
+ msg = self.error_messages['invalid']
+ raise ValidationError(msg)
+
+ def field_to_native(self, obj, field_name):
+ try:
+ # Prefer obj.serializable_value for performance reasons
+ pk = obj.serializable_value(self.source or field_name)
+ except AttributeError:
+ # RelatedObject (reverse relationship)
+ try:
+ obj = getattr(obj, self.source or field_name)
+ except ObjectDoesNotExist:
+ return None
+ return self.to_native(obj.pk)
+ # Forward relationship
+ return self.to_native(pk)
+
+
+class ManyPrimaryKeyRelatedField(ManyRelatedField):
+ """
+ Represents a to-many relationship as a pk value.
+ """
+ default_read_only = False
+ form_field_class = forms.MultipleChoiceField
+
+ default_error_messages = {
+ 'does_not_exist': _("Invalid pk '%s' - object does not exist."),
+ 'invalid': _('Invalid value.'),
+ }
+
+ def prepare_value(self, obj):
+ return self.to_native(obj.pk)
+
+ def label_from_instance(self, obj):
+ """
+ Return a readable representation for use with eg. select widgets.
+ """
+ desc = smart_unicode(obj)
+ ident = smart_unicode(self.to_native(obj.pk))
+ if desc == ident:
+ return desc
+ return "%s - %s" % (desc, ident)
+
+ def to_native(self, pk):
+ return pk
+
+ def field_to_native(self, obj, field_name):
+ try:
+ # Prefer obj.serializable_value for performance reasons
+ queryset = obj.serializable_value(self.source or field_name)
+ except AttributeError:
+ # RelatedManager (reverse relationship)
+ queryset = getattr(obj, self.source or field_name)
+ return [self.to_native(item.pk) for item in queryset.all()]
+ # Forward relationship
+ return [self.to_native(item.pk) for item in queryset.all()]
+
+ def from_native(self, data):
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ try:
+ return self.queryset.get(pk=data)
+ except ObjectDoesNotExist:
+ msg = self.error_messages['does_not_exist'] % smart_unicode(data)
+ raise ValidationError(msg)
+ except (TypeError, ValueError):
+ msg = self.error_messages['invalid']
+ raise ValidationError(msg)
+
+### Slug relationships
+
+
+class SlugRelatedField(RelatedField):
+ default_read_only = False
+ form_field_class = forms.ChoiceField
+
+ default_error_messages = {
+ 'does_not_exist': _("Object with %s=%s does not exist."),
+ 'invalid': _('Invalid value.'),
+ }
+
+ def __init__(self, *args, **kwargs):
+ self.slug_field = kwargs.pop('slug_field', None)
+ assert self.slug_field, 'slug_field is required'
+ super(SlugRelatedField, self).__init__(*args, **kwargs)
+
+ def to_native(self, obj):
+ return getattr(obj, self.slug_field)
+
+ def from_native(self, data):
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ try:
+ return self.queryset.get(**{self.slug_field: data})
+ except ObjectDoesNotExist:
+ raise ValidationError(self.error_messages['does_not_exist'] %
+ (self.slug_field, unicode(data)))
+ except (TypeError, ValueError):
+ msg = self.error_messages['invalid']
+ raise ValidationError(msg)
+
+
+class ManySlugRelatedField(ManyRelatedMixin, SlugRelatedField):
+ form_field_class = forms.MultipleChoiceField
+
+
+### Hyperlinked relationships
+
+class HyperlinkedRelatedField(RelatedField):
+ """
+ Represents a to-one relationship, using hyperlinking.
+ """
+ pk_url_kwarg = 'pk'
+ slug_field = 'slug'
+ slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
+ default_read_only = False
+ form_field_class = forms.ChoiceField
+
+ default_error_messages = {
+ 'no_match': _('Invalid hyperlink - No URL match'),
+ 'incorrect_match': _('Invalid hyperlink - Incorrect URL match'),
+ 'configuration_error': _('Invalid hyperlink due to configuration error'),
+ 'does_not_exist': _("Invalid hyperlink - object does not exist."),
+ 'invalid': _('Invalid value.'),
+ }
+
+ def __init__(self, *args, **kwargs):
+ try:
+ self.view_name = kwargs.pop('view_name')
+ except:
+ raise ValueError("Hyperlinked field requires 'view_name' kwarg")
+
+ self.slug_field = kwargs.pop('slug_field', self.slug_field)
+ default_slug_kwarg = self.slug_url_kwarg or self.slug_field
+ self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
+ self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
+
+ self.format = kwargs.pop('format', None)
+ super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
+
+ def get_slug_field(self):
+ """
+ Get the name of a slug field to be used to look up by slug.
+ """
+ return self.slug_field
+
+ def to_native(self, obj):
+ view_name = self.view_name
+ request = self.context.get('request', None)
+ format = self.format or self.context.get('format', None)
+ pk = getattr(obj, 'pk', None)
+ if pk is None:
+ return
+ kwargs = {self.pk_url_kwarg: pk}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ slug = getattr(obj, self.slug_field, None)
+
+ if not slug:
+ raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
+
+ kwargs = {self.slug_url_kwarg: slug}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
+
+ def from_native(self, value):
+ # Convert URL -> model instance pk
+ # TODO: Use values_list
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ try:
+ http_prefix = value.startswith('http:') or value.startswith('https:')
+ except AttributeError:
+ msg = self.error_messages['invalid']
+ raise ValidationError(msg)
+
+ if http_prefix:
+ # If needed convert absolute URLs to relative path
+ value = urlparse(value).path
+ prefix = get_script_prefix()
+ if value.startswith(prefix):
+ value = '/' + value[len(prefix):]
+
+ try:
+ match = resolve(value)
+ except:
+ raise ValidationError(self.error_messages['no_match'])
+
+ if match.view_name != self.view_name:
+ raise ValidationError(self.error_messages['incorrect_match'])
+
+ pk = match.kwargs.get(self.pk_url_kwarg, None)
+ slug = match.kwargs.get(self.slug_url_kwarg, None)
+
+ # Try explicit primary key.
+ if pk is not None:
+ queryset = self.queryset.filter(pk=pk)
+ # Next, try looking up by slug.
+ elif slug is not None:
+ slug_field = self.get_slug_field()
+ queryset = self.queryset.filter(**{slug_field: slug})
+ # If none of those are defined, it's probably a configuation error.
+ else:
+ raise ValidationError(self.error_messages['configuration_error'])
+
+ try:
+ obj = queryset.get()
+ except ObjectDoesNotExist:
+ raise ValidationError(self.error_messages['does_not_exist'])
+ except (TypeError, ValueError):
+ msg = self.error_messages['invalid']
+ raise ValidationError(msg)
+
+ return obj
+
+
+class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField):
+ """
+ Represents a to-many relationship, using hyperlinking.
+ """
+ form_field_class = forms.MultipleChoiceField
+
+
+class HyperlinkedIdentityField(Field):
+ """
+ Represents the instance, or a property on the instance, using hyperlinking.
+ """
+ pk_url_kwarg = 'pk'
+ slug_field = 'slug'
+ slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
+
+ def __init__(self, *args, **kwargs):
+ # TODO: Make view_name mandatory, and have the
+ # HyperlinkedModelSerializer set it on-the-fly
+ self.view_name = kwargs.pop('view_name', None)
+ # Optionally the format of the target hyperlink may be specified
+ self.format = kwargs.pop('format', None)
+
+ self.slug_field = kwargs.pop('slug_field', self.slug_field)
+ default_slug_kwarg = self.slug_url_kwarg or self.slug_field
+ self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
+ self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
+
+ super(HyperlinkedIdentityField, self).__init__(*args, **kwargs)
+
+ def field_to_native(self, obj, field_name):
+ request = self.context.get('request', None)
+ format = self.context.get('format', None)
+ view_name = self.view_name or self.parent.opts.view_name
+ kwargs = {self.pk_url_kwarg: obj.pk}
+
+ # By default use whatever format is given for the current context
+ # unless the target is a different type to the source.
+ #
+ # Eg. Consider a HyperlinkedIdentityField pointing from a json
+ # representation to an html property of that representation...
+ #
+ # '/snippets/1/' should link to '/snippets/1/highlight/'
+ # ...but...
+ # '/snippets/1/.json' should link to '/snippets/1/highlight/.html'
+ if format and self.format and self.format != format:
+ format = self.format
+
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ slug = getattr(obj, self.slug_field, None)
+
+ if not slug:
+ raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
+
+ kwargs = {self.slug_url_kwarg: slug}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index c64fb517..0a34abaa 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -4,14 +4,14 @@ Renderers are used to serialize a response into specific media types.
They give us a generic way of being able to handle various media types
on the response, such as JSON encoded data or HTML output.
-REST framework also provides an HTML renderer the renders the browseable API.
+REST framework also provides an HTML renderer the renders the browsable API.
"""
import copy
import string
+import json
from django import forms
from django.http.multipartparser import parse_header
-from django.template import RequestContext, loader
-from django.utils import simplejson as json
+from django.template import RequestContext, loader, Template
from rest_framework.compat import yaml
from rest_framework.exceptions import ConfigurationError
from rest_framework.settings import api_settings
@@ -19,8 +19,8 @@ from rest_framework.request import clone_request
from rest_framework.utils import dict2xml
from rest_framework.utils import encoders
from rest_framework.utils.breadcrumbs import get_breadcrumbs
-from rest_framework import VERSION
-from rest_framework import serializers, parsers
+from rest_framework import VERSION, status
+from rest_framework import parsers
class BaseRenderer(object):
@@ -100,7 +100,7 @@ class JSONPRenderer(JSONRenderer):
callback = self.get_callback(renderer_context)
json = super(JSONPRenderer, self).render(data, accepted_media_type,
renderer_context)
- return "%s(%s);" % (callback, json)
+ return u"%s(%s);" % (callback, json)
class XMLRenderer(BaseRenderer):
@@ -139,18 +139,33 @@ class YAMLRenderer(BaseRenderer):
return yaml.dump(data, stream=None, Dumper=self.encoder)
-class HTMLRenderer(BaseRenderer):
+class TemplateHTMLRenderer(BaseRenderer):
"""
- A Base class provided for convenience.
+ An HTML renderer for use with templates.
- Render the object simply by using the given template.
- To create a template renderer, subclass this class, and set
- the :attr:`media_type` and :attr:`template` attributes.
+ The data supplied to the Response object should be a dictionary that will
+ be used as context for the template.
+
+ The template name is determined by (in order of preference):
+
+ 1. An explicit `.template_name` attribute set on the response.
+ 2. An explicit `.template_name` attribute set on this class.
+ 3. The return result of calling `view.get_template_names()`.
+
+ For example:
+ data = {'users': User.objects.all()}
+ return Response(data, template_name='users.html')
+
+ For pre-rendered HTML, see StaticHTMLRenderer.
"""
media_type = 'text/html'
format = 'html'
template_name = None
+ exception_template_names = [
+ '%(status_code)s.html',
+ 'api_exception.html'
+ ]
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
@@ -167,15 +182,21 @@ class HTMLRenderer(BaseRenderer):
request = renderer_context['request']
response = renderer_context['response']
- template_names = self.get_template_names(response, view)
- template = self.resolve_template(template_names)
- context = self.resolve_context(data, request)
+ if response.exception:
+ template = self.get_exception_template(response)
+ else:
+ template_names = self.get_template_names(response, view)
+ template = self.resolve_template(template_names)
+
+ context = self.resolve_context(data, request, response)
return template.render(context)
def resolve_template(self, template_names):
return loader.select_template(template_names)
- def resolve_context(self, data, request):
+ def resolve_context(self, data, request, response):
+ if response.exception:
+ data['status_code'] = response.status_code
return RequestContext(request, data)
def get_template_names(self, response, view):
@@ -187,6 +208,48 @@ class HTMLRenderer(BaseRenderer):
return view.get_template_names()
raise ConfigurationError('Returned a template response with no template_name')
+ def get_exception_template(self, response):
+ template_names = [name % {'status_code': response.status_code}
+ for name in self.exception_template_names]
+
+ try:
+ # Try to find an appropriate error template
+ return self.resolve_template(template_names)
+ except:
+ # Fall back to using eg '404 Not Found'
+ return Template('%d %s' % (response.status_code,
+ response.status_text.title()))
+
+
+# Note, subclass TemplateHTMLRenderer simply for the exception behavior
+class StaticHTMLRenderer(TemplateHTMLRenderer):
+ """
+ An HTML renderer class that simply returns pre-rendered HTML.
+
+ The data supplied to the Response object should be a string representing
+ the pre-rendered HTML content.
+
+ For example:
+ data = '<html><body>example</body></html>'
+ return Response(data)
+
+ For template rendered HTML, see TemplateHTMLRenderer.
+ """
+ media_type = 'text/html'
+ format = 'html'
+
+ def render(self, data, accepted_media_type=None, renderer_context=None):
+ renderer_context = renderer_context or {}
+ response = renderer_context['response']
+
+ if response and response.exception:
+ request = renderer_context['request']
+ template = self.get_exception_template(response)
+ context = self.resolve_context(data, request, response)
+ return template.render(context)
+
+ return data
+
class BrowsableAPIRenderer(BaseRenderer):
"""
@@ -224,7 +287,7 @@ class BrowsableAPIRenderer(BaseRenderer):
return content
- def show_form_for_method(self, view, method, request):
+ def show_form_for_method(self, view, method, request, obj):
"""
Returns True if a form should be shown for this method.
"""
@@ -236,46 +299,32 @@ class BrowsableAPIRenderer(BaseRenderer):
request = clone_request(request, method)
try:
- if not view.has_permission(request):
+ if not view.has_permission(request, obj):
return # Don't have permission
except:
return # Don't have permission and exception explicitly raise
return True
def serializer_to_form_fields(self, serializer):
- field_mapping = {
- serializers.FloatField: forms.FloatField,
- serializers.IntegerField: forms.IntegerField,
- serializers.DateTimeField: forms.DateTimeField,
- serializers.DateField: forms.DateField,
- serializers.EmailField: forms.EmailField,
- serializers.CharField: forms.CharField,
- serializers.BooleanField: forms.BooleanField,
- serializers.PrimaryKeyRelatedField: forms.ModelChoiceField,
- serializers.ManyPrimaryKeyRelatedField: forms.ModelMultipleChoiceField
- }
-
fields = {}
- for k, v in serializer.get_fields(True).items():
- if getattr(v, 'readonly', True):
+ for k, v in serializer.get_fields().items():
+ if getattr(v, 'read_only', True):
continue
kwargs = {}
kwargs['required'] = v.required
- if getattr(v, 'queryset', None):
- kwargs['queryset'] = v.queryset
+ #if getattr(v, 'queryset', None):
+ # kwargs['queryset'] = v.queryset
+
+ if getattr(v, 'choices', None) is not None:
+ kwargs['choices'] = v.choices
+
+ if getattr(v, 'regex', None) is not None:
+ kwargs['regex'] = v.regex
if getattr(v, 'widget', None):
widget = copy.deepcopy(v.widget)
- # If choices have friendly readable names,
- # then add in the identities too
- if getattr(widget, 'choices', None):
- choices = widget.choices
- if any([ident != desc for (ident, desc) in choices]):
- choices = [(ident, "%s (%s)" % (desc, ident))
- for (ident, desc) in choices]
- widget.choices = choices
kwargs['widget'] = widget
if getattr(v, 'default', None) is not None:
@@ -283,10 +332,7 @@ class BrowsableAPIRenderer(BaseRenderer):
kwargs['label'] = k
- try:
- fields[k] = field_mapping[v.__class__](**kwargs)
- except KeyError:
- fields[k] = forms.CharField(**kwargs)
+ fields[k] = v.form_field_class(**kwargs)
return fields
def get_form(self, view, method, request):
@@ -295,7 +341,8 @@ class BrowsableAPIRenderer(BaseRenderer):
In the absence on of the Resource having an associated form then
provide a form that can be used to submit arbitrary content.
"""
- if not self.show_form_for_method(view, method, request):
+ obj = getattr(view, 'object', None)
+ if not self.show_form_for_method(view, method, request, obj):
return
if method == 'DELETE' or method == 'OPTIONS':
@@ -305,17 +352,13 @@ class BrowsableAPIRenderer(BaseRenderer):
media_types = [parser.media_type for parser in view.parser_classes]
return self.get_generic_content_form(media_types)
- # Creating an on the fly form see: http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
- obj, data = None, None
- if getattr(view, 'object', None):
- obj = view.object
-
serializer = view.get_serializer(instance=obj)
fields = self.serializer_to_form_fields(serializer)
+ # Creating an on the fly form see:
+ # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields)
- if obj:
- data = serializer.data
+ data = (obj is not None) and serializer.data or None
form_instance = OnTheFlyForm(data)
return form_instance
@@ -416,7 +459,7 @@ class BrowsableAPIRenderer(BaseRenderer):
# Munge DELETE Response code to allow us to return content
# (Do this *after* we've rendered the template so that we include
# the normal deletion response code in the output)
- if response.status_code == 204:
- response.status_code = 200
+ if response.status_code == status.HTTP_204_NO_CONTENT:
+ response.status_code = status.HTTP_200_OK
return ret
diff --git a/rest_framework/request.py b/rest_framework/request.py
index 5870be82..b7133608 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -21,8 +21,8 @@ def is_form_media_type(media_type):
Return True if the media type is a valid form media type.
"""
base_media_type, params = parse_header(media_type)
- return base_media_type == 'application/x-www-form-urlencoded' or \
- base_media_type == 'multipart/form-data'
+ return (base_media_type == 'application/x-www-form-urlencoded' or
+ base_media_type == 'multipart/form-data')
class Empty(object):
@@ -169,6 +169,15 @@ class Request(object):
self._user, self._auth = self._authenticate()
return self._user
+ @user.setter
+ def user(self, value):
+ """
+ Sets the user on the current request. This is necessary to maintain
+ compatilbility with django.contrib.auth where the user proprety is
+ set in the login and logout functions.
+ """
+ self._user = value
+
@property
def auth(self):
"""
@@ -179,6 +188,14 @@ class Request(object):
self._user, self._auth = self._authenticate()
return self._auth
+ @auth.setter
+ def auth(self, value):
+ """
+ Sets any non-user authentication information associated with the
+ request, such as an authentication token.
+ """
+ self._auth = value
+
def _load_data_and_files(self):
"""
Parses the request content into self.DATA and self.FILES.
diff --git a/rest_framework/response.py b/rest_framework/response.py
index 7a459c8f..be78c43a 100644
--- a/rest_framework/response.py
+++ b/rest_framework/response.py
@@ -9,18 +9,23 @@ class Response(SimpleTemplateResponse):
"""
def __init__(self, data=None, status=200,
- template_name=None, headers=None):
+ template_name=None, headers=None,
+ exception=False):
"""
Alters the init arguments slightly.
For example, drop 'template_name', and instead use 'data'.
- Setting 'renderer' and 'media_type' will typically be defered,
+ Setting 'renderer' and 'media_type' will typically be deferred,
For example being set automatically by the `APIView`.
"""
super(Response, self).__init__(None, status=status)
self.data = data
- self.headers = headers and headers[:] or []
self.template_name = template_name
+ self.exception = exception
+
+ if headers:
+ for name,value in headers.iteritems():
+ self[name] = value
@property
def rendered_content(self):
@@ -45,3 +50,13 @@ class Response(SimpleTemplateResponse):
# TODO: Deprecate and use a template tag instead
# TODO: Status code text for RFC 6585 status codes
return STATUS_CODE_TEXT.get(self.status_code, '')
+
+ def __getstate__(self):
+ """
+ Remove attributes from the response that shouldn't be cached
+ """
+ state = super(Response, self).__getstate__()
+ for key in ('accepted_renderer', 'renderer_context', 'data'):
+ if key in state:
+ del state[key]
+ return state
diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py
index ba663f98..c9db02f0 100644
--- a/rest_framework/reverse.py
+++ b/rest_framework/reverse.py
@@ -5,13 +5,15 @@ from django.core.urlresolvers import reverse as django_reverse
from django.utils.functional import lazy
-def reverse(viewname, *args, **kwargs):
+def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
"""
Same as `django.core.urlresolvers.reverse`, but optionally takes a request
and returns a fully qualified URL, using the request to get the base URL.
"""
- request = kwargs.pop('request', None)
- url = django_reverse(viewname, *args, **kwargs)
+ if format is not None:
+ kwargs = kwargs or {}
+ kwargs['format'] = format
+ url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
if request:
return request.build_absolute_uri(url)
return url
diff --git a/rest_framework/runtests/runcoverage.py b/rest_framework/runtests/runcoverage.py
index ea2e3d45..bcab1d14 100755
--- a/rest_framework/runtests/runcoverage.py
+++ b/rest_framework/runtests/runcoverage.py
@@ -8,6 +8,9 @@ Useful tool to run the test suite for rest_framework and generate a coverage rep
# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py
import os
import sys
+
+# fix sys path so we don't need to setup PYTHONPATH
+sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings'
from coverage import coverage
@@ -32,10 +35,10 @@ def main():
'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.',
DeprecationWarning
)
- failures = TestRunner(['rest_framework'])
+ failures = TestRunner(['tests'])
else:
test_runner = TestRunner()
- failures = test_runner.run_tests(['rest_framework'])
+ failures = test_runner.run_tests(['tests'])
cov.stop()
# Discover the list of all modules that we should test coverage for
@@ -55,6 +58,12 @@ def main():
if 'compat.py' in files:
files.remove('compat.py')
+ # Same applies to template tags module.
+ # This module has to include branching on Django versions,
+ # so it's never possible for it to have full coverage.
+ if 'rest_framework.py' in files:
+ files.remove('rest_framework.py')
+
cov_files.extend([os.path.join(path, file) for file in files if file.endswith('.py')])
cov.report(cov_files)
diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py
index b2438c9b..505994e2 100755
--- a/rest_framework/runtests/runtests.py
+++ b/rest_framework/runtests/runtests.py
@@ -5,6 +5,9 @@
# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py
import os
import sys
+
+# fix sys path so we don't need to setup PYTHONPATH
+sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings'
from django.conf import settings
@@ -32,7 +35,7 @@ def main():
else:
print usage()
sys.exit(1)
- failures = test_runner.run_tests(['rest_framework' + test_case])
+ failures = test_runner.run_tests(['tests' + test_case])
sys.exit(failures)
diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py
index 67de82c8..dd5d9dc3 100644
--- a/rest_framework/runtests/settings.py
+++ b/rest_framework/runtests/settings.py
@@ -21,6 +21,12 @@ DATABASES = {
}
}
+CACHES = {
+ 'default': {
+ 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
+ }
+}
+
# Local time zone for this installation. Choices can be found here:
# http://en.wikipedia.org/wiki/List_of_tz_zones_by_name
# although not all choices may be available on all operating systems.
@@ -91,6 +97,7 @@ INSTALLED_APPS = (
# 'django.contrib.admindocs',
'rest_framework',
'rest_framework.authtoken',
+ 'rest_framework.tests'
)
STATIC_URL = '/static/'
@@ -100,13 +107,6 @@ import django
if django.VERSION < (1, 3):
INSTALLED_APPS += ('staticfiles',)
-# OAuth support is optional, so we only test oauth if it's installed.
-try:
- import oauth_provider
-except ImportError:
- pass
-else:
- INSTALLED_APPS += ('oauth_provider',)
# If we're running on the Jenkins server we want to archive the coverage reports as XML.
import os
diff --git a/rest_framework/runtests/urls.py b/rest_framework/runtests/urls.py
index 4b7da787..ed5baeae 100644
--- a/rest_framework/runtests/urls.py
+++ b/rest_framework/runtests/urls.py
@@ -1,7 +1,7 @@
"""
Blank URLConf just to keep runtests.py happy.
"""
-from django.conf.urls.defaults import *
+from rest_framework.compat import patterns
urlpatterns = patterns('',
)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 8ee9a0ec..27458f96 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -3,8 +3,18 @@ import datetime
import types
from decimal import Decimal
from django.db import models
+from django.forms import widgets
from django.utils.datastructures import SortedDict
from rest_framework.compat import get_concrete_model
+
+# Note: We do the following so that users of the framework can use this style:
+#
+# example_field = serializers.CharField(...)
+#
+# This helps keep the seperation between model fields, form fields, and
+# serializer fields more explicit.
+
+from rest_framework.relations import *
from rest_framework.fields import *
@@ -12,7 +22,16 @@ class DictWithMetadata(dict):
"""
A dict-like object, that can have additional properties attached.
"""
- pass
+ def __getstate__(self):
+ """
+ Used by pickle (e.g., caching).
+ Overriden to remove metadata from the dict, since it shouldn't be pickled
+ and may in some instances be unpickleable.
+ """
+ # return an instance of the first dict in MRO that isn't a DictWithMetadata
+ for base in self.__class__.__mro__:
+ if not isinstance(base, DictWithMetadata) and isinstance(base, dict):
+ return base(self)
class SortedDictWithMetadata(SortedDict, DictWithMetadata):
@@ -22,10 +41,6 @@ class SortedDictWithMetadata(SortedDict, DictWithMetadata):
pass
-class RecursionOccured(BaseException):
- pass
-
-
def _is_protected_type(obj):
"""
True if the object is a native datatype that does not need to
@@ -33,10 +48,10 @@ def _is_protected_type(obj):
"""
return isinstance(obj, (
types.NoneType,
- int, long,
- datetime.datetime, datetime.date, datetime.time,
- float, Decimal,
- basestring)
+ int, long,
+ datetime.datetime, datetime.date, datetime.time,
+ float, Decimal,
+ basestring)
)
@@ -54,7 +69,7 @@ def _get_declared_fields(bases, attrs):
# If this class is subclassing another Serializer, add that Serializer's
# fields. Note that we loop over the bases in *reverse*. This is necessary
- # in order to the correct order of fields.
+ # in order to maintain the correct order of fields.
for base in bases[::-1]:
if hasattr(base, 'base_fields'):
fields = base.base_fields.items() + fields
@@ -73,7 +88,7 @@ class SerializerOptions(object):
Meta class options for Serializer
"""
def __init__(self, meta):
- self.nested = getattr(meta, 'nested', False)
+ self.depth = getattr(meta, 'depth', 0)
self.fields = getattr(meta, 'fields', ())
self.exclude = getattr(meta, 'exclude', ())
@@ -83,51 +98,53 @@ class BaseSerializer(Field):
pass
_options_class = SerializerOptions
- _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatability with unsorted implementations.
+ _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatibility with unsorted implementations.
- def __init__(self, data=None, instance=None, context=None, **kwargs):
+ def __init__(self, instance=None, data=None, files=None,
+ context=None, partial=False, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
- self.fields = copy.deepcopy(self.base_fields)
self.opts = self._options_class(self.Meta)
self.parent = None
self.root = None
+ self.partial = partial
- self.stack = []
self.context = context or {}
self.init_data = data
+ self.init_files = files
self.object = instance
+ self.fields = self.get_fields()
self._data = None
+ self._files = None
self._errors = None
#####
# Methods to determine which fields to use when (de)serializing objects.
- def default_fields(self, serialize, obj=None, data=None, nested=False):
+ def get_default_fields(self):
"""
Return the complete set of default fields for the object, as a dict.
"""
return {}
- def get_fields(self, serialize, obj=None, data=None, nested=False):
+ def get_fields(self):
"""
Returns the complete set of fields for the object as a dict.
This will be the set of any explicitly declared fields,
- plus the set of fields returned by default_fields().
+ plus the set of fields returned by get_default_fields().
"""
ret = SortedDict()
# Get the explicitly declared fields
- for key, field in self.fields.items():
+ base_fields = copy.deepcopy(self.base_fields)
+ for key, field in base_fields.items():
ret[key] = field
- # Set up the field
- field.initialize(parent=self)
# Add in the default fields
- fields = self.default_fields(serialize, obj, data, nested)
- for key, val in fields.items():
+ default_fields = self.get_default_fields()
+ for key, val in default_fields.items():
if key not in ret:
ret[key] = val
@@ -143,25 +160,25 @@ class BaseSerializer(Field):
for key in self.opts.exclude:
ret.pop(key, None)
+ for key, field in ret.items():
+ field.initialize(parent=self, field_name=key)
+
return ret
#####
# Field methods - used when the serializer class is itself used as a field.
- def initialize(self, parent):
+ def initialize(self, parent, field_name):
"""
Same behaviour as usual Field, except that we need to keep track
- of state so that we can deal with handling maximum depth and recursion.
+ of state so that we can deal with handling maximum depth.
"""
- super(BaseSerializer, self).initialize(parent)
- self.stack = parent.stack[:]
- if parent.opts.nested and not isinstance(parent.opts.nested, bool):
- self.opts.nested = parent.opts.nested - 1
- else:
- self.opts.nested = parent.opts.nested
+ super(BaseSerializer, self).initialize(parent, field_name)
+ if parent.opts.depth:
+ self.opts.depth = parent.opts.depth - 1
#####
- # Methods to convert or revert from objects <--> primative representations.
+ # Methods to convert or revert from objects <--> primitive representations.
def get_field_key(self, field_name):
"""
@@ -174,35 +191,32 @@ class BaseSerializer(Field):
Core of serialization.
Convert an object into a dictionary of serialized field values.
"""
- if obj in self.stack and not self.source == '*':
- raise RecursionOccured()
- self.stack.append(obj)
-
ret = self._dict_class()
ret.fields = {}
- fields = self.get_fields(serialize=True, obj=obj, nested=self.opts.nested)
- for field_name, field in fields.items():
+ for field_name, field in self.fields.items():
+ field.initialize(parent=self, field_name=field_name)
key = self.get_field_key(field_name)
- try:
- value = field.field_to_native(obj, field_name)
- except RecursionOccured:
- field = self.get_fields(serialize=True, obj=obj, nested=False)[field_name]
- value = field.field_to_native(obj, field_name)
+ value = field.field_to_native(obj, field_name)
ret[key] = value
ret.fields[key] = field
return ret
- def restore_fields(self, data):
+ def restore_fields(self, data, files):
"""
Core of deserialization, together with `restore_object`.
Converts a dictionary of data into a dictionary of deserialized fields.
"""
- fields = self.get_fields(serialize=False, data=data, nested=self.opts.nested)
reverted_data = {}
- for field_name, field in fields.items():
+
+ if data is not None and not isinstance(data, dict):
+ self._errors['non_field_errors'] = [u'Invalid data']
+ return None
+
+ for field_name, field in self.fields.items():
+ field.initialize(parent=self, field_name=field_name)
try:
- field.field_from_native(data, field_name, reverted_data)
+ field.field_from_native(data, files, field_name, reverted_data)
except ValidationError as err:
self._errors[field_name] = list(err.messages)
@@ -212,9 +226,7 @@ class BaseSerializer(Field):
"""
Run `validate_<fieldname>()` and `validate()` methods on the serializer
"""
- fields = self.get_fields(serialize=False, data=attrs, nested=self.opts.nested)
-
- for field_name, field in fields.items():
+ for field_name, field in self.fields.items():
try:
validate_method = getattr(self, 'validate_%s' % field_name, None)
if validate_method:
@@ -223,10 +235,18 @@ class BaseSerializer(Field):
except ValidationError as err:
self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages)
- try:
- attrs = self.validate(attrs)
- except ValidationError as err:
- self._errors['non_field_errors'] = err.messages
+ # If there are already errors, we don't run .validate() because
+ # field-validation failed and thus `attrs` may not be complete.
+ # which in turn can cause inconsistent validation errors.
+ if not self._errors:
+ try:
+ attrs = self.validate(attrs)
+ except ValidationError as err:
+ if hasattr(err, 'message_dict'):
+ for field_name, error_messages in err.message_dict.items():
+ self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages)
+ elif hasattr(err, 'messages'):
+ self._errors['non_field_errors'] = err.messages
return attrs
@@ -249,26 +269,23 @@ class BaseSerializer(Field):
def to_native(self, obj):
"""
- Serialize objects -> primatives.
+ Serialize objects -> primitives.
"""
- if isinstance(obj, dict):
- return dict([(key, self.to_native(val))
- for (key, val) in obj.items()])
- elif hasattr(obj, '__iter__'):
- return [self.to_native(item) for item in obj]
+ if hasattr(obj, '__iter__'):
+ return [self.convert_object(item) for item in obj]
return self.convert_object(obj)
- def from_native(self, data):
+ def from_native(self, data, files):
"""
- Deserialize primatives -> objects.
+ Deserialize primitives -> objects.
"""
if hasattr(data, '__iter__') and not isinstance(data, dict):
# TODO: error data when deserializing lists
- return (self.from_native(item) for item in data)
+ return [self.from_native(item, None) for item in data]
self._errors = {}
- if data is not None:
- attrs = self.restore_fields(data)
+ if data is not None or files is not None:
+ attrs = self.restore_fields(data, files)
attrs = self.perform_validation(attrs)
else:
self._errors['non_field_errors'] = ['No input provided']
@@ -281,22 +298,36 @@ class BaseSerializer(Field):
Override default so that we can apply ModelSerializer as a nested
field to relationships.
"""
- obj = getattr(obj, self.source or field_name)
+ try:
+ if self.source:
+ for component in self.source.split('.'):
+ obj = getattr(obj, component)
+ if is_simple_callable(obj):
+ obj = obj()
+ else:
+ obj = getattr(obj, field_name)
+ if is_simple_callable(obj):
+ obj = obj()
+ except ObjectDoesNotExist:
+ return None
# If the object has an "all" method, assume it's a relationship
if is_simple_callable(getattr(obj, 'all', None)):
return [self.to_native(item) for item in obj.all()]
+ if obj is None:
+ return None
+
return self.to_native(obj)
@property
def errors(self):
"""
Run deserialization and return error data,
- setting self.object if no errors occured.
+ setting self.object if no errors occurred.
"""
if self._errors is None:
- obj = self.from_native(self.init_data)
+ obj = self.from_native(self.init_data, self.init_files)
if not self._errors:
self.object = obj
return self._errors
@@ -329,6 +360,7 @@ class ModelSerializerOptions(SerializerOptions):
def __init__(self, meta):
super(ModelSerializerOptions, self).__init__(meta)
self.model = getattr(meta, 'model', None)
+ self.read_only_fields = getattr(meta, 'read_only_fields', ())
class ModelSerializer(Serializer):
@@ -337,16 +369,10 @@ class ModelSerializer(Serializer):
"""
_options_class = ModelSerializerOptions
- def default_fields(self, serialize, obj=None, data=None, nested=False):
+ def get_default_fields(self):
"""
Return all the fields that should be serialized for the model.
"""
- # TODO: Modfiy this so that it's called on init, and drop
- # serialize/obj/data arguments.
- #
- # We *could* provide a hook for dynamic fields, but
- # it'd be nice if the default was to generate fields statically
- # at the point of __init__
cls = self.opts.model
opts = get_concrete_model(cls)._meta
@@ -358,6 +384,7 @@ class ModelSerializer(Serializer):
fields += [field for field in opts.many_to_many if field.serialize]
ret = SortedDict()
+ nested = bool(self.opts.depth)
is_pk = True # First field in the list is the pk
for model_field in fields:
@@ -374,22 +401,30 @@ class ModelSerializer(Serializer):
field = self.get_field(model_field)
if field:
- field.initialize(parent=self)
ret[model_field.name] = field
+ for field_name in self.opts.read_only_fields:
+ assert field_name in ret, \
+ "read_only_fields on '%s' included invalid item '%s'" % \
+ (self.__class__.__name__, field_name)
+ ret[field_name].read_only = True
+
return ret
def get_pk_field(self, model_field):
"""
Returns a default instance of the pk field.
"""
- return Field()
+ return self.get_field(model_field)
def get_nested_field(self, model_field):
"""
Creates a default instance of a nested relational field.
"""
- return ModelSerializer()
+ class NestedModelSerializer(ModelSerializer):
+ class Meta:
+ model = model_field.rel.to
+ return NestedModelSerializer()
def get_related_field(self, model_field, to_many=False):
"""
@@ -397,20 +432,43 @@ class ModelSerializer(Serializer):
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
- queryset = model_field.rel.to._default_manager
+ kwargs = {
+ 'null': model_field.null or model_field.blank,
+ 'queryset': model_field.rel.to._default_manager
+ }
+
if to_many:
- return ManyPrimaryKeyRelatedField(queryset=queryset)
- return PrimaryKeyRelatedField(queryset=queryset)
+ return ManyPrimaryKeyRelatedField(**kwargs)
+ return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field):
"""
Creates a default instance of a basic non-relational field.
"""
kwargs = {}
+
+ kwargs['blank'] = model_field.blank
+
+ if model_field.null or model_field.blank:
+ kwargs['required'] = False
+
+ if isinstance(model_field, models.AutoField) or not model_field.editable:
+ kwargs['read_only'] = True
+
if model_field.has_default():
kwargs['required'] = False
+ kwargs['default'] = model_field.get_default()
+
+ if model_field.__class__ == models.TextField:
+ kwargs['widget'] = widgets.Textarea
+
+ # TODO: TypedChoiceField?
+ if model_field.flatchoices: # This ModelField contains choices
+ kwargs['choices'] = model_field.flatchoices
+ return ChoiceField(**kwargs)
field_mapping = {
+ models.AutoField: IntegerField,
models.FloatField: FloatField,
models.IntegerField: IntegerField,
models.PositiveIntegerField: IntegerField,
@@ -420,42 +478,86 @@ class ModelSerializer(Serializer):
models.DateField: DateField,
models.EmailField: EmailField,
models.CharField: CharField,
+ models.URLField: URLField,
+ models.SlugField: SlugField,
models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField,
+ models.FileField: FileField,
+ models.ImageField: ImageField,
}
try:
return field_mapping[model_field.__class__](**kwargs)
except KeyError:
return ModelField(model_field=model_field, **kwargs)
+ def get_validation_exclusions(self):
+ """
+ Return a list of field names to exclude from model validation.
+ """
+ cls = self.opts.model
+ opts = get_concrete_model(cls)._meta
+ exclusions = [field.name for field in opts.fields + opts.many_to_many]
+ for field_name, field in self.fields.items():
+ if field_name in exclusions and not field.read_only:
+ exclusions.remove(field_name)
+ return exclusions
+
def restore_object(self, attrs, instance=None):
"""
Restore the model instance.
"""
self.m2m_data = {}
+ self.related_data = {}
- if instance:
- for key, val in attrs.items():
- setattr(instance, key, val)
- return instance
+ # Reverse fk relations
+ for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model():
+ field_name = obj.field.related_query_name()
+ if field_name in attrs:
+ self.related_data[field_name] = attrs.pop(field_name)
+ # Reverse m2m relations
+ for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model():
+ field_name = obj.field.related_query_name()
+ if field_name in attrs:
+ self.m2m_data[field_name] = attrs.pop(field_name)
+
+ # Forward m2m relations
for field in self.opts.model._meta.many_to_many:
if field.name in attrs:
self.m2m_data[field.name] = attrs.pop(field.name)
- return self.opts.model(**attrs)
- def save(self, save_m2m=True):
+ if instance is not None:
+ for key, val in attrs.items():
+ setattr(instance, key, val)
+
+ else:
+ instance = self.opts.model(**attrs)
+
+ try:
+ instance.full_clean(exclude=self.get_validation_exclusions())
+ except ValidationError, err:
+ self._errors = err.message_dict
+ return None
+
+ return instance
+
+ def save(self):
"""
Save the deserialized object and return it.
"""
self.object.save()
- if self.m2m_data and save_m2m:
+ if getattr(self, 'm2m_data', None):
for accessor_name, object_list in self.m2m_data.items():
setattr(self.object, accessor_name, object_list)
self.m2m_data = {}
+ if getattr(self, 'related_data', None):
+ for accessor_name, object_list in self.related_data.items():
+ setattr(self.object, accessor_name, object_list)
+ self.related_data = {}
+
return self.object
@@ -502,9 +604,9 @@ class HyperlinkedModelSerializer(ModelSerializer):
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
rel = model_field.rel.to
- queryset = rel._default_manager
kwargs = {
- 'queryset': queryset,
+ 'null': model_field.null,
+ 'queryset': rel._default_manager,
'view_name': self._get_default_view_name(rel)
}
if to_many:
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 3c508294..5c77c55c 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -37,11 +37,14 @@ DEFAULTS = {
'rest_framework.authentication.SessionAuthentication',
'rest_framework.authentication.BasicAuthentication'
),
- 'DEFAULT_PERMISSION_CLASSES': (),
- 'DEFAULT_THROTTLE_CLASSES': (),
+ 'DEFAULT_PERMISSION_CLASSES': (
+ 'rest_framework.permissions.AllowAny',
+ ),
+ 'DEFAULT_THROTTLE_CLASSES': (
+ ),
+
'DEFAULT_CONTENT_NEGOTIATION_CLASS':
'rest_framework.negotiation.DefaultContentNegotiation',
-
'DEFAULT_MODEL_SERIALIZER_CLASS':
'rest_framework.serializers.ModelSerializer',
'DEFAULT_PAGINATION_SERIALIZER_CLASS':
@@ -51,18 +54,26 @@ DEFAULTS = {
'user': None,
'anon': None,
},
+
+ # Pagination
'PAGINATE_BY': None,
+ 'PAGINATE_BY_PARAM': None,
+
+ # Filtering
+ 'FILTER_BACKEND': None,
+ # Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
+ # Browser enhancements
'FORM_METHOD_OVERRIDE': '_method',
'FORM_CONTENT_OVERRIDE': '_content',
'FORM_CONTENTTYPE_OVERRIDE': '_content_type',
'URL_ACCEPT_OVERRIDE': 'accept',
'URL_FORMAT_OVERRIDE': 'format',
- 'FORMAT_SUFFIX_KWARG': 'format'
+ 'FORMAT_SUFFIX_KWARG': 'format',
}
@@ -76,6 +87,7 @@ IMPORT_STRINGS = (
'DEFAULT_CONTENT_NEGOTIATION_CLASS',
'DEFAULT_MODEL_SERIALIZER_CLASS',
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
+ 'FILTER_BACKEND',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
)
@@ -103,8 +115,8 @@ def import_from_string(val, setting_name):
module_path, class_name = '.'.join(parts[:-1]), parts[-1]
module = importlib.import_module(module_path)
return getattr(module, class_name)
- except:
- msg = "Could not import '%s' for API setting '%s'" % (val, setting_name)
+ except ImportError as e:
+ msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e)
raise ImportError(msg)
@@ -139,8 +151,15 @@ class APISettings(object):
if val and attr in self.import_strings:
val = perform_import(val, attr)
+ self.validate_setting(attr, val)
+
# Cache the result
setattr(self, attr, val)
return val
+ def validate_setting(self, attr, val):
+ if attr == 'FILTER_BACKEND' and val is not None:
+ # Make sure we can initialize the class
+ val()
+
api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS)
diff --git a/rest_framework/static/rest_framework/css/default.css b/rest_framework/static/rest_framework/css/default.css
index 739b9300..b2e41b99 100644
--- a/rest_framework/static/rest_framework/css/default.css
+++ b/rest_framework/static/rest_framework/css/default.css
@@ -32,6 +32,17 @@ h2, h3 {
margin-right: 1em;
}
+ul.breadcrumb {
+ margin: 58px 0 0 0;
+}
+
+form select, form input, form textarea {
+ width: 90%;
+}
+
+form select[multiple] {
+ height: 150px;
+}
/* To allow tooltips to work on disabled elements */
.disabled-tooltip-shield {
position: absolute;
@@ -55,6 +66,7 @@ pre {
.page-header {
border-bottom: none;
padding-bottom: 0px;
+ margin-bottom: 20px;
}
@@ -65,7 +77,7 @@ html{
background: none;
}
-body, .navbar .navbar-inner .container-fluid{
+body, .navbar .navbar-inner .container-fluid {
max-width: 1150px;
margin: 0 auto;
}
@@ -76,13 +88,14 @@ body{
}
#content{
- margin: 40px 0 0 0;
+ margin: 0;
}
/* custom navigation styles */
.wrapper .navbar{
- width:100%;
+ width: 100%;
position: absolute;
- left:0;
+ left: 0;
+ top: 0;
}
.navbar .navbar-inner{
diff --git a/rest_framework/status.py b/rest_framework/status.py
index f3a5e481..a1eb48da 100644
--- a/rest_framework/status.py
+++ b/rest_framework/status.py
@@ -49,4 +49,4 @@ HTTP_502_BAD_GATEWAY = 502
HTTP_503_SERVICE_UNAVAILABLE = 503
HTTP_504_GATEWAY_TIMEOUT = 504
HTTP_505_HTTP_VERSION_NOT_SUPPORTED = 505
-HTTP_511_NETWORD_AUTHENTICATION_REQUIRED = 511
+HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 5ac6ef67..092bf2e4 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -1,6 +1,5 @@
{% load url from future %}
{% load rest_framework %}
-{% load static %}
<!DOCTYPE html>
<html>
<head>
@@ -14,10 +13,10 @@
<title>{% block title %}Django REST framework{% endblock %}</title>
{% block style %}
- <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap.min.css"/>
- <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap-tweaks.css"/>
- <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/prettify.css'/>
- <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/default.css'/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/prettify.css" %}"/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/>
{% endblock %}
{% endblock %}
@@ -109,11 +108,11 @@
<div class="content-main">
<div class="page-header"><h1>{{ name }}</h1></div>
- <p class="resource-description">{{ description }}</p>
+ {{ description }}
<div class="request-info">
<pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre>
- <div>
+ </div>
<div class="response-info">
<pre class="prettyprint"><div class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %}
{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|urlize_quoted_links }}</span>
@@ -131,12 +130,12 @@
{% csrf_token %}
{{ post_form.non_field_errors }}
{% for field in post_form %}
- <div class="control-group {% if field.errors %}error{% endif %}">
+ <div class="control-group"> <!--{% if field.errors %}error{% endif %}-->
{{ field.label_tag|add_class:"control-label" }}
<div class="controls">
- {{ field|add_class:"input-xlarge" }}
+ {{ field }}
<span class="help-inline">{{ field.help_text }}</span>
- {{ field.errors|add_class:"help-block" }}
+ <!--{{ field.errors|add_class:"help-block" }}-->
</div>
</div>
{% endfor %}
@@ -156,12 +155,12 @@
{% csrf_token %}
{{ put_form.non_field_errors }}
{% for field in put_form %}
- <div class="control-group {% if field.errors %}error{% endif %}">
+ <div class="control-group"> <!--{% if field.errors %}error{% endif %}-->
{{ field.label_tag|add_class:"control-label" }}
<div class="controls">
- {{ field|add_class:"input-xlarge" }}
+ {{ field }}
<span class='help-inline'>{{ field.help_text }}</span>
- {{ field.errors|add_class:"help-block" }}
+ <!--{{ field.errors|add_class:"help-block" }}-->
</div>
</div>
{% endfor %}
@@ -195,10 +194,10 @@
{% endblock %}
{% block script %}
- <script src="{% get_static_prefix %}rest_framework/js/jquery-1.8.1-min.js"></script>
- <script src="{% get_static_prefix %}rest_framework/js/bootstrap.min.js"></script>
- <script src="{% get_static_prefix %}rest_framework/js/prettify-min.js"></script>
- <script src="{% get_static_prefix %}rest_framework/js/default.js"></script>
+ <script src="{% static "rest_framework/js/jquery-1.8.1-min.js" %}"></script>
+ <script src="{% static "rest_framework/js/bootstrap.min.js" %}"></script>
+ <script src="{% static "rest_framework/js/prettify-min.js" %}"></script>
+ <script src="{% static "rest_framework/js/default.js" %}"></script>
{% endblock %}
</body>
</html>
diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html
index 65af512e..6e2bd8d4 100644
--- a/rest_framework/templates/rest_framework/login.html
+++ b/rest_framework/templates/rest_framework/login.html
@@ -1,44 +1,52 @@
{% load url from future %}
-{% load static %}
+{% load rest_framework %}
<html>
<head>
- <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/style.css'/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/>
</head>
- <body class="login">
+ <body class="container">
- <div id="container">
-
- <div id="header">
- <div id="branding">
- <h1 id="site-name">Django REST framework</h1>
+<div class="container-fluid" style="margin-top: 30px">
+ <div class="row-fluid">
+
+ <div class="well" style="width: 320px; margin-left: auto; margin-right: auto">
+ <div class="row-fluid">
+ <div>
+ <h3 style="margin: 0 0 20px;">Django REST framework</h3>
</div>
- </div>
+ </div><!-- /row fluid -->
- <div id="content" class="colM">
- <div id="content-main">
- <form method="post" action="{% url 'rest_framework:login' %}" id="login-form">
+ <div class="row-fluid">
+ <div>
+ <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post">
{% csrf_token %}
- <div class="form-row">
- <label for="id_username">Username:</label> {{ form.username }}
+ <div id="div_id_username" class="clearfix control-group">
+ <div class="controls" style="height: 30px">
+ <Label class="span4" style="margin-top: 3px">Username:</label>
+ <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username">
+ </div>
</div>
- <div class="form-row">
- <label for="id_password">Password:</label> {{ form.password }}
- <input type="hidden" name="next" value="{{ next }}" />
+ <div id="div_id_password" class="clearfix control-group">
+ <div class="controls" style="height: 30px">
+ <Label class="span4" style="margin-top: 3px">Password:</label>
+ <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password">
+ </div>
</div>
- <div class="form-row">
- <label>&nbsp;</label><input type="submit" value="Log in">
+ <input type="hidden" name="next" value="{{ next }}" />
+ <div class="form-actions-no-box">
+ <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit">
</div>
</form>
- <script type="text/javascript">
- document.getElementById('id_username').focus()
- </script>
</div>
- <br class="clear">
- </div>
+ </div><!-- /row fluid -->
+ </div><!--/span-->
- <div id="footer"></div>
+ </div><!-- /.row-fluid -->
+ </div>
</div>
</body>
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index c9b6eb10..82fcdfe7 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -11,6 +11,101 @@ import string
register = template.Library()
+# Note we don't use 'load staticfiles', because we need a 1.3 compatible
+# version, so instead we include the `static` template tag ourselves.
+
+# When 1.3 becomes unsupported by REST framework, we can instead start to
+# use the {% load staticfiles %} tag, remove the following code,
+# and add a dependancy that `django.contrib.staticfiles` must be installed.
+
+# Note: We can't put this into the `compat` module because the compat import
+# from rest_framework.compat import ...
+# conflicts with this rest_framework template tag module.
+
+try: # Django 1.5+
+ from django.contrib.staticfiles.templatetags.staticfiles import StaticFilesNode
+
+ @register.tag('static')
+ def do_static(parser, token):
+ return StaticFilesNode.handle_token(parser, token)
+
+except:
+ try: # Django 1.4
+ from django.contrib.staticfiles.storage import staticfiles_storage
+
+ @register.simple_tag
+ def static(path):
+ """
+ A template tag that returns the URL to a file
+ using staticfiles' storage backend
+ """
+ return staticfiles_storage.url(path)
+
+ except: # Django 1.3
+ from urlparse import urljoin
+ from django import template
+ from django.templatetags.static import PrefixNode
+
+ class StaticNode(template.Node):
+ def __init__(self, varname=None, path=None):
+ if path is None:
+ raise template.TemplateSyntaxError(
+ "Static template nodes must be given a path to return.")
+ self.path = path
+ self.varname = varname
+
+ def url(self, context):
+ path = self.path.resolve(context)
+ return self.handle_simple(path)
+
+ def render(self, context):
+ url = self.url(context)
+ if self.varname is None:
+ return url
+ context[self.varname] = url
+ return ''
+
+ @classmethod
+ def handle_simple(cls, path):
+ return urljoin(PrefixNode.handle_simple("STATIC_URL"), path)
+
+ @classmethod
+ def handle_token(cls, parser, token):
+ """
+ Class method to parse prefix node and return a Node.
+ """
+ bits = token.split_contents()
+
+ if len(bits) < 2:
+ raise template.TemplateSyntaxError(
+ "'%s' takes at least one argument (path to file)" % bits[0])
+
+ path = parser.compile_filter(bits[1])
+
+ if len(bits) >= 2 and bits[-2] == 'as':
+ varname = bits[3]
+ else:
+ varname = None
+
+ return cls(varname, path)
+
+ @register.tag('static')
+ def do_static_13(parser, token):
+ return StaticNode.handle_token(parser, token)
+
+
+def replace_query_param(url, key, val):
+ """
+ Given a URL and a key/val pair, set or replace an item in the query
+ parameters of the URL, and return the new URL.
+ """
+ (scheme, netloc, path, query, fragment) = urlsplit(url)
+ query_dict = QueryDict(query).copy()
+ query_dict[key] = val
+ query = query_dict.urlencode()
+ return urlunsplit((scheme, netloc, path, query, fragment))
+
+
# Regex for adding classes to html snippets
class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])')
@@ -31,19 +126,6 @@ hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|
trailing_empty_content_re = re.compile(r'(?:<p>(?:&nbsp;|\s|<br \/>)*?</p>\s*)+\Z')
-# Helper function for 'add_query_param'
-def replace_query_param(url, key, val):
- """
- Given a URL and a key/val pair, set or replace an item in the query
- parameters of the URL, and return the new URL.
- """
- (scheme, netloc, path, query, fragment) = urlsplit(url)
- query_dict = QueryDict(query).copy()
- query_dict[key] = val
- query = query_dict.urlencode()
- return urlunsplit((scheme, netloc, path, query, fragment))
-
-
# And the template tags themselves...
@register.simple_tag
diff --git a/rest_framework/tests/__init__.py b/rest_framework/tests/__init__.py
index adeaf6da..e69de29b 100644
--- a/rest_framework/tests/__init__.py
+++ b/rest_framework/tests/__init__.py
@@ -1,13 +0,0 @@
-"""
-Force import of all modules in this package in order to get the standard test
-runner to pick up the tests. Yowzers.
-"""
-import os
-
-modules = [filename.rsplit('.', 1)[0]
- for filename in os.listdir(os.path.dirname(__file__))
- if filename.endswith('.py') and not filename.startswith('_')]
-__test__ = dict()
-
-for module in modules:
- exec("from rest_framework.tests.%s import *" % module)
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index 8ab4c4e4..e86041bc 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -1,16 +1,14 @@
-from django.conf.urls.defaults import patterns
from django.contrib.auth.models import User
-from django.test import Client, TestCase
-
-from django.utils import simplejson as json
from django.http import HttpResponse
+from django.test import Client, TestCase
-from rest_framework.views import APIView
from rest_framework import permissions
-
from rest_framework.authtoken.models import Token
from rest_framework.authentication import TokenAuthentication
+from rest_framework.compat import patterns
+from rest_framework.views import APIView
+import json
import base64
@@ -27,6 +25,7 @@ MockView.authentication_classes += (TokenAuthentication,)
urlpatterns = patterns('',
(r'^$', MockView.as_view()),
+ (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
)
@@ -152,3 +151,33 @@ class TokenAuthTests(TestCase):
self.token.delete()
token = Token.objects.create(user=self.user)
self.assertTrue(bool(token.key))
+
+ def test_token_login_json(self):
+ """Ensure token login view using JSON POST works."""
+ client = Client(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ json.dumps({'username': self.username, 'password': self.password}), 'application/json')
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(json.loads(response.content)['token'], self.key)
+
+ def test_token_login_json_bad_creds(self):
+ """Ensure token login view using JSON POST fails if bad credentials are used."""
+ client = Client(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ json.dumps({'username': self.username, 'password': "badpass"}), 'application/json')
+ self.assertEqual(response.status_code, 400)
+
+ def test_token_login_json_missing_fields(self):
+ """Ensure token login view using JSON POST fails if missing fields."""
+ client = Client(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ json.dumps({'username': self.username}), 'application/json')
+ self.assertEqual(response.status_code, 400)
+
+ def test_token_login_form(self):
+ """Ensure token login view using form POST works."""
+ client = Client(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ {'username': self.username, 'password': self.password})
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(json.loads(response.content)['token'], self.key)
diff --git a/rest_framework/tests/breadcrumbs.py b/rest_framework/tests/breadcrumbs.py
index 647ab96d..df891683 100644
--- a/rest_framework/tests/breadcrumbs.py
+++ b/rest_framework/tests/breadcrumbs.py
@@ -1,5 +1,5 @@
-from django.conf.urls.defaults import patterns, url
from django.test import TestCase
+from rest_framework.compat import patterns, url
from rest_framework.utils.breadcrumbs import get_breadcrumbs
from rest_framework.views import APIView
diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py
index 41864d71..5e6bce4e 100644
--- a/rest_framework/tests/decorators.py
+++ b/rest_framework/tests/decorators.py
@@ -1,7 +1,6 @@
from django.test import TestCase
from rest_framework import status
from rest_framework.response import Response
-from django.test.client import RequestFactory
from rest_framework.renderers import JSONRenderer
from rest_framework.parsers import JSONParser
from rest_framework.authentication import BasicAuthentication
@@ -17,6 +16,8 @@ from rest_framework.decorators import (
permission_classes,
)
+from rest_framework.tests.utils import RequestFactory
+
class DecoratorTestCase(TestCase):
@@ -63,6 +64,20 @@ class DecoratorTestCase(TestCase):
response = view(request)
self.assertEqual(response.status_code, 405)
+ def test_calling_patch_method(self):
+
+ @api_view(['GET', 'PATCH'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.patch('/')
+ response = view(request)
+ self.assertEqual(response.status_code, 200)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, 405)
+
def test_renderer_classes(self):
@api_view(['GET'])
diff --git a/rest_framework/tests/extras/__init__.py b/rest_framework/tests/extras/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/rest_framework/tests/extras/__init__.py
diff --git a/rest_framework/tests/extras/bad_import.py b/rest_framework/tests/extras/bad_import.py
new file mode 100644
index 00000000..68263d94
--- /dev/null
+++ b/rest_framework/tests/extras/bad_import.py
@@ -0,0 +1 @@
+raise ValueError
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py
new file mode 100644
index 00000000..8068272d
--- /dev/null
+++ b/rest_framework/tests/fields.py
@@ -0,0 +1,49 @@
+"""
+General serializer field tests.
+"""
+
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class TimestampedModel(models.Model):
+ added = models.DateTimeField(auto_now_add=True)
+ updated = models.DateTimeField(auto_now=True)
+
+
+class CharPrimaryKeyModel(models.Model):
+ id = models.CharField(max_length=20, primary_key=True)
+
+
+class TimestampedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TimestampedModel
+
+
+class CharPrimaryKeyModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = CharPrimaryKeyModel
+
+
+class ReadOnlyFieldTests(TestCase):
+ def test_auto_now_fields_read_only(self):
+ """
+ auto_now and auto_now_add fields should be read_only by default.
+ """
+ serializer = TimestampedModelSerializer()
+ self.assertEquals(serializer.fields['added'].read_only, True)
+
+ def test_auto_pk_fields_read_only(self):
+ """
+ AutoField fields should be read_only by default.
+ """
+ serializer = TimestampedModelSerializer()
+ self.assertEquals(serializer.fields['id'].read_only, True)
+
+ def test_non_auto_pk_fields_not_read_only(self):
+ """
+ PK fields other than AutoField fields should not be read_only by default.
+ """
+ serializer = CharPrimaryKeyModelSerializer()
+ self.assertEquals(serializer.fields['id'].read_only, False)
diff --git a/rest_framework/tests/files.py b/rest_framework/tests/files.py
index 61d7f7b1..446e23c0 100644
--- a/rest_framework/tests/files.py
+++ b/rest_framework/tests/files.py
@@ -1,34 +1,51 @@
-# from django.test import TestCase
-# from django import forms
+import StringIO
+import datetime
-# from django.test.client import RequestFactory
-# from rest_framework.views import View
-# from rest_framework.response import Response
+from django.test import TestCase
-# import StringIO
+from rest_framework import serializers
-# class UploadFilesTests(TestCase):
-# """Check uploading of files"""
-# def setUp(self):
-# self.factory = RequestFactory()
+class UploadedFile(object):
+ def __init__(self, file, created=None):
+ self.file = file
+ self.created = created or datetime.datetime.now()
-# def test_upload_file(self):
-# class FileForm(forms.Form):
-# file = forms.FileField()
+class UploadedFileSerializer(serializers.Serializer):
+ file = serializers.FileField()
+ created = serializers.DateTimeField()
-# class MockView(View):
-# permissions = ()
-# form = FileForm
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.file = attrs['file']
+ instance.created = attrs['created']
+ return instance
+ return UploadedFile(**attrs)
-# def post(self, request, *args, **kwargs):
-# return Response({'FILE_NAME': self.CONTENT['file'].name,
-# 'FILE_CONTENT': self.CONTENT['file'].read()})
-# file = StringIO.StringIO('stuff')
-# file.name = 'stuff.txt'
-# request = self.factory.post('/', {'file': file})
-# view = MockView.as_view()
-# response = view(request)
-# self.assertEquals(response.raw_content, {"FILE_CONTENT": "stuff", "FILE_NAME": "stuff.txt"})
+class FileSerializerTests(TestCase):
+ def test_create(self):
+ now = datetime.datetime.now()
+ file = StringIO.StringIO('stuff')
+ file.name = 'stuff.txt'
+ file.size = file.len
+ serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})
+ uploaded_file = UploadedFile(file=file, created=now)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.object.created, uploaded_file.created)
+ self.assertEquals(serializer.object.file, uploaded_file.file)
+ self.assertFalse(serializer.object is uploaded_file)
+
+ def test_creation_failure(self):
+ """
+ Passing files=None should result in an ValidationError
+
+ Regression test for:
+ https://github.com/tomchristie/django-rest-framework/issues/542
+ """
+ now = datetime.datetime.now()
+
+ serializer = UploadedFileSerializer(data={'created': now})
+ self.assertFalse(serializer.is_valid())
+ self.assertIn('file', serializer.errors)
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
new file mode 100644
index 00000000..af2e6c2e
--- /dev/null
+++ b/rest_framework/tests/filterset.py
@@ -0,0 +1,168 @@
+import datetime
+from decimal import Decimal
+from django.test import TestCase
+from django.test.client import RequestFactory
+from django.utils import unittest
+from rest_framework import generics, status, filters
+from rest_framework.compat import django_filters
+from rest_framework.tests.models import FilterableItem, BasicModel
+
+factory = RequestFactory()
+
+
+if django_filters:
+ # Basic filter on a list view.
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_fields = ['decimal', 'date']
+ filter_backend = filters.DjangoFilterBackend
+
+ # These class are used to test a filter class.
+ class SeveralFieldsFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+ date = django_filters.DateFilter(lookup_type='gt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterClassRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backend = filters.DjangoFilterBackend
+
+ # These classes are used to test a misconfigured filter class.
+ class MisconfiguredFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+
+ class Meta:
+ model = BasicModel
+ fields = ['text']
+
+ class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = MisconfiguredFilter
+ filter_backend = filters.DjangoFilterBackend
+
+
+class IntegrationTestFiltering(TestCase):
+ """
+ Integration tests for filtered list views.
+ """
+
+ def setUp(self):
+ """
+ Create 10 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(10):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+ for obj in self.objects.all()
+ ]
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_fields_root_view(self):
+ """
+ GET requests to paginated ListCreateAPIView should return paginated results.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that the date filter works.
+ search_date = datetime.date(2012, 9, 22)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] == search_date]
+ self.assertEquals(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_class_root_view(self):
+ """
+ GET requests to filtered ListCreateAPIView that have a filter_class set
+ should return filtered results.
+ """
+ view = FilterClassRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
+
+ # Tests that the decimal filter set with 'lt' in the filter class works.
+ search_decimal = Decimal('4.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] < search_decimal]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that the date filter set with 'gt' in the filter class works.
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that the text filter set with 'icontains' in the filter class works.
+ search_text = 'ff'
+ request = factory.get('/?text=%s' % search_text)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if search_text in f['text'].lower()]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date and
+ f['decimal'] < search_decimal]
+ self.assertEquals(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_incorrectly_configured_filter(self):
+ """
+ An error should be displayed when the filter class is misconfigured.
+ """
+ view = IncorrectlyConfiguredRootView.as_view()
+
+ request = factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_unknown_filter(self):
+ """
+ GET requests with filters that aren't configured should return 200.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ search_integer = 10
+ request = factory.get('/?integer=%s' % search_integer)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py
index 1d7e33bc..bc7378e1 100644
--- a/rest_framework/tests/genericrelations.py
+++ b/rest_framework/tests/genericrelations.py
@@ -25,7 +25,7 @@ class TestGenericRelations(TestCase):
model = Bookmark
exclude = ('id',)
- serializer = BookmarkSerializer(instance=self.bookmark)
+ serializer = BookmarkSerializer(self.bookmark)
expected = {
'tags': [u'django', u'python'],
'url': u'https://www.djangoproject.com/'
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
index f4263478..4799a04b 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/generics.py
@@ -1,8 +1,9 @@
+import json
+from django.db import models
from django.test import TestCase
-from django.test.client import RequestFactory
-from django.utils import simplejson as json
from rest_framework import generics, serializers, status
-from rest_framework.tests.models import BasicModel, Comment
+from rest_framework.tests.utils import RequestFactory
+from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
factory = RequestFactory()
@@ -22,6 +23,22 @@ class InstanceView(generics.RetrieveUpdateDestroyAPIView):
model = BasicModel
+class SlugSerializer(serializers.ModelSerializer):
+ slug = serializers.Field() # read only
+
+ class Meta:
+ model = SlugBasedModel
+ exclude = ('id',)
+
+
+class SlugBasedInstanceView(InstanceView):
+ """
+ A model with a slug-field.
+ """
+ model = SlugBasedModel
+ serializer_class = SlugSerializer
+
+
class TestRootView(TestCase):
def setUp(self):
"""
@@ -129,6 +146,7 @@ class TestInstanceView(TestCase):
for obj in self.objects.all()
]
self.view = InstanceView.as_view()
+ self.slug_based_view = SlugBasedInstanceView.as_view()
def test_get_instance_view(self):
"""
@@ -157,6 +175,20 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
+ response = self.view(request, pk='1').render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEquals(updated.text, 'foobar')
+
+ def test_patch_instance_view(self):
+ """
+ PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
+ """
+ content = {'text': 'foobar'}
+ request = factory.patch('/1', json.dumps(content),
+ content_type='application/json')
+
response = self.view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
@@ -198,7 +230,7 @@ class TestInstanceView(TestCase):
def test_put_cannot_set_id(self):
"""
- POST requests to create a new object should not be able to set the id.
+ PUT requests to create a new object should not be able to set the id.
"""
content = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
@@ -219,11 +251,39 @@ class TestInstanceView(TestCase):
request = factory.put('/1', json.dumps(content),
content_type='application/json')
response = self.view(request, pk=1).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.status_code, status.HTTP_201_CREATED)
self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
self.assertEquals(updated.text, 'foobar')
+ def test_put_as_create_on_id_based_url(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ at the requested url if it doesn't exist.
+ """
+ content = {'text': 'foobar'}
+ # pk fields can not be created on demand, only the database can set th pk for a new object
+ request = factory.put('/5', json.dumps(content),
+ content_type='application/json')
+ response = self.view(request, pk=5).render()
+ self.assertEquals(response.status_code, status.HTTP_201_CREATED)
+ new_obj = self.objects.get(pk=5)
+ self.assertEquals(new_obj.text, 'foobar')
+
+ def test_put_as_create_on_slug_based_url(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ at the requested url if possible, else return HTTP_403_FORBIDDEN error-response.
+ """
+ content = {'text': 'foobar'}
+ request = factory.put('/test_slug', json.dumps(content),
+ content_type='application/json')
+ response = self.slug_based_view(request, slug='test_slug').render()
+ self.assertEquals(response.status_code, status.HTTP_201_CREATED)
+ self.assertEquals(response.data, {'slug': 'test_slug', 'text': 'foobar'})
+ new_obj = SlugBasedModel.objects.get(slug='test_slug')
+ self.assertEquals(new_obj.text, 'foobar')
+
# Regression test for #285
@@ -256,3 +316,36 @@ class TestCreateModelWithAutoNowAddField(TestCase):
self.assertEquals(response.status_code, status.HTTP_201_CREATED)
created = self.objects.get(id=1)
self.assertEquals(created.content, 'foobar')
+
+
+# Test for particularly ugly reression with m2m in browseable API
+class ClassB(models.Model):
+ name = models.CharField(max_length=255)
+
+
+class ClassA(models.Model):
+ name = models.CharField(max_length=255)
+ childs = models.ManyToManyField(ClassB, blank=True, null=True)
+
+
+class ClassASerializer(serializers.ModelSerializer):
+ childs = serializers.ManyPrimaryKeyRelatedField(source='childs')
+
+ class Meta:
+ model = ClassA
+
+
+class ExampleView(generics.ListCreateAPIView):
+ serializer_class = ClassASerializer
+ model = ClassA
+
+
+class TestM2MBrowseableAPI(TestCase):
+ def test_m2m_in_browseable_api(self):
+ """
+ Test for particularly ugly reression with m2m in browseable API
+ """
+ request = factory.get('/', HTTP_ACCEPT='text/html')
+ view = ExampleView().as_view()
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/htmlrenderer.py
index da2f83c3..54096206 100644
--- a/rest_framework/tests/htmlrenderer.py
+++ b/rest_framework/tests/htmlrenderer.py
@@ -1,14 +1,16 @@
-from django.conf.urls.defaults import patterns, url
+from django.core.exceptions import PermissionDenied
+from django.http import Http404
from django.test import TestCase
from django.template import TemplateDoesNotExist, Template
import django.template.loader
+from rest_framework.compat import patterns, url
from rest_framework.decorators import api_view, renderer_classes
-from rest_framework.renderers import HTMLRenderer
+from rest_framework.renderers import TemplateHTMLRenderer
from rest_framework.response import Response
@api_view(('GET',))
-@renderer_classes((HTMLRenderer,))
+@renderer_classes((TemplateHTMLRenderer,))
def example(request):
"""
A view that can returns an HTML representation.
@@ -17,12 +19,26 @@ def example(request):
return Response(data, template_name='example.html')
+@api_view(('GET',))
+@renderer_classes((TemplateHTMLRenderer,))
+def permission_denied(request):
+ raise PermissionDenied()
+
+
+@api_view(('GET',))
+@renderer_classes((TemplateHTMLRenderer,))
+def not_found(request):
+ raise Http404()
+
+
urlpatterns = patterns('',
url(r'^$', example),
+ url(r'^permission_denied$', permission_denied),
+ url(r'^not_found$', not_found),
)
-class HTMLRendererTests(TestCase):
+class TemplateHTMLRendererTests(TestCase):
urls = 'rest_framework.tests.htmlrenderer'
def setUp(self):
@@ -48,3 +64,52 @@ class HTMLRendererTests(TestCase):
response = self.client.get('/')
self.assertContains(response, "example: foobar")
self.assertEquals(response['Content-Type'], 'text/html')
+
+ def test_not_found_html_view(self):
+ response = self.client.get('/not_found')
+ self.assertEquals(response.status_code, 404)
+ self.assertEquals(response.content, "404 Not Found")
+ self.assertEquals(response['Content-Type'], 'text/html')
+
+ def test_permission_denied_html_view(self):
+ response = self.client.get('/permission_denied')
+ self.assertEquals(response.status_code, 403)
+ self.assertEquals(response.content, "403 Forbidden")
+ self.assertEquals(response['Content-Type'], 'text/html')
+
+
+class TemplateHTMLRendererExceptionTests(TestCase):
+ urls = 'rest_framework.tests.htmlrenderer'
+
+ def setUp(self):
+ """
+ Monkeypatch get_template
+ """
+ self.get_template = django.template.loader.get_template
+
+ def get_template(template_name):
+ if template_name == '404.html':
+ return Template("404: {{ detail }}")
+ if template_name == '403.html':
+ return Template("403: {{ detail }}")
+ raise TemplateDoesNotExist(template_name)
+
+ django.template.loader.get_template = get_template
+
+ def tearDown(self):
+ """
+ Revert monkeypatching
+ """
+ django.template.loader.get_template = self.get_template
+
+ def test_not_found_html_view_with_template(self):
+ response = self.client.get('/not_found')
+ self.assertEquals(response.status_code, 404)
+ self.assertEquals(response.content, "404: Not found")
+ self.assertEquals(response['Content-Type'], 'text/html')
+
+ def test_permission_denied_html_view_with_template(self):
+ response = self.client.get('/permission_denied')
+ self.assertEquals(response.status_code, 403)
+ self.assertEquals(response.content, "403: Permission denied")
+ self.assertEquals(response['Content-Type'], 'text/html')
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py
index 5532a8ee..c6a8224b 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/hyperlinkedserializers.py
@@ -1,12 +1,31 @@
-from django.conf.urls.defaults import patterns, url
+import json
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import generics, status, serializers
-from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel
+from rest_framework.compat import patterns, url
+from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel
factory = RequestFactory()
+class BlogPostCommentSerializer(serializers.ModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail')
+ text = serializers.CharField()
+ blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail')
+
+ class Meta:
+ model = BlogPostComment
+ fields = ('text', 'blog_post_url', 'url')
+
+
+class PhotoSerializer(serializers.Serializer):
+ description = serializers.CharField()
+ album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), slug_field='title', slug_url_kwarg='title')
+
+ def restore_object(self, attrs, instance=None):
+ return Photo(**attrs)
+
+
class BasicList(generics.ListCreateAPIView):
model = BasicModel
model_serializer_class = serializers.HyperlinkedModelSerializer
@@ -32,12 +51,46 @@ class ManyToManyDetail(generics.RetrieveAPIView):
model_serializer_class = serializers.HyperlinkedModelSerializer
+class BlogPostCommentListCreate(generics.ListCreateAPIView):
+ model = BlogPostComment
+ serializer_class = BlogPostCommentSerializer
+
+
+class BlogPostCommentDetail(generics.RetrieveAPIView):
+ model = BlogPostComment
+ serializer_class = BlogPostCommentSerializer
+
+
+class BlogPostDetail(generics.RetrieveAPIView):
+ model = BlogPost
+
+
+class PhotoListCreate(generics.ListCreateAPIView):
+ model = Photo
+ model_serializer_class = PhotoSerializer
+
+
+class AlbumDetail(generics.RetrieveAPIView):
+ model = Album
+
+
+class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
+ model = OptionalRelationModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
urlpatterns = patterns('',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),
url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'),
url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
+ url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'),
+ url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'),
+ url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'),
+ url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'),
+ url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'),
+ url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'),
)
@@ -112,7 +165,7 @@ class TestManyToManyHyperlinkedView(TestCase):
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/manytomany/')
- response = self.list_view(request).render()
+ response = self.list_view(request)
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data)
@@ -121,6 +174,89 @@ class TestManyToManyHyperlinkedView(TestCase):
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/manytomany/1/')
- response = self.detail_view(request, pk=1).render()
+ response = self.detail_view(request, pk=1)
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data[0])
+
+
+class TestCreateWithForeignKeys(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create a blog post
+ """
+ self.post = BlogPost.objects.create(title="Test post")
+ self.create_view = BlogPostCommentListCreate.as_view()
+
+ def test_create_comment(self):
+
+ data = {
+ 'text': 'A test comment',
+ 'blog_post_url': 'http://testserver/posts/1/'
+ }
+
+ request = factory.post('/comments/', data=data)
+ response = self.create_view(request)
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response['Location'], 'http://testserver/comments/1/')
+ self.assertEqual(self.post.blogpostcomment_set.count(), 1)
+ self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment')
+
+
+class TestCreateWithForeignKeysAndCustomSlug(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create an Album
+ """
+ self.post = Album.objects.create(title='test-album')
+ self.list_create_view = PhotoListCreate.as_view()
+
+ def test_create_photo(self):
+
+ data = {
+ 'description': 'A test photo',
+ 'album_url': 'http://testserver/albums/test-album/'
+ }
+
+ request = factory.post('/photos/', data=data)
+ response = self.list_create_view(request)
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')
+ self.assertEqual(self.post.photo_set.count(), 1)
+ self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo')
+
+
+class TestOptionalRelationHyperlinkedView(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 1 OptionalRelationModel intances.
+ """
+ OptionalRelationModel().save()
+ self.objects = OptionalRelationModel.objects
+ self.detail_view = OptionalRelationDetail.as_view()
+ self.data = {"url": "http://testserver/optionalrelation/1/", "other": None}
+
+ def test_get_detail_view(self):
+ """
+ GET requests to RetrieveAPIView with optional relations should return None
+ for non existing relations.
+ """
+ request = factory.get('/optionalrelationmodel-detail/1')
+ response = self.detail_view(request, pk=1)
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
+
+ def test_put_detail_view(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView with optional relations
+ should accept None for non existing relations.
+ """
+ response = self.client.put('/optionalrelation/1/',
+ data=json.dumps(self.data),
+ content_type='application/json')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index 97cd0849..93f09761 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -35,15 +35,27 @@ def foobar():
return 'foobar'
+class CustomField(models.CharField):
+
+ def __init__(self, *args, **kwargs):
+ kwargs['max_length'] = 12
+ super(CustomField, self).__init__(*args, **kwargs)
+
+
class RESTFrameworkModel(models.Model):
"""
Base for test models that sets app_label, so they play nicely.
"""
class Meta:
- app_label = 'rest_framework'
+ app_label = 'tests'
abstract = True
+class HasPositiveIntegerAsChoice(RESTFrameworkModel):
+ some_choices = ((1, 'A'), (2, 'B'), (3, 'C'))
+ some_integer = models.PositiveIntegerField(choices=some_choices)
+
+
class Anchor(RESTFrameworkModel):
text = models.CharField(max_length=100, default='anchor')
@@ -52,8 +64,14 @@ class BasicModel(RESTFrameworkModel):
text = models.CharField(max_length=100)
+class SlugBasedModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100)
+ slug = models.SlugField(max_length=32)
+
+
class DefaultValueModel(RESTFrameworkModel):
text = models.CharField(default='foobar', max_length=100)
+ extra = models.CharField(blank=True, null=True, max_length=100)
class CallableDefaultValueModel(RESTFrameworkModel):
@@ -62,12 +80,12 @@ class CallableDefaultValueModel(RESTFrameworkModel):
class ManyToManyModel(RESTFrameworkModel):
rel = models.ManyToManyField(Anchor)
-
+
class ReadOnlyManyToManyModel(RESTFrameworkModel):
text = models.CharField(max_length=100, default='anchor')
rel = models.ManyToManyField(Anchor)
-
+
# Models to test generic relations
@@ -90,6 +108,13 @@ class Bookmark(RESTFrameworkModel):
tags = GenericRelation(TaggedItem)
+# Model to test filtering.
+class FilterableItem(RESTFrameworkModel):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
# Model for regression test for #285
class Comment(RESTFrameworkModel):
@@ -101,13 +126,93 @@ class Comment(RESTFrameworkModel):
class ActionItem(RESTFrameworkModel):
title = models.CharField(max_length=200)
done = models.BooleanField(default=False)
+ info = CustomField(default='---', max_length=12)
# Models for reverse relations
+class Person(RESTFrameworkModel):
+ name = models.CharField(max_length=10)
+ age = models.IntegerField(null=True, blank=True)
+
+ @property
+ def info(self):
+ return {
+ 'name': self.name,
+ 'age': self.age,
+ }
+
+
class BlogPost(RESTFrameworkModel):
title = models.CharField(max_length=100)
+ writer = models.ForeignKey(Person, null=True, blank=True)
+
+ def get_first_comment(self):
+ return self.blogpostcomment_set.all()[0]
class BlogPostComment(RESTFrameworkModel):
text = models.TextField()
blog_post = models.ForeignKey(BlogPost)
+
+
+class Album(RESTFrameworkModel):
+ title = models.CharField(max_length=100, unique=True)
+
+
+class Photo(RESTFrameworkModel):
+ description = models.TextField()
+ album = models.ForeignKey(Album)
+
+
+# Model for issue #324
+class BlankFieldModel(RESTFrameworkModel):
+ title = models.CharField(max_length=100, blank=True, null=False)
+
+
+# Model for issue #380
+class OptionalRelationModel(RESTFrameworkModel):
+ other = models.ForeignKey('OptionalRelationModel', blank=True, null=True)
+
+
+# Model for RegexField
+class Book(RESTFrameworkModel):
+ isbn = models.CharField(max_length=13)
+
+
+# Models for relations tests
+# ManyToMany
+class ManyToManyTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class ManyToManySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
+
+
+# ForeignKey
+class ForeignKeyTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class ForeignKeySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
+
+
+# Nullable ForeignKey
+class NullableForeignKeySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
+ related_name='nullable_sources')
+
+
+# OneToOne
+class OneToOneTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class NullableOneToOneSource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
+ related_name='nullable_source')
diff --git a/rest_framework/tests/modelviews.py b/rest_framework/tests/modelviews.py
index 1f8468e8..f12e3b97 100644
--- a/rest_framework/tests/modelviews.py
+++ b/rest_framework/tests/modelviews.py
@@ -1,4 +1,4 @@
-# from django.conf.urls.defaults import patterns, url
+# from rest_framework.compat import patterns, url
# from django.forms import ModelForm
# from django.contrib.auth.models import Group, User
# from rest_framework.resources import ModelResource
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index a939c9ef..3b550877 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -1,8 +1,12 @@
+import datetime
+from decimal import Decimal
from django.core.paginator import Paginator
from django.test import TestCase
from django.test.client import RequestFactory
-from rest_framework import generics, status, pagination
-from rest_framework.tests.models import BasicModel
+from django.utils import unittest
+from rest_framework import generics, status, pagination, filters, serializers
+from rest_framework.compat import django_filters
+from rest_framework.tests.models import BasicModel, FilterableItem
factory = RequestFactory()
@@ -15,6 +19,36 @@ class RootView(generics.ListCreateAPIView):
paginate_by = 10
+if django_filters:
+ class DecimalFilter(django_filters.FilterSet):
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_class = DecimalFilter
+ filter_backend = filters.DjangoFilterBackend
+
+
+class DefaultPageSizeKwargView(generics.ListAPIView):
+ """
+ View for testing default paginate_by_param usage
+ """
+ model = BasicModel
+
+
+class PaginateByParamView(generics.ListAPIView):
+ """
+ View for testing custom paginate_by_param usage
+ """
+ model = BasicModel
+ paginate_by_param = 'page_size'
+
+
class IntegrationTestPagination(TestCase):
"""
Integration tests for paginated list views.
@@ -22,7 +56,7 @@ class IntegrationTestPagination(TestCase):
def setUp(self):
"""
- Create 26 BasicModel intances.
+ Create 26 BasicModel instances.
"""
for char in 'abcdefghijklmnopqrstuvwxyz':
BasicModel(text=char * 3).save()
@@ -62,9 +96,66 @@ class IntegrationTestPagination(TestCase):
self.assertNotEquals(response.data['previous'], None)
+class IntegrationTestPaginationAndFiltering(TestCase):
+
+ def setUp(self):
+ """
+ Create 50 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(26):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+ for obj in self.objects.all()
+ ]
+ self.view = FilterFieldsRootView.as_view()
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_paginated_filtered_root_view(self):
+ """
+ GET requests to paginated filtered ListCreateAPIView should return
+ paginated results. The next and previous links should preserve the
+ filtered parameters.
+ """
+ request = factory.get('/?decimal=15.20')
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 15)
+ self.assertEquals(response.data['results'], self.data[:10])
+ self.assertNotEquals(response.data['next'], None)
+ self.assertEquals(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 15)
+ self.assertEquals(response.data['results'], self.data[10:15])
+ self.assertEquals(response.data['next'], None)
+ self.assertNotEquals(response.data['previous'], None)
+
+ request = factory.get(response.data['previous'])
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 15)
+ self.assertEquals(response.data['results'], self.data[:10])
+ self.assertNotEquals(response.data['next'], None)
+ self.assertEquals(response.data['previous'], None)
+
+
+class PassOnContextPaginationSerializer(pagination.PaginationSerializer):
+ class Meta:
+ object_serializer_class = serializers.Serializer
+
+
class UnitTestPagination(TestCase):
"""
- Unit tests for pagination of primative objects.
+ Unit tests for pagination of primitive objects.
"""
def setUp(self):
@@ -74,14 +165,117 @@ class UnitTestPagination(TestCase):
self.last_page = paginator.page(3)
def test_native_pagination(self):
- serializer = pagination.PaginationSerializer(instance=self.first_page)
+ serializer = pagination.PaginationSerializer(self.first_page)
self.assertEquals(serializer.data['count'], 26)
self.assertEquals(serializer.data['next'], '?page=2')
self.assertEquals(serializer.data['previous'], None)
self.assertEquals(serializer.data['results'], self.objects[:10])
- serializer = pagination.PaginationSerializer(instance=self.last_page)
+ serializer = pagination.PaginationSerializer(self.last_page)
self.assertEquals(serializer.data['count'], 26)
self.assertEquals(serializer.data['next'], None)
self.assertEquals(serializer.data['previous'], '?page=2')
self.assertEquals(serializer.data['results'], self.objects[20:])
+
+ def test_context_available_in_result(self):
+ """
+ Ensure context gets passed through to the object serializer.
+ """
+ serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})
+ serializer.data
+ results = serializer.fields[serializer.results_field]
+ self.assertEquals(serializer.context, results.context)
+
+
+class TestUnpaginated(TestCase):
+ """
+ Tests for list views without pagination.
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = DefaultPageSizeKwargView.as_view()
+
+ def test_unpaginated(self):
+ """
+ Tests the default page size for this view.
+ no page size --> no limit --> no meta data
+ """
+ request = factory.get('/')
+ response = self.view(request)
+ self.assertEquals(response.data, self.data)
+
+
+class TestCustomPaginateByParam(TestCase):
+ """
+ Tests for list views with default page size kwarg
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = PaginateByParamView.as_view()
+
+ def test_default_page_size(self):
+ """
+ Tests the default page size for this view.
+ no page size --> no limit --> no meta data
+ """
+ request = factory.get('/')
+ response = self.view(request).render()
+ self.assertEquals(response.data, self.data)
+
+ def test_paginate_by_param(self):
+ """
+ If paginate_by_param is set, the new kwarg should limit per view requests.
+ """
+ request = factory.get('/?page_size=5')
+ response = self.view(request).render()
+ self.assertEquals(response.data['count'], 13)
+ self.assertEquals(response.data['results'], self.data[:5])
+
+
+class CustomField(serializers.Field):
+ def to_native(self, value):
+ if not 'view' in self.context:
+ raise RuntimeError("context isn't getting passed into custom field")
+ return "value"
+
+
+class BasicModelSerializer(serializers.Serializer):
+ text = CustomField()
+
+
+class TestContextPassedToCustomField(TestCase):
+ def setUp(self):
+ BasicModel.objects.create(text='ala ma kota')
+
+ def test_with_pagination(self):
+ class ListView(generics.ListCreateAPIView):
+ model = BasicModel
+ serializer_class = BasicModelSerializer
+ paginate_by = 1
+
+ self.view = ListView.as_view()
+ request = factory.get('/')
+ response = self.view(request).render()
+
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+
diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py
new file mode 100644
index 00000000..91daea8a
--- /dev/null
+++ b/rest_framework/tests/relations.py
@@ -0,0 +1,33 @@
+"""
+General tests for relational fields.
+"""
+
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class NullModel(models.Model):
+ pass
+
+
+class FieldTests(TestCase):
+ def test_pk_related_field_with_empty_string(self):
+ """
+ Regression test for #446
+
+ https://github.com/tomchristie/django-rest-framework/issues/446
+ """
+ field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+ def test_hyperlinked_related_field_with_empty_string(self):
+ field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+ def test_slug_related_field_with_empty_string(self):
+ field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py
new file mode 100644
index 00000000..57913670
--- /dev/null
+++ b/rest_framework/tests/relations_hyperlink.py
@@ -0,0 +1,434 @@
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.compat import patterns, url
+from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+
+def dummy_view(request, pk):
+ pass
+
+urlpatterns = patterns('',
+ url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
+ url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
+ url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),
+ url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'),
+ url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'),
+ url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'),
+ url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),
+)
+
+class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
+ sources = serializers.ManyHyperlinkedRelatedField(view_name='manytomanysource-detail')
+
+ class Meta:
+ model = ManyToManyTarget
+
+
+class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ManyToManySource
+
+
+class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
+ sources = serializers.ManyHyperlinkedRelatedField(view_name='foreignkeysource-detail')
+
+ class Meta:
+ model = ForeignKeyTarget
+
+
+class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+
+
+# Nullable ForeignKey
+class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
+
+
+# OneToOne
+class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
+ nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail')
+
+ class Meta:
+ model = OneToOneTarget
+
+
+# TODO: Add test that .data cannot be accessed prior to .is_valid
+
+class HyperlinkedManyToManyTests(TestCase):
+ urls = 'rest_framework.tests.relations_hyperlink'
+
+ def setUp(self):
+ for idx in range(1, 4):
+ target = ManyToManyTarget(name='target-%d' % idx)
+ target.save()
+ source = ManyToManySource(name='source-%d' % idx)
+ source.save()
+ for target in ManyToManyTarget.objects.all():
+ source.targets.add(target)
+
+ def test_many_to_many_retrieve(self):
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset)
+ expected = [
+ {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']},
+ {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
+ {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_many_to_many_retrieve(self):
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset)
+ expected = [
+ {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']},
+ {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
+ {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_many_to_many_update(self):
+ data = {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
+ instance = ManyToManySource.objects.get(pk=1)
+ serializer = ManyToManySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEquals(serializer.data, data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset)
+ expected = [
+ {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']},
+ {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
+ {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_many_to_many_update(self):
+ data = {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']}
+ instance = ManyToManyTarget.objects.get(pk=1)
+ serializer = ManyToManyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEquals(serializer.data, data)
+
+ # Ensure target 1 is updated, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset)
+ expected = [
+ {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']},
+ {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
+ {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']}
+
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_many_to_many_create(self):
+ data = {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']}
+ serializer = ManyToManySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset)
+ expected = [
+ {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']},
+ {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
+ {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']},
+ {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_many_to_many_create(self):
+ data = {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']}
+ serializer = ManyToManyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'target-4')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset)
+ expected = [
+ {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']},
+ {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
+ {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']},
+ {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+
+class HyperlinkedForeignKeyTests(TestCase):
+ urls = 'rest_framework.tests.relations_hyperlink'
+
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
+ {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
+ {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update(self):
+ data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'},
+ {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_update(self):
+ data = {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
+ {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []},
+ ]
+ self.assertEquals(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEquals(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']},
+ {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'}
+ serializer = ForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
+ {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'},
+ {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'target-3')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']},
+ {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []},
+ {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'target': [u'Value may not be null']})
+
+
+class HyperlinkedNullableForeignKeyTests(TestCase):
+ urls = 'rest_framework.tests.relations_hyperlink'
+
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': ''}
+ expected_data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, expected_data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None},
+ {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': ''}
+ expected_data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, expected_data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None},
+ {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ # reverse foreign keys MUST be read_only
+ # In the general case they do not provide .remove() or .clear()
+ # and cannot be arbitrarily set.
+
+ # def test_reverse_foreign_key_update(self):
+ # data = {'id': 1, 'name': u'target-1', 'sources': [1]}
+ # instance = ForeignKeyTarget.objects.get(pk=1)
+ # serializer = ForeignKeyTargetSerializer(instance, data=data)
+ # self.assertTrue(serializer.is_valid())
+ # self.assertEquals(serializer.data, data)
+ # serializer.save()
+
+ # # Ensure target 1 is updated, and everything else is as expected
+ # queryset = ForeignKeyTarget.objects.all()
+ # serializer = ForeignKeyTargetSerializer(queryset)
+ # expected = [
+ # {'id': 1, 'name': u'target-1', 'sources': [1]},
+ # {'id': 2, 'name': u'target-2', 'sources': []},
+ # ]
+ # self.assertEquals(serializer.data, expected)
+
+
+class HyperlinkedNullableOneToOneTests(TestCase):
+ urls = 'rest_framework.tests.relations_hyperlink'
+
+ def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=target)
+ source.save()
+
+ def test_reverse_foreign_key_retrieve_with_null(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset)
+ expected = [
+ {'url': '/onetoonetarget/1/', 'name': u'target-1', 'nullable_source': '/nullableonetoonesource/1/'},
+ {'url': '/onetoonetarget/2/', 'name': u'target-2', 'nullable_source': None},
+ ]
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py
new file mode 100644
index 00000000..0e129fae
--- /dev/null
+++ b/rest_framework/tests/relations_nested.py
@@ -0,0 +1,114 @@
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+
+
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ depth = 1
+ model = ForeignKeySource
+
+
+class FlatForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+
+
+class ForeignKeyTargetSerializer(serializers.ModelSerializer):
+ sources = FlatForeignKeySourceSerializer()
+
+ class Meta:
+ model = ForeignKeyTarget
+
+
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ depth = 1
+ model = NullableForeignKeySource
+
+
+class NullableOneToOneSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = NullableOneToOneSource
+
+
+class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
+ nullable_source = NullableOneToOneSourceSerializer()
+
+ class Meta:
+ model = OneToOneTarget
+
+
+class ReverseForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}},
+ {'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}},
+ {'id': 3, 'name': u'source-3', 'target': {'id': 1, 'name': u'target-1'}},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [
+ {'id': 1, 'name': u'source-1', 'target': 1},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': 1},
+ ]},
+ {'id': 2, 'name': u'target-2', 'sources': [
+ ]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+
+class NestedNullableForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}},
+ {'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}},
+ {'id': 3, 'name': u'source-3', 'target': None},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+
+class NestedNullableOneToOneTests(TestCase):
+ def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=target)
+ source.save()
+
+ def test_reverse_foreign_key_retrieve_with_null(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'nullable_source': {'id': 1, 'name': u'source-1', 'target': 1}},
+ {'id': 2, 'name': u'target-2', 'nullable_source': None},
+ ]
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py
new file mode 100644
index 00000000..54835860
--- /dev/null
+++ b/rest_framework/tests/relations_pk.py
@@ -0,0 +1,411 @@
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+
+
+class ManyToManyTargetSerializer(serializers.ModelSerializer):
+ sources = serializers.ManyPrimaryKeyRelatedField()
+
+ class Meta:
+ model = ManyToManyTarget
+
+
+class ManyToManySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManySource
+
+
+class ForeignKeyTargetSerializer(serializers.ModelSerializer):
+ sources = serializers.ManyPrimaryKeyRelatedField()
+
+ class Meta:
+ model = ForeignKeyTarget
+
+
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+
+
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
+
+
+# OneToOne
+class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
+ nullable_source = serializers.PrimaryKeyRelatedField()
+
+ class Meta:
+ model = OneToOneTarget
+
+
+# TODO: Add test that .data cannot be accessed prior to .is_valid
+
+class PKManyToManyTests(TestCase):
+ def setUp(self):
+ for idx in range(1, 4):
+ target = ManyToManyTarget(name='target-%d' % idx)
+ target.save()
+ source = ManyToManySource(name='source-%d' % idx)
+ source.save()
+ for target in ManyToManyTarget.objects.all():
+ source.targets.add(target)
+
+ def test_many_to_many_retrieve(self):
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'targets': [1]},
+ {'id': 2, 'name': u'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_many_to_many_retrieve(self):
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': u'target-3', 'sources': [3]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_many_to_many_update(self):
+ data = {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]}
+ instance = ManyToManySource.objects.get(pk=1)
+ serializer = ManyToManySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEquals(serializer.data, data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]},
+ {'id': 2, 'name': u'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_many_to_many_update(self):
+ data = {'id': 1, 'name': u'target-1', 'sources': [1]}
+ instance = ManyToManyTarget.objects.get(pk=1)
+ serializer = ManyToManyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEquals(serializer.data, data)
+
+ # Ensure target 1 is updated, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [1]},
+ {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': u'target-3', 'sources': [3]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_many_to_many_create(self):
+ data = {'id': 4, 'name': u'source-4', 'targets': [1, 3]}
+ serializer = ManyToManySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'targets': [1]},
+ {'id': 2, 'name': u'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]},
+ {'id': 4, 'name': u'source-4', 'targets': [1, 3]},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_many_to_many_create(self):
+ data = {'id': 4, 'name': u'target-4', 'sources': [1, 3]}
+ serializer = ManyToManyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'target-4')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': u'target-3', 'sources': [3]},
+ {'id': 4, 'name': u'target-4', 'sources': [1, 3]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+
+class PKForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 1},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': 1}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': u'target-2', 'sources': []},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update(self):
+ data = {'id': 1, 'name': u'source-1', 'target': 2}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 2},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': 1}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_update(self):
+ data = {'id': 2, 'name': u'target-2', 'sources': [1, 3]}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': u'target-2', 'sources': []},
+ ]
+ self.assertEquals(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEquals(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [2]},
+ {'id': 2, 'name': u'target-2', 'sources': [1, 3]},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'id': 4, 'name': u'source-4', 'target': 2}
+ serializer = ForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 1},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': 1},
+ {'id': 4, 'name': u'source-4', 'target': 2},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'id': 3, 'name': u'target-3', 'sources': [1, 3]}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'target-3')
+
+ # Ensure target 3 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [2]},
+ {'id': 2, 'name': u'target-2', 'sources': []},
+ {'id': 3, 'name': u'target-3', 'sources': [1, 3]},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'id': 1, 'name': u'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'target': [u'Value may not be null']})
+
+
+class PKNullableForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 1},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': None},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'id': 4, 'name': u'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 1},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': None},
+ {'id': 4, 'name': u'source-4', 'target': None}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 4, 'name': u'source-4', 'target': ''}
+ expected_data = {'id': 4, 'name': u'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, expected_data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 1},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': None},
+ {'id': 4, 'name': u'source-4', 'target': None}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'id': 1, 'name': u'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': None},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': None}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 1, 'name': u'source-1', 'target': ''}
+ expected_data = {'id': 1, 'name': u'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, expected_data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': None},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': None}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ # reverse foreign keys MUST be read_only
+ # In the general case they do not provide .remove() or .clear()
+ # and cannot be arbitrarily set.
+
+ # def test_reverse_foreign_key_update(self):
+ # data = {'id': 1, 'name': u'target-1', 'sources': [1]}
+ # instance = ForeignKeyTarget.objects.get(pk=1)
+ # serializer = ForeignKeyTargetSerializer(instance, data=data)
+ # self.assertTrue(serializer.is_valid())
+ # self.assertEquals(serializer.data, data)
+ # serializer.save()
+
+ # # Ensure target 1 is updated, and everything else is as expected
+ # queryset = ForeignKeyTarget.objects.all()
+ # serializer = ForeignKeyTargetSerializer(queryset)
+ # expected = [
+ # {'id': 1, 'name': u'target-1', 'sources': [1]},
+ # {'id': 2, 'name': u'target-2', 'sources': []},
+ # ]
+ # self.assertEquals(serializer.data, expected)
+
+
+class PKNullableOneToOneTests(TestCase):
+ def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=target)
+ source.save()
+
+ def test_reverse_foreign_key_retrieve_with_null(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'nullable_source': 1},
+ {'id': 2, 'name': u'target-2', 'nullable_source': None},
+ ]
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/renderers.py
index 48d8d9bd..c1b4e624 100644
--- a/rest_framework/tests/renderers.py
+++ b/rest_framework/tests/renderers.py
@@ -1,11 +1,12 @@
+import pickle
import re
-from django.conf.urls.defaults import patterns, url, include
+from django.core.cache import cache
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import status, permissions
-from rest_framework.compat import yaml
+from rest_framework.compat import yaml, patterns, url, include
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
@@ -83,6 +84,7 @@ class HTMLView1(APIView):
urlpatterns = patterns('',
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
+ url(r'^cache$', MockGETView.as_view()),
url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])),
url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])),
url(r'^html$', HTMLView.as_view()),
@@ -416,3 +418,89 @@ class XMLRendererTestCase(TestCase):
self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
self.assertTrue(xml.endswith('</root>'))
self.assertTrue(string in xml, '%r not in %r' % (string, xml))
+
+
+# Tests for caching issue, #346
+class CacheRenderTest(TestCase):
+ """
+ Tests specific to caching responses
+ """
+
+ urls = 'rest_framework.tests.renderers'
+
+ cache_key = 'just_a_cache_key'
+
+ @classmethod
+ def _get_pickling_errors(cls, obj, seen=None):
+ """ Return any errors that would be raised if `obj' is pickled
+ Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897
+ """
+ if seen == None:
+ seen = []
+ try:
+ state = obj.__getstate__()
+ except AttributeError:
+ return
+ if state == None:
+ return
+ if isinstance(state, tuple):
+ if not isinstance(state[0], dict):
+ state = state[1]
+ else:
+ state = state[0].update(state[1])
+ result = {}
+ for i in state:
+ try:
+ pickle.dumps(state[i], protocol=2)
+ except pickle.PicklingError:
+ if not state[i] in seen:
+ seen.append(state[i])
+ result[i] = cls._get_pickling_errors(state[i], seen)
+ return result
+
+ def http_resp(self, http_method, url):
+ """
+ Simple wrapper for Client http requests
+ Removes the `client' and `request' attributes from as they are
+ added by django.test.client.Client and not part of caching
+ responses outside of tests.
+ """
+ method = getattr(self.client, http_method)
+ resp = method(url)
+ del resp.client, resp.request
+ return resp
+
+ def test_obj_pickling(self):
+ """
+ Test that responses are properly pickled
+ """
+ resp = self.http_resp('get', '/cache')
+
+ # Make sure that no pickling errors occurred
+ self.assertEqual(self._get_pickling_errors(resp), {})
+
+ # Unfortunately LocMem backend doesn't raise PickleErrors but returns
+ # None instead.
+ cache.set(self.cache_key, resp)
+ self.assertTrue(cache.get(self.cache_key) is not None)
+
+ def test_head_caching(self):
+ """
+ Test caching of HEAD requests
+ """
+ resp = self.http_resp('head', '/cache')
+ cache.set(self.cache_key, resp)
+
+ cached_resp = cache.get(self.cache_key)
+ self.assertIsInstance(cached_resp, Response)
+
+ def test_get_caching(self):
+ """
+ Test caching of GET requests
+ """
+ resp = self.http_resp('get', '/cache')
+ cache.set(self.cache_key, resp)
+
+ cached_resp = cache.get(self.cache_key)
+ self.assertIsInstance(cached_resp, Response)
+ self.assertEqual(cached_resp.content, resp.content)
diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py
index ff48f3fa..4b032405 100644
--- a/rest_framework/tests/request.py
+++ b/rest_framework/tests/request.py
@@ -1,14 +1,15 @@
"""
Tests for content parsing, and form-overloaded content parsing.
"""
-from django.conf.urls.defaults import patterns
+import json
from django.contrib.auth.models import User
+from django.contrib.auth import authenticate, login, logout
+from django.contrib.sessions.middleware import SessionMiddleware
from django.test import TestCase, Client
-from django.utils import simplejson as json
-
+from django.test.client import RequestFactory
from rest_framework import status
from rest_framework.authentication import SessionAuthentication
-from django.test.client import RequestFactory
+from rest_framework.compat import patterns
from rest_framework.parsers import (
BaseParser,
FormParser,
@@ -276,3 +277,37 @@ class TestContentParsingWithAuthentication(TestCase):
# response = self.csrf_client.post('/', content)
# self.assertEqual(status.OK, response.status_code, "POST data is malformed")
+
+
+class TestUserSetter(TestCase):
+
+ def setUp(self):
+ # Pass request object through session middleware so session is
+ # available to login and logout functions
+ self.request = Request(factory.get('/'))
+ SessionMiddleware().process_request(self.request)
+
+ User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow')
+ self.user = authenticate(username='ringo', password='yellow')
+
+ def test_user_can_be_set(self):
+ self.request.user = self.user
+ self.assertEqual(self.request.user, self.user)
+
+ def test_user_can_login(self):
+ login(self.request, self.user)
+ self.assertEqual(self.request.user, self.user)
+
+ def test_user_can_logout(self):
+ self.request.user = self.user
+ self.assertFalse(self.request.user.is_anonymous())
+ logout(self.request)
+ self.assertTrue(self.request.user.is_anonymous())
+
+
+class TestAuthSetter(TestCase):
+
+ def test_auth_can_be_set(self):
+ request = Request(factory.get('/'))
+ request.auth = 'DUMMY'
+ self.assertEqual(request.auth, 'DUMMY')
diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py
index 18b6af39..875f4d42 100644
--- a/rest_framework/tests/response.py
+++ b/rest_framework/tests/response.py
@@ -1,8 +1,5 @@
-import unittest
-
-from django.conf.urls.defaults import patterns, url, include
from django.test import TestCase
-
+from rest_framework.compat import patterns, url, include
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework import status
@@ -131,12 +128,6 @@ class RendererIntegrationTests(TestCase):
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
- @unittest.skip('can\'t pass because view is a simple Django view and response is an ImmediateResponse')
- def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
- """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
- resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
- self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
-
def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
diff --git a/rest_framework/tests/reverse.py b/rest_framework/tests/reverse.py
index fd9a7d64..8c86e1fb 100644
--- a/rest_framework/tests/reverse.py
+++ b/rest_framework/tests/reverse.py
@@ -1,6 +1,6 @@
-from django.conf.urls.defaults import patterns, url
from django.test import TestCase
from django.test.client import RequestFactory
+from rest_framework.compat import patterns, url
from rest_framework.reverse import reverse
factory = RequestFactory()
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index 936f15aa..bd96ba23 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -1,13 +1,16 @@
import datetime
+import pickle
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import *
+from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
+ BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel,
+ ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo)
class SubComment(object):
def __init__(self, sub_comment):
self.sub_comment = sub_comment
-
+
class Comment(object):
def __init__(self, email, content, created):
@@ -18,7 +21,7 @@ class Comment(object):
def __eq__(self, other):
return all([getattr(self, attr) == getattr(other, attr)
for attr in ('email', 'content', 'created')])
-
+
def get_sub_comment(self):
sub_comment = SubComment('And Merry Christmas!')
return sub_comment
@@ -29,7 +32,7 @@ class CommentSerializer(serializers.Serializer):
content = serializers.CharField(max_length=1000)
created = serializers.DateTimeField()
sub_comment = serializers.Field(source='get_sub_comment.sub_comment')
-
+
def restore_object(self, data, instance=None):
if instance is None:
return Comment(**data)
@@ -38,10 +41,41 @@ class CommentSerializer(serializers.Serializer):
return instance
+class BookSerializer(serializers.ModelSerializer):
+ isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'})
+
+ class Meta:
+ model = Book
+
+
class ActionItemSerializer(serializers.ModelSerializer):
+
class Meta:
model = ActionItem
+
+class PersonSerializer(serializers.ModelSerializer):
+ info = serializers.Field(source='info')
+
+ class Meta:
+ model = Person
+ fields = ('name', 'age', 'info')
+ read_only_fields = ('age',)
+
+
+class AlbumsSerializer(serializers.ModelSerializer):
+
+ class Meta:
+ model = Album
+ fields = ['title'] # lists are also valid options
+
+
+class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = HasPositiveIntegerAsChoice
+ fields = ['some_integer']
+
+
class BasicTests(TestCase):
def setUp(self):
self.comment = Comment(
@@ -61,6 +95,9 @@ class BasicTests(TestCase):
'created': datetime.datetime(2012, 1, 1),
'sub_comment': 'And Merry Christmas!'
}
+ self.person_data = {'name': 'dwight', 'age': 35}
+ self.person = Person(**self.person_data)
+ self.person.save()
def test_empty(self):
serializer = CommentSerializer()
@@ -73,11 +110,11 @@ class BasicTests(TestCase):
self.assertEquals(serializer.data, expected)
def test_retrieve(self):
- serializer = CommentSerializer(instance=self.comment)
+ serializer = CommentSerializer(self.comment)
self.assertEquals(serializer.data, self.expected)
def test_create(self):
- serializer = CommentSerializer(self.data)
+ serializer = CommentSerializer(data=self.data)
expected = self.comment
self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.object, expected)
@@ -85,13 +122,54 @@ class BasicTests(TestCase):
self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!')
def test_update(self):
- serializer = CommentSerializer(self.data, instance=self.comment)
+ serializer = CommentSerializer(self.comment, data=self.data)
expected = self.comment
self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.object, expected)
self.assertTrue(serializer.object is expected)
self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!')
+ def test_partial_update(self):
+ msg = 'Merry New Year!'
+ partial_data = {'content': msg}
+ serializer = CommentSerializer(self.comment, data=partial_data)
+ self.assertEquals(serializer.is_valid(), False)
+ serializer = CommentSerializer(self.comment, data=partial_data, partial=True)
+ expected = self.comment
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEquals(serializer.object, expected)
+ self.assertTrue(serializer.object is expected)
+ self.assertEquals(serializer.data['content'], msg)
+
+ def test_model_fields_as_expected(self):
+ """
+ Make sure that the fields returned are the same as defined
+ in the Meta data
+ """
+ serializer = PersonSerializer(self.person)
+ self.assertEquals(set(serializer.data.keys()),
+ set(['name', 'age', 'info']))
+
+ def test_field_with_dictionary(self):
+ """
+ Make sure that dictionaries from fields are left intact
+ """
+ serializer = PersonSerializer(self.person)
+ expected = self.person_data
+ self.assertEquals(serializer.data['info'], expected)
+
+ def test_read_only_fields(self):
+ """
+ Attempting to update fields set as read_only should have no effect.
+ """
+
+ serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99})
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(serializer.errors, {})
+ # Assert age is unchanged (35)
+ self.assertEquals(instance.age, self.person_data['age'])
+
class ValidationTests(TestCase):
def setUp(self):
@@ -104,17 +182,17 @@ class ValidationTests(TestCase):
'email': 'tom@example.com',
'content': 'x' * 1001,
'created': datetime.datetime(2012, 1, 1)
- }
- self.actionitem = ActionItem('Some to do item',
+ }
+ self.actionitem = ActionItem(title='Some to do item',
)
def test_create(self):
- serializer = CommentSerializer(self.data)
+ serializer = CommentSerializer(data=self.data)
self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']})
def test_update(self):
- serializer = CommentSerializer(self.data, instance=self.comment)
+ serializer = CommentSerializer(self.comment, data=self.data)
self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']})
@@ -123,7 +201,7 @@ class ValidationTests(TestCase):
'content': 'xxx',
'created': datetime.datetime(2012, 1, 1)
}
- serializer = CommentSerializer(data, instance=self.comment)
+ serializer = CommentSerializer(self.comment, data=data)
self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'email': [u'This field is required.']})
@@ -131,10 +209,10 @@ class ValidationTests(TestCase):
"""Make sure that a boolean value with a 'False' value is not
mistaken for not having a default."""
data = {
- 'title':'Some action item',
+ 'title': 'Some action item',
#No 'done' value.
}
- serializer = ActionItemSerializer(data, instance=self.actionitem)
+ serializer = ActionItemSerializer(self.actionitem, data=data)
self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.errors, {})
@@ -154,15 +232,34 @@ class ValidationTests(TestCase):
'created': datetime.datetime(2012, 1, 1)
}
- serializer = CommentSerializerWithFieldValidator(data)
+ serializer = CommentSerializerWithFieldValidator(data=data)
self.assertTrue(serializer.is_valid())
data['content'] = 'This should not validate'
- serializer = CommentSerializerWithFieldValidator(data)
+ serializer = CommentSerializerWithFieldValidator(data=data)
self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'content': [u'Test not in value']})
+ def test_bad_type_data_is_false(self):
+ """
+ Data of the wrong type is not valid.
+ """
+ data = ['i am', 'a', 'list']
+ serializer = CommentSerializer(self.comment, data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']})
+
+ data = 'and i am a string'
+ serializer = CommentSerializer(self.comment, data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']})
+
+ data = 42
+ serializer = CommentSerializer(self.comment, data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']})
+
def test_cross_field_validation(self):
class CommentSerializerWithCrossFieldValidator(CommentSerializer):
@@ -178,15 +275,109 @@ class ValidationTests(TestCase):
'created': datetime.datetime(2012, 1, 1)
}
- serializer = CommentSerializerWithCrossFieldValidator(data)
+ serializer = CommentSerializerWithCrossFieldValidator(data=data)
self.assertTrue(serializer.is_valid())
data['content'] = 'A comment from foo@bar.com'
- serializer = CommentSerializerWithCrossFieldValidator(data)
+ serializer = CommentSerializerWithCrossFieldValidator(data=data)
self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'non_field_errors': [u'Email address not in content']})
+ def test_null_is_true_fields(self):
+ """
+ Omitting a value for null-field should validate.
+ """
+ serializer = PersonSerializer(data={'name': 'marko'})
+ self.assertEquals(serializer.is_valid(), True)
+ self.assertEquals(serializer.errors, {})
+
+ def test_modelserializer_max_length_exceeded(self):
+ data = {
+ 'title': 'x' * 201,
+ }
+ serializer = ActionItemSerializer(data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'title': [u'Ensure this value has at most 200 characters (it has 201).']})
+
+ def test_default_modelfield_max_length_exceeded(self):
+ data = {
+ 'title': 'Testing "info" field...',
+ 'info': 'x' * 13,
+ }
+ serializer = ActionItemSerializer(data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'info': [u'Ensure this value has at most 12 characters (it has 13).']})
+
+
+class PositiveIntegerAsChoiceTests(TestCase):
+ def test_positive_integer_in_json_is_correctly_parsed(self):
+ data = {'some_integer':1}
+ serializer = PositiveIntegerAsChoiceSerializer(data=data)
+ self.assertEquals(serializer.is_valid(), True)
+
+class ModelValidationTests(TestCase):
+ def test_validate_unique(self):
+ """
+ Just check if serializers.ModelSerializer handles unique checks via .full_clean()
+ """
+ serializer = AlbumsSerializer(data={'title': 'a'})
+ serializer.is_valid()
+ serializer.save()
+ second_serializer = AlbumsSerializer(data={'title': 'a'})
+ self.assertFalse(second_serializer.is_valid())
+ self.assertEqual(second_serializer.errors, {'title': [u'Album with this Title already exists.']})
+
+ def test_foreign_key_with_partial(self):
+ """
+ Test ModelSerializer validation with partial=True
+
+ Specifically test foreign key validation.
+ """
+
+ album = Album(title='test')
+ album.save()
+
+ class PhotoSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Photo
+
+ photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk})
+ self.assertTrue(photo_serializer.is_valid())
+ photo = photo_serializer.save()
+
+ # Updating only the album (foreign key)
+ photo_serializer = PhotoSerializer(instance=photo, data={'album': album.pk}, partial=True)
+ self.assertTrue(photo_serializer.is_valid())
+ self.assertTrue(photo_serializer.save())
+
+ # Updating only the description
+ photo_serializer = PhotoSerializer(instance=photo,
+ data={'description': 'new'},
+ partial=True)
+
+ self.assertTrue(photo_serializer.is_valid())
+ self.assertTrue(photo_serializer.save())
+
+
+class RegexValidationTest(TestCase):
+ def test_create_failed(self):
+ serializer = BookSerializer(data={'isbn': '1234567890'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']})
+
+ serializer = BookSerializer(data={'isbn': '12345678901234'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']})
+
+ serializer = BookSerializer(data={'isbn': 'abcdefghijklm'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']})
+
+ def test_create_success(self):
+ serializer = BookSerializer(data={'isbn': '1234567890123'})
+ self.assertTrue(serializer.is_valid())
+
class MetadataTests(TestCase):
def test_empty(self):
@@ -233,7 +424,7 @@ class ManyToManyTests(TestCase):
Create an instance of a model with a ManyToMany relationship.
"""
data = {'rel': [self.anchor.id]}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 2)
@@ -247,7 +438,7 @@ class ManyToManyTests(TestCase):
new_anchor = Anchor()
new_anchor.save()
data = {'rel': [self.anchor.id, new_anchor.id]}
- serializer = self.serializer_class(data, instance=self.instance)
+ serializer = self.serializer_class(self.instance, data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 1)
@@ -260,7 +451,7 @@ class ManyToManyTests(TestCase):
containing no items.
"""
data = {'rel': []}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 2)
@@ -275,7 +466,7 @@ class ManyToManyTests(TestCase):
new_anchor = Anchor()
new_anchor.save()
data = {'rel': []}
- serializer = self.serializer_class(data, instance=self.instance)
+ serializer = self.serializer_class(self.instance, data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 1)
@@ -289,17 +480,19 @@ class ManyToManyTests(TestCase):
lists (eg form data).
"""
data = {'rel': ''}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 2)
self.assertEquals(instance.pk, 2)
self.assertEquals(list(instance.rel.all()), [])
-
+
+
class ReadOnlyManyToManyTests(TestCase):
def setUp(self):
class ReadOnlyManyToManySerializer(serializers.ModelSerializer):
- rel = serializers.ManyRelatedField(readonly=True)
+ rel = serializers.ManyRelatedField(read_only=True)
+
class Meta:
model = ReadOnlyManyToManyModel
@@ -317,16 +510,15 @@ class ReadOnlyManyToManyTests(TestCase):
# A serialized representation of the model instance
self.data = {'rel': [self.anchor.id], 'id': 1, 'text': 'anchor'}
-
def test_update(self):
"""
Attempt to update an instance of a model with a ManyToMany
- relationship. Not updated due to readonly=True
+ relationship. Not updated due to read_only=True
"""
new_anchor = Anchor()
new_anchor.save()
data = {'rel': [self.anchor.id, new_anchor.id]}
- serializer = self.serializer_class(data, instance=self.instance)
+ serializer = self.serializer_class(self.instance, data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1)
@@ -337,12 +529,12 @@ class ReadOnlyManyToManyTests(TestCase):
def test_update_without_relationship(self):
"""
Attempt to update an instance of a model where many to ManyToMany
- relationship is not supplied. Not updated due to readonly=True
+ relationship is not supplied. Not updated due to read_only=True
"""
new_anchor = Anchor()
new_anchor.save()
data = {}
- serializer = self.serializer_class(data, instance=self.instance)
+ serializer = self.serializer_class(self.instance, data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1)
@@ -362,7 +554,7 @@ class DefaultValueTests(TestCase):
def test_create_using_default(self):
data = {}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(self.objects.all()), 1)
@@ -371,13 +563,28 @@ class DefaultValueTests(TestCase):
def test_create_overriding_default(self):
data = {'text': 'overridden'}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(self.objects.all()), 1)
self.assertEquals(instance.pk, 1)
self.assertEquals(instance.text, 'overridden')
+ def test_partial_update_default(self):
+ """ Regression test for issue #532 """
+ data = {'text': 'overridden'}
+ serializer = self.serializer_class(data=data, partial=True)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+
+ data = {'extra': 'extra_value'}
+ serializer = self.serializer_class(instance=instance, data=data, partial=True)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+
+ self.assertEquals(instance.extra, 'extra_value')
+ self.assertEquals(instance.text, 'overridden')
+
class CallableDefaultValueTests(TestCase):
def setUp(self):
@@ -390,7 +597,7 @@ class CallableDefaultValueTests(TestCase):
def test_create_using_default(self):
data = {}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(self.objects.all()), 1)
@@ -399,7 +606,7 @@ class CallableDefaultValueTests(TestCase):
def test_create_overriding_default(self):
data = {'text': 'overridden'}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(self.objects.all()), 1)
@@ -408,7 +615,10 @@ class CallableDefaultValueTests(TestCase):
class ManyRelatedTests(TestCase):
- def setUp(self):
+ def test_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
class BlogPostCommentSerializer(serializers.Serializer):
text = serializers.CharField()
@@ -417,14 +627,7 @@ class ManyRelatedTests(TestCase):
title = serializers.CharField()
comments = BlogPostCommentSerializer(source='blogpostcomment_set')
- self.serializer_class = BlogPostSerializer
-
- def test_reverse_relations(self):
- post = BlogPost.objects.create(title="Test blog post")
- post.blogpostcomment_set.create(text="I hate this blog post")
- post.blogpostcomment_set.create(text="I love this blog post")
-
- serializer = self.serializer_class(instance=post)
+ serializer = BlogPostSerializer(instance=post)
expected = {
'title': 'Test blog post',
'comments': [
@@ -434,3 +637,267 @@ class ManyRelatedTests(TestCase):
}
self.assertEqual(serializer.data, expected)
+
+ def test_callable_source(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostCommentSerializer(serializers.Serializer):
+ text = serializers.CharField()
+
+ class BlogPostSerializer(serializers.Serializer):
+ title = serializers.CharField()
+ first_comment = BlogPostCommentSerializer(source='get_first_comment')
+
+ serializer = BlogPostSerializer(post)
+
+ expected = {
+ 'title': 'Test blog post',
+ 'first_comment': {'text': 'I love this blog post'}
+ }
+ self.assertEqual(serializer.data, expected)
+
+
+class RelatedTraversalTest(TestCase):
+ def test_nested_traversal(self):
+ user = Person.objects.create(name="django")
+ post = BlogPost.objects.create(title="Test blog post", writer=user)
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ from rest_framework.tests.models import BlogPostComment
+
+ class PersonSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Person
+ fields = ("name", "age")
+
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPostComment
+ fields = ("text", "post_owner")
+
+ text = serializers.CharField()
+ post_owner = PersonSerializer(source='blog_post.writer')
+
+ class BlogPostSerializer(serializers.Serializer):
+ title = serializers.CharField()
+ comments = BlogPostCommentSerializer(source='blogpostcomment_set')
+
+ serializer = BlogPostSerializer(instance=post)
+
+ expected = {
+ 'title': u'Test blog post',
+ 'comments': [{
+ 'text': u'I love this blog post',
+ 'post_owner': {
+ "name": u"django",
+ "age": None
+ }
+ }]
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+
+class SerializerMethodFieldTests(TestCase):
+ def setUp(self):
+
+ class BoopSerializer(serializers.Serializer):
+ beep = serializers.SerializerMethodField('get_beep')
+ boop = serializers.Field()
+ boop_count = serializers.SerializerMethodField('get_boop_count')
+
+ def get_beep(self, obj):
+ return 'hello!'
+
+ def get_boop_count(self, obj):
+ return len(obj.boop)
+
+ self.serializer_class = BoopSerializer
+
+ def test_serializer_method_field(self):
+
+ class MyModel(object):
+ boop = ['a', 'b', 'c']
+
+ source_data = MyModel()
+
+ serializer = self.serializer_class(source_data)
+
+ expected = {
+ 'beep': u'hello!',
+ 'boop': [u'a', u'b', u'c'],
+ 'boop_count': 3,
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+
+# Test for issue #324
+class BlankFieldTests(TestCase):
+ def setUp(self):
+
+ class BlankFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlankFieldModel
+
+ class BlankFieldSerializer(serializers.Serializer):
+ title = serializers.CharField(blank=True)
+
+ class NotBlankFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
+
+ class NotBlankFieldSerializer(serializers.Serializer):
+ title = serializers.CharField()
+
+ self.model_serializer_class = BlankFieldModelSerializer
+ self.serializer_class = BlankFieldSerializer
+ self.not_blank_model_serializer_class = NotBlankFieldModelSerializer
+ self.not_blank_serializer_class = NotBlankFieldSerializer
+ self.data = {'title': ''}
+
+ def test_create_blank_field(self):
+ serializer = self.serializer_class(data=self.data)
+ self.assertEquals(serializer.is_valid(), True)
+
+ def test_create_model_blank_field(self):
+ serializer = self.model_serializer_class(data=self.data)
+ self.assertEquals(serializer.is_valid(), True)
+
+ def test_create_model_null_field(self):
+ serializer = self.model_serializer_class(data={'title': None})
+ self.assertEquals(serializer.is_valid(), True)
+
+ def test_create_not_blank_field(self):
+ """
+ Test to ensure blank data in a field not marked as blank=True
+ is considered invalid in a non-model serializer
+ """
+ serializer = self.not_blank_serializer_class(data=self.data)
+ self.assertEquals(serializer.is_valid(), False)
+
+ def test_create_model_not_blank_field(self):
+ """
+ Test to ensure blank data in a field not marked as blank=True
+ is considered invalid in a model serializer
+ """
+ serializer = self.not_blank_model_serializer_class(data=self.data)
+ self.assertEquals(serializer.is_valid(), False)
+
+ def test_create_model_null_field(self):
+ serializer = self.model_serializer_class(data={})
+ self.assertEquals(serializer.is_valid(), True)
+
+
+#test for issue #460
+class SerializerPickleTests(TestCase):
+ """
+ Test pickleability of the output of Serializers
+ """
+ def test_pickle_simple_model_serializer_data(self):
+ """
+ Test simple serializer
+ """
+ pickle.dumps(PersonSerializer(Person(name="Methusela", age=969)).data)
+
+ def test_pickle_inner_serializer(self):
+ """
+ Test pickling a serializer whose resulting .data (a SortedDictWithMetadata) will
+ have unpickleable meta data--in order to make sure metadata doesn't get pulled into the pickle.
+ See DictWithMetadata.__getstate__
+ """
+ class InnerPersonSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Person
+ fields = ('name', 'age')
+ pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data)
+
+
+class DepthTest(TestCase):
+ def test_implicit_nesting(self):
+ writer = Person.objects.create(name="django", age=1)
+ post = BlogPost.objects.create(title="Test blog post", writer=writer)
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ depth = 1
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {'id': 1, 'title': u'Test blog post',
+ 'writer': {'id': 1, 'name': u'django', 'age': 1}}
+
+ self.assertEqual(serializer.data, expected)
+
+ def test_explicit_nesting(self):
+ writer = Person.objects.create(name="django", age=1)
+ post = BlogPost.objects.create(title="Test blog post", writer=writer)
+
+ class PersonSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Person
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ writer = PersonSerializer()
+
+ class Meta:
+ model = BlogPost
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {'id': 1, 'title': u'Test blog post',
+ 'writer': {'id': 1, 'name': u'django', 'age': 1}}
+
+ self.assertEqual(serializer.data, expected)
+
+
+class NestedSerializerContextTests(TestCase):
+
+ def test_nested_serializer_context(self):
+ """
+ Regression for #497
+
+ https://github.com/tomchristie/django-rest-framework/issues/497
+ """
+ class PhotoSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Photo
+ fields = ("description", "callable")
+
+ callable = serializers.SerializerMethodField('_callable')
+
+ def _callable(self, instance):
+ if not 'context_item' in self.context:
+ raise RuntimeError("context isn't getting passed into 2nd level nested serializer")
+ return "success"
+
+ class AlbumSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Album
+ fields = ("photo_set", "callable")
+
+ photo_set = PhotoSerializer(source="photo_set")
+ callable = serializers.SerializerMethodField("_callable")
+
+ def _callable(self, instance):
+ if not 'context_item' in self.context:
+ raise RuntimeError("context isn't getting passed into 1st level nested serializer")
+ return "success"
+
+ class AlbumCollection(object):
+ albums = None
+
+ class AlbumCollectionSerializer(serializers.Serializer):
+ albums = AlbumSerializer(source="albums")
+
+ album1 = Album.objects.create(title="album 1")
+ album2 = Album.objects.create(title="album 2")
+ Photo.objects.create(description="Bigfoot", album=album1)
+ Photo.objects.create(description="Unicorn", album=album1)
+ Photo.objects.create(description="Yeti", album=album2)
+ Photo.objects.create(description="Sasquatch", album=album2)
+ album_collection = AlbumCollection()
+ album_collection.albums = [album1, album2]
+
+ # This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers
+ AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data
diff --git a/rest_framework/tests/settings.py b/rest_framework/tests/settings.py
new file mode 100644
index 00000000..0293fdc3
--- /dev/null
+++ b/rest_framework/tests/settings.py
@@ -0,0 +1,21 @@
+"""Tests for the settings module"""
+from django.test import TestCase
+
+from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS
+
+
+class TestSettings(TestCase):
+ """Tests relating to the api settings"""
+
+ def test_non_import_errors(self):
+ """Make sure other errors aren't suppressed."""
+ settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS)
+ with self.assertRaises(ValueError):
+ settings.DEFAULT_MODEL_SERIALIZER_CLASS
+
+ def test_import_error_message_maintained(self):
+ """Make sure real import errors are captured and raised sensibly."""
+ settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS)
+ with self.assertRaises(ImportError) as cm:
+ settings.DEFAULT_MODEL_SERIALIZER_CLASS
+ self.assertTrue('ImportError' in str(cm.exception))
diff --git a/rest_framework/tests/testcases.py b/rest_framework/tests/testcases.py
index c90224aa..97f492ff 100644
--- a/rest_framework/tests/testcases.py
+++ b/rest_framework/tests/testcases.py
@@ -6,6 +6,7 @@ from django.test import TestCase
NO_SETTING = ('!', None)
+
class TestSettingsManager(object):
"""
A class which can modify some Django settings temporarily for a
@@ -19,7 +20,7 @@ class TestSettingsManager(object):
self._original_settings = {}
def set(self, **kwargs):
- for k,v in kwargs.iteritems():
+ for k, v in kwargs.iteritems():
self._original_settings.setdefault(k, getattr(settings, k,
NO_SETTING))
setattr(settings, k, v)
@@ -31,7 +32,7 @@ class TestSettingsManager(object):
call_command('syncdb', verbosity=0)
def revert(self):
- for k,v in self._original_settings.iteritems():
+ for k, v in self._original_settings.iteritems():
if v == NO_SETTING:
delattr(settings, k)
else:
@@ -57,6 +58,7 @@ class SettingsTestCase(TestCase):
def tearDown(self):
self.settings_manager.revert()
+
class TestModelsTestCase(SettingsTestCase):
def setUp(self, *args, **kwargs):
installed_apps = tuple(settings.INSTALLED_APPS) + ('rest_framework.tests',)
diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py
new file mode 100644
index 00000000..adeaf6da
--- /dev/null
+++ b/rest_framework/tests/tests.py
@@ -0,0 +1,13 @@
+"""
+Force import of all modules in this package in order to get the standard test
+runner to pick up the tests. Yowzers.
+"""
+import os
+
+modules = [filename.rsplit('.', 1)[0]
+ for filename in os.listdir(os.path.dirname(__file__))
+ if filename.endswith('.py') and not filename.startswith('_')]
+__test__ = dict()
+
+for module in modules:
+ exec("from rest_framework.tests.%s import *" % module)
diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/throttling.py
index 0b94c25b..4b98b941 100644
--- a/rest_framework/tests/throttling.py
+++ b/rest_framework/tests/throttling.py
@@ -106,7 +106,7 @@ class ThrottlingTests(TestCase):
if expect is not None:
self.assertEquals(response['X-Throttle-Wait-Seconds'], expect)
else:
- self.assertFalse('X-Throttle-Wait-Seconds' in response.headers)
+ self.assertFalse('X-Throttle-Wait-Seconds' in response)
def test_seconds_fields(self):
"""
diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py
new file mode 100644
index 00000000..3906adb9
--- /dev/null
+++ b/rest_framework/tests/utils.py
@@ -0,0 +1,27 @@
+from django.test.client import RequestFactory, FakePayload
+from django.test.client import MULTIPART_CONTENT
+from urlparse import urlparse
+
+
+class RequestFactory(RequestFactory):
+
+ def __init__(self, **defaults):
+ super(RequestFactory, self).__init__(**defaults)
+
+ def patch(self, path, data={}, content_type=MULTIPART_CONTENT,
+ **extra):
+ "Construct a PATCH request."
+
+ patch_data = self._encode_data(data, content_type)
+
+ parsed = urlparse(path)
+ r = {
+ 'CONTENT_LENGTH': len(patch_data),
+ 'CONTENT_TYPE': content_type,
+ 'PATH_INFO': self._get_path(parsed),
+ 'QUERY_STRING': parsed[4],
+ 'REQUEST_METHOD': 'PATCH',
+ 'wsgi.input': FakePayload(patch_data),
+ }
+ r.update(extra)
+ return self.request(**r)
diff --git a/rest_framework/tests/validators.py b/rest_framework/tests/validators.py
index b390c42f..c032985e 100644
--- a/rest_framework/tests/validators.py
+++ b/rest_framework/tests/validators.py
@@ -285,7 +285,7 @@
# uiop = models.CharField(max_length=256, blank=True)
# @property
-# def readonly(self):
+# def read_only(self):
# return 'read only'
# class MockResource(ModelResource):
@@ -298,7 +298,7 @@
# def test_property_fields_are_allowed_on_model_forms(self):
# """Validation on ModelForms may include property fields that exist on the Model to be included in the input."""
-# content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only'}
+# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only'}
# self.assertEqual(self.validator.validate_request(content, None), content)
# def test_property_fields_are_not_required_on_model_forms(self):
@@ -310,19 +310,19 @@
# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
# broken clients more easily (eg submitting content with a misnamed field)"""
-# content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only', 'extra': 'extra'}
+# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only', 'extra': 'extra'}
# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None)
# def test_validate_requires_fields_on_model_forms(self):
# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
# broken clients more easily (eg submitting content with a misnamed field)"""
-# content = {'readonly': 'read only'}
+# content = {'read_only': 'read only'}
# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None)
# def test_validate_does_not_require_blankable_fields_on_model_forms(self):
# """Test standard ModelForm validation behaviour - fields with blank=True are not required."""
-# content = {'qwerty': 'example', 'readonly': 'read only'}
+# content = {'qwerty': 'example', 'read_only': 'read only'}
# self.validator.validate_request(content, None)
# def test_model_form_validator_uses_model_forms(self):
diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py
index 43365e07..7cd82656 100644
--- a/rest_framework/tests/views.py
+++ b/rest_framework/tests/views.py
@@ -18,7 +18,7 @@ class BasicView(APIView):
return Response({'method': 'POST', 'data': request.DATA})
-@api_view(['GET', 'POST', 'PUT'])
+@api_view(['GET', 'POST', 'PUT', 'PATCH'])
def basic_view(request):
if request.method == 'GET':
return {'method': 'GET'}
@@ -26,6 +26,8 @@ def basic_view(request):
return {'method': 'POST', 'data': request.DATA}
elif request.method == 'PUT':
return {'method': 'PUT', 'data': request.DATA}
+ elif request.method == 'PATCH':
+ return {'method': 'PATCH', 'data': request.DATA}
def sanitise_json_error(error_dict):
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index 6860e6b9..8fe64248 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -16,7 +16,7 @@ class BaseThrottle(object):
def wait(self):
"""
- Optionally, return a recommeded number of seconds to wait before
+ Optionally, return a recommended number of seconds to wait before
the next request.
"""
return None
diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py
index 316ccd19..143928c9 100644
--- a/rest_framework/urlpatterns.py
+++ b/rest_framework/urlpatterns.py
@@ -1,10 +1,10 @@
-from django.conf.urls.defaults import url
+from rest_framework.compat import url
from rest_framework.settings import api_settings
def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
"""
- Supplement existing urlpatterns with corrosponding patterns that also
+ Supplement existing urlpatterns with corresponding patterns that also
include a '.format' suffix. Retains urlpattern ordering.
urlpatterns:
diff --git a/rest_framework/urls.py b/rest_framework/urls.py
index 1a81101f..fbe4bc07 100644
--- a/rest_framework/urls.py
+++ b/rest_framework/urls.py
@@ -1,7 +1,7 @@
"""
-Login and logout views for the browseable API.
+Login and logout views for the browsable API.
-Add these to your root URLconf if you're using the browseable API and
+Add these to your root URLconf if you're using the browsable API and
your API requires authentication.
The urls must be namespaced as 'rest_framework', and you should make sure
@@ -12,7 +12,7 @@ your authentication settings include `SessionAuthentication`.
url(r'^auth', include('rest_framework.urls', namespace='rest_framework'))
)
"""
-from django.conf.urls.defaults import patterns, url
+from rest_framework.compat import patterns, url
template_name = {'template_name': 'rest_framework/login.html'}
diff --git a/rest_framework/utils/__init__.py b/rest_framework/utils/__init__.py
index a59fff45..84fcb5db 100644
--- a/rest_framework/utils/__init__.py
+++ b/rest_framework/utils/__init__.py
@@ -1,7 +1,6 @@
from django.utils.encoding import smart_unicode
from django.utils.xmlutils import SimplerXMLGenerator
from rest_framework.compat import StringIO
-
import re
import xml.etree.ElementTree as ET
diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py
index 672d32a3..80e39d46 100644
--- a/rest_framework/utils/breadcrumbs.py
+++ b/rest_framework/utils/breadcrumbs.py
@@ -6,7 +6,7 @@ def get_breadcrumbs(url):
from rest_framework.views import APIView
- def breadcrumbs_recursive(url, breadcrumbs_list, prefix):
+ def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen):
"""Add tuples of (name, url) to the breadcrumbs list, progressively chomping off parts of the url."""
try:
@@ -16,7 +16,11 @@ def get_breadcrumbs(url):
else:
# Check if this is a REST framework view, and if so add it to the breadcrumbs
if isinstance(getattr(view, 'cls_instance', None), APIView):
- breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url))
+ # Don't list the same view twice in a row.
+ # Probably an optional trailing slash.
+ if not seen or seen[-1] != view:
+ breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url))
+ seen.append(view)
if url == '':
# All done
@@ -24,11 +28,11 @@ def get_breadcrumbs(url):
elif url.endswith('/'):
# Drop trailing slash off the end and continue to try to resolve more breadcrumbs
- return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix)
+ return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix, seen)
# Drop trailing non-slash off the end and continue to try to resolve more breadcrumbs
- return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix)
+ return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix, seen)
prefix = get_script_prefix().rstrip('/')
url = url[len(prefix):]
- return breadcrumbs_recursive(url, [], prefix)
+ return breadcrumbs_recursive(url, [], prefix, [])
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index 2d1fb353..7afe100a 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -4,7 +4,7 @@ Helper classes for parsers.
import datetime
import decimal
import types
-from django.utils import simplejson as json
+import json
from django.utils.datastructures import SortedDict
from rest_framework.compat import timezone
from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
@@ -12,7 +12,7 @@ from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
class JSONEncoder(json.JSONEncoder):
"""
- JSONEncoder subclass that knows how to encode date/time,
+ JSONEncoder subclass that knows how to encode date/time/timedelta,
decimal types, and generators.
"""
def default(self, o):
@@ -34,6 +34,8 @@ class JSONEncoder(json.JSONEncoder):
if o.microsecond:
r = r[:12]
return r
+ elif isinstance(o, datetime.timedelta):
+ return str(o.total_seconds())
elif isinstance(o, decimal.Decimal):
return str(o)
elif hasattr(o, '__iter__'):
diff --git a/rest_framework/views.py b/rest_framework/views.py
index c721be3c..10bdd5a5 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -140,7 +140,7 @@ class APIView(View):
def http_method_not_allowed(self, request, *args, **kwargs):
"""
- Called if `request.method` does not corrospond to a handler method.
+ Called if `request.method` does not correspond to a handler method.
"""
raise exceptions.MethodNotAllowed(request.method)
@@ -218,7 +218,7 @@ class APIView(View):
def get_throttles(self):
"""
- Instantiates and returns the list of thottles that this view uses.
+ Instantiates and returns the list of throttles that this view uses.
"""
return [throttle() for throttle in self.throttle_classes]
@@ -320,13 +320,17 @@ class APIView(View):
self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
if isinstance(exc, exceptions.APIException):
- return Response({'detail': exc.detail}, status=exc.status_code)
+ return Response({'detail': exc.detail},
+ status=exc.status_code,
+ exception=True)
elif isinstance(exc, Http404):
return Response({'detail': 'Not found'},
- status=status.HTTP_404_NOT_FOUND)
+ status=status.HTTP_404_NOT_FOUND,
+ exception=True)
elif isinstance(exc, PermissionDenied):
return Response({'detail': 'Permission denied'},
- status=status.HTTP_403_FORBIDDEN)
+ status=status.HTTP_403_FORBIDDEN,
+ exception=True)
raise
# Note: session based authentication is explicitly CSRF validated,