aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py2
-rw-r--r--rest_framework/authtoken/views.py2
-rw-r--r--rest_framework/fields.py28
-rw-r--r--rest_framework/negotiation.py4
-rw-r--r--rest_framework/renderers.py28
-rw-r--r--rest_framework/request.py9
-rw-r--r--rest_framework/serializers.py36
-rw-r--r--rest_framework/tests/authentication.py4
-rw-r--r--rest_framework/tests/models.py25
-rw-r--r--rest_framework/tests/request.py28
-rw-r--r--rest_framework/tests/serializer.py41
11 files changed, 141 insertions, 66 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index a2233f3d..48cebbc5 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,3 +1,3 @@
-__version__ = '2.1.4'
+__version__ = '2.1.6'
VERSION = __version__ # synonym
diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py
index 3ac674e2..cfaacbe9 100644
--- a/rest_framework/authtoken/views.py
+++ b/rest_framework/authtoken/views.py
@@ -18,7 +18,7 @@ class ObtainAuthToken(APIView):
if serializer.is_valid():
token, created = Token.objects.get_or_create(user=serializer.object['user'])
return Response({'token': token.key})
- return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+ return Response(serializer.errors, status=status.HTTP_401_UNAUTHORIZED)
obtain_auth_token = ObtainAuthToken.as_view()
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index dff9123d..ff39fac4 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -10,6 +10,7 @@ from django.core import validators
from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.core.urlresolvers import resolve, get_script_prefix
from django.conf import settings
+from django import forms
from django.forms import widgets
from django.forms.models import ModelChoiceIterator
from django.utils.encoding import is_protected_type, smart_unicode
@@ -35,6 +36,7 @@ class Field(object):
empty = ''
type_name = None
_use_files = None
+ form_field_class = forms.CharField
def __init__(self, source=None):
self.parent = None
@@ -55,7 +57,7 @@ class Field(object):
self.root = parent.root or parent
self.context = self.root.context
if self.root.partial:
- self.required = False
+ self.required = False
def field_from_native(self, data, files, field_name, into):
"""
@@ -223,7 +225,7 @@ class ModelField(WritableField):
getattr(self.model_field, 'min_length', None))
self.max_length = kwargs.pop('max_length',
getattr(self.model_field, 'max_length', None))
-
+
super(ModelField, self).__init__(*args, **kwargs)
if self.min_length is not None:
@@ -394,6 +396,7 @@ class PrimaryKeyRelatedField(RelatedField):
Represents a to-one relationship as a pk value.
"""
default_read_only = False
+ form_field_class = forms.ChoiceField
# TODO: Remove these field hacks...
def prepare_value(self, obj):
@@ -440,6 +443,7 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):
Represents a to-many relationship as a pk value.
"""
default_read_only = False
+ form_field_class = forms.MultipleChoiceField
def prepare_value(self, obj):
return self.to_native(obj.pk)
@@ -483,6 +487,7 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):
class SlugRelatedField(RelatedField):
default_read_only = False
+ form_field_class = forms.ChoiceField
def __init__(self, *args, **kwargs):
self.slug_field = kwargs.pop('slug_field', None)
@@ -504,7 +509,7 @@ class SlugRelatedField(RelatedField):
class ManySlugRelatedField(ManyRelatedMixin, SlugRelatedField):
- pass
+ form_field_class = forms.MultipleChoiceField
### Hyperlinked relationships
@@ -517,6 +522,7 @@ class HyperlinkedRelatedField(RelatedField):
slug_field = 'slug'
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
default_read_only = False
+ form_field_class = forms.ChoiceField
def __init__(self, *args, **kwargs):
try:
@@ -616,7 +622,7 @@ class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField):
"""
Represents a to-many relationship, using hyperlinking.
"""
- pass
+ form_field_class = forms.MultipleChoiceField
class HyperlinkedIdentityField(Field):
@@ -674,6 +680,7 @@ class HyperlinkedIdentityField(Field):
class BooleanField(WritableField):
type_name = 'BooleanField'
+ form_field_class = forms.BooleanField
widget = widgets.CheckboxInput
default_error_messages = {
'invalid': _(u"'%s' value must be either True or False."),
@@ -686,15 +693,16 @@ class BooleanField(WritableField):
default = False
def from_native(self, value):
- if value in ('t', 'True', '1'):
+ if value in ('true', 't', 'True', '1'):
return True
- if value in ('f', 'False', '0'):
+ if value in ('false', 'f', 'False', '0'):
return False
return bool(value)
class CharField(WritableField):
type_name = 'CharField'
+ form_field_class = forms.CharField
def __init__(self, max_length=None, min_length=None, *args, **kwargs):
self.max_length, self.min_length = max_length, min_length
@@ -739,6 +747,7 @@ class SlugField(CharField):
class ChoiceField(WritableField):
type_name = 'ChoiceField'
+ form_field_class = forms.ChoiceField
widget = widgets.Select
default_error_messages = {
'invalid_choice': _('Select a valid choice. %(value)s is not one of the available choices.'),
@@ -785,6 +794,7 @@ class ChoiceField(WritableField):
class EmailField(CharField):
type_name = 'EmailField'
+ form_field_class = forms.EmailField
default_error_messages = {
'invalid': _('Enter a valid e-mail address.'),
@@ -836,6 +846,7 @@ class RegexField(CharField):
class DateField(WritableField):
type_name = 'DateField'
widget = widgets.DateInput
+ form_field_class = forms.DateField
default_error_messages = {
'invalid': _(u"'%s' value has an invalid date format. It must be "
@@ -874,6 +885,7 @@ class DateField(WritableField):
class DateTimeField(WritableField):
type_name = 'DateTimeField'
widget = widgets.DateTimeInput
+ form_field_class = forms.DateTimeField
default_error_messages = {
'invalid': _(u"'%s' value has an invalid format. It must be in "
@@ -928,6 +940,7 @@ class DateTimeField(WritableField):
class IntegerField(WritableField):
type_name = 'IntegerField'
+ form_field_class = forms.IntegerField
default_error_messages = {
'invalid': _('Enter a whole number.'),
@@ -957,6 +970,7 @@ class IntegerField(WritableField):
class FloatField(WritableField):
type_name = 'FloatField'
+ form_field_class = forms.FloatField
default_error_messages = {
'invalid': _("'%s' value must be a float."),
@@ -976,6 +990,7 @@ class FloatField(WritableField):
class FileField(WritableField):
_use_files = True
type_name = 'FileField'
+ form_field_class = forms.FileField
widget = widgets.FileInput
default_error_messages = {
@@ -1018,6 +1033,7 @@ class FileField(WritableField):
class ImageField(FileField):
_use_files = True
+ form_field_class = forms.ImageField
default_error_messages = {
'invalid_image': _("Upload a valid image. The file you uploaded was either not an image or a corrupted image."),
diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py
index dae38477..ee2800a6 100644
--- a/rest_framework/negotiation.py
+++ b/rest_framework/negotiation.py
@@ -2,6 +2,7 @@ from django.http import Http404
from rest_framework import exceptions
from rest_framework.settings import api_settings
from rest_framework.utils.mediatypes import order_by_precedence, media_type_matches
+from rest_framework.utils.mediatypes import _MediaType
class BaseContentNegotiation(object):
@@ -48,7 +49,8 @@ class DefaultContentNegotiation(BaseContentNegotiation):
for media_type in media_type_set:
if media_type_matches(renderer.media_type, media_type):
# Return the most specific media type as accepted.
- if len(renderer.media_type) > len(media_type):
+ if (_MediaType(renderer.media_type).precedence >
+ _MediaType(media_type).precedence):
# Eg client requests '*/*'
# Accepted media type is 'application/json'
return renderer, renderer.media_type
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 550963cb..25a32baa 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -306,26 +306,6 @@ class BrowsableAPIRenderer(BaseRenderer):
return True
def serializer_to_form_fields(self, serializer):
- field_mapping = {
- serializers.FloatField: forms.FloatField,
- serializers.IntegerField: forms.IntegerField,
- serializers.DateTimeField: forms.DateTimeField,
- serializers.DateField: forms.DateField,
- serializers.EmailField: forms.EmailField,
- serializers.RegexField: forms.RegexField,
- serializers.CharField: forms.CharField,
- serializers.ChoiceField: forms.ChoiceField,
- serializers.BooleanField: forms.BooleanField,
- serializers.PrimaryKeyRelatedField: forms.ChoiceField,
- serializers.ManyPrimaryKeyRelatedField: forms.MultipleChoiceField,
- serializers.SlugRelatedField: forms.ChoiceField,
- serializers.ManySlugRelatedField: forms.MultipleChoiceField,
- serializers.HyperlinkedRelatedField: forms.ChoiceField,
- serializers.ManyHyperlinkedRelatedField: forms.MultipleChoiceField,
- serializers.FileField: forms.FileField,
- serializers.ImageField: forms.ImageField,
- }
-
fields = {}
for k, v in serializer.get_fields().items():
if getattr(v, 'read_only', True):
@@ -349,13 +329,7 @@ class BrowsableAPIRenderer(BaseRenderer):
kwargs['label'] = k
- try:
- fields[k] = field_mapping[v.__class__](**kwargs)
- except KeyError:
- if getattr(v, 'choices', None) is not None:
- fields[k] = forms.ChoiceField(**kwargs)
- else:
- fields[k] = forms.CharField(**kwargs)
+ fields[k] = v.form_field_class(**kwargs)
return fields
def get_form(self, view, method, request):
diff --git a/rest_framework/request.py b/rest_framework/request.py
index a1827ba4..39c64321 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -169,6 +169,15 @@ class Request(object):
self._user, self._auth = self._authenticate()
return self._user
+ @user.setter
+ def user(self, value):
+ """
+ Sets the user on the current request. This is necessary to maintain
+ compatilbility with django.contrib.auth where the user proprety is
+ set in the login and logout functions.
+ """
+ self._user = value
+
@property
def auth(self):
"""
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 775a8a1e..984f3ac5 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -60,7 +60,7 @@ def _get_declared_fields(bases, attrs):
# If this class is subclassing another Serializer, add that Serializer's
# fields. Note that we loop over the bases in *reverse*. This is necessary
- # in order to the correct order of fields.
+ # in order to maintain the correct order of fields.
for base in bases[::-1]:
if hasattr(base, 'base_fields'):
fields = base.base_fields.items() + fields
@@ -94,7 +94,6 @@ class BaseSerializer(Field):
def __init__(self, instance=None, data=None, files=None, context=None, partial=False, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta)
- self.fields = copy.deepcopy(self.base_fields)
self.parent = None
self.root = None
self.partial = partial
@@ -104,7 +103,7 @@ class BaseSerializer(Field):
self.init_data = data
self.init_files = files
self.object = instance
- self.default_fields = self.get_default_fields()
+ self.fields = self.get_fields()
self._data = None
self._files = None
@@ -140,13 +139,15 @@ class BaseSerializer(Field):
ret = SortedDict()
# Get the explicitly declared fields
- for key, field in self.fields.items():
+ base_fields = copy.deepcopy(self.base_fields)
+ for key, field in base_fields.items():
ret[key] = field
# Set up the field
field.initialize(parent=self, field_name=key)
# Add in the default fields
- for key, val in self.default_fields.items():
+ default_fields = self.get_default_fields()
+ for key, val in default_fields.items():
if key not in ret:
ret[key] = val
@@ -193,8 +194,7 @@ class BaseSerializer(Field):
ret = self._dict_class()
ret.fields = {}
- fields = self.get_fields()
- for field_name, field in fields.items():
+ for field_name, field in self.fields.items():
key = self.get_field_key(field_name)
value = field.field_to_native(obj, field_name)
ret[key] = value
@@ -206,9 +206,8 @@ class BaseSerializer(Field):
Core of deserialization, together with `restore_object`.
Converts a dictionary of data into a dictionary of deserialized fields.
"""
- fields = self.get_fields()
reverted_data = {}
- for field_name, field in fields.items():
+ for field_name, field in self.fields.items():
try:
field.field_from_native(data, files, field_name, reverted_data)
except ValidationError as err:
@@ -220,10 +219,7 @@ class BaseSerializer(Field):
"""
Run `validate_<fieldname>()` and `validate()` methods on the serializer
"""
- # TODO: refactor this so we're not determining the fields again
- fields = self.get_fields()
-
- for field_name, field in fields.items():
+ for field_name, field in self.fields.items():
try:
validate_method = getattr(self, 'validate_%s' % field_name, None)
if validate_method:
@@ -294,10 +290,18 @@ class BaseSerializer(Field):
Override default so that we can apply ModelSerializer as a nested
field to relationships.
"""
- obj = getattr(obj, self.source or field_name)
- if is_simple_callable(obj):
- obj = obj()
+ if self.source:
+ value = obj
+ for component in self.source.split('.'):
+ value = getattr(value, component)
+ if is_simple_callable(value):
+ value = value()
+ obj = value
+ else:
+ value = getattr(obj, field_name)
+ if is_simple_callable(value):
+ obj = value()
# If the object has an "all" method, assume it's a relationship
if is_simple_callable(getattr(obj, 'all', None)):
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index 96ca9f52..802bc6c1 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -167,14 +167,14 @@ class TokenAuthTests(TestCase):
client = Client(enforce_csrf_checks=True)
response = client.post('/auth-token/login/',
json.dumps({'username': self.username, 'password': "badpass"}), 'application/json')
- self.assertEqual(response.status_code, 400)
+ self.assertEqual(response.status_code, 401)
def test_token_login_json_missing_fields(self):
"""Ensure token login view using JSON POST fails if missing fields."""
client = Client(enforce_csrf_checks=True)
response = client.post('/auth-token/login/',
json.dumps({'username': self.username}), 'application/json')
- self.assertEqual(response.status_code, 400)
+ self.assertEqual(response.status_code, 401)
def test_token_login_form(self):
"""Ensure token login view using form POST works."""
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index 9a59e841..428bf130 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -124,8 +124,21 @@ class ActionItem(RESTFrameworkModel):
# Models for reverse relations
+class Person(RESTFrameworkModel):
+ name = models.CharField(max_length=10)
+ age = models.IntegerField(null=True, blank=True)
+
+ @property
+ def info(self):
+ return {
+ 'name': self.name,
+ 'age': self.age,
+ }
+
+
class BlogPost(RESTFrameworkModel):
title = models.CharField(max_length=100)
+ writer = models.ForeignKey(Person, null=True, blank=True)
def get_first_comment(self):
return self.blogpostcomment_set.all()[0]
@@ -145,18 +158,6 @@ class Photo(RESTFrameworkModel):
album = models.ForeignKey(Album)
-class Person(RESTFrameworkModel):
- name = models.CharField(max_length=10)
- age = models.IntegerField(null=True, blank=True)
-
- @property
- def info(self):
- return {
- 'name': self.name,
- 'age': self.age,
- }
-
-
# Model for issue #324
class BlankFieldModel(RESTFrameworkModel):
title = models.CharField(max_length=100, blank=True, null=True)
diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py
index ff48f3fa..2850992d 100644
--- a/rest_framework/tests/request.py
+++ b/rest_framework/tests/request.py
@@ -3,6 +3,8 @@ Tests for content parsing, and form-overloaded content parsing.
"""
from django.conf.urls.defaults import patterns
from django.contrib.auth.models import User
+from django.contrib.auth import authenticate, login, logout
+from django.contrib.sessions.middleware import SessionMiddleware
from django.test import TestCase, Client
from django.utils import simplejson as json
@@ -276,3 +278,29 @@ class TestContentParsingWithAuthentication(TestCase):
# response = self.csrf_client.post('/', content)
# self.assertEqual(status.OK, response.status_code, "POST data is malformed")
+
+
+class TestUserSetter(TestCase):
+
+ def setUp(self):
+ # Pass request object through session middleware so session is
+ # available to login and logout functions
+ self.request = Request(factory.get('/'))
+ SessionMiddleware().process_request(self.request)
+
+ User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow')
+ self.user = authenticate(username='ringo', password='yellow')
+
+ def test_user_can_be_set(self):
+ self.request.user = self.user
+ self.assertEqual(self.request.user, self.user)
+
+ def test_user_can_login(self):
+ login(self.request, self.user)
+ self.assertEqual(self.request.user, self.user)
+
+ def test_user_can_logout(self):
+ self.request.user = self.user
+ self.assertFalse(self.request.user.is_anonymous())
+ logout(self.request)
+ self.assertTrue(self.request.user.is_anonymous())
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index bdf72a91..18e24b71 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -577,6 +577,47 @@ class ManyRelatedTests(TestCase):
self.assertEqual(serializer.data, expected)
+class RelatedTraversalTest(TestCase):
+ def test_nested_traversal(self):
+ user = Person.objects.create(name="django")
+ post = BlogPost.objects.create(title="Test blog post", writer=user)
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ from rest_framework.tests.models import BlogPostComment
+
+ class PersonSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Person
+ fields = ("name", "age")
+
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPostComment
+ fields = ("text", "post_owner")
+
+ text = serializers.CharField()
+ post_owner = PersonSerializer(source='blog_post.writer')
+
+ class BlogPostSerializer(serializers.Serializer):
+ title = serializers.CharField()
+ comments = BlogPostCommentSerializer(source='blogpostcomment_set')
+
+ serializer = BlogPostSerializer(instance=post)
+
+ expected = {
+ 'title': u'Test blog post',
+ 'comments': [{
+ 'text': u'I love this blog post',
+ 'post_owner': {
+ "name": u"django",
+ "age": None
+ }
+ }]
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+
class SerializerMethodFieldTests(TestCase):
def setUp(self):