aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py4
-rw-r--r--rest_framework/authentication.py16
-rw-r--r--rest_framework/compat.py29
-rw-r--r--rest_framework/exceptions.py2
-rw-r--r--rest_framework/fields.py12
-rw-r--r--rest_framework/filters.py7
-rw-r--r--rest_framework/parsers.py6
-rw-r--r--rest_framework/relations.py12
-rw-r--r--rest_framework/renderers.py15
-rw-r--r--rest_framework/request.py4
-rw-r--r--rest_framework/serializers.py57
-rw-r--r--rest_framework/settings.py4
-rw-r--r--rest_framework/templatetags/rest_framework.py2
-rw-r--r--rest_framework/test.py12
-rw-r--r--rest_framework/utils/mediatypes.py2
-rw-r--r--rest_framework/views.py4
16 files changed, 130 insertions, 58 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 6759680b..2d76b55d 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -8,10 +8,10 @@ ______ _____ _____ _____ __ _
"""
__title__ = 'Django REST framework'
-__version__ = '2.3.12'
+__version__ = '2.3.13'
__author__ = 'Tom Christie'
__license__ = 'BSD 2-Clause'
-__copyright__ = 'Copyright 2011-2013 Tom Christie'
+__copyright__ = 'Copyright 2011-2014 Tom Christie'
# Version synonym
VERSION = __version__
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index e491ce5f..da9ca510 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -6,6 +6,7 @@ import base64
from django.contrib.auth import authenticate
from django.core.exceptions import ImproperlyConfigured
+from django.conf import settings
from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
@@ -291,6 +292,7 @@ class OAuth2Authentication(BaseAuthentication):
OAuth 2 authentication backend using `django-oauth2-provider`
"""
www_authenticate_realm = 'api'
+ allow_query_params_token = settings.DEBUG
def __init__(self, *args, **kwargs):
super(OAuth2Authentication, self).__init__(*args, **kwargs)
@@ -308,7 +310,13 @@ class OAuth2Authentication(BaseAuthentication):
auth = get_authorization_header(request).split()
- if not auth or auth[0].lower() != b'bearer':
+ if auth and auth[0].lower() == b'bearer':
+ access_token = auth[1]
+ elif 'access_token' in request.POST:
+ access_token = request.POST['access_token']
+ elif 'access_token' in request.GET and self.allow_query_params_token:
+ access_token = request.GET['access_token']
+ else:
return None
if len(auth) == 1:
@@ -318,7 +326,7 @@ class OAuth2Authentication(BaseAuthentication):
msg = 'Invalid bearer header. Token string should not contain spaces.'
raise exceptions.AuthenticationFailed(msg)
- return self.authenticate_credentials(request, auth[1])
+ return self.authenticate_credentials(request, access_token)
def authenticate_credentials(self, request, access_token):
"""
@@ -326,11 +334,11 @@ class OAuth2Authentication(BaseAuthentication):
"""
try:
- token = oauth2_provider.models.AccessToken.objects.select_related('user')
+ token = oauth2_provider.oauth2.models.AccessToken.objects.select_related('user')
# provider_now switches to timezone aware datetime when
# the oauth2_provider version supports to it.
token = token.get(token=access_token, expires__gt=provider_now())
- except oauth2_provider.models.AccessToken.DoesNotExist:
+ except oauth2_provider.oauth2.models.AccessToken.DoesNotExist:
raise exceptions.AuthenticationFailed('Invalid token')
user = token.user
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index d283e2f5..d155f554 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -550,13 +550,10 @@ except (ImportError, ImproperlyConfigured):
# OAuth 2 support is optional
try:
- import provider.oauth2 as oauth2_provider
- from provider.oauth2 import models as oauth2_provider_models
- from provider.oauth2 import forms as oauth2_provider_forms
+ import provider as oauth2_provider
from provider import scope as oauth2_provider_scope
from provider import constants as oauth2_constants
- from provider import __version__ as provider_version
- if provider_version in ('0.2.3', '0.2.4'):
+ if oauth2_provider.__version__ in ('0.2.3', '0.2.4'):
# 0.2.3 and 0.2.4 are supported version that do not support
# timezone aware datetimes
import datetime
@@ -566,8 +563,6 @@ try:
from django.utils.timezone import now as provider_now
except ImportError:
oauth2_provider = None
- oauth2_provider_models = None
- oauth2_provider_forms = None
oauth2_provider_scope = None
oauth2_constants = None
provider_now = None
@@ -584,3 +579,23 @@ if six.PY3:
else:
def is_non_str_iterable(obj):
return hasattr(obj, '__iter__')
+
+
+try:
+ from django.utils.encoding import python_2_unicode_compatible
+except ImportError:
+ def python_2_unicode_compatible(klass):
+ """
+ A decorator that defines __unicode__ and __str__ methods under Python 2.
+ Under Python 3 it does nothing.
+
+ To support Python 2 and 3 with a single code base, define a __str__ method
+ returning text and apply this decorator to the class.
+ """
+ if '__str__' not in klass.__dict__:
+ raise ValueError("@python_2_unicode_compatible cannot be applied "
+ "to %s because it doesn't define __str__()." %
+ klass.__name__)
+ klass.__unicode__ = klass.__str__
+ klass.__str__ = lambda self: self.__unicode__().encode('utf-8')
+ return klass
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index 0ac5866e..5f774a9f 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -20,6 +20,8 @@ class APIException(Exception):
def __init__(self, detail=None):
self.detail = detail or self.default_detail
+ def __str__(self):
+ return self.detail
class ParseError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 05daaab7..946a5954 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -164,7 +164,7 @@ class Field(object):
Called to set up a field prior to field_to_native or field_from_native.
parent - The parent serializer.
- model_field - The model field this field corresponds to, if one exists.
+ field_name - The name of the field being initialized.
"""
self.parent = parent
self.root = parent.root or parent
@@ -301,6 +301,11 @@ class WritableField(Field):
result.validators = self.validators[:]
return result
+ def get_default_value(self):
+ if is_simple_callable(self.default):
+ return self.default()
+ return self.default
+
def validate(self, value):
if value in validators.EMPTY_VALUES and self.required:
raise ValidationError(self.error_messages['required'])
@@ -349,10 +354,7 @@ class WritableField(Field):
except KeyError:
if self.default is not None and not self.partial:
# Note: partial updates shouldn't set defaults
- if is_simple_callable(self.default):
- native = self.default()
- else:
- native = self.default
+ native = self.get_default_value()
else:
if self.required:
raise ValidationError(self.error_messages['required'])
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index de91caed..96d15eb9 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -6,6 +6,7 @@ from __future__ import unicode_literals
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from rest_framework.compat import django_filters, six, guardian, get_model_name
+from rest_framework.settings import api_settings
from functools import reduce
import operator
@@ -69,7 +70,8 @@ class DjangoFilterBackend(BaseFilterBackend):
class SearchFilter(BaseFilterBackend):
- search_param = 'search' # The URL query parameter used for the search.
+ # The URL query parameter used for the search.
+ search_param = api_settings.SEARCH_PARAM
def get_search_terms(self, request):
"""
@@ -107,7 +109,8 @@ class SearchFilter(BaseFilterBackend):
class OrderingFilter(BaseFilterBackend):
- ordering_param = 'ordering' # The URL query parameter used for the ordering.
+ # The URL query parameter used for the ordering.
+ ordering_param = api_settings.ORDERING_PARAM
ordering_fields = None
def get_ordering(self, request):
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index f1b3e38d..4990971b 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -10,7 +10,7 @@ from django.core.files.uploadhandler import StopFutureHandlers
from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter
-from rest_framework.compat import etree, six, yaml
+from rest_framework.compat import etree, six, yaml, force_text
from rest_framework.exceptions import ParseError
from rest_framework import renderers
import json
@@ -288,7 +288,7 @@ class FileUploadParser(BaseParser):
try:
meta = parser_context['request'].META
- disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'])
- return disposition[1]['filename']
+ disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'].encode('utf-8'))
+ return force_text(disposition[1]['filename'])
except (AttributeError, KeyError):
pass
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 163a8984..3463954d 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -59,6 +59,8 @@ class RelatedField(WritableField):
super(RelatedField, self).__init__(*args, **kwargs)
if not self.required:
+ # Accessed in ModelChoiceIterator django/forms/models.py:1034
+ # If set adds empty choice.
self.empty_label = BLANK_CHOICE_DASH[0][1]
self.queryset = queryset
@@ -119,6 +121,14 @@ class RelatedField(WritableField):
choices = property(_get_choices, _set_choices)
+ ### Default value handling
+
+ def get_default_value(self):
+ default = super(RelatedField, self).get_default_value()
+ if self.many and default is None:
+ return []
+ return default
+
### Regular serializer stuff...
def field_to_native(self, obj, field_name):
@@ -167,7 +177,7 @@ class RelatedField(WritableField):
except KeyError:
if self.partial:
return
- value = [] if self.many else None
+ value = self.get_default_value()
if value in self.null_values:
if self.required:
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 7cf1c051..484961ad 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -146,7 +146,7 @@ class XMLRenderer(BaseRenderer):
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
- Renders *obj* into serialized XML.
+ Renders `data` into serialized XML.
"""
if data is None:
return ''
@@ -193,17 +193,26 @@ class YAMLRenderer(BaseRenderer):
format = 'yaml'
encoder = encoders.SafeDumper
charset = 'utf-8'
+ ensure_ascii = True
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
- Renders *obj* into serialized YAML.
+ Renders `data` into serialized YAML.
"""
assert yaml, 'YAMLRenderer requires pyyaml to be installed'
if data is None:
return ''
- return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder)
+ return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder, allow_unicode=not self.ensure_ascii)
+
+
+class UnicodeYAMLRenderer(YAMLRenderer):
+ """
+ Renderer which serializes to YAML.
+ Does *not* apply character escaping for non-ascii characters.
+ """
+ ensure_ascii = False
class TemplateHTMLRenderer(BaseRenderer):
diff --git a/rest_framework/request.py b/rest_framework/request.py
index ca70b49e..40467c03 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -346,7 +346,7 @@ class Request(object):
media_type = self.content_type
if stream is None or media_type is None:
- empty_data = QueryDict('', self._request._encoding)
+ empty_data = QueryDict('', encoding=self._request._encoding)
empty_files = MultiValueDict()
return (empty_data, empty_files)
@@ -362,7 +362,7 @@ class Request(object):
# re-raise. Ensures we don't simply repeat the error when
# attempting to render the browsable renderer response, or when
# logging the request or similar.
- self._data = QueryDict('', self._request._encoding)
+ self._data = QueryDict('', encoding=self._request._encoding)
self._files = MultiValueDict()
raise
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 10256d47..ea9509bf 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -16,6 +16,7 @@ import datetime
import inspect
import types
from decimal import Decimal
+from django.contrib.contenttypes.generic import GenericForeignKey
from django.core.paginator import Page
from django.db import models
from django.forms import widgets
@@ -438,16 +439,6 @@ class BaseSerializer(WritableField):
raise ValidationError(self.error_messages['required'])
return
- # Set the serializer object if it exists
- obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None
-
- # If we have a model manager or similar object then we need
- # to iterate through each instance.
- if (self.many and
- not hasattr(obj, '__iter__') and
- is_simple_callable(getattr(obj, 'all', None))):
- obj = obj.all()
-
if self.source == '*':
if value:
reverted_data = self.restore_fields(value, {})
@@ -457,6 +448,16 @@ class BaseSerializer(WritableField):
if value in (None, ''):
into[(self.source or field_name)] = None
else:
+ # Set the serializer object if it exists
+ obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None
+
+ # If we have a model manager or similar object then we need
+ # to iterate through each instance.
+ if (self.many and
+ not hasattr(obj, '__iter__') and
+ is_simple_callable(getattr(obj, 'all', None))):
+ obj = obj.all()
+
kwargs = {
'instance': obj,
'data': value,
@@ -757,8 +758,11 @@ class ModelSerializer(Serializer):
field.read_only = True
ret[accessor_name] = field
+
+ # Ensure that 'read_only_fields' is an iterable
+ assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'
- # Add the `read_only` flag to any fields that have bee specified
+ # Add the `read_only` flag to any fields that have been specified
# in the `read_only_fields` option
for field_name in self.opts.read_only_fields:
assert field_name not in self.base_fields.keys(), (
@@ -771,7 +775,10 @@ class ModelSerializer(Serializer):
"on serializer '%s'." %
(field_name, self.__class__.__name__))
ret[field_name].read_only = True
-
+
+ # Ensure that 'write_only_fields' is an iterable
+ assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'
+
for field_name in self.opts.write_only_fields:
assert field_name not in self.base_fields.keys(), (
"field '%s' on serializer '%s' specified in "
@@ -881,7 +888,7 @@ class ModelSerializer(Serializer):
except KeyError:
return ModelField(model_field=model_field, **kwargs)
- def get_validation_exclusions(self):
+ def get_validation_exclusions(self, instance=None):
"""
Return a list of field names to exclude from model validation.
"""
@@ -893,7 +900,7 @@ class ModelSerializer(Serializer):
field_name = field.source or field_name
if field_name in exclusions \
and not field.read_only \
- and field.required \
+ and (field.required or hasattr(instance, field_name)) \
and not isinstance(field, Serializer):
exclusions.remove(field_name)
return exclusions
@@ -908,7 +915,7 @@ class ModelSerializer(Serializer):
the full_clean validation checking.
"""
try:
- instance.full_clean(exclude=self.get_validation_exclusions())
+ instance.full_clean(exclude=self.get_validation_exclusions(instance))
except ValidationError as err:
self._errors = err.message_dict
return None
@@ -937,6 +944,8 @@ class ModelSerializer(Serializer):
# Forward m2m relations
for field in meta.many_to_many + meta.virtual_fields:
+ if isinstance(field, GenericForeignKey):
+ continue
if field.name in attrs:
m2m_data[field.name] = attrs.pop(field.name)
@@ -946,17 +955,15 @@ class ModelSerializer(Serializer):
if isinstance(self.fields.get(field_name, None), Serializer):
nested_forward_relations[field_name] = attrs[field_name]
- # Update an existing instance...
- if instance is not None:
- for key, val in attrs.items():
- try:
- setattr(instance, key, val)
- except ValueError:
- self._errors[key] = self.error_messages['required']
+ # Create an empty instance of the model
+ if instance is None:
+ instance = self.opts.model()
- # ...or create a new instance
- else:
- instance = self.opts.model(**attrs)
+ for key, val in attrs.items():
+ try:
+ setattr(instance, key, val)
+ except ValueError:
+ self._errors[key] = self.error_messages['required']
# Any relations that cannot be set until we've
# saved the model get hidden away on these
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index ce171d6d..38753c96 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -69,6 +69,10 @@ DEFAULTS = {
'PAGINATE_BY_PARAM': None,
'MAX_PAGINATE_BY': None,
+ # Filtering
+ 'SEARCH_PARAM': 'search',
+ 'ORDERING_PARAM': 'ordering',
+
# Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index beb8c5b0..dff176d6 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -180,7 +180,7 @@ def add_class(value, css_class):
# Bunch of stuff cloned from urlize
-TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"]
+TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "']", "'}", "'"]
WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('&lt;', '&gt;'),
('"', '"'), ("'", "'")]
word_split_re = re.compile(r'(\s+)')
diff --git a/rest_framework/test.py b/rest_framework/test.py
index 234d10a4..df5a5b3b 100644
--- a/rest_framework/test.py
+++ b/rest_framework/test.py
@@ -8,6 +8,7 @@ from django.conf import settings
from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler
from django.test import testcases
+from django.utils.http import urlencode
from rest_framework.settings import api_settings
from rest_framework.compat import RequestFactory as DjangoRequestFactory
from rest_framework.compat import force_bytes_or_smart_bytes, six
@@ -71,6 +72,17 @@ class APIRequestFactory(DjangoRequestFactory):
return ret, content_type
+ def get(self, path, data=None, **extra):
+ r = {
+ 'QUERY_STRING': urlencode(data or {}, doseq=True),
+ }
+ # Fix to support old behavior where you have the arguments in the url
+ # See #1461
+ if not data and '?' in path:
+ r['QUERY_STRING'] = path.split('?')[1]
+ r.update(extra)
+ return self.generic('GET', path, **r)
+
def post(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('POST', path, data, content_type, **extra)
diff --git a/rest_framework/utils/mediatypes.py b/rest_framework/utils/mediatypes.py
index c09c2933..92f99efd 100644
--- a/rest_framework/utils/mediatypes.py
+++ b/rest_framework/utils/mediatypes.py
@@ -74,7 +74,7 @@ class _MediaType(object):
return 0
elif self.sub_type == '*':
return 1
- elif not self.params or self.params.keys() == ['q']:
+ elif not self.params or list(self.params.keys()) == ['q']:
return 2
return 3
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 02a6e25a..a2668f2c 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -131,7 +131,7 @@ class APIView(View):
"""
If request is not permitted, determine what kind of exception to raise.
"""
- if not self.request.successful_authenticator:
+ if not request.successful_authenticator:
raise exceptions.NotAuthenticated()
raise exceptions.PermissionDenied()
@@ -295,7 +295,7 @@ class APIView(View):
# Dispatch methods
- def initialize_request(self, request, *args, **kargs):
+ def initialize_request(self, request, *args, **kwargs):
"""
Returns the initial request object.
"""