aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py2
-rw-r--r--rest_framework/authentication.py14
-rw-r--r--rest_framework/compat.py14
-rw-r--r--rest_framework/fields.py22
-rw-r--r--rest_framework/filters.py4
-rw-r--r--rest_framework/generics.py8
-rw-r--r--rest_framework/renderers.py37
-rw-r--r--rest_framework/response.py4
-rw-r--r--rest_framework/serializers.py22
-rw-r--r--rest_framework/templates/rest_framework/base.html4
-rw-r--r--rest_framework/templatetags/rest_framework.py4
-rw-r--r--rest_framework/test.py2
-rw-r--r--rest_framework/tests/models.py4
-rw-r--r--rest_framework/tests/test_authentication.py9
-rw-r--r--rest_framework/tests/test_fields.py27
-rw-r--r--rest_framework/tests/test_serializer.py5
-rw-r--r--rest_framework/tests/test_serializers.py5
-rw-r--r--rest_framework/urls.py8
18 files changed, 134 insertions, 61 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 2d76b55d..01036cef 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -8,7 +8,7 @@ ______ _____ _____ _____ __ _
"""
__title__ = 'Django REST framework'
-__version__ = '2.3.13'
+__version__ = '2.3.14'
__author__ = 'Tom Christie'
__license__ = 'BSD 2-Clause'
__copyright__ = 'Copyright 2011-2014 Tom Christie'
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index da9ca510..887ef5d7 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -310,6 +310,13 @@ class OAuth2Authentication(BaseAuthentication):
auth = get_authorization_header(request).split()
+ if len(auth) == 1:
+ msg = 'Invalid bearer header. No credentials provided.'
+ raise exceptions.AuthenticationFailed(msg)
+ elif len(auth) > 2:
+ msg = 'Invalid bearer header. Token string should not contain spaces.'
+ raise exceptions.AuthenticationFailed(msg)
+
if auth and auth[0].lower() == b'bearer':
access_token = auth[1]
elif 'access_token' in request.POST:
@@ -319,13 +326,6 @@ class OAuth2Authentication(BaseAuthentication):
else:
return None
- if len(auth) == 1:
- msg = 'Invalid bearer header. No credentials provided.'
- raise exceptions.AuthenticationFailed(msg)
- elif len(auth) > 2:
- msg = 'Invalid bearer header. Token string should not contain spaces.'
- raise exceptions.AuthenticationFailed(msg)
-
return self.authenticate_credentials(request, access_token)
def authenticate_credentials(self, request, access_token):
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index d155f554..9ad8b0d2 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -48,11 +48,15 @@ try:
except ImportError:
django_filters = None
-# guardian is optional
-try:
- import guardian
-except ImportError:
- guardian = None
+# Django-guardian is optional. Import only if guardian is in INSTALLED_APPS
+# Fixes (#1712). We keep the try/except for the test suite.
+guardian = None
+if 'guardian' in settings.INSTALLED_APPS:
+ try:
+ import guardian
+ import guardian.shortcuts # Fixes #1624
+ except ImportError:
+ pass
# cStringIO only if it's available, otherwise StringIO
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index d80aab56..6caae924 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -187,7 +187,7 @@ class Field(object):
def field_to_native(self, obj, field_name):
"""
- Given and object and a field name, returns the value that should be
+ Given an object and a field name, returns the value that should be
serialized for that field.
"""
if obj is None:
@@ -474,8 +474,12 @@ class CharField(WritableField):
self.validators.append(validators.MaxLengthValidator(max_length))
def from_native(self, value):
- if isinstance(value, six.string_types) or value is None:
+ if isinstance(value, six.string_types):
return value
+
+ if value is None:
+ return ''
+
return smart_text(value)
@@ -506,7 +510,7 @@ class SlugField(CharField):
class ChoiceField(WritableField):
type_name = 'ChoiceField'
- type_label = 'multiple choice'
+ type_label = 'choice'
form_field_class = forms.ChoiceField
widget = widgets.Select
default_error_messages = {
@@ -514,12 +518,16 @@ class ChoiceField(WritableField):
'the available choices.'),
}
- def __init__(self, choices=(), *args, **kwargs):
+ def __init__(self, choices=(), blank_display_value=None, *args, **kwargs):
self.empty = kwargs.pop('empty', '')
super(ChoiceField, self).__init__(*args, **kwargs)
self.choices = choices
if not self.required:
- self.choices = BLANK_CHOICE_DASH + self.choices
+ if blank_display_value is None:
+ blank_choice = BLANK_CHOICE_DASH
+ else:
+ blank_choice = [('', blank_display_value)]
+ self.choices = blank_choice + self.choices
def _get_choices(self):
return self._choices
@@ -1023,9 +1031,9 @@ class SerializerMethodField(Field):
A field that gets its value by calling a method on the serializer it's attached to.
"""
- def __init__(self, method_name):
+ def __init__(self, method_name, *args, **kwargs):
self.method_name = method_name
- super(SerializerMethodField, self).__init__()
+ super(SerializerMethodField, self).__init__(*args, **kwargs)
def field_to_native(self, obj, field_name):
value = getattr(self.parent, self.method_name)(obj)
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index 96d15eb9..c3b846ae 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -116,6 +116,10 @@ class OrderingFilter(BaseFilterBackend):
def get_ordering(self, request):
"""
Ordering is set by a comma delimited ?ordering=... query parameter.
+
+ The `ordering` query parameter can be overridden by setting
+ the `ordering_param` value on the OrderingFilter or by
+ specifying an `ORDERING_PARAM` value in the API settings.
"""
params = request.QUERY_PARAMS.get(self.ordering_param)
if params:
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 65ccd952..42204841 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -94,8 +94,8 @@ class GenericAPIView(views.APIView):
'view': self
}
- def get_serializer(self, instance=None, data=None,
- files=None, many=False, partial=False):
+ def get_serializer(self, instance=None, data=None, files=None, many=False,
+ partial=False, allow_add_remove=False):
"""
Return the serializer instance that should be used for validating and
deserializing input, and for serializing output.
@@ -103,7 +103,9 @@ class GenericAPIView(views.APIView):
serializer_class = self.get_serializer_class()
context = self.get_serializer_context()
return serializer_class(instance, data=data, files=files,
- many=many, partial=partial, context=context)
+ many=many, partial=partial,
+ allow_add_remove=allow_add_remove,
+ context=context)
def get_pagination_serializer(self, page):
"""
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 484961ad..7048d87d 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -54,32 +54,37 @@ class JSONRenderer(BaseRenderer):
format = 'json'
encoder_class = encoders.JSONEncoder
ensure_ascii = True
- charset = None
- # JSON is a binary encoding, that can be encoded as utf-8, utf-16 or utf-32.
+
+ # We don't set a charset because JSON is a binary encoding,
+ # that can be encoded as utf-8, utf-16 or utf-32.
# See: http://www.ietf.org/rfc/rfc4627.txt
# Also: http://lucumr.pocoo.org/2013/7/19/application-mimetypes-and-encodings/
+ charset = None
+
+ def get_indent(self, accepted_media_type, renderer_context):
+ if accepted_media_type:
+ # If the media type looks like 'application/json; indent=4',
+ # then pretty print the result.
+ base_media_type, params = parse_header(accepted_media_type.encode('ascii'))
+ try:
+ return max(min(int(params['indent']), 8), 0)
+ except (KeyError, ValueError, TypeError):
+ pass
+
+ # If 'indent' is provided in the context, then pretty print the result.
+ # E.g. If we're being called by the BrowsableAPIRenderer.
+ return renderer_context.get('indent', None)
+
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
- Render `data` into JSON.
+ Render `data` into JSON, returning a bytestring.
"""
if data is None:
return bytes()
- # If 'indent' is provided in the context, then pretty print the result.
- # E.g. If we're being called by the BrowsableAPIRenderer.
renderer_context = renderer_context or {}
- indent = renderer_context.get('indent', None)
-
- if accepted_media_type:
- # If the media type looks like 'application/json; indent=4',
- # then pretty print the result.
- base_media_type, params = parse_header(accepted_media_type.encode('ascii'))
- indent = params.get('indent', indent)
- try:
- indent = max(min(int(indent), 8), 0)
- except (ValueError, TypeError):
- indent = None
+ indent = self.get_indent(accepted_media_type, renderer_context)
ret = json.dumps(data, cls=self.encoder_class,
indent=indent, ensure_ascii=self.ensure_ascii)
diff --git a/rest_framework/response.py b/rest_framework/response.py
index 1dc6abcf..5c02ea50 100644
--- a/rest_framework/response.py
+++ b/rest_framework/response.py
@@ -5,6 +5,7 @@ it is initialized with unrendered data, instead of a pre-rendered string.
The appropriate renderer is called during Django's template response rendering.
"""
from __future__ import unicode_literals
+import django
from django.core.handlers.wsgi import STATUS_CODE_TEXT
from django.template.response import SimpleTemplateResponse
from rest_framework.compat import six
@@ -15,6 +16,9 @@ class Response(SimpleTemplateResponse):
An HttpResponse that allows its data to be rendered into
arbitrary media types.
"""
+ # TODO: remove that once Django 1.3 isn't supported
+ if django.VERSION >= (1, 4):
+ rendering_attrs = SimpleTemplateResponse.rendering_attrs + ['_closable_objects']
def __init__(self, data=None, status=200,
template_name=None, headers=None,
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 6dd09f68..43d339da 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -33,8 +33,8 @@ from rest_framework.settings import api_settings
# This helps keep the separation between model fields, form fields, and
# serializer fields more explicit.
-from rest_framework.relations import *
-from rest_framework.fields import *
+from rest_framework.relations import * # NOQA
+from rest_framework.fields import * # NOQA
def _resolve_model(obj):
@@ -49,7 +49,7 @@ def _resolve_model(obj):
String representations should have the format:
'appname.ModelName'
"""
- if type(obj) == str and len(obj.split('.')) == 2:
+ if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.')
return models.get_model(app_name, model_name)
elif inspect.isclass(obj) and issubclass(obj, models.Model):
@@ -345,7 +345,7 @@ class BaseSerializer(WritableField):
for field_name, field in self.fields.items():
if field.read_only and obj is None:
- continue
+ continue
field.initialize(parent=self, field_name=field_name)
key = self.get_field_key(field_name)
value = field.field_to_native(obj, field_name)
@@ -759,9 +759,9 @@ class ModelSerializer(Serializer):
field.read_only = True
ret[accessor_name] = field
-
+
# Ensure that 'read_only_fields' is an iterable
- assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'
+ assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'
# Add the `read_only` flag to any fields that have been specified
# in the `read_only_fields` option
@@ -776,10 +776,10 @@ class ModelSerializer(Serializer):
"on serializer '%s'." %
(field_name, self.__class__.__name__))
ret[field_name].read_only = True
-
+
# Ensure that 'write_only_fields' is an iterable
- assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'
-
+ assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'
+
for field_name in self.opts.write_only_fields:
assert field_name not in self.base_fields.keys(), (
"field '%s' on serializer '%s' specified in "
@@ -790,7 +790,7 @@ class ModelSerializer(Serializer):
"Non-existant field '%s' specified in `write_only_fields` "
"on serializer '%s'." %
(field_name, self.__class__.__name__))
- ret[field_name].write_only = True
+ ret[field_name].write_only = True
return ret
@@ -977,7 +977,7 @@ class ModelSerializer(Serializer):
try:
setattr(instance, key, val)
except ValueError:
- self._errors[key] = self.error_messages['required']
+ self._errors[key] = [self.error_messages['required']]
# Any relations that cannot be set until we've
# saved the model get hidden away on these
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 7067ee2f..e96fa8ec 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -93,7 +93,7 @@
{% endif %}
{% if options_form %}
- <form class="button-form" action="{{ request.get_full_path }}" method="POST" class="pull-right">
+ <form class="button-form" action="{{ request.get_full_path }}" method="POST">
{% csrf_token %}
<input type="hidden" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="OPTIONS" />
<button class="btn btn-primary js-tooltip" title="Make an OPTIONS request on the {{ name }} resource">OPTIONS</button>
@@ -101,7 +101,7 @@
{% endif %}
{% if delete_form %}
- <form class="button-form" action="{{ request.get_full_path }}" method="POST" class="pull-right">
+ <form class="button-form" action="{{ request.get_full_path }}" method="POST">
{% csrf_token %}
<input type="hidden" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="DELETE" />
<button class="btn btn-danger js-tooltip" title="Make a DELETE request on the {{ name }} resource">DELETE</button>
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index dff176d6..a155d8d2 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -122,7 +122,7 @@ def optional_login(request):
except NoReverseMatch:
return ''
- snippet = "<a href='%s?next=%s'>Log in</a>" % (login_url, request.path)
+ snippet = "<a href='%s?next=%s'>Log in</a>" % (login_url, escape(request.path))
return snippet
@@ -136,7 +136,7 @@ def optional_logout(request):
except NoReverseMatch:
return ''
- snippet = "<a href='%s?next=%s'>Log out</a>" % (logout_url, request.path)
+ snippet = "<a href='%s?next=%s'>Log out</a>" % (logout_url, escape(request.path))
return snippet
diff --git a/rest_framework/test.py b/rest_framework/test.py
index df5a5b3b..284bcee0 100644
--- a/rest_framework/test.py
+++ b/rest_framework/test.py
@@ -36,7 +36,7 @@ class APIRequestFactory(DjangoRequestFactory):
"""
if not data:
- return ('', None)
+ return ('', content_type)
assert format is None or content_type is None, (
'You may not set both `format` and `content_type`.'
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index e171d3bd..fba3f8f7 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -105,6 +105,7 @@ class Album(RESTFrameworkModel):
title = models.CharField(max_length=100, unique=True)
ref = models.CharField(max_length=10, unique=True, null=True, blank=True)
+
class Photo(RESTFrameworkModel):
description = models.TextField()
album = models.ForeignKey(Album)
@@ -112,7 +113,8 @@ class Photo(RESTFrameworkModel):
# Model for issue #324
class BlankFieldModel(RESTFrameworkModel):
- title = models.CharField(max_length=100, blank=True, null=False)
+ title = models.CharField(max_length=100, blank=True, null=False,
+ default="title")
# Model for issue #380
diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py
index a1c43d9c..34bf2910 100644
--- a/rest_framework/tests/test_authentication.py
+++ b/rest_framework/tests/test_authentication.py
@@ -550,6 +550,15 @@ class OAuth2Tests(TestCase):
self.assertEqual(response.status_code, 401)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_missing(self):
+ """Ensure that a missing token lead to the correct HTTP error status code"""
+ auth = "Bearer"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_passing_auth(self):
"""Ensure GETing form over OAuth with correct client credentials succeed"""
auth = self._create_authorization_header()
diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py
index 3ae1c438..17d12f23 100644
--- a/rest_framework/tests/test_fields.py
+++ b/rest_framework/tests/test_fields.py
@@ -717,6 +717,15 @@ class ChoiceFieldTests(TestCase):
f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES)
+ def test_blank_choice_display(self):
+ blank = 'No Preference'
+ f = serializers.ChoiceField(
+ required=False,
+ choices=SAMPLE_CHOICES,
+ blank_display_value=blank,
+ )
+ self.assertEqual(f.choices, [('', blank)] + SAMPLE_CHOICES)
+
def test_invalid_choice_model(self):
s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'})
self.assertFalse(s.is_valid())
@@ -993,3 +1002,21 @@ class BooleanField(TestCase):
bool_field = serializers.BooleanField(required=True)
self.assertFalse(BooleanRequiredSerializer(data={}).is_valid())
+
+
+class SerializerMethodFieldTest(TestCase):
+ """
+ Tests for the SerializerMethodField field_to_native() behavior
+ """
+ class SerializerTest(serializers.Serializer):
+ def get_my_test(self, obj):
+ return obj.my_test[0:5]
+
+ class Example():
+ my_test = 'Hey, this is a test !'
+
+ def test_field_to_native(self):
+ s = serializers.SerializerMethodField('get_my_test')
+ s.initialize(self.SerializerTest(), 'name')
+ result = s.field_to_native(self.Example(), None)
+ self.assertEqual(result, 'Hey, ')
diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py
index e688c823..fb2eac0b 100644
--- a/rest_framework/tests/test_serializer.py
+++ b/rest_framework/tests/test_serializer.py
@@ -685,7 +685,7 @@ class ModelValidationTests(TestCase):
photo_serializer = PhotoSerializer(instance=photo, data={'album': ''}, partial=True)
self.assertFalse(photo_serializer.is_valid())
self.assertTrue('album' in photo_serializer.errors)
- self.assertEqual(photo_serializer.errors['album'], photo_serializer.error_messages['required'])
+ self.assertEqual(photo_serializer.errors['album'], [photo_serializer.error_messages['required']])
def test_foreign_key_with_partial(self):
"""
@@ -1236,6 +1236,9 @@ class BlankFieldTests(TestCase):
def test_create_model_null_field(self):
serializer = self.model_serializer_class(data={'title': None})
self.assertEqual(serializer.is_valid(), True)
+ serializer.save()
+ self.assertIsNot(serializer.object.pk, None)
+ self.assertEqual(serializer.object.title, '')
def test_create_not_blank_field(self):
"""
diff --git a/rest_framework/tests/test_serializers.py b/rest_framework/tests/test_serializers.py
index 082a400c..120510ac 100644
--- a/rest_framework/tests/test_serializers.py
+++ b/rest_framework/tests/test_serializers.py
@@ -3,6 +3,7 @@ from django.test import TestCase
from rest_framework.serializers import _resolve_model
from rest_framework.tests.models import BasicModel
+from rest_framework.compat import six
class ResolveModelTests(TestCase):
@@ -19,6 +20,10 @@ class ResolveModelTests(TestCase):
resolved_model = _resolve_model('tests.BasicModel')
self.assertEqual(resolved_model, BasicModel)
+ def test_resolve_unicode_representation(self):
+ resolved_model = _resolve_model(six.text_type('tests.BasicModel'))
+ self.assertEqual(resolved_model, BasicModel)
+
def test_resolve_non_django_model(self):
with self.assertRaises(ValueError):
_resolve_model(TestCase)
diff --git a/rest_framework/urls.py b/rest_framework/urls.py
index 9c4719f1..5d70f899 100644
--- a/rest_framework/urls.py
+++ b/rest_framework/urls.py
@@ -2,15 +2,15 @@
Login and logout views for the browsable API.
Add these to your root URLconf if you're using the browsable API and
-your API requires authentication.
-
-The urls must be namespaced as 'rest_framework', and you should make sure
-your authentication settings include `SessionAuthentication`.
+your API requires authentication:
urlpatterns = patterns('',
...
url(r'^auth', include('rest_framework.urls', namespace='rest_framework'))
)
+
+The urls must be namespaced as 'rest_framework', and you should make sure
+your authentication settings include `SessionAuthentication`.
"""
from __future__ import unicode_literals
from rest_framework.compat import patterns, url