aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py2
-rw-r--r--rest_framework/authtoken/migrations/0001_initial.py14
-rw-r--r--rest_framework/authtoken/models.py3
-rw-r--r--rest_framework/authtoken/serializers.py24
-rw-r--r--rest_framework/authtoken/views.py24
-rw-r--r--rest_framework/compat.py16
-rw-r--r--rest_framework/decorators.py2
-rw-r--r--rest_framework/fields.py152
-rw-r--r--rest_framework/filters.py2
-rw-r--r--rest_framework/generics.py52
-rw-r--r--rest_framework/mixins.py30
-rw-r--r--rest_framework/renderers.py12
-rw-r--r--rest_framework/response.py7
-rw-r--r--rest_framework/serializers.py60
-rw-r--r--rest_framework/settings.py11
-rw-r--r--rest_framework/tests/authentication.py33
-rw-r--r--rest_framework/tests/files.py55
-rw-r--r--rest_framework/tests/hyperlinkedserializers.py9
-rw-r--r--rest_framework/tests/models.py13
-rw-r--r--rest_framework/tests/pagination.py82
-rw-r--r--rest_framework/tests/serializer.py85
-rw-r--r--rest_framework/tests/throttling.py2
-rw-r--r--rest_framework/urlpatterns.py2
-rw-r--r--rest_framework/urls.py4
-rw-r--r--rest_framework/views.py2
25 files changed, 579 insertions, 119 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index fd176603..88108a8d 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,3 +1,3 @@
-__version__ = '2.1.2'
+__version__ = '2.1.3'
VERSION = __version__ # synonym
diff --git a/rest_framework/authtoken/migrations/0001_initial.py b/rest_framework/authtoken/migrations/0001_initial.py
index 9d750381..f4e052e4 100644
--- a/rest_framework/authtoken/migrations/0001_initial.py
+++ b/rest_framework/authtoken/migrations/0001_initial.py
@@ -5,13 +5,21 @@ from south.v2 import SchemaMigration
from django.db import models
+try:
+ from django.contrib.auth import get_user_model
+except ImportError: # django < 1.5
+ from django.contrib.auth.models import User
+else:
+ User = get_user_model()
+
+
class Migration(SchemaMigration):
def forwards(self, orm):
# Adding model 'Token'
db.create_table('authtoken_token', (
('key', self.gf('django.db.models.fields.CharField')(max_length=40, primary_key=True)),
- ('user', self.gf('django.db.models.fields.related.OneToOneField')(related_name='auth_token', unique=True, to=orm['auth.User'])),
+ ('user', self.gf('django.db.models.fields.related.OneToOneField')(related_name='auth_token', unique=True, to=orm['%s.%s' % (User._meta.app_label, User._meta.object_name)])),
('created', self.gf('django.db.models.fields.DateTimeField')(auto_now_add=True, blank=True)),
))
db.send_create_signal('authtoken', ['Token'])
@@ -36,7 +44,7 @@ class Migration(SchemaMigration):
'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
'name': ('django.db.models.fields.CharField', [], {'max_length': '50'})
},
- 'auth.user': {
+ "%s.%s" % (User._meta.app_label, User._meta.module_name): {
'Meta': {'object_name': 'User'},
'date_joined': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}),
'email': ('django.db.models.fields.EmailField', [], {'max_length': '75', 'blank': 'True'}),
@@ -56,7 +64,7 @@ class Migration(SchemaMigration):
'Meta': {'object_name': 'Token'},
'created': ('django.db.models.fields.DateTimeField', [], {'auto_now_add': 'True', 'blank': 'True'}),
'key': ('django.db.models.fields.CharField', [], {'max_length': '40', 'primary_key': 'True'}),
- 'user': ('django.db.models.fields.related.OneToOneField', [], {'related_name': "'auth_token'", 'unique': 'True', 'to': "orm['auth.User']"})
+ 'user': ('django.db.models.fields.related.OneToOneField', [], {'related_name': "'auth_token'", 'unique': 'True', 'to': "orm['%s.%s']" % (User._meta.app_label, User._meta.object_name)})
},
'contenttypes.contenttype': {
'Meta': {'ordering': "('name',)", 'unique_together': "(('app_label', 'model'),)", 'object_name': 'ContentType', 'db_table': "'django_content_type'"},
diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py
index 5b3071aa..4da2aa62 100644
--- a/rest_framework/authtoken/models.py
+++ b/rest_framework/authtoken/models.py
@@ -1,6 +1,7 @@
import uuid
import hmac
from hashlib import sha1
+from rest_framework.compat import User
from django.db import models
@@ -9,7 +10,7 @@ class Token(models.Model):
The default authorization token model.
"""
key = models.CharField(max_length=40, primary_key=True)
- user = models.OneToOneField('auth.User', related_name='auth_token')
+ user = models.OneToOneField(User, related_name='auth_token')
created = models.DateTimeField(auto_now_add=True)
def save(self, *args, **kwargs):
diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py
new file mode 100644
index 00000000..a5ed6e6d
--- /dev/null
+++ b/rest_framework/authtoken/serializers.py
@@ -0,0 +1,24 @@
+from django.contrib.auth import authenticate
+from rest_framework import serializers
+
+class AuthTokenSerializer(serializers.Serializer):
+ username = serializers.CharField()
+ password = serializers.CharField()
+
+ def validate(self, attrs):
+ username = attrs.get('username')
+ password = attrs.get('password')
+
+ if username and password:
+ user = authenticate(username=username, password=password)
+
+ if user:
+ if not user.is_active:
+ raise serializers.ValidationError('User account is disabled.')
+ attrs['user'] = user
+ return attrs
+ else:
+ raise serializers.ValidationError('Unable to login with provided credentials.')
+ else:
+ raise serializers.ValidationError('Must include "username" and "password"')
+
diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py
index e69de29b..3ac674e2 100644
--- a/rest_framework/authtoken/views.py
+++ b/rest_framework/authtoken/views.py
@@ -0,0 +1,24 @@
+from rest_framework.views import APIView
+from rest_framework import status
+from rest_framework import parsers
+from rest_framework import renderers
+from rest_framework.response import Response
+from rest_framework.authtoken.models import Token
+from rest_framework.authtoken.serializers import AuthTokenSerializer
+
+class ObtainAuthToken(APIView):
+ throttle_classes = ()
+ permission_classes = ()
+ parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
+ renderer_classes = (renderers.JSONRenderer,)
+ model = Token
+
+ def post(self, request):
+ serializer = AuthTokenSerializer(data=request.DATA)
+ if serializer.is_valid():
+ token, created = Token.objects.get_or_create(user=serializer.object['user'])
+ return Response({'token': token.key})
+ return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+
+
+obtain_auth_token = ObtainAuthToken.as_view()
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 5055bfd3..09b76368 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -1,6 +1,6 @@
"""
The `compat` module provides support for backwards compatibility with older
-versions of django/python, and compatbility wrappers around optional packages.
+versions of django/python, and compatibility wrappers around optional packages.
"""
# flake8: noqa
import django
@@ -27,6 +27,20 @@ def get_concrete_model(model_cls):
return model_cls
+# Django 1.5 add support for custom auth user model
+if django.VERSION >= (1, 5):
+ from django.conf import settings
+ if hasattr(settings, 'AUTH_USER_MODEL'):
+ User = settings.AUTH_USER_MODEL
+ else:
+ from django.contrib.auth.models import User
+else:
+ try:
+ from django.contrib.auth.models import User
+ except ImportError:
+ raise ImportError(u"User model is not to be found.")
+
+
# First implementation of Django class-based views did not include head method
# in base View class - https://code.djangoproject.com/ticket/15668
if django.VERSION >= (1, 4):
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index a231f191..1b710a03 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -17,7 +17,7 @@ def api_view(http_method_names):
)
# Note, the above allows us to set the docstring.
- # It is the equivelent of:
+ # It is the equivalent of:
#
# class WrappedAPIView(APIView):
# pass
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 071746de..25d98645 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -4,6 +4,8 @@ import inspect
import re
import warnings
+from io import BytesIO
+
from django.core import validators
from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.core.urlresolvers import resolve, get_script_prefix
@@ -32,6 +34,7 @@ class Field(object):
creation_counter = 0
empty = ''
type_name = None
+ _use_files = None
def __init__(self, source=None):
self.parent = None
@@ -52,7 +55,7 @@ class Field(object):
self.root = parent.root or parent
self.context = self.root.context
- def field_from_native(self, data, field_name, into):
+ def field_from_native(self, data, files, field_name, into):
"""
Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value.
@@ -167,7 +170,7 @@ class WritableField(Field):
if errors:
raise ValidationError(errors)
- def field_from_native(self, data, field_name, into):
+ def field_from_native(self, data, files, field_name, into):
"""
Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value.
@@ -176,7 +179,10 @@ class WritableField(Field):
return
try:
- native = data[field_name]
+ if self._use_files:
+ native = files[field_name]
+ else:
+ native = data[field_name]
except KeyError:
if self.default is not None:
native = self.default
@@ -210,8 +216,19 @@ class ModelField(WritableField):
self.model_field = kwargs.pop('model_field')
except:
raise ValueError("ModelField requires 'model_field' kwarg")
+
+ self.min_length = kwargs.pop('min_length',
+ getattr(self.model_field, 'min_length', None))
+ self.max_length = kwargs.pop('max_length',
+ getattr(self.model_field, 'max_length', None))
+
super(ModelField, self).__init__(*args, **kwargs)
+ if self.min_length is not None:
+ self.validators.append(validators.MinLengthValidator(self.min_length))
+ if self.max_length is not None:
+ self.validators.append(validators.MaxLengthValidator(self.max_length))
+
def from_native(self, value):
rel = getattr(self.model_field, "rel", None)
if rel is not None:
@@ -318,13 +335,13 @@ class RelatedField(WritableField):
choices = property(_get_choices, _set_choices)
- ### Regular serializier stuff...
+ ### Regular serializer stuff...
def field_to_native(self, obj, field_name):
value = getattr(obj, self.source or field_name)
return self.to_native(value)
- def field_from_native(self, data, field_name, into):
+ def field_from_native(self, data, files, field_name, into):
if self.read_only:
return
@@ -342,7 +359,7 @@ class ManyRelatedMixin(object):
value = getattr(obj, self.source or field_name)
return [self.to_native(item) for item in value.all()]
- def field_from_native(self, data, field_name, into):
+ def field_from_native(self, data, files, field_name, into):
if self.read_only:
return
@@ -701,6 +718,23 @@ class CharField(WritableField):
return smart_unicode(value)
+class URLField(CharField):
+ type_name = 'URLField'
+
+ def __init__(self, **kwargs):
+ kwargs['max_length'] = kwargs.get('max_length', 200)
+ kwargs['validators'] = [validators.URLValidator()]
+ super(URLField, self).__init__(**kwargs)
+
+
+class SlugField(CharField):
+ type_name = 'SlugField'
+
+ def __init__(self, *args, **kwargs):
+ kwargs['max_length'] = kwargs.get('max_length', 50)
+ super(SlugField, self).__init__(*args, **kwargs)
+
+
class ChoiceField(WritableField):
type_name = 'ChoiceField'
widget = widgets.Select
@@ -933,3 +967,109 @@ class FloatField(WritableField):
except (TypeError, ValueError):
msg = self.error_messages['invalid'] % value
raise ValidationError(msg)
+
+
+class FileField(WritableField):
+ _use_files = True
+ type_name = 'FileField'
+ widget = widgets.FileInput
+
+ default_error_messages = {
+ 'invalid': _("No file was submitted. Check the encoding type on the form."),
+ 'missing': _("No file was submitted."),
+ 'empty': _("The submitted file is empty."),
+ 'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'),
+ 'contradiction': _('Please either submit a file or check the clear checkbox, not both.')
+ }
+
+ def __init__(self, *args, **kwargs):
+ self.max_length = kwargs.pop('max_length', None)
+ self.allow_empty_file = kwargs.pop('allow_empty_file', False)
+ super(FileField, self).__init__(*args, **kwargs)
+
+ def from_native(self, data):
+ if data in validators.EMPTY_VALUES:
+ return None
+
+ # UploadedFile objects should have name and size attributes.
+ try:
+ file_name = data.name
+ file_size = data.size
+ except AttributeError:
+ raise ValidationError(self.error_messages['invalid'])
+
+ if self.max_length is not None and len(file_name) > self.max_length:
+ error_values = {'max': self.max_length, 'length': len(file_name)}
+ raise ValidationError(self.error_messages['max_length'] % error_values)
+ if not file_name:
+ raise ValidationError(self.error_messages['invalid'])
+ if not self.allow_empty_file and not file_size:
+ raise ValidationError(self.error_messages['empty'])
+
+ return data
+
+ def to_native(self, value):
+ return value.name
+
+
+class ImageField(FileField):
+ _use_files = True
+
+ default_error_messages = {
+ 'invalid_image': _("Upload a valid image. The file you uploaded was either not an image or a corrupted image."),
+ }
+
+ def from_native(self, data):
+ """
+ Checks that the file-upload field data contains a valid image (GIF, JPG,
+ PNG, possibly others -- whatever the Python Imaging Library supports).
+ """
+ f = super(ImageField, self).from_native(data)
+ if f is None:
+ return None
+
+ # Try to import PIL in either of the two ways it can end up installed.
+ try:
+ from PIL import Image
+ except ImportError:
+ import Image
+
+ # We need to get a file object for PIL. We might have a path or we might
+ # have to read the data into memory.
+ if hasattr(data, 'temporary_file_path'):
+ file = data.temporary_file_path()
+ else:
+ if hasattr(data, 'read'):
+ file = BytesIO(data.read())
+ else:
+ file = BytesIO(data['content'])
+
+ try:
+ # load() could spot a truncated JPEG, but it loads the entire
+ # image in memory, which is a DoS vector. See #3848 and #18520.
+ # verify() must be called immediately after the constructor.
+ Image.open(file).verify()
+ except ImportError:
+ # Under PyPy, it is possible to import PIL. However, the underlying
+ # _imaging C module isn't available, so an ImportError will be
+ # raised. Catch and re-raise.
+ raise
+ except Exception: # Python Imaging Library doesn't recognize it as an image
+ raise ValidationError(self.error_messages['invalid_image'])
+ if hasattr(f, 'seek') and callable(f.seek):
+ f.seek(0)
+ return f
+
+
+class SerializerMethodField(Field):
+ """
+ A field that gets its value by calling a method on the serializer it's attached to.
+ """
+
+ def __init__(self, method_name):
+ self.method_name = method_name
+ super(SerializerMethodField, self).__init__()
+
+ def field_to_native(self, obj, field_name):
+ value = getattr(self.parent, self.method_name)(obj)
+ return self.to_native(value)
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index ccae4825..bcc87660 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -45,7 +45,7 @@ class DjangoFilterBackend(BaseFilterBackend):
class AutoFilterSet(self.default_filter_set):
class Meta:
model = view_model
- fields = filter_fields
+ fields = filter_fields
return AutoFilterSet
return None
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index ebd06e45..dd8dfcf8 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -14,6 +14,8 @@ class GenericAPIView(views.APIView):
"""
Base class for all other generic views.
"""
+
+ model = None
serializer_class = None
model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
@@ -30,8 +32,10 @@ class GenericAPIView(views.APIView):
def get_serializer_class(self):
"""
Return the class to use for the serializer.
- Use `self.serializer_class`, falling back to constructing a
- model serializer class from `self.model_serializer_class`
+
+ Defaults to using `self.serializer_class`, falls back to constructing a
+ model serializer class using `self.model_serializer_class`, with
+ `self.model` as the model.
"""
serializer_class = self.serializer_class
@@ -44,11 +48,13 @@ class GenericAPIView(views.APIView):
return serializer_class
def get_serializer(self, instance=None, data=None, files=None):
- # TODO: add support for files
- # TODO: add support for seperate serializer/deserializer
+ """
+ Return the serializer instance that should be used for validating and
+ deserializing input, and for serializing output.
+ """
serializer_class = self.get_serializer_class()
context = self.get_serializer_context()
- return serializer_class(instance, data=data, context=context)
+ return serializer_class(instance, data=data, files=files, context=context)
class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
@@ -56,47 +62,59 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
Base class for generic views onto a queryset.
"""
- pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
paginate_by = api_settings.PAGINATE_BY
+ paginate_by_param = api_settings.PAGINATE_BY_PARAM
+ pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
filter_backend = api_settings.FILTER_BACKEND
def filter_queryset(self, queryset):
+ """
+ Given a queryset, filter it with whichever filter backend is in use.
+ """
if not self.filter_backend:
return queryset
backend = self.filter_backend()
return backend.filter_queryset(self.request, queryset, self)
- def get_filtered_queryset(self):
- return self.filter_queryset(self.get_queryset())
-
- def get_pagination_serializer_class(self):
+ def get_pagination_serializer(self, page=None):
"""
- Return the class to use for the pagination serializer.
+ Return a serializer instance to use with paginated data.
"""
class SerializerClass(self.pagination_serializer_class):
class Meta:
object_serializer_class = self.get_serializer_class()
- return SerializerClass
-
- def get_pagination_serializer(self, page=None):
- pagination_serializer_class = self.get_pagination_serializer_class()
+ pagination_serializer_class = SerializerClass
context = self.get_serializer_context()
return pagination_serializer_class(instance=page, context=context)
+ def get_paginate_by(self, queryset):
+ """
+ Return the size of pages to use with pagination.
+ """
+ if self.paginate_by_param:
+ query_params = self.request.QUERY_PARAMS
+ try:
+ return int(query_params[self.paginate_by_param])
+ except (KeyError, ValueError):
+ pass
+ return self.paginate_by
+
class SingleObjectAPIView(SingleObjectMixin, GenericAPIView):
"""
Base class for generic views onto a model instance.
"""
+
pk_url_kwarg = 'pk' # Not provided in Django 1.3
slug_url_kwarg = 'slug' # Not provided in Django 1.3
+ slug_field = 'slug'
- def get_object(self):
+ def get_object(self, queryset=None):
"""
Override default to add support for object-level permissions.
"""
- obj = super(SingleObjectAPIView, self).get_object()
+ obj = super(SingleObjectAPIView, self).get_object(queryset)
if not self.has_permission(self.request, obj):
self.permission_denied(self.request)
return obj
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index c3625a88..1edcfa5c 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -15,13 +15,20 @@ class CreateModelMixin(object):
Should be mixed in with any `BaseView`.
"""
def create(self, request, *args, **kwargs):
- serializer = self.get_serializer(data=request.DATA)
+ serializer = self.get_serializer(data=request.DATA, files=request.FILES)
if serializer.is_valid():
self.pre_save(serializer.object)
self.object = serializer.save()
- return Response(serializer.data, status=status.HTTP_201_CREATED)
+ headers = self.get_success_headers(serializer.data)
+ return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+ def get_success_headers(self, data):
+ try:
+ return {'Location': data['url']}
+ except (TypeError, KeyError):
+ return {}
+
def pre_save(self, obj):
pass
@@ -34,14 +41,16 @@ class ListModelMixin(object):
empty_error = u"Empty list and '%(class_name)s.allow_empty' is False."
def list(self, request, *args, **kwargs):
- self.object_list = self.get_filtered_queryset()
+ queryset = self.get_queryset()
+ self.object_list = self.filter_queryset(queryset)
# Default is to allow empty querysets. This can be altered by setting
# `.allow_empty = False`, to raise 404 errors on empty querysets.
allow_empty = self.get_allow_empty()
- if not allow_empty and len(self.object_list) == 0:
- error_args = {'class_name': self.__class__.__name__}
- raise Http404(self.empty_error % error_args)
+ if not allow_empty and not self.object_list:
+ class_name = self.__class__.__name__
+ error_msg = self.empty_error % {'class_name': class_name}
+ raise Http404(error_msg)
# Pagination size is set by the `.paginate_by` attribute,
# which may be `None` to disable pagination.
@@ -75,17 +84,18 @@ class UpdateModelMixin(object):
def update(self, request, *args, **kwargs):
try:
self.object = self.get_object()
- success_status = status.HTTP_200_OK
+ created = False
except Http404:
self.object = None
- success_status = status.HTTP_201_CREATED
+ created = True
- serializer = self.get_serializer(self.object, data=request.DATA)
+ serializer = self.get_serializer(self.object, data=request.DATA, files=request.FILES)
if serializer.is_valid():
self.pre_save(serializer.object)
self.object = serializer.save()
- return Response(serializer.data, status=success_status)
+ status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK
+ return Response(serializer.data, status=status_code)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 22fd6e74..db1bce39 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -4,7 +4,7 @@ Renderers are used to serialize a response into specific media types.
They give us a generic way of being able to handle various media types
on the response, such as JSON encoded data or HTML output.
-REST framework also provides an HTML renderer the renders the browseable API.
+REST framework also provides an HTML renderer the renders the browsable API.
"""
import copy
import string
@@ -19,7 +19,7 @@ from rest_framework.request import clone_request
from rest_framework.utils import dict2xml
from rest_framework.utils import encoders
from rest_framework.utils.breadcrumbs import get_breadcrumbs
-from rest_framework import VERSION
+from rest_framework import VERSION, status
from rest_framework import serializers, parsers
@@ -320,7 +320,9 @@ class BrowsableAPIRenderer(BaseRenderer):
serializers.SlugRelatedField: forms.ChoiceField,
serializers.ManySlugRelatedField: forms.MultipleChoiceField,
serializers.HyperlinkedRelatedField: forms.ChoiceField,
- serializers.ManyHyperlinkedRelatedField: forms.MultipleChoiceField
+ serializers.ManyHyperlinkedRelatedField: forms.MultipleChoiceField,
+ serializers.FileField: forms.FileField,
+ serializers.ImageField: forms.ImageField,
}
fields = {}
@@ -479,7 +481,7 @@ class BrowsableAPIRenderer(BaseRenderer):
# Munge DELETE Response code to allow us to return content
# (Do this *after* we've rendered the template so that we include
# the normal deletion response code in the output)
- if response.status_code == 204:
- response.status_code = 200
+ if response.status_code == status.HTTP_204_NO_CONTENT:
+ response.status_code = status.HTTP_200_OK
return ret
diff --git a/rest_framework/response.py b/rest_framework/response.py
index 0de01204..be78c43a 100644
--- a/rest_framework/response.py
+++ b/rest_framework/response.py
@@ -15,14 +15,17 @@ class Response(SimpleTemplateResponse):
Alters the init arguments slightly.
For example, drop 'template_name', and instead use 'data'.
- Setting 'renderer' and 'media_type' will typically be defered,
+ Setting 'renderer' and 'media_type' will typically be deferred,
For example being set automatically by the `APIView`.
"""
super(Response, self).__init__(None, status=status)
self.data = data
- self.headers = headers and headers[:] or []
self.template_name = template_name
self.exception = exception
+
+ if headers:
+ for name,value in headers.iteritems():
+ self[name] = value
@property
def rendered_content(self):
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 46d4765e..f7918c4c 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -89,9 +89,9 @@ class BaseSerializer(Field):
pass
_options_class = SerializerOptions
- _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatability with unsorted implementations.
+ _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatibility with unsorted implementations.
- def __init__(self, instance=None, data=None, context=None, **kwargs):
+ def __init__(self, instance=None, data=None, files=None, context=None, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta)
self.fields = copy.deepcopy(self.base_fields)
@@ -101,26 +101,29 @@ class BaseSerializer(Field):
self.context = context or {}
self.init_data = data
+ self.init_files = files
self.object = instance
+ self.default_fields = self.get_default_fields()
self._data = None
+ self._files = None
self._errors = None
#####
# Methods to determine which fields to use when (de)serializing objects.
- def default_fields(self, nested=False):
+ def get_default_fields(self):
"""
Return the complete set of default fields for the object, as a dict.
"""
return {}
- def get_fields(self, nested=False):
+ def get_fields(self):
"""
Returns the complete set of fields for the object as a dict.
This will be the set of any explicitly declared fields,
- plus the set of fields returned by default_fields().
+ plus the set of fields returned by get_default_fields().
"""
ret = SortedDict()
@@ -131,8 +134,7 @@ class BaseSerializer(Field):
field.initialize(parent=self, field_name=key)
# Add in the default fields
- fields = self.default_fields(nested)
- for key, val in fields.items():
+ for key, val in self.default_fields.items():
if key not in ret:
ret[key] = val
@@ -163,7 +165,7 @@ class BaseSerializer(Field):
self.opts.depth = parent.opts.depth - 1
#####
- # Methods to convert or revert from objects <--> primative representations.
+ # Methods to convert or revert from objects <--> primitive representations.
def get_field_key(self, field_name):
"""
@@ -179,7 +181,7 @@ class BaseSerializer(Field):
ret = self._dict_class()
ret.fields = {}
- fields = self.get_fields(nested=bool(self.opts.depth))
+ fields = self.get_fields()
for field_name, field in fields.items():
key = self.get_field_key(field_name)
value = field.field_to_native(obj, field_name)
@@ -187,16 +189,16 @@ class BaseSerializer(Field):
ret.fields[key] = field
return ret
- def restore_fields(self, data):
+ def restore_fields(self, data, files):
"""
Core of deserialization, together with `restore_object`.
Converts a dictionary of data into a dictionary of deserialized fields.
"""
- fields = self.get_fields(nested=bool(self.opts.depth))
+ fields = self.get_fields()
reverted_data = {}
for field_name, field in fields.items():
try:
- field.field_from_native(data, field_name, reverted_data)
+ field.field_from_native(data, files, field_name, reverted_data)
except ValidationError as err:
self._errors[field_name] = list(err.messages)
@@ -207,7 +209,7 @@ class BaseSerializer(Field):
Run `validate_<fieldname>()` and `validate()` methods on the serializer
"""
# TODO: refactor this so we're not determining the fields again
- fields = self.get_fields(nested=bool(self.opts.depth))
+ fields = self.get_fields()
for field_name, field in fields.items():
try:
@@ -244,23 +246,23 @@ class BaseSerializer(Field):
def to_native(self, obj):
"""
- Serialize objects -> primatives.
+ Serialize objects -> primitives.
"""
if hasattr(obj, '__iter__'):
return [self.convert_object(item) for item in obj]
return self.convert_object(obj)
- def from_native(self, data):
+ def from_native(self, data, files):
"""
- Deserialize primatives -> objects.
+ Deserialize primitives -> objects.
"""
if hasattr(data, '__iter__') and not isinstance(data, dict):
# TODO: error data when deserializing lists
return (self.from_native(item) for item in data)
self._errors = {}
- if data is not None:
- attrs = self.restore_fields(data)
+ if data is not None or files is not None:
+ attrs = self.restore_fields(data, files)
attrs = self.perform_validation(attrs)
else:
self._errors['non_field_errors'] = ['No input provided']
@@ -275,6 +277,9 @@ class BaseSerializer(Field):
"""
obj = getattr(obj, self.source or field_name)
+ if is_simple_callable(obj):
+ obj = obj()
+
# If the object has an "all" method, assume it's a relationship
if is_simple_callable(getattr(obj, 'all', None)):
return [self.to_native(item) for item in obj.all()]
@@ -288,7 +293,7 @@ class BaseSerializer(Field):
setting self.object if no errors occurred.
"""
if self._errors is None:
- obj = self.from_native(self.init_data)
+ obj = self.from_native(self.init_data, self.init_files)
if not self._errors:
self.object = obj
return self._errors
@@ -330,16 +335,10 @@ class ModelSerializer(Serializer):
"""
_options_class = ModelSerializerOptions
- def default_fields(self, nested=False):
+ def get_default_fields(self):
"""
Return all the fields that should be serialized for the model.
"""
- # TODO: Modfiy this so that it's called on init, and drop
- # serialize/obj/data arguments.
- #
- # We *could* provide a hook for dynamic fields, but
- # it'd be nice if the default was to generate fields statically
- # at the point of __init__
cls = self.opts.model
opts = get_concrete_model(cls)._meta
@@ -351,6 +350,7 @@ class ModelSerializer(Serializer):
fields += [field for field in opts.many_to_many if field.serialize]
ret = SortedDict()
+ nested = bool(self.opts.depth)
is_pk = True # First field in the list is the pk
for model_field in fields:
@@ -427,6 +427,10 @@ class ModelSerializer(Serializer):
kwargs['choices'] = model_field.flatchoices
return ChoiceField(**kwargs)
+ max_length = getattr(model_field, 'max_length', None)
+ if max_length:
+ kwargs['max_length'] = max_length
+
field_mapping = {
models.FloatField: FloatField,
models.IntegerField: IntegerField,
@@ -437,9 +441,13 @@ class ModelSerializer(Serializer):
models.DateField: DateField,
models.EmailField: EmailField,
models.CharField: CharField,
+ models.URLField: URLField,
+ models.SlugField: SlugField,
models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField,
+ models.FileField: FileField,
+ models.ImageField: ImageField,
}
try:
return field_mapping[model_field.__class__](**kwargs)
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 906a7cf6..ee24a4ad 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -54,19 +54,26 @@ DEFAULTS = {
'user': None,
'anon': None,
},
+
+ # Pagination
'PAGINATE_BY': None,
+ 'PAGINATE_BY_PARAM': None,
+
+ # Filtering
'FILTER_BACKEND': None,
+ # Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
+ # Browser enhancements
'FORM_METHOD_OVERRIDE': '_method',
'FORM_CONTENT_OVERRIDE': '_content',
'FORM_CONTENTTYPE_OVERRIDE': '_content_type',
'URL_ACCEPT_OVERRIDE': 'accept',
'URL_FORMAT_OVERRIDE': 'format',
- 'FORMAT_SUFFIX_KWARG': 'format'
+ 'FORMAT_SUFFIX_KWARG': 'format',
}
@@ -152,7 +159,7 @@ class APISettings(object):
def validate_setting(self, attr, val):
if attr == 'FILTER_BACKEND' and val is not None:
- # Make sure we can initilize the class
+ # Make sure we can initialize the class
val()
api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS)
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index 8ab4c4e4..96ca9f52 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -1,4 +1,4 @@
-from django.conf.urls.defaults import patterns
+from django.conf.urls.defaults import patterns, include
from django.contrib.auth.models import User
from django.test import Client, TestCase
@@ -27,6 +27,7 @@ MockView.authentication_classes += (TokenAuthentication,)
urlpatterns = patterns('',
(r'^$', MockView.as_view()),
+ (r'^auth-token/', 'rest_framework.authtoken.views.obtain_auth_token'),
)
@@ -152,3 +153,33 @@ class TokenAuthTests(TestCase):
self.token.delete()
token = Token.objects.create(user=self.user)
self.assertTrue(bool(token.key))
+
+ def test_token_login_json(self):
+ """Ensure token login view using JSON POST works."""
+ client = Client(enforce_csrf_checks=True)
+ response = client.post('/auth-token/login/',
+ json.dumps({'username': self.username, 'password': self.password}), 'application/json')
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(json.loads(response.content)['token'], self.key)
+
+ def test_token_login_json_bad_creds(self):
+ """Ensure token login view using JSON POST fails if bad credentials are used."""
+ client = Client(enforce_csrf_checks=True)
+ response = client.post('/auth-token/login/',
+ json.dumps({'username': self.username, 'password': "badpass"}), 'application/json')
+ self.assertEqual(response.status_code, 400)
+
+ def test_token_login_json_missing_fields(self):
+ """Ensure token login view using JSON POST fails if missing fields."""
+ client = Client(enforce_csrf_checks=True)
+ response = client.post('/auth-token/login/',
+ json.dumps({'username': self.username}), 'application/json')
+ self.assertEqual(response.status_code, 400)
+
+ def test_token_login_form(self):
+ """Ensure token login view using form POST works."""
+ client = Client(enforce_csrf_checks=True)
+ response = client.post('/auth-token/login/',
+ {'username': self.username, 'password': self.password})
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(json.loads(response.content)['token'], self.key)
diff --git a/rest_framework/tests/files.py b/rest_framework/tests/files.py
index 61d7f7b1..5dd57b7c 100644
--- a/rest_framework/tests/files.py
+++ b/rest_framework/tests/files.py
@@ -1,34 +1,39 @@
-# from django.test import TestCase
-# from django import forms
+import StringIO
+import datetime
-# from django.test.client import RequestFactory
-# from rest_framework.views import View
-# from rest_framework.response import Response
+from django.test import TestCase
-# import StringIO
+from rest_framework import serializers
-# class UploadFilesTests(TestCase):
-# """Check uploading of files"""
-# def setUp(self):
-# self.factory = RequestFactory()
+class UploadedFile(object):
+ def __init__(self, file, created=None):
+ self.file = file
+ self.created = created or datetime.datetime.now()
-# def test_upload_file(self):
-# class FileForm(forms.Form):
-# file = forms.FileField()
+class UploadedFileSerializer(serializers.Serializer):
+ file = serializers.FileField()
+ created = serializers.DateTimeField()
-# class MockView(View):
-# permissions = ()
-# form = FileForm
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.file = attrs['file']
+ instance.created = attrs['created']
+ return instance
+ return UploadedFile(**attrs)
-# def post(self, request, *args, **kwargs):
-# return Response({'FILE_NAME': self.CONTENT['file'].name,
-# 'FILE_CONTENT': self.CONTENT['file'].read()})
-# file = StringIO.StringIO('stuff')
-# file.name = 'stuff.txt'
-# request = self.factory.post('/', {'file': file})
-# view = MockView.as_view()
-# response = view(request)
-# self.assertEquals(response.raw_content, {"FILE_CONTENT": "stuff", "FILE_NAME": "stuff.txt"})
+class FileSerializerTests(TestCase):
+
+ def test_create(self):
+ now = datetime.datetime.now()
+ file = StringIO.StringIO('stuff')
+ file.name = 'stuff.txt'
+ file.size = file.len
+ serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})
+ uploaded_file = UploadedFile(file=file, created=now)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.object.created, uploaded_file.created)
+ self.assertEquals(serializer.object.file, uploaded_file.file)
+ self.assertFalse(serializer.object is uploaded_file)
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py
index 5ab850af..d7effce7 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/hyperlinkedserializers.py
@@ -8,12 +8,13 @@ factory = RequestFactory()
class BlogPostCommentSerializer(serializers.ModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail')
text = serializers.CharField()
blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail')
class Meta:
model = BlogPostComment
- fields = ('text', 'blog_post_url')
+ fields = ('text', 'blog_post_url', 'url')
class PhotoSerializer(serializers.Serializer):
@@ -53,6 +54,9 @@ class BlogPostCommentListCreate(generics.ListCreateAPIView):
model = BlogPostComment
serializer_class = BlogPostCommentSerializer
+class BlogPostCommentDetail(generics.RetrieveAPIView):
+ model = BlogPostComment
+ serializer_class = BlogPostCommentSerializer
class BlogPostDetail(generics.RetrieveAPIView):
model = BlogPost
@@ -80,6 +84,7 @@ urlpatterns = patterns('',
url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'),
url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'),
+ url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'),
url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'),
url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'),
url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'),
@@ -191,6 +196,7 @@ class TestCreateWithForeignKeys(TestCase):
request = factory.post('/comments/', data=data)
response = self.create_view(request).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response['Location'], 'http://testserver/comments/1/')
self.assertEqual(self.post.blogpostcomment_set.count(), 1)
self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment')
@@ -215,6 +221,7 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase):
request = factory.post('/photos/', data=data)
response = self.list_create_view(request).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')
self.assertEqual(self.post.photo_set.count(), 1)
self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo')
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index f6e5333b..70523fc0 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -35,6 +35,13 @@ def foobar():
return 'foobar'
+class CustomField(models.CharField):
+
+ def __init__(self, *args, **kwargs):
+ kwargs['max_length'] = 12
+ super(CustomField, self).__init__(*args, **kwargs)
+
+
class RESTFrameworkModel(models.Model):
"""
Base for test models that sets app_label, so they play nicely.
@@ -113,12 +120,16 @@ class Comment(RESTFrameworkModel):
class ActionItem(RESTFrameworkModel):
title = models.CharField(max_length=200)
done = models.BooleanField(default=False)
+ info = CustomField(default='---', max_length=12)
# Models for reverse relations
class BlogPost(RESTFrameworkModel):
title = models.CharField(max_length=100)
+ def get_first_comment(self):
+ return self.blogpostcomment_set.all()[0]
+
class BlogPostComment(RESTFrameworkModel):
text = models.TextField()
@@ -157,4 +168,4 @@ class OptionalRelationModel(RESTFrameworkModel):
# Model for RegexField
class Book(RESTFrameworkModel):
- isbn = models.CharField(max_length=13) \ No newline at end of file
+ isbn = models.CharField(max_length=13)
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index 713a7255..3062007d 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -34,6 +34,21 @@ if django_filters:
filter_backend = filters.DjangoFilterBackend
+class DefaultPageSizeKwargView(generics.ListAPIView):
+ """
+ View for testing default paginate_by_param usage
+ """
+ model = BasicModel
+
+
+class PaginateByParamView(generics.ListAPIView):
+ """
+ View for testing custom paginate_by_param usage
+ """
+ model = BasicModel
+ paginate_by_param = 'page_size'
+
+
class IntegrationTestPagination(TestCase):
"""
Integration tests for paginated list views.
@@ -135,7 +150,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
class UnitTestPagination(TestCase):
"""
- Unit tests for pagination of primative objects.
+ Unit tests for pagination of primitive objects.
"""
def setUp(self):
@@ -156,3 +171,68 @@ class UnitTestPagination(TestCase):
self.assertEquals(serializer.data['next'], None)
self.assertEquals(serializer.data['previous'], '?page=2')
self.assertEquals(serializer.data['results'], self.objects[20:])
+
+
+class TestUnpaginated(TestCase):
+ """
+ Tests for list views without pagination.
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = DefaultPageSizeKwargView.as_view()
+
+ def test_unpaginated(self):
+ """
+ Tests the default page size for this view.
+ no page size --> no limit --> no meta data
+ """
+ request = factory.get('/')
+ response = self.view(request)
+ self.assertEquals(response.data, self.data)
+
+
+class TestCustomPaginateByParam(TestCase):
+ """
+ Tests for list views with default page size kwarg
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = PaginateByParamView.as_view()
+
+ def test_default_page_size(self):
+ """
+ Tests the default page size for this view.
+ no page size --> no limit --> no meta data
+ """
+ request = factory.get('/')
+ response = self.view(request).render()
+ self.assertEquals(response.data, self.data)
+
+ def test_paginate_by_param(self):
+ """
+ If paginate_by_param is set, the new kwarg should limit per view requests.
+ """
+ request = factory.get('/?page_size=5')
+ response = self.view(request).render()
+ self.assertEquals(response.data['count'], 13)
+ self.assertEquals(response.data['results'], self.data[:5])
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index ad100e53..520029ec 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -48,6 +48,7 @@ class BookSerializer(serializers.ModelSerializer):
class ActionItemSerializer(serializers.ModelSerializer):
+
class Meta:
model = ActionItem
@@ -246,6 +247,23 @@ class ValidationTests(TestCase):
self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.errors, {})
+ def test_modelserializer_max_length_exceeded(self):
+ data = {
+ 'title': 'x' * 201,
+ }
+ serializer = ActionItemSerializer(data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'title': [u'Ensure this value has at most 200 characters (it has 201).']})
+
+ def test_default_modelfield_max_length_exceeded(self):
+ data = {
+ 'title': 'Testing "info" field...',
+ 'info': 'x' * 13,
+ }
+ serializer = ActionItemSerializer(data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'info': [u'Ensure this value has at most 12 characters (it has 13).']})
+
class RegexValidationTest(TestCase):
def test_create_failed(self):
@@ -487,7 +505,10 @@ class CallableDefaultValueTests(TestCase):
class ManyRelatedTests(TestCase):
- def setUp(self):
+ def test_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
class BlogPostCommentSerializer(serializers.Serializer):
text = serializers.CharField()
@@ -496,14 +517,7 @@ class ManyRelatedTests(TestCase):
title = serializers.CharField()
comments = BlogPostCommentSerializer(source='blogpostcomment_set')
- self.serializer_class = BlogPostSerializer
-
- def test_reverse_relations(self):
- post = BlogPost.objects.create(title="Test blog post")
- post.blogpostcomment_set.create(text="I hate this blog post")
- post.blogpostcomment_set.create(text="I love this blog post")
-
- serializer = self.serializer_class(instance=post)
+ serializer = BlogPostSerializer(instance=post)
expected = {
'title': 'Test blog post',
'comments': [
@@ -514,6 +528,59 @@ class ManyRelatedTests(TestCase):
self.assertEqual(serializer.data, expected)
+ def test_callable_source(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostCommentSerializer(serializers.Serializer):
+ text = serializers.CharField()
+
+ class BlogPostSerializer(serializers.Serializer):
+ title = serializers.CharField()
+ first_comment = BlogPostCommentSerializer(source='get_first_comment')
+
+ serializer = BlogPostSerializer(post)
+
+ expected = {
+ 'title': 'Test blog post',
+ 'first_comment': {'text': 'I love this blog post'}
+ }
+ self.assertEqual(serializer.data, expected)
+
+
+class SerializerMethodFieldTests(TestCase):
+ def setUp(self):
+
+ class BoopSerializer(serializers.Serializer):
+ beep = serializers.SerializerMethodField('get_beep')
+ boop = serializers.Field()
+ boop_count = serializers.SerializerMethodField('get_boop_count')
+
+ def get_beep(self, obj):
+ return 'hello!'
+
+ def get_boop_count(self, obj):
+ return len(obj.boop)
+
+ self.serializer_class = BoopSerializer
+
+ def test_serializer_method_field(self):
+
+ class MyModel(object):
+ boop = ['a', 'b', 'c']
+
+ source_data = MyModel()
+
+ serializer = self.serializer_class(source_data)
+
+ expected = {
+ 'beep': u'hello!',
+ 'boop': [u'a', u'b', u'c'],
+ 'boop_count': 3,
+ }
+
+ self.assertEqual(serializer.data, expected)
+
# Test for issue #324
class BlankFieldTests(TestCase):
diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/throttling.py
index 0b94c25b..4b98b941 100644
--- a/rest_framework/tests/throttling.py
+++ b/rest_framework/tests/throttling.py
@@ -106,7 +106,7 @@ class ThrottlingTests(TestCase):
if expect is not None:
self.assertEquals(response['X-Throttle-Wait-Seconds'], expect)
else:
- self.assertFalse('X-Throttle-Wait-Seconds' in response.headers)
+ self.assertFalse('X-Throttle-Wait-Seconds' in response)
def test_seconds_fields(self):
"""
diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py
index 316ccd19..0ad926fa 100644
--- a/rest_framework/urlpatterns.py
+++ b/rest_framework/urlpatterns.py
@@ -4,7 +4,7 @@ from rest_framework.settings import api_settings
def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
"""
- Supplement existing urlpatterns with corrosponding patterns that also
+ Supplement existing urlpatterns with corresponding patterns that also
include a '.format' suffix. Retains urlpattern ordering.
urlpatterns:
diff --git a/rest_framework/urls.py b/rest_framework/urls.py
index 1a81101f..bcdc23e7 100644
--- a/rest_framework/urls.py
+++ b/rest_framework/urls.py
@@ -1,7 +1,7 @@
"""
-Login and logout views for the browseable API.
+Login and logout views for the browsable API.
-Add these to your root URLconf if you're using the browseable API and
+Add these to your root URLconf if you're using the browsable API and
your API requires authentication.
The urls must be namespaced as 'rest_framework', and you should make sure
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 1afbd697..10bdd5a5 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -140,7 +140,7 @@ class APIView(View):
def http_method_not_allowed(self, request, *args, **kwargs):
"""
- Called if `request.method` does not corrospond to a handler method.
+ Called if `request.method` does not correspond to a handler method.
"""
raise exceptions.MethodNotAllowed(request.method)