aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py2
-rw-r--r--rest_framework/authtoken/migrations/0001_initial.py14
-rw-r--r--rest_framework/authtoken/models.py3
-rw-r--r--rest_framework/authtoken/serializers.py24
-rw-r--r--rest_framework/authtoken/views.py24
-rw-r--r--rest_framework/compat.py52
-rw-r--r--rest_framework/decorators.py2
-rw-r--r--rest_framework/fields.py192
-rw-r--r--rest_framework/filters.py59
-rw-r--r--rest_framework/generics.py56
-rw-r--r--rest_framework/mixins.py30
-rw-r--r--rest_framework/pagination.py17
-rw-r--r--rest_framework/renderers.py11
-rw-r--r--rest_framework/response.py7
-rw-r--r--rest_framework/runtests/settings.py1
-rw-r--r--rest_framework/serializers.py94
-rw-r--r--rest_framework/settings.py18
-rw-r--r--rest_framework/templatetags/rest_framework.py25
-rw-r--r--rest_framework/tests/authentication.py33
-rw-r--r--rest_framework/tests/files.py55
-rw-r--r--rest_framework/tests/filterset.py168
-rw-r--r--rest_framework/tests/hyperlinkedserializers.py42
-rw-r--r--rest_framework/tests/models.py28
-rw-r--r--rest_framework/tests/pagination.py159
-rw-r--r--rest_framework/tests/pk_relations.py19
-rw-r--r--rest_framework/tests/response.py6
-rw-r--r--rest_framework/tests/serializer.py144
-rw-r--r--rest_framework/tests/throttling.py2
-rw-r--r--rest_framework/urlpatterns.py2
-rw-r--r--rest_framework/urls.py4
-rw-r--r--rest_framework/utils/__init__.py1
-rw-r--r--rest_framework/views.py2
32 files changed, 1107 insertions, 189 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index fc99b879..48cebbc5 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,3 +1,3 @@
-__version__ = '2.1.1'
+__version__ = '2.1.6'
VERSION = __version__ # synonym
diff --git a/rest_framework/authtoken/migrations/0001_initial.py b/rest_framework/authtoken/migrations/0001_initial.py
index 9d750381..f4e052e4 100644
--- a/rest_framework/authtoken/migrations/0001_initial.py
+++ b/rest_framework/authtoken/migrations/0001_initial.py
@@ -5,13 +5,21 @@ from south.v2 import SchemaMigration
from django.db import models
+try:
+ from django.contrib.auth import get_user_model
+except ImportError: # django < 1.5
+ from django.contrib.auth.models import User
+else:
+ User = get_user_model()
+
+
class Migration(SchemaMigration):
def forwards(self, orm):
# Adding model 'Token'
db.create_table('authtoken_token', (
('key', self.gf('django.db.models.fields.CharField')(max_length=40, primary_key=True)),
- ('user', self.gf('django.db.models.fields.related.OneToOneField')(related_name='auth_token', unique=True, to=orm['auth.User'])),
+ ('user', self.gf('django.db.models.fields.related.OneToOneField')(related_name='auth_token', unique=True, to=orm['%s.%s' % (User._meta.app_label, User._meta.object_name)])),
('created', self.gf('django.db.models.fields.DateTimeField')(auto_now_add=True, blank=True)),
))
db.send_create_signal('authtoken', ['Token'])
@@ -36,7 +44,7 @@ class Migration(SchemaMigration):
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'name': ('django.db.models.fields.CharField', [], {'max_length': '50'})
},
- 'auth.user': {
+ "%s.%s" % (User._meta.app_label, User._meta.module_name): {
'Meta': {'object_name': 'User'},
'date_joined': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}),
'email': ('django.db.models.fields.EmailField', [], {'max_length': '75', 'blank': 'True'}),
@@ -56,7 +64,7 @@ class Migration(SchemaMigration):
'Meta': {'object_name': 'Token'},
'created': ('django.db.models.fields.DateTimeField', [], {'auto_now_add': 'True', 'blank': 'True'}),
'key': ('django.db.models.fields.CharField', [], {'max_length': '40', 'primary_key': 'True'}),
- 'user': ('django.db.models.fields.related.OneToOneField', [], {'related_name': "'auth_token'", 'unique': 'True', 'to': "orm['auth.User']"})
+ 'user': ('django.db.models.fields.related.OneToOneField', [], {'related_name': "'auth_token'", 'unique': 'True', 'to': "orm['%s.%s']" % (User._meta.app_label, User._meta.object_name)})
},
'contenttypes.contenttype': {
'Meta': {'ordering': "('name',)", 'unique_together': "(('app_label', 'model'),)", 'object_name': 'ContentType', 'db_table': "'django_content_type'"},
diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py
index 5b3071aa..4da2aa62 100644
--- a/rest_framework/authtoken/models.py
+++ b/rest_framework/authtoken/models.py
@@ -1,6 +1,7 @@
import uuid
import hmac
from hashlib import sha1
+from rest_framework.compat import User
from django.db import models
@@ -9,7 +10,7 @@ class Token(models.Model):
The default authorization token model.
"""
key = models.CharField(max_length=40, primary_key=True)
- user = models.OneToOneField('auth.User', related_name='auth_token')
+ user = models.OneToOneField(User, related_name='auth_token')
created = models.DateTimeField(auto_now_add=True)
def save(self, *args, **kwargs):
diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py
new file mode 100644
index 00000000..a5ed6e6d
--- /dev/null
+++ b/rest_framework/authtoken/serializers.py
@@ -0,0 +1,24 @@
+from django.contrib.auth import authenticate
+from rest_framework import serializers
+
+class AuthTokenSerializer(serializers.Serializer):
+ username = serializers.CharField()
+ password = serializers.CharField()
+
+ def validate(self, attrs):
+ username = attrs.get('username')
+ password = attrs.get('password')
+
+ if username and password:
+ user = authenticate(username=username, password=password)
+
+ if user:
+ if not user.is_active:
+ raise serializers.ValidationError('User account is disabled.')
+ attrs['user'] = user
+ return attrs
+ else:
+ raise serializers.ValidationError('Unable to login with provided credentials.')
+ else:
+ raise serializers.ValidationError('Must include "username" and "password"')
+
diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py
index e69de29b..3ac674e2 100644
--- a/rest_framework/authtoken/views.py
+++ b/rest_framework/authtoken/views.py
@@ -0,0 +1,24 @@
+from rest_framework.views import APIView
+from rest_framework import status
+from rest_framework import parsers
+from rest_framework import renderers
+from rest_framework.response import Response
+from rest_framework.authtoken.models import Token
+from rest_framework.authtoken.serializers import AuthTokenSerializer
+
+class ObtainAuthToken(APIView):
+ throttle_classes = ()
+ permission_classes = ()
+ parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
+ renderer_classes = (renderers.JSONRenderer,)
+ model = Token
+
+ def post(self, request):
+ serializer = AuthTokenSerializer(data=request.DATA)
+ if serializer.is_valid():
+ token, created = Token.objects.get_or_create(user=serializer.object['user'])
+ return Response({'token': token.key})
+ return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+
+
+obtain_auth_token = ObtainAuthToken.as_view()
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index b0367a32..09b76368 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -1,10 +1,17 @@
"""
The `compat` module provides support for backwards compatibility with older
-versions of django/python, and compatbility wrappers around optional packages.
+versions of django/python, and compatibility wrappers around optional packages.
"""
# flake8: noqa
import django
+# django-filter is optional
+try:
+ import django_filters
+except:
+ django_filters = None
+
+
# cStringIO only if it's available, otherwise StringIO
try:
import cStringIO as StringIO
@@ -20,6 +27,20 @@ def get_concrete_model(model_cls):
return model_cls
+# Django 1.5 add support for custom auth user model
+if django.VERSION >= (1, 5):
+ from django.conf import settings
+ if hasattr(settings, 'AUTH_USER_MODEL'):
+ User = settings.AUTH_USER_MODEL
+ else:
+ from django.contrib.auth.models import User
+else:
+ try:
+ from django.contrib.auth.models import User
+ except ImportError:
+ raise ImportError(u"User model is not to be found.")
+
+
# First implementation of Django class-based views did not include head method
# in base View class - https://code.djangoproject.com/ticket/15668
if django.VERSION >= (1, 4):
@@ -333,7 +354,7 @@ try:
"""
extensions = ['headerid(level=2)']
- safe_mode = False,
+ safe_mode = False
md = markdown.Markdown(extensions=extensions, safe_mode=safe_mode)
return md.convert(text)
@@ -348,33 +369,6 @@ except ImportError:
yaml = None
-import unittest
-try:
- import unittest.skip
-except ImportError: # python < 2.7
- from unittest import TestCase
- import functools
-
- def skip(reason):
- # Pasted from py27/lib/unittest/case.py
- """
- Unconditionally skip a test.
- """
- def decorator(test_item):
- if not (isinstance(test_item, type) and issubclass(test_item, TestCase)):
- @functools.wraps(test_item)
- def skip_wrapper(*args, **kwargs):
- pass
- test_item = skip_wrapper
-
- test_item.__unittest_skip__ = True
- test_item.__unittest_skip_why__ = reason
- return test_item
- return decorator
-
- unittest.skip = skip
-
-
# xml.etree.parse only throws ParseError for python >= 2.7
try:
from xml.etree import ParseError as ETParseError
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index a231f191..1b710a03 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -17,7 +17,7 @@ def api_view(http_method_names):
)
# Note, the above allows us to set the docstring.
- # It is the equivelent of:
+ # It is the equivalent of:
#
# class WrappedAPIView(APIView):
# pass
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index b8e1e2ad..482a3d48 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,8 +1,11 @@
import copy
import datetime
import inspect
+import re
import warnings
+from io import BytesIO
+
from django.core import validators
from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.core.urlresolvers import resolve, get_script_prefix
@@ -32,6 +35,7 @@ class Field(object):
creation_counter = 0
empty = ''
type_name = None
+ _use_files = None
form_field_class = forms.CharField
def __init__(self, source=None):
@@ -52,8 +56,10 @@ class Field(object):
self.parent = parent
self.root = parent.root or parent
self.context = self.root.context
+ if self.root.partial:
+ self.required = False
- def field_from_native(self, data, field_name, into):
+ def field_from_native(self, data, files, field_name, into):
"""
Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value.
@@ -168,7 +174,7 @@ class WritableField(Field):
if errors:
raise ValidationError(errors)
- def field_from_native(self, data, field_name, into):
+ def field_from_native(self, data, files, field_name, into):
"""
Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value.
@@ -177,7 +183,10 @@ class WritableField(Field):
return
try:
- native = data[field_name]
+ if self._use_files:
+ native = files[field_name]
+ else:
+ native = data[field_name]
except KeyError:
if self.default is not None:
native = self.default
@@ -211,8 +220,19 @@ class ModelField(WritableField):
self.model_field = kwargs.pop('model_field')
except:
raise ValueError("ModelField requires 'model_field' kwarg")
+
+ self.min_length = kwargs.pop('min_length',
+ getattr(self.model_field, 'min_length', None))
+ self.max_length = kwargs.pop('max_length',
+ getattr(self.model_field, 'max_length', None))
+
super(ModelField, self).__init__(*args, **kwargs)
+ if self.min_length is not None:
+ self.validators.append(validators.MinLengthValidator(self.min_length))
+ if self.max_length is not None:
+ self.validators.append(validators.MaxLengthValidator(self.max_length))
+
def from_native(self, value):
rel = getattr(self.model_field, "rel", None)
if rel is not None:
@@ -319,13 +339,13 @@ class RelatedField(WritableField):
choices = property(_get_choices, _set_choices)
- ### Regular serializier stuff...
+ ### Regular serializer stuff...
def field_to_native(self, obj, field_name):
value = getattr(obj, self.source or field_name)
return self.to_native(value)
- def field_from_native(self, data, field_name, into):
+ def field_from_native(self, data, files, field_name, into):
if self.read_only:
return
@@ -343,7 +363,7 @@ class ManyRelatedMixin(object):
value = getattr(obj, self.source or field_name)
return [self.to_native(item) for item in value.all()]
- def field_from_native(self, data, field_name, into):
+ def field_from_native(self, data, files, field_name, into):
if self.read_only:
return
@@ -528,7 +548,10 @@ class HyperlinkedRelatedField(RelatedField):
view_name = self.view_name
request = self.context.get('request', None)
format = self.format or self.context.get('format', None)
- kwargs = {self.pk_url_kwarg: obj.pk}
+ pk = getattr(obj, 'pk', None)
+ if pk is None:
+ return
+ kwargs = {self.pk_url_kwarg: pk}
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
except:
@@ -705,6 +728,23 @@ class CharField(WritableField):
return smart_unicode(value)
+class URLField(CharField):
+ type_name = 'URLField'
+
+ def __init__(self, **kwargs):
+ kwargs['max_length'] = kwargs.get('max_length', 200)
+ kwargs['validators'] = [validators.URLValidator()]
+ super(URLField, self).__init__(**kwargs)
+
+
+class SlugField(CharField):
+ type_name = 'SlugField'
+
+ def __init__(self, *args, **kwargs):
+ kwargs['max_length'] = kwargs.get('max_length', 50)
+ super(SlugField, self).__init__(*args, **kwargs)
+
+
class ChoiceField(WritableField):
type_name = 'ChoiceField'
form_field_class = forms.ChoiceField
@@ -775,8 +815,37 @@ class EmailField(CharField):
return result
+class RegexField(CharField):
+ type_name = 'RegexField'
+
+ def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs):
+ super(RegexField, self).__init__(max_length, min_length, *args, **kwargs)
+ self.regex = regex
+
+ def _get_regex(self):
+ return self._regex
+
+ def _set_regex(self, regex):
+ if isinstance(regex, basestring):
+ regex = re.compile(regex)
+ self._regex = regex
+ if hasattr(self, '_regex_validator') and self._regex_validator in self.validators:
+ self.validators.remove(self._regex_validator)
+ self._regex_validator = validators.RegexValidator(regex=regex)
+ self.validators.append(self._regex_validator)
+
+ regex = property(_get_regex, _set_regex)
+
+ def __deepcopy__(self, memo):
+ result = copy.copy(self)
+ memo[id(self)] = result
+ result.validators = self.validators[:]
+ return result
+
+
class DateField(WritableField):
type_name = 'DateField'
+ widget = widgets.DateInput
form_field_class = forms.DateField
default_error_messages = {
@@ -815,6 +884,7 @@ class DateField(WritableField):
class DateTimeField(WritableField):
type_name = 'DateTimeField'
+ widget = widgets.DateTimeInput
form_field_class = forms.DateTimeField
default_error_messages = {
@@ -915,3 +985,111 @@ class FloatField(WritableField):
except (TypeError, ValueError):
msg = self.error_messages['invalid'] % value
raise ValidationError(msg)
+
+
+class FileField(WritableField):
+ _use_files = True
+ type_name = 'FileField'
+ form_field_class = forms.FileField
+ widget = widgets.FileInput
+
+ default_error_messages = {
+ 'invalid': _("No file was submitted. Check the encoding type on the form."),
+ 'missing': _("No file was submitted."),
+ 'empty': _("The submitted file is empty."),
+ 'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'),
+ 'contradiction': _('Please either submit a file or check the clear checkbox, not both.')
+ }
+
+ def __init__(self, *args, **kwargs):
+ self.max_length = kwargs.pop('max_length', None)
+ self.allow_empty_file = kwargs.pop('allow_empty_file', False)
+ super(FileField, self).__init__(*args, **kwargs)
+
+ def from_native(self, data):
+ if data in validators.EMPTY_VALUES:
+ return None
+
+ # UploadedFile objects should have name and size attributes.
+ try:
+ file_name = data.name
+ file_size = data.size
+ except AttributeError:
+ raise ValidationError(self.error_messages['invalid'])
+
+ if self.max_length is not None and len(file_name) > self.max_length:
+ error_values = {'max': self.max_length, 'length': len(file_name)}
+ raise ValidationError(self.error_messages['max_length'] % error_values)
+ if not file_name:
+ raise ValidationError(self.error_messages['invalid'])
+ if not self.allow_empty_file and not file_size:
+ raise ValidationError(self.error_messages['empty'])
+
+ return data
+
+ def to_native(self, value):
+ return value.name
+
+
+class ImageField(FileField):
+ _use_files = True
+ form_field_class = forms.ImageField
+
+ default_error_messages = {
+ 'invalid_image': _("Upload a valid image. The file you uploaded was either not an image or a corrupted image."),
+ }
+
+ def from_native(self, data):
+ """
+ Checks that the file-upload field data contains a valid image (GIF, JPG,
+ PNG, possibly others -- whatever the Python Imaging Library supports).
+ """
+ f = super(ImageField, self).from_native(data)
+ if f is None:
+ return None
+
+ # Try to import PIL in either of the two ways it can end up installed.
+ try:
+ from PIL import Image
+ except ImportError:
+ import Image
+
+ # We need to get a file object for PIL. We might have a path or we might
+ # have to read the data into memory.
+ if hasattr(data, 'temporary_file_path'):
+ file = data.temporary_file_path()
+ else:
+ if hasattr(data, 'read'):
+ file = BytesIO(data.read())
+ else:
+ file = BytesIO(data['content'])
+
+ try:
+ # load() could spot a truncated JPEG, but it loads the entire
+ # image in memory, which is a DoS vector. See #3848 and #18520.
+ # verify() must be called immediately after the constructor.
+ Image.open(file).verify()
+ except ImportError:
+ # Under PyPy, it is possible to import PIL. However, the underlying
+ # _imaging C module isn't available, so an ImportError will be
+ # raised. Catch and re-raise.
+ raise
+ except Exception: # Python Imaging Library doesn't recognize it as an image
+ raise ValidationError(self.error_messages['invalid_image'])
+ if hasattr(f, 'seek') and callable(f.seek):
+ f.seek(0)
+ return f
+
+
+class SerializerMethodField(Field):
+ """
+ A field that gets its value by calling a method on the serializer it's attached to.
+ """
+
+ def __init__(self, method_name):
+ self.method_name = method_name
+ super(SerializerMethodField, self).__init__()
+
+ def field_to_native(self, obj, field_name):
+ value = getattr(self.parent, self.method_name)(obj)
+ return self.to_native(value)
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
new file mode 100644
index 00000000..bcc87660
--- /dev/null
+++ b/rest_framework/filters.py
@@ -0,0 +1,59 @@
+from rest_framework.compat import django_filters
+
+FilterSet = django_filters and django_filters.FilterSet or None
+
+
+class BaseFilterBackend(object):
+ """
+ A base class from which all filter backend classes should inherit.
+ """
+
+ def filter_queryset(self, request, queryset, view):
+ """
+ Return a filtered queryset.
+ """
+ raise NotImplementedError(".filter_queryset() must be overridden.")
+
+
+class DjangoFilterBackend(BaseFilterBackend):
+ """
+ A filter backend that uses django-filter.
+ """
+ default_filter_set = FilterSet
+
+ def __init__(self):
+ assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed'
+
+ def get_filter_class(self, view):
+ """
+ Return the django-filters `FilterSet` used to filter the queryset.
+ """
+ filter_class = getattr(view, 'filter_class', None)
+ filter_fields = getattr(view, 'filter_fields', None)
+ view_model = getattr(view, 'model', None)
+
+ if filter_class:
+ filter_model = filter_class.Meta.model
+
+ assert issubclass(filter_model, view_model), \
+ 'FilterSet model %s does not match view model %s' % \
+ (filter_model, view_model)
+
+ return filter_class
+
+ if filter_fields:
+ class AutoFilterSet(self.default_filter_set):
+ class Meta:
+ model = view_model
+ fields = filter_fields
+ return AutoFilterSet
+
+ return None
+
+ def filter_queryset(self, request, queryset, view):
+ filter_class = self.get_filter_class(view)
+
+ if filter_class:
+ return filter_class(request.GET, queryset=queryset)
+
+ return queryset
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 45cedd8b..dd8dfcf8 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -14,6 +14,8 @@ class GenericAPIView(views.APIView):
"""
Base class for all other generic views.
"""
+
+ model = None
serializer_class = None
model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
@@ -30,8 +32,10 @@ class GenericAPIView(views.APIView):
def get_serializer_class(self):
"""
Return the class to use for the serializer.
- Use `self.serializer_class`, falling back to constructing a
- model serializer class from `self.model_serializer_class`
+
+ Defaults to using `self.serializer_class`, falls back to constructing a
+ model serializer class using `self.model_serializer_class`, with
+ `self.model` as the model.
"""
serializer_class = self.serializer_class
@@ -44,11 +48,13 @@ class GenericAPIView(views.APIView):
return serializer_class
def get_serializer(self, instance=None, data=None, files=None):
- # TODO: add support for files
- # TODO: add support for seperate serializer/deserializer
+ """
+ Return the serializer instance that should be used for validating and
+ deserializing input, and for serializing output.
+ """
serializer_class = self.get_serializer_class()
context = self.get_serializer_context()
- return serializer_class(instance, data=data, context=context)
+ return serializer_class(instance, data=data, files=files, context=context)
class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
@@ -56,37 +62,59 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
Base class for generic views onto a queryset.
"""
- pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
paginate_by = api_settings.PAGINATE_BY
+ paginate_by_param = api_settings.PAGINATE_BY_PARAM
+ pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
+ filter_backend = api_settings.FILTER_BACKEND
+
+ def filter_queryset(self, queryset):
+ """
+ Given a queryset, filter it with whichever filter backend is in use.
+ """
+ if not self.filter_backend:
+ return queryset
+ backend = self.filter_backend()
+ return backend.filter_queryset(self.request, queryset, self)
- def get_pagination_serializer_class(self):
+ def get_pagination_serializer(self, page=None):
"""
- Return the class to use for the pagination serializer.
+ Return a serializer instance to use with paginated data.
"""
class SerializerClass(self.pagination_serializer_class):
class Meta:
object_serializer_class = self.get_serializer_class()
- return SerializerClass
-
- def get_pagination_serializer(self, page=None):
- pagination_serializer_class = self.get_pagination_serializer_class()
+ pagination_serializer_class = SerializerClass
context = self.get_serializer_context()
return pagination_serializer_class(instance=page, context=context)
+ def get_paginate_by(self, queryset):
+ """
+ Return the size of pages to use with pagination.
+ """
+ if self.paginate_by_param:
+ query_params = self.request.QUERY_PARAMS
+ try:
+ return int(query_params[self.paginate_by_param])
+ except (KeyError, ValueError):
+ pass
+ return self.paginate_by
+
class SingleObjectAPIView(SingleObjectMixin, GenericAPIView):
"""
Base class for generic views onto a model instance.
"""
+
pk_url_kwarg = 'pk' # Not provided in Django 1.3
slug_url_kwarg = 'slug' # Not provided in Django 1.3
+ slug_field = 'slug'
- def get_object(self):
+ def get_object(self, queryset=None):
"""
Override default to add support for object-level permissions.
"""
- obj = super(SingleObjectAPIView, self).get_object()
+ obj = super(SingleObjectAPIView, self).get_object(queryset)
if not self.has_permission(self.request, obj):
self.permission_denied(self.request)
return obj
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 6824a4d2..1edcfa5c 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -15,13 +15,20 @@ class CreateModelMixin(object):
Should be mixed in with any `BaseView`.
"""
def create(self, request, *args, **kwargs):
- serializer = self.get_serializer(data=request.DATA)
+ serializer = self.get_serializer(data=request.DATA, files=request.FILES)
if serializer.is_valid():
self.pre_save(serializer.object)
self.object = serializer.save()
- return Response(serializer.data, status=status.HTTP_201_CREATED)
+ headers = self.get_success_headers(serializer.data)
+ return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+ def get_success_headers(self, data):
+ try:
+ return {'Location': data['url']}
+ except (TypeError, KeyError):
+ return {}
+
def pre_save(self, obj):
pass
@@ -34,14 +41,16 @@ class ListModelMixin(object):
empty_error = u"Empty list and '%(class_name)s.allow_empty' is False."
def list(self, request, *args, **kwargs):
- self.object_list = self.get_queryset()
+ queryset = self.get_queryset()
+ self.object_list = self.filter_queryset(queryset)
# Default is to allow empty querysets. This can be altered by setting
# `.allow_empty = False`, to raise 404 errors on empty querysets.
allow_empty = self.get_allow_empty()
- if not allow_empty and len(self.object_list) == 0:
- error_args = {'class_name': self.__class__.__name__}
- raise Http404(self.empty_error % error_args)
+ if not allow_empty and not self.object_list:
+ class_name = self.__class__.__name__
+ error_msg = self.empty_error % {'class_name': class_name}
+ raise Http404(error_msg)
# Pagination size is set by the `.paginate_by` attribute,
# which may be `None` to disable pagination.
@@ -75,17 +84,18 @@ class UpdateModelMixin(object):
def update(self, request, *args, **kwargs):
try:
self.object = self.get_object()
- success_status = status.HTTP_200_OK
+ created = False
except Http404:
self.object = None
- success_status = status.HTTP_201_CREATED
+ created = True
- serializer = self.get_serializer(self.object, data=request.DATA)
+ serializer = self.get_serializer(self.object, data=request.DATA, files=request.FILES)
if serializer.is_valid():
self.pre_save(serializer.object)
self.object = serializer.save()
- return Response(serializer.data, status=success_status)
+ status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK
+ return Response(serializer.data, status=status_code)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index 131718fd..d241ade7 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -1,4 +1,5 @@
from rest_framework import serializers
+from rest_framework.templatetags.rest_framework import replace_query_param
# TODO: Support URLconf kwarg-style paging
@@ -7,30 +8,30 @@ class NextPageField(serializers.Field):
"""
Field that returns a link to the next page in paginated results.
"""
+ page_field = 'page'
+
def to_native(self, value):
if not value.has_next():
return None
page = value.next_page_number()
request = self.context.get('request')
- relative_url = '?page=%d' % page
- if request:
- return request.build_absolute_uri(relative_url)
- return relative_url
+ url = request and request.build_absolute_uri() or ''
+ return replace_query_param(url, self.page_field, page)
class PreviousPageField(serializers.Field):
"""
Field that returns a link to the previous page in paginated results.
"""
+ page_field = 'page'
+
def to_native(self, value):
if not value.has_previous():
return None
page = value.previous_page_number()
request = self.context.get('request')
- relative_url = '?page=%d' % page
- if request:
- return request.build_absolute_uri('?page=%d' % page)
- return relative_url
+ url = request and request.build_absolute_uri() or ''
+ return replace_query_param(url, self.page_field, page)
class PaginationSerializerOptions(serializers.SerializerOptions):
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 4f3aa02c..25a32baa 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -4,7 +4,7 @@ Renderers are used to serialize a response into specific media types.
They give us a generic way of being able to handle various media types
on the response, such as JSON encoded data or HTML output.
-REST framework also provides an HTML renderer the renders the browseable API.
+REST framework also provides an HTML renderer the renders the browsable API.
"""
import copy
import string
@@ -19,7 +19,7 @@ from rest_framework.request import clone_request
from rest_framework.utils import dict2xml
from rest_framework.utils import encoders
from rest_framework.utils.breadcrumbs import get_breadcrumbs
-from rest_framework import VERSION
+from rest_framework import VERSION, status
from rest_framework import serializers, parsers
@@ -306,9 +306,8 @@ class BrowsableAPIRenderer(BaseRenderer):
return True
def serializer_to_form_fields(self, serializer):
-
fields = {}
- for k, v in serializer.get_fields(True).items():
+ for k, v in serializer.get_fields().items():
if getattr(v, 'read_only', True):
continue
@@ -457,7 +456,7 @@ class BrowsableAPIRenderer(BaseRenderer):
# Munge DELETE Response code to allow us to return content
# (Do this *after* we've rendered the template so that we include
# the normal deletion response code in the output)
- if response.status_code == 204:
- response.status_code = 200
+ if response.status_code == status.HTTP_204_NO_CONTENT:
+ response.status_code = status.HTTP_200_OK
return ret
diff --git a/rest_framework/response.py b/rest_framework/response.py
index 0de01204..be78c43a 100644
--- a/rest_framework/response.py
+++ b/rest_framework/response.py
@@ -15,14 +15,17 @@ class Response(SimpleTemplateResponse):
Alters the init arguments slightly.
For example, drop 'template_name', and instead use 'data'.
- Setting 'renderer' and 'media_type' will typically be defered,
+ Setting 'renderer' and 'media_type' will typically be deferred,
For example being set automatically by the `APIView`.
"""
super(Response, self).__init__(None, status=status)
self.data = data
- self.headers = headers and headers[:] or []
self.template_name = template_name
self.exception = exception
+
+ if headers:
+ for name,value in headers.iteritems():
+ self[name] = value
@property
def rendered_content(self):
diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py
index b48f85e4..dd5d9dc3 100644
--- a/rest_framework/runtests/settings.py
+++ b/rest_framework/runtests/settings.py
@@ -107,6 +107,7 @@ import django
if django.VERSION < (1, 3):
INSTALLED_APPS += ('staticfiles',)
+
# If we're running on the Jenkins server we want to archive the coverage reports as XML.
import os
if os.environ.get('HUDSON_URL', None):
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 4f68ada6..4519ab05 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -60,7 +60,7 @@ def _get_declared_fields(bases, attrs):
# If this class is subclassing another Serializer, add that Serializer's
# fields. Note that we loop over the bases in *reverse*. This is necessary
- # in order to the correct order of fields.
+ # in order to maintain the correct order of fields.
for base in bases[::-1]:
if hasattr(base, 'base_fields'):
fields = base.base_fields.items() + fields
@@ -89,50 +89,54 @@ class BaseSerializer(Field):
pass
_options_class = SerializerOptions
- _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatability with unsorted implementations.
+ _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatibility with unsorted implementations.
- def __init__(self, instance=None, data=None, context=None, **kwargs):
+ def __init__(self, instance=None, data=None, files=None, context=None, partial=False, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta)
- self.fields = copy.deepcopy(self.base_fields)
self.parent = None
self.root = None
+ self.partial = partial
self.context = context or {}
self.init_data = data
+ self.init_files = files
self.object = instance
+ self.fields = self.get_fields()
self._data = None
+ self._files = None
self._errors = None
#####
# Methods to determine which fields to use when (de)serializing objects.
- def default_fields(self, nested=False):
+ def get_default_fields(self):
"""
Return the complete set of default fields for the object, as a dict.
"""
return {}
- def get_fields(self, nested=False):
+ def get_fields(self):
"""
Returns the complete set of fields for the object as a dict.
This will be the set of any explicitly declared fields,
- plus the set of fields returned by default_fields().
+ plus the set of fields returned by get_default_fields().
"""
ret = SortedDict()
# Get the explicitly declared fields
- for key, field in self.fields.items():
+ base_fields = copy.deepcopy(self.base_fields)
+ for key, field in base_fields.items():
ret[key] = field
# Set up the field
field.initialize(parent=self, field_name=key)
# Add in the default fields
- fields = self.default_fields(nested)
- for key, val in fields.items():
+ default_fields = self.get_default_fields()
+ for key, val in default_fields.items():
if key not in ret:
ret[key] = val
@@ -163,7 +167,7 @@ class BaseSerializer(Field):
self.opts.depth = parent.opts.depth - 1
#####
- # Methods to convert or revert from objects <--> primative representations.
+ # Methods to convert or revert from objects <--> primitive representations.
def get_field_key(self, field_name):
"""
@@ -179,24 +183,22 @@ class BaseSerializer(Field):
ret = self._dict_class()
ret.fields = {}
- fields = self.get_fields(nested=bool(self.opts.depth))
- for field_name, field in fields.items():
+ for field_name, field in self.fields.items():
key = self.get_field_key(field_name)
value = field.field_to_native(obj, field_name)
ret[key] = value
ret.fields[key] = field
return ret
- def restore_fields(self, data):
+ def restore_fields(self, data, files):
"""
Core of deserialization, together with `restore_object`.
Converts a dictionary of data into a dictionary of deserialized fields.
"""
- fields = self.get_fields(nested=bool(self.opts.depth))
reverted_data = {}
- for field_name, field in fields.items():
+ for field_name, field in self.fields.items():
try:
- field.field_from_native(data, field_name, reverted_data)
+ field.field_from_native(data, files, field_name, reverted_data)
except ValidationError as err:
self._errors[field_name] = list(err.messages)
@@ -206,10 +208,7 @@ class BaseSerializer(Field):
"""
Run `validate_<fieldname>()` and `validate()` methods on the serializer
"""
- # TODO: refactor this so we're not determining the fields again
- fields = self.get_fields(nested=bool(self.opts.depth))
-
- for field_name, field in fields.items():
+ for field_name, field in self.fields.items():
try:
validate_method = getattr(self, 'validate_%s' % field_name, None)
if validate_method:
@@ -244,23 +243,23 @@ class BaseSerializer(Field):
def to_native(self, obj):
"""
- Serialize objects -> primatives.
+ Serialize objects -> primitives.
"""
if hasattr(obj, '__iter__'):
return [self.convert_object(item) for item in obj]
return self.convert_object(obj)
- def from_native(self, data):
+ def from_native(self, data, files):
"""
- Deserialize primatives -> objects.
+ Deserialize primitives -> objects.
"""
if hasattr(data, '__iter__') and not isinstance(data, dict):
# TODO: error data when deserializing lists
return (self.from_native(item) for item in data)
self._errors = {}
- if data is not None:
- attrs = self.restore_fields(data)
+ if data is not None or files is not None:
+ attrs = self.restore_fields(data, files)
attrs = self.perform_validation(attrs)
else:
self._errors['non_field_errors'] = ['No input provided']
@@ -275,6 +274,9 @@ class BaseSerializer(Field):
"""
obj = getattr(obj, self.source or field_name)
+ if is_simple_callable(obj):
+ obj = obj()
+
# If the object has an "all" method, assume it's a relationship
if is_simple_callable(getattr(obj, 'all', None)):
return [self.to_native(item) for item in obj.all()]
@@ -288,7 +290,7 @@ class BaseSerializer(Field):
setting self.object if no errors occurred.
"""
if self._errors is None:
- obj = self.from_native(self.init_data)
+ obj = self.from_native(self.init_data, self.init_files)
if not self._errors:
self.object = obj
return self._errors
@@ -321,6 +323,7 @@ class ModelSerializerOptions(SerializerOptions):
def __init__(self, meta):
super(ModelSerializerOptions, self).__init__(meta)
self.model = getattr(meta, 'model', None)
+ self.read_only_fields = getattr(meta, 'read_only_fields', ())
class ModelSerializer(Serializer):
@@ -329,16 +332,10 @@ class ModelSerializer(Serializer):
"""
_options_class = ModelSerializerOptions
- def default_fields(self, nested=False):
+ def get_default_fields(self):
"""
Return all the fields that should be serialized for the model.
"""
- # TODO: Modfiy this so that it's called on init, and drop
- # serialize/obj/data arguments.
- #
- # We *could* provide a hook for dynamic fields, but
- # it'd be nice if the default was to generate fields statically
- # at the point of __init__
cls = self.opts.model
opts = get_concrete_model(cls)._meta
@@ -350,6 +347,7 @@ class ModelSerializer(Serializer):
fields += [field for field in opts.many_to_many if field.serialize]
ret = SortedDict()
+ nested = bool(self.opts.depth)
is_pk = True # First field in the list is the pk
for model_field in fields:
@@ -369,6 +367,12 @@ class ModelSerializer(Serializer):
field.initialize(parent=self, field_name=model_field.name)
ret[model_field.name] = field
+ for field_name in self.opts.read_only_fields:
+ assert field_name in ret, \
+ "read_only_fields on '%s' included invalid item '%s'" % \
+ (self.__class__.__name__, field_name)
+ ret[field_name].read_only = True
+
return ret
def get_pk_field(self, model_field):
@@ -381,7 +385,10 @@ class ModelSerializer(Serializer):
"""
Creates a default instance of a nested relational field.
"""
- return ModelSerializer()
+ class NestedModelSerializer(ModelSerializer):
+ class Meta:
+ model = model_field.rel.to
+ return NestedModelSerializer()
def get_related_field(self, model_field, to_many=False):
"""
@@ -417,6 +424,10 @@ class ModelSerializer(Serializer):
kwargs['choices'] = model_field.flatchoices
return ChoiceField(**kwargs)
+ max_length = getattr(model_field, 'max_length', None)
+ if max_length:
+ kwargs['max_length'] = max_length
+
field_mapping = {
models.FloatField: FloatField,
models.IntegerField: IntegerField,
@@ -427,9 +438,13 @@ class ModelSerializer(Serializer):
models.DateField: DateField,
models.EmailField: EmailField,
models.CharField: CharField,
+ models.URLField: URLField,
+ models.SlugField: SlugField,
models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField,
+ models.FileField: FileField,
+ models.ImageField: ImageField,
}
try:
return field_mapping[model_field.__class__](**kwargs)
@@ -442,11 +457,18 @@ class ModelSerializer(Serializer):
"""
self.m2m_data = {}
- if instance:
+ if instance is not None:
for key, val in attrs.items():
setattr(instance, key, val)
return instance
+ # Reverse relations
+ for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model():
+ field_name = obj.field.related_query_name()
+ if field_name in attrs:
+ self.m2m_data[field_name] = attrs.pop(field_name)
+
+ # Forward relations
for field in self.opts.model._meta.many_to_many:
if field.name in attrs:
self.m2m_data[field.name] = attrs.pop(field.name)
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 9c40a214..ee24a4ad 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -54,18 +54,26 @@ DEFAULTS = {
'user': None,
'anon': None,
},
+
+ # Pagination
'PAGINATE_BY': None,
+ 'PAGINATE_BY_PARAM': None,
+
+ # Filtering
+ 'FILTER_BACKEND': None,
+ # Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
+ # Browser enhancements
'FORM_METHOD_OVERRIDE': '_method',
'FORM_CONTENT_OVERRIDE': '_content',
'FORM_CONTENTTYPE_OVERRIDE': '_content_type',
'URL_ACCEPT_OVERRIDE': 'accept',
'URL_FORMAT_OVERRIDE': 'format',
- 'FORMAT_SUFFIX_KWARG': 'format'
+ 'FORMAT_SUFFIX_KWARG': 'format',
}
@@ -79,6 +87,7 @@ IMPORT_STRINGS = (
'DEFAULT_CONTENT_NEGOTIATION_CLASS',
'DEFAULT_MODEL_SERIALIZER_CLASS',
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
+ 'FILTER_BACKEND',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
)
@@ -142,8 +151,15 @@ class APISettings(object):
if val and attr in self.import_strings:
val = perform_import(val, attr)
+ self.validate_setting(attr, val)
+
# Cache the result
setattr(self, attr, val)
return val
+ def validate_setting(self, attr, val):
+ if attr == 'FILTER_BACKEND' and val is not None:
+ # Make sure we can initialize the class
+ val()
+
api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS)
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index c9b6eb10..4e0181ee 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -11,6 +11,18 @@ import string
register = template.Library()
+def replace_query_param(url, key, val):
+ """
+ Given a URL and a key/val pair, set or replace an item in the query
+ parameters of the URL, and return the new URL.
+ """
+ (scheme, netloc, path, query, fragment) = urlsplit(url)
+ query_dict = QueryDict(query).copy()
+ query_dict[key] = val
+ query = query_dict.urlencode()
+ return urlunsplit((scheme, netloc, path, query, fragment))
+
+
# Regex for adding classes to html snippets
class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])')
@@ -31,19 +43,6 @@ hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|
trailing_empty_content_re = re.compile(r'(?:<p>(?:&nbsp;|\s|<br \/>)*?</p>\s*)+\Z')
-# Helper function for 'add_query_param'
-def replace_query_param(url, key, val):
- """
- Given a URL and a key/val pair, set or replace an item in the query
- parameters of the URL, and return the new URL.
- """
- (scheme, netloc, path, query, fragment) = urlsplit(url)
- query_dict = QueryDict(query).copy()
- query_dict[key] = val
- query = query_dict.urlencode()
- return urlunsplit((scheme, netloc, path, query, fragment))
-
-
# And the template tags themselves...
@register.simple_tag
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index 8ab4c4e4..96ca9f52 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -1,4 +1,4 @@
-from django.conf.urls.defaults import patterns
+from django.conf.urls.defaults import patterns, include
from django.contrib.auth.models import User
from django.test import Client, TestCase
@@ -27,6 +27,7 @@ MockView.authentication_classes += (TokenAuthentication,)
urlpatterns = patterns('',
(r'^$', MockView.as_view()),
+ (r'^auth-token/', 'rest_framework.authtoken.views.obtain_auth_token'),
)
@@ -152,3 +153,33 @@ class TokenAuthTests(TestCase):
self.token.delete()
token = Token.objects.create(user=self.user)
self.assertTrue(bool(token.key))
+
+ def test_token_login_json(self):
+ """Ensure token login view using JSON POST works."""
+ client = Client(enforce_csrf_checks=True)
+ response = client.post('/auth-token/login/',
+ json.dumps({'username': self.username, 'password': self.password}), 'application/json')
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(json.loads(response.content)['token'], self.key)
+
+ def test_token_login_json_bad_creds(self):
+ """Ensure token login view using JSON POST fails if bad credentials are used."""
+ client = Client(enforce_csrf_checks=True)
+ response = client.post('/auth-token/login/',
+ json.dumps({'username': self.username, 'password': "badpass"}), 'application/json')
+ self.assertEqual(response.status_code, 400)
+
+ def test_token_login_json_missing_fields(self):
+ """Ensure token login view using JSON POST fails if missing fields."""
+ client = Client(enforce_csrf_checks=True)
+ response = client.post('/auth-token/login/',
+ json.dumps({'username': self.username}), 'application/json')
+ self.assertEqual(response.status_code, 400)
+
+ def test_token_login_form(self):
+ """Ensure token login view using form POST works."""
+ client = Client(enforce_csrf_checks=True)
+ response = client.post('/auth-token/login/',
+ {'username': self.username, 'password': self.password})
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(json.loads(response.content)['token'], self.key)
diff --git a/rest_framework/tests/files.py b/rest_framework/tests/files.py
index 61d7f7b1..5dd57b7c 100644
--- a/rest_framework/tests/files.py
+++ b/rest_framework/tests/files.py
@@ -1,34 +1,39 @@
-# from django.test import TestCase
-# from django import forms
+import StringIO
+import datetime
-# from django.test.client import RequestFactory
-# from rest_framework.views import View
-# from rest_framework.response import Response
+from django.test import TestCase
-# import StringIO
+from rest_framework import serializers
-# class UploadFilesTests(TestCase):
-# """Check uploading of files"""
-# def setUp(self):
-# self.factory = RequestFactory()
+class UploadedFile(object):
+ def __init__(self, file, created=None):
+ self.file = file
+ self.created = created or datetime.datetime.now()
-# def test_upload_file(self):
-# class FileForm(forms.Form):
-# file = forms.FileField()
+class UploadedFileSerializer(serializers.Serializer):
+ file = serializers.FileField()
+ created = serializers.DateTimeField()
-# class MockView(View):
-# permissions = ()
-# form = FileForm
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.file = attrs['file']
+ instance.created = attrs['created']
+ return instance
+ return UploadedFile(**attrs)
-# def post(self, request, *args, **kwargs):
-# return Response({'FILE_NAME': self.CONTENT['file'].name,
-# 'FILE_CONTENT': self.CONTENT['file'].read()})
-# file = StringIO.StringIO('stuff')
-# file.name = 'stuff.txt'
-# request = self.factory.post('/', {'file': file})
-# view = MockView.as_view()
-# response = view(request)
-# self.assertEquals(response.raw_content, {"FILE_CONTENT": "stuff", "FILE_NAME": "stuff.txt"})
+class FileSerializerTests(TestCase):
+
+ def test_create(self):
+ now = datetime.datetime.now()
+ file = StringIO.StringIO('stuff')
+ file.name = 'stuff.txt'
+ file.size = file.len
+ serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})
+ uploaded_file = UploadedFile(file=file, created=now)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.object.created, uploaded_file.created)
+ self.assertEquals(serializer.object.file, uploaded_file.file)
+ self.assertFalse(serializer.object is uploaded_file)
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
new file mode 100644
index 00000000..af2e6c2e
--- /dev/null
+++ b/rest_framework/tests/filterset.py
@@ -0,0 +1,168 @@
+import datetime
+from decimal import Decimal
+from django.test import TestCase
+from django.test.client import RequestFactory
+from django.utils import unittest
+from rest_framework import generics, status, filters
+from rest_framework.compat import django_filters
+from rest_framework.tests.models import FilterableItem, BasicModel
+
+factory = RequestFactory()
+
+
+if django_filters:
+ # Basic filter on a list view.
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_fields = ['decimal', 'date']
+ filter_backend = filters.DjangoFilterBackend
+
+ # These class are used to test a filter class.
+ class SeveralFieldsFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+ date = django_filters.DateFilter(lookup_type='gt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterClassRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backend = filters.DjangoFilterBackend
+
+ # These classes are used to test a misconfigured filter class.
+ class MisconfiguredFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+
+ class Meta:
+ model = BasicModel
+ fields = ['text']
+
+ class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = MisconfiguredFilter
+ filter_backend = filters.DjangoFilterBackend
+
+
+class IntegrationTestFiltering(TestCase):
+ """
+ Integration tests for filtered list views.
+ """
+
+ def setUp(self):
+ """
+ Create 10 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(10):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+ for obj in self.objects.all()
+ ]
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_fields_root_view(self):
+ """
+ GET requests to paginated ListCreateAPIView should return paginated results.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that the date filter works.
+ search_date = datetime.date(2012, 9, 22)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] == search_date]
+ self.assertEquals(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_class_root_view(self):
+ """
+ GET requests to filtered ListCreateAPIView that have a filter_class set
+ should return filtered results.
+ """
+ view = FilterClassRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
+
+ # Tests that the decimal filter set with 'lt' in the filter class works.
+ search_decimal = Decimal('4.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] < search_decimal]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that the date filter set with 'gt' in the filter class works.
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that the text filter set with 'icontains' in the filter class works.
+ search_text = 'ff'
+ request = factory.get('/?text=%s' % search_text)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if search_text in f['text'].lower()]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date and
+ f['decimal'] < search_decimal]
+ self.assertEquals(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_incorrectly_configured_filter(self):
+ """
+ An error should be displayed when the filter class is misconfigured.
+ """
+ view = IncorrectlyConfiguredRootView.as_view()
+
+ request = factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_unknown_filter(self):
+ """
+ GET requests with filters that aren't configured should return 200.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ search_integer = 10
+ request = factory.get('/?integer=%s' % search_integer)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py
index f71e2e28..d7effce7 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/hyperlinkedserializers.py
@@ -2,18 +2,19 @@ from django.conf.urls.defaults import patterns, url
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import generics, status, serializers
-from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo
+from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel
factory = RequestFactory()
class BlogPostCommentSerializer(serializers.ModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail')
text = serializers.CharField()
blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail')
class Meta:
model = BlogPostComment
- fields = ('text', 'blog_post_url')
+ fields = ('text', 'blog_post_url', 'url')
class PhotoSerializer(serializers.Serializer):
@@ -53,6 +54,9 @@ class BlogPostCommentListCreate(generics.ListCreateAPIView):
model = BlogPostComment
serializer_class = BlogPostCommentSerializer
+class BlogPostCommentDetail(generics.RetrieveAPIView):
+ model = BlogPostComment
+ serializer_class = BlogPostCommentSerializer
class BlogPostDetail(generics.RetrieveAPIView):
model = BlogPost
@@ -67,6 +71,11 @@ class AlbumDetail(generics.RetrieveAPIView):
model = Album
+class OptionalRelationDetail(generics.RetrieveAPIView):
+ model = OptionalRelationModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
urlpatterns = patterns('',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
@@ -75,8 +84,10 @@ urlpatterns = patterns('',
url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'),
url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'),
+ url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'),
url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'),
- url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list')
+ url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'),
+ url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'),
)
@@ -185,6 +196,7 @@ class TestCreateWithForeignKeys(TestCase):
request = factory.post('/comments/', data=data)
response = self.create_view(request).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response['Location'], 'http://testserver/comments/1/')
self.assertEqual(self.post.blogpostcomment_set.count(), 1)
self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment')
@@ -209,5 +221,29 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase):
request = factory.post('/photos/', data=data)
response = self.list_create_view(request).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')
self.assertEqual(self.post.photo_set.count(), 1)
self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo')
+
+
+class TestOptionalRelationHyperlinkedView(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 1 OptionalRelationModel intances.
+ """
+ OptionalRelationModel().save()
+ self.objects = OptionalRelationModel.objects
+ self.detail_view = OptionalRelationDetail.as_view()
+ self.data = {"url": "http://testserver/optionalrelation/1/", "other": None}
+
+ def test_get_detail_view(self):
+ """
+ GET requests to RetrieveAPIView with optional relations should return None
+ for non existing relations.
+ """
+ request = factory.get('/optionalrelationmodel-detail/1')
+ response = self.detail_view(request, pk=1).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index 0e23734e..c35861c6 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -35,6 +35,13 @@ def foobar():
return 'foobar'
+class CustomField(models.CharField):
+
+ def __init__(self, *args, **kwargs):
+ kwargs['max_length'] = 12
+ super(CustomField, self).__init__(*args, **kwargs)
+
+
class RESTFrameworkModel(models.Model):
"""
Base for test models that sets app_label, so they play nicely.
@@ -95,6 +102,13 @@ class Bookmark(RESTFrameworkModel):
tags = GenericRelation(TaggedItem)
+# Model to test filtering.
+class FilterableItem(RESTFrameworkModel):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
# Model for regression test for #285
class Comment(RESTFrameworkModel):
@@ -106,12 +120,16 @@ class Comment(RESTFrameworkModel):
class ActionItem(RESTFrameworkModel):
title = models.CharField(max_length=200)
done = models.BooleanField(default=False)
+ info = CustomField(default='---', max_length=12)
# Models for reverse relations
class BlogPost(RESTFrameworkModel):
title = models.CharField(max_length=100)
+ def get_first_comment(self):
+ return self.blogpostcomment_set.all()[0]
+
class BlogPostComment(RESTFrameworkModel):
text = models.TextField()
@@ -142,3 +160,13 @@ class Person(RESTFrameworkModel):
# Model for issue #324
class BlankFieldModel(RESTFrameworkModel):
title = models.CharField(max_length=100, blank=True)
+
+
+# Model for issue #380
+class OptionalRelationModel(RESTFrameworkModel):
+ other = models.ForeignKey('OptionalRelationModel', blank=True, null=True)
+
+
+# Model for RegexField
+class Book(RESTFrameworkModel):
+ isbn = models.CharField(max_length=13)
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index 64e8d822..3062007d 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -1,8 +1,12 @@
+import datetime
+from decimal import Decimal
from django.core.paginator import Paginator
from django.test import TestCase
from django.test.client import RequestFactory
-from rest_framework import generics, status, pagination
-from rest_framework.tests.models import BasicModel
+from django.utils import unittest
+from rest_framework import generics, status, pagination, filters
+from rest_framework.compat import django_filters
+from rest_framework.tests.models import BasicModel, FilterableItem
factory = RequestFactory()
@@ -15,6 +19,36 @@ class RootView(generics.ListCreateAPIView):
paginate_by = 10
+if django_filters:
+ class DecimalFilter(django_filters.FilterSet):
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_class = DecimalFilter
+ filter_backend = filters.DjangoFilterBackend
+
+
+class DefaultPageSizeKwargView(generics.ListAPIView):
+ """
+ View for testing default paginate_by_param usage
+ """
+ model = BasicModel
+
+
+class PaginateByParamView(generics.ListAPIView):
+ """
+ View for testing custom paginate_by_param usage
+ """
+ model = BasicModel
+ paginate_by_param = 'page_size'
+
+
class IntegrationTestPagination(TestCase):
"""
Integration tests for paginated list views.
@@ -22,7 +56,7 @@ class IntegrationTestPagination(TestCase):
def setUp(self):
"""
- Create 26 BasicModel intances.
+ Create 26 BasicModel instances.
"""
for char in 'abcdefghijklmnopqrstuvwxyz':
BasicModel(text=char * 3).save()
@@ -62,9 +96,61 @@ class IntegrationTestPagination(TestCase):
self.assertNotEquals(response.data['previous'], None)
+class IntegrationTestPaginationAndFiltering(TestCase):
+
+ def setUp(self):
+ """
+ Create 50 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(26):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+ for obj in self.objects.all()
+ ]
+ self.view = FilterFieldsRootView.as_view()
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_paginated_filtered_root_view(self):
+ """
+ GET requests to paginated filtered ListCreateAPIView should return
+ paginated results. The next and previous links should preserve the
+ filtered parameters.
+ """
+ request = factory.get('/?decimal=15.20')
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 15)
+ self.assertEquals(response.data['results'], self.data[:10])
+ self.assertNotEquals(response.data['next'], None)
+ self.assertEquals(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 15)
+ self.assertEquals(response.data['results'], self.data[10:15])
+ self.assertEquals(response.data['next'], None)
+ self.assertNotEquals(response.data['previous'], None)
+
+ request = factory.get(response.data['previous'])
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 15)
+ self.assertEquals(response.data['results'], self.data[:10])
+ self.assertNotEquals(response.data['next'], None)
+ self.assertEquals(response.data['previous'], None)
+
+
class UnitTestPagination(TestCase):
"""
- Unit tests for pagination of primative objects.
+ Unit tests for pagination of primitive objects.
"""
def setUp(self):
@@ -85,3 +171,68 @@ class UnitTestPagination(TestCase):
self.assertEquals(serializer.data['next'], None)
self.assertEquals(serializer.data['previous'], '?page=2')
self.assertEquals(serializer.data['results'], self.objects[20:])
+
+
+class TestUnpaginated(TestCase):
+ """
+ Tests for list views without pagination.
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = DefaultPageSizeKwargView.as_view()
+
+ def test_unpaginated(self):
+ """
+ Tests the default page size for this view.
+ no page size --> no limit --> no meta data
+ """
+ request = factory.get('/')
+ response = self.view(request)
+ self.assertEquals(response.data, self.data)
+
+
+class TestCustomPaginateByParam(TestCase):
+ """
+ Tests for list views with default page size kwarg
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = PaginateByParamView.as_view()
+
+ def test_default_page_size(self):
+ """
+ Tests the default page size for this view.
+ no page size --> no limit --> no meta data
+ """
+ request = factory.get('/')
+ response = self.view(request).render()
+ self.assertEquals(response.data, self.data)
+
+ def test_paginate_by_param(self):
+ """
+ If paginate_by_param is set, the new kwarg should limit per view requests.
+ """
+ request = factory.get('/?page_size=5')
+ response = self.view(request).render()
+ self.assertEquals(response.data['count'], 13)
+ self.assertEquals(response.data['results'], self.data[:5])
diff --git a/rest_framework/tests/pk_relations.py b/rest_framework/tests/pk_relations.py
index 94709810..3dcc76f9 100644
--- a/rest_framework/tests/pk_relations.py
+++ b/rest_framework/tests/pk_relations.py
@@ -117,6 +117,25 @@ class PrimaryKeyManyToManyTests(TestCase):
]
self.assertEquals(serializer.data, expected)
+ def test_reverse_many_to_many_create(self):
+ data = {'id': 4, 'name': u'target-4', 'sources': [1, 3]}
+ serializer = ManyToManyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'target-4')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': u'target-3', 'sources': [3]},
+ {'id': 4, 'name': u'target-4', 'sources': [1, 3]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
class PrimaryKeyForeignKeyTests(TestCase):
def setUp(self):
diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py
index 18b6af39..d7b75450 100644
--- a/rest_framework/tests/response.py
+++ b/rest_framework/tests/response.py
@@ -131,12 +131,6 @@ class RendererIntegrationTests(TestCase):
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
- @unittest.skip('can\'t pass because view is a simple Django view and response is an ImmediateResponse')
- def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
- """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
- resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
- self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
-
def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index 8d1de429..61a05da1 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -2,7 +2,7 @@ import datetime
from django.test import TestCase
from rest_framework import serializers
from rest_framework.tests.models import (ActionItem, Anchor, BasicModel,
- BlankFieldModel, BlogPost, CallableDefaultValueModel, DefaultValueModel,
+ BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel,
ManyToManyModel, Person, ReadOnlyManyToManyModel)
@@ -40,7 +40,15 @@ class CommentSerializer(serializers.Serializer):
return instance
+class BookSerializer(serializers.ModelSerializer):
+ isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'})
+
+ class Meta:
+ model = Book
+
+
class ActionItemSerializer(serializers.ModelSerializer):
+
class Meta:
model = ActionItem
@@ -51,6 +59,7 @@ class PersonSerializer(serializers.ModelSerializer):
class Meta:
model = Person
fields = ('name', 'age', 'info')
+ read_only_fields = ('age',)
class BasicTests(TestCase):
@@ -106,8 +115,21 @@ class BasicTests(TestCase):
self.assertTrue(serializer.object is expected)
self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!')
+ def test_partial_update(self):
+ msg = 'Merry New Year!'
+ partial_data = {'content': msg}
+ serializer = CommentSerializer(self.comment, data=partial_data)
+ self.assertEquals(serializer.is_valid(), False)
+ serializer = CommentSerializer(self.comment, data=partial_data, partial=True)
+ expected = self.comment
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEquals(serializer.object, expected)
+ self.assertTrue(serializer.object is expected)
+ self.assertEquals(serializer.data['content'], msg)
+
def test_model_fields_as_expected(self):
- """ Make sure that the fields returned are the same as defined
+ """
+ Make sure that the fields returned are the same as defined
in the Meta data
"""
serializer = PersonSerializer(self.person)
@@ -115,12 +137,25 @@ class BasicTests(TestCase):
set(['name', 'age', 'info']))
def test_field_with_dictionary(self):
- """ Make sure that dictionaries from fields are left intact
+ """
+ Make sure that dictionaries from fields are left intact
"""
serializer = PersonSerializer(self.person)
expected = self.person_data
self.assertEquals(serializer.data['info'], expected)
+ def test_read_only_fields(self):
+ """
+ Attempting to update fields set as read_only should have no effect.
+ """
+
+ serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99})
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(serializer.errors, {})
+ # Assert age is unchanged (35)
+ self.assertEquals(instance.age, self.person_data['age'])
+
class ValidationTests(TestCase):
def setUp(self):
@@ -224,6 +259,42 @@ class ValidationTests(TestCase):
self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.errors, {})
+ def test_modelserializer_max_length_exceeded(self):
+ data = {
+ 'title': 'x' * 201,
+ }
+ serializer = ActionItemSerializer(data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'title': [u'Ensure this value has at most 200 characters (it has 201).']})
+
+ def test_default_modelfield_max_length_exceeded(self):
+ data = {
+ 'title': 'Testing "info" field...',
+ 'info': 'x' * 13,
+ }
+ serializer = ActionItemSerializer(data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'info': [u'Ensure this value has at most 12 characters (it has 13).']})
+
+
+class RegexValidationTest(TestCase):
+ def test_create_failed(self):
+ serializer = BookSerializer(data={'isbn': '1234567890'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']})
+
+ serializer = BookSerializer(data={'isbn': '12345678901234'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']})
+
+ serializer = BookSerializer(data={'isbn': 'abcdefghijklm'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']})
+
+ def test_create_success(self):
+ serializer = BookSerializer(data={'isbn': '1234567890123'})
+ self.assertTrue(serializer.is_valid())
+
class MetadataTests(TestCase):
def test_empty(self):
@@ -446,7 +517,10 @@ class CallableDefaultValueTests(TestCase):
class ManyRelatedTests(TestCase):
- def setUp(self):
+ def test_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
class BlogPostCommentSerializer(serializers.Serializer):
text = serializers.CharField()
@@ -455,14 +529,7 @@ class ManyRelatedTests(TestCase):
title = serializers.CharField()
comments = BlogPostCommentSerializer(source='blogpostcomment_set')
- self.serializer_class = BlogPostSerializer
-
- def test_reverse_relations(self):
- post = BlogPost.objects.create(title="Test blog post")
- post.blogpostcomment_set.create(text="I hate this blog post")
- post.blogpostcomment_set.create(text="I love this blog post")
-
- serializer = self.serializer_class(instance=post)
+ serializer = BlogPostSerializer(instance=post)
expected = {
'title': 'Test blog post',
'comments': [
@@ -473,6 +540,59 @@ class ManyRelatedTests(TestCase):
self.assertEqual(serializer.data, expected)
+ def test_callable_source(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostCommentSerializer(serializers.Serializer):
+ text = serializers.CharField()
+
+ class BlogPostSerializer(serializers.Serializer):
+ title = serializers.CharField()
+ first_comment = BlogPostCommentSerializer(source='get_first_comment')
+
+ serializer = BlogPostSerializer(post)
+
+ expected = {
+ 'title': 'Test blog post',
+ 'first_comment': {'text': 'I love this blog post'}
+ }
+ self.assertEqual(serializer.data, expected)
+
+
+class SerializerMethodFieldTests(TestCase):
+ def setUp(self):
+
+ class BoopSerializer(serializers.Serializer):
+ beep = serializers.SerializerMethodField('get_beep')
+ boop = serializers.Field()
+ boop_count = serializers.SerializerMethodField('get_boop_count')
+
+ def get_beep(self, obj):
+ return 'hello!'
+
+ def get_boop_count(self, obj):
+ return len(obj.boop)
+
+ self.serializer_class = BoopSerializer
+
+ def test_serializer_method_field(self):
+
+ class MyModel(object):
+ boop = ['a', 'b', 'c']
+
+ source_data = MyModel()
+
+ serializer = self.serializer_class(source_data)
+
+ expected = {
+ 'beep': u'hello!',
+ 'boop': [u'a', u'b', u'c'],
+ 'boop_count': 3,
+ }
+
+ self.assertEqual(serializer.data, expected)
+
# Test for issue #324
class BlankFieldTests(TestCase):
diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/throttling.py
index 0b94c25b..4b98b941 100644
--- a/rest_framework/tests/throttling.py
+++ b/rest_framework/tests/throttling.py
@@ -106,7 +106,7 @@ class ThrottlingTests(TestCase):
if expect is not None:
self.assertEquals(response['X-Throttle-Wait-Seconds'], expect)
else:
- self.assertFalse('X-Throttle-Wait-Seconds' in response.headers)
+ self.assertFalse('X-Throttle-Wait-Seconds' in response)
def test_seconds_fields(self):
"""
diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py
index 316ccd19..0ad926fa 100644
--- a/rest_framework/urlpatterns.py
+++ b/rest_framework/urlpatterns.py
@@ -4,7 +4,7 @@ from rest_framework.settings import api_settings
def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
"""
- Supplement existing urlpatterns with corrosponding patterns that also
+ Supplement existing urlpatterns with corresponding patterns that also
include a '.format' suffix. Retains urlpattern ordering.
urlpatterns:
diff --git a/rest_framework/urls.py b/rest_framework/urls.py
index 1a81101f..bcdc23e7 100644
--- a/rest_framework/urls.py
+++ b/rest_framework/urls.py
@@ -1,7 +1,7 @@
"""
-Login and logout views for the browseable API.
+Login and logout views for the browsable API.
-Add these to your root URLconf if you're using the browseable API and
+Add these to your root URLconf if you're using the browsable API and
your API requires authentication.
The urls must be namespaced as 'rest_framework', and you should make sure
diff --git a/rest_framework/utils/__init__.py b/rest_framework/utils/__init__.py
index a59fff45..84fcb5db 100644
--- a/rest_framework/utils/__init__.py
+++ b/rest_framework/utils/__init__.py
@@ -1,7 +1,6 @@
from django.utils.encoding import smart_unicode
from django.utils.xmlutils import SimplerXMLGenerator
from rest_framework.compat import StringIO
-
import re
import xml.etree.ElementTree as ET
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 1afbd697..10bdd5a5 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -140,7 +140,7 @@ class APIView(View):
def http_method_not_allowed(self, request, *args, **kwargs):
"""
- Called if `request.method` does not corrospond to a handler method.
+ Called if `request.method` does not correspond to a handler method.
"""
raise exceptions.MethodNotAllowed(request.method)