aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py4
-rw-r--r--rest_framework/authentication.py16
-rw-r--r--rest_framework/authtoken/models.py2
-rw-r--r--rest_framework/compat.py29
-rw-r--r--rest_framework/exceptions.py2
-rw-r--r--rest_framework/fields.py14
-rw-r--r--rest_framework/filters.py7
-rw-r--r--rest_framework/mixins.py37
-rw-r--r--rest_framework/parsers.py6
-rw-r--r--rest_framework/relations.py19
-rw-r--r--rest_framework/renderers.py26
-rw-r--r--rest_framework/request.py4
-rwxr-xr-xrest_framework/runtests/runtests.py4
-rw-r--r--rest_framework/serializers.py65
-rw-r--r--rest_framework/settings.py4
-rw-r--r--rest_framework/templates/rest_framework/base.html2
-rw-r--r--rest_framework/templatetags/rest_framework.py20
-rw-r--r--rest_framework/test.py12
-rw-r--r--rest_framework/tests/models.py12
-rw-r--r--rest_framework/tests/serializers.py8
-rw-r--r--rest_framework/tests/test_authentication.py42
-rw-r--r--rest_framework/tests/test_filters.py100
-rw-r--r--rest_framework/tests/test_genericrelations.py28
-rw-r--r--rest_framework/tests/test_htmlrenderer.py8
-rw-r--r--rest_framework/tests/test_nullable_fields.py30
-rw-r--r--rest_framework/tests/test_pagination.py32
-rw-r--r--rest_framework/tests/test_parsers.py4
-rw-r--r--rest_framework/tests/test_relations.py24
-rw-r--r--rest_framework/tests/test_relations_nested.py4
-rw-r--r--rest_framework/tests/test_renderers.py29
-rw-r--r--rest_framework/tests/test_serializer.py158
-rw-r--r--rest_framework/tests/test_templatetags.py34
-rw-r--r--rest_framework/tests/test_testing.py10
-rw-r--r--rest_framework/tests/test_urlizer.py38
-rw-r--r--rest_framework/tests/test_validation.py44
-rw-r--r--rest_framework/tests/utils.py25
-rw-r--r--rest_framework/tests/views.py8
-rw-r--r--rest_framework/throttling.py2
-rw-r--r--rest_framework/utils/mediatypes.py2
-rw-r--r--rest_framework/views.py4
40 files changed, 748 insertions, 171 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 6759680b..2d76b55d 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -8,10 +8,10 @@ ______ _____ _____ _____ __ _
"""
__title__ = 'Django REST framework'
-__version__ = '2.3.12'
+__version__ = '2.3.13'
__author__ = 'Tom Christie'
__license__ = 'BSD 2-Clause'
-__copyright__ = 'Copyright 2011-2013 Tom Christie'
+__copyright__ = 'Copyright 2011-2014 Tom Christie'
# Version synonym
VERSION = __version__
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index e491ce5f..da9ca510 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -6,6 +6,7 @@ import base64
from django.contrib.auth import authenticate
from django.core.exceptions import ImproperlyConfigured
+from django.conf import settings
from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
@@ -291,6 +292,7 @@ class OAuth2Authentication(BaseAuthentication):
OAuth 2 authentication backend using `django-oauth2-provider`
"""
www_authenticate_realm = 'api'
+ allow_query_params_token = settings.DEBUG
def __init__(self, *args, **kwargs):
super(OAuth2Authentication, self).__init__(*args, **kwargs)
@@ -308,7 +310,13 @@ class OAuth2Authentication(BaseAuthentication):
auth = get_authorization_header(request).split()
- if not auth or auth[0].lower() != b'bearer':
+ if auth and auth[0].lower() == b'bearer':
+ access_token = auth[1]
+ elif 'access_token' in request.POST:
+ access_token = request.POST['access_token']
+ elif 'access_token' in request.GET and self.allow_query_params_token:
+ access_token = request.GET['access_token']
+ else:
return None
if len(auth) == 1:
@@ -318,7 +326,7 @@ class OAuth2Authentication(BaseAuthentication):
msg = 'Invalid bearer header. Token string should not contain spaces.'
raise exceptions.AuthenticationFailed(msg)
- return self.authenticate_credentials(request, auth[1])
+ return self.authenticate_credentials(request, access_token)
def authenticate_credentials(self, request, access_token):
"""
@@ -326,11 +334,11 @@ class OAuth2Authentication(BaseAuthentication):
"""
try:
- token = oauth2_provider.models.AccessToken.objects.select_related('user')
+ token = oauth2_provider.oauth2.models.AccessToken.objects.select_related('user')
# provider_now switches to timezone aware datetime when
# the oauth2_provider version supports to it.
token = token.get(token=access_token, expires__gt=provider_now())
- except oauth2_provider.models.AccessToken.DoesNotExist:
+ except oauth2_provider.oauth2.models.AccessToken.DoesNotExist:
raise exceptions.AuthenticationFailed('Invalid token')
user = token.user
diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py
index 8eac2cc4..167fa531 100644
--- a/rest_framework/authtoken/models.py
+++ b/rest_framework/authtoken/models.py
@@ -34,7 +34,7 @@ class Token(models.Model):
return super(Token, self).save(*args, **kwargs)
def generate_key(self):
- return binascii.hexlify(os.urandom(20))
+ return binascii.hexlify(os.urandom(20)).decode()
def __unicode__(self):
return self.key
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index d283e2f5..d155f554 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -550,13 +550,10 @@ except (ImportError, ImproperlyConfigured):
# OAuth 2 support is optional
try:
- import provider.oauth2 as oauth2_provider
- from provider.oauth2 import models as oauth2_provider_models
- from provider.oauth2 import forms as oauth2_provider_forms
+ import provider as oauth2_provider
from provider import scope as oauth2_provider_scope
from provider import constants as oauth2_constants
- from provider import __version__ as provider_version
- if provider_version in ('0.2.3', '0.2.4'):
+ if oauth2_provider.__version__ in ('0.2.3', '0.2.4'):
# 0.2.3 and 0.2.4 are supported version that do not support
# timezone aware datetimes
import datetime
@@ -566,8 +563,6 @@ try:
from django.utils.timezone import now as provider_now
except ImportError:
oauth2_provider = None
- oauth2_provider_models = None
- oauth2_provider_forms = None
oauth2_provider_scope = None
oauth2_constants = None
provider_now = None
@@ -584,3 +579,23 @@ if six.PY3:
else:
def is_non_str_iterable(obj):
return hasattr(obj, '__iter__')
+
+
+try:
+ from django.utils.encoding import python_2_unicode_compatible
+except ImportError:
+ def python_2_unicode_compatible(klass):
+ """
+ A decorator that defines __unicode__ and __str__ methods under Python 2.
+ Under Python 3 it does nothing.
+
+ To support Python 2 and 3 with a single code base, define a __str__ method
+ returning text and apply this decorator to the class.
+ """
+ if '__str__' not in klass.__dict__:
+ raise ValueError("@python_2_unicode_compatible cannot be applied "
+ "to %s because it doesn't define __str__()." %
+ klass.__name__)
+ klass.__unicode__ = klass.__str__
+ klass.__str__ = lambda self: self.__unicode__().encode('utf-8')
+ return klass
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index 0ac5866e..5f774a9f 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -20,6 +20,8 @@ class APIException(Exception):
def __init__(self, detail=None):
self.detail = detail or self.default_detail
+ def __str__(self):
+ return self.detail
class ParseError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 05daaab7..8cdc5551 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -164,7 +164,7 @@ class Field(object):
Called to set up a field prior to field_to_native or field_from_native.
parent - The parent serializer.
- model_field - The model field this field corresponds to, if one exists.
+ field_name - The name of the field being initialized.
"""
self.parent = parent
self.root = parent.root or parent
@@ -289,7 +289,7 @@ class WritableField(Field):
self.validators = self.default_validators + validators
self.default = default if default is not None else self.default
- # Widgets are ony used for HTML forms.
+ # Widgets are only used for HTML forms.
widget = widget or self.widget
if isinstance(widget, type):
widget = widget()
@@ -301,6 +301,11 @@ class WritableField(Field):
result.validators = self.validators[:]
return result
+ def get_default_value(self):
+ if is_simple_callable(self.default):
+ return self.default()
+ return self.default
+
def validate(self, value):
if value in validators.EMPTY_VALUES and self.required:
raise ValidationError(self.error_messages['required'])
@@ -349,10 +354,7 @@ class WritableField(Field):
except KeyError:
if self.default is not None and not self.partial:
# Note: partial updates shouldn't set defaults
- if is_simple_callable(self.default):
- native = self.default()
- else:
- native = self.default
+ native = self.get_default_value()
else:
if self.required:
raise ValidationError(self.error_messages['required'])
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index de91caed..96d15eb9 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -6,6 +6,7 @@ from __future__ import unicode_literals
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from rest_framework.compat import django_filters, six, guardian, get_model_name
+from rest_framework.settings import api_settings
from functools import reduce
import operator
@@ -69,7 +70,8 @@ class DjangoFilterBackend(BaseFilterBackend):
class SearchFilter(BaseFilterBackend):
- search_param = 'search' # The URL query parameter used for the search.
+ # The URL query parameter used for the search.
+ search_param = api_settings.SEARCH_PARAM
def get_search_terms(self, request):
"""
@@ -107,7 +109,8 @@ class SearchFilter(BaseFilterBackend):
class OrderingFilter(BaseFilterBackend):
- ordering_param = 'ordering' # The URL query parameter used for the ordering.
+ # The URL query parameter used for the ordering.
+ ordering_param = api_settings.ORDERING_PARAM
ordering_fields = None
def get_ordering(self, request):
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 5fbcf700..e1a24dc7 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -116,30 +116,27 @@ class UpdateModelMixin(object):
partial = kwargs.pop('partial', False)
self.object = self.get_object_or_none()
- if self.object is None:
- created = True
- save_kwargs = {'force_insert': True}
- success_status_code = status.HTTP_201_CREATED
- else:
- created = False
- save_kwargs = {'force_update': True}
- success_status_code = status.HTTP_200_OK
-
serializer = self.get_serializer(self.object, data=request.DATA,
files=request.FILES, partial=partial)
- if serializer.is_valid():
- try:
- self.pre_save(serializer.object)
- except ValidationError as err:
- # full_clean on model instance may be called in pre_save, so we
- # have to handle eventual errors.
- return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST)
- self.object = serializer.save(**save_kwargs)
- self.post_save(self.object, created=created)
- return Response(serializer.data, status=success_status_code)
+ if not serializer.is_valid():
+ return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+ try:
+ self.pre_save(serializer.object)
+ except ValidationError as err:
+ # full_clean on model instance may be called in pre_save,
+ # so we have to handle eventual errors.
+ return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST)
+
+ if self.object is None:
+ self.object = serializer.save(force_insert=True)
+ self.post_save(self.object, created=True)
+ return Response(serializer.data, status=status.HTTP_201_CREATED)
+
+ self.object = serializer.save(force_update=True)
+ self.post_save(self.object, created=False)
+ return Response(serializer.data, status=status.HTTP_200_OK)
def partial_update(self, request, *args, **kwargs):
kwargs['partial'] = True
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index f1b3e38d..4990971b 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -10,7 +10,7 @@ from django.core.files.uploadhandler import StopFutureHandlers
from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter
-from rest_framework.compat import etree, six, yaml
+from rest_framework.compat import etree, six, yaml, force_text
from rest_framework.exceptions import ParseError
from rest_framework import renderers
import json
@@ -288,7 +288,7 @@ class FileUploadParser(BaseParser):
try:
meta = parser_context['request'].META
- disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'])
- return disposition[1]['filename']
+ disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'].encode('utf-8'))
+ return force_text(disposition[1]['filename'])
except (AttributeError, KeyError):
pass
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 02185c2f..3463954d 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -33,6 +33,7 @@ class RelatedField(WritableField):
many_widget = widgets.SelectMultiple
form_field_class = forms.ChoiceField
many_form_field_class = forms.MultipleChoiceField
+ null_values = (None, '', 'None')
cache_choices = False
empty_label = None
@@ -58,6 +59,8 @@ class RelatedField(WritableField):
super(RelatedField, self).__init__(*args, **kwargs)
if not self.required:
+ # Accessed in ModelChoiceIterator django/forms/models.py:1034
+ # If set adds empty choice.
self.empty_label = BLANK_CHOICE_DASH[0][1]
self.queryset = queryset
@@ -118,6 +121,14 @@ class RelatedField(WritableField):
choices = property(_get_choices, _set_choices)
+ ### Default value handling
+
+ def get_default_value(self):
+ default = super(RelatedField, self).get_default_value()
+ if self.many and default is None:
+ return []
+ return default
+
### Regular serializer stuff...
def field_to_native(self, obj, field_name):
@@ -166,11 +177,11 @@ class RelatedField(WritableField):
except KeyError:
if self.partial:
return
- value = [] if self.many else None
+ value = self.get_default_value()
- if value in (None, '') and self.required:
- raise ValidationError(self.error_messages['required'])
- elif value in (None, ''):
+ if value in self.null_values:
+ if self.required:
+ raise ValidationError(self.error_messages['required'])
into[(self.source or field_name)] = None
elif self.many:
into[(self.source or field_name)] = [self.from_native(item) for item in value]
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index e8afc26d..484961ad 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -146,7 +146,7 @@ class XMLRenderer(BaseRenderer):
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
- Renders *obj* into serialized XML.
+ Renders `data` into serialized XML.
"""
if data is None:
return ''
@@ -193,17 +193,26 @@ class YAMLRenderer(BaseRenderer):
format = 'yaml'
encoder = encoders.SafeDumper
charset = 'utf-8'
+ ensure_ascii = True
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
- Renders *obj* into serialized YAML.
+ Renders `data` into serialized YAML.
"""
assert yaml, 'YAMLRenderer requires pyyaml to be installed'
if data is None:
return ''
- return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder)
+ return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder, allow_unicode=not self.ensure_ascii)
+
+
+class UnicodeYAMLRenderer(YAMLRenderer):
+ """
+ Renderer which serializes to YAML.
+ Does *not* apply character escaping for non-ascii characters.
+ """
+ ensure_ascii = False
class TemplateHTMLRenderer(BaseRenderer):
@@ -427,7 +436,7 @@ class BrowsableAPIRenderer(BaseRenderer):
files = request.FILES
except ParseError:
data = None
- files = None
+ files = None
else:
data = None
files = None
@@ -544,6 +553,14 @@ class BrowsableAPIRenderer(BaseRenderer):
raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request)
raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form
+ response_headers = dict(response.items())
+ renderer_content_type = ''
+ if renderer:
+ renderer_content_type = '%s' % renderer.media_type
+ if renderer.charset:
+ renderer_content_type += ' ;%s' % renderer.charset
+ response_headers['Content-Type'] = renderer_content_type
+
context = {
'content': self.get_content(renderer, data, accepted_media_type, renderer_context),
'view': view,
@@ -555,6 +572,7 @@ class BrowsableAPIRenderer(BaseRenderer):
'breadcrumblist': self.get_breadcrumbs(request),
'allowed_methods': view.allowed_methods,
'available_formats': [renderer.format for renderer in view.renderer_classes],
+ 'response_headers': response_headers,
'put_form': self.get_rendered_html_form(view, 'PUT', request),
'post_form': self.get_rendered_html_form(view, 'POST', request),
diff --git a/rest_framework/request.py b/rest_framework/request.py
index ca70b49e..40467c03 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -346,7 +346,7 @@ class Request(object):
media_type = self.content_type
if stream is None or media_type is None:
- empty_data = QueryDict('', self._request._encoding)
+ empty_data = QueryDict('', encoding=self._request._encoding)
empty_files = MultiValueDict()
return (empty_data, empty_files)
@@ -362,7 +362,7 @@ class Request(object):
# re-raise. Ensures we don't simply repeat the error when
# attempting to render the browsable renderer response, or when
# logging the request or similar.
- self._data = QueryDict('', self._request._encoding)
+ self._data = QueryDict('', encoding=self._request._encoding)
self._files = MultiValueDict()
raise
diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py
index da36d23f..2daaae4e 100755
--- a/rest_framework/runtests/runtests.py
+++ b/rest_framework/runtests/runtests.py
@@ -26,6 +26,10 @@ def usage():
def main():
+ try:
+ django.setup()
+ except AttributeError:
+ pass
TestRunner = get_runner(settings)
test_runner = TestRunner()
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 10256d47..9cb548a5 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -16,6 +16,7 @@ import datetime
import inspect
import types
from decimal import Decimal
+from django.contrib.contenttypes.generic import GenericForeignKey
from django.core.paginator import Page
from django.db import models
from django.forms import widgets
@@ -438,16 +439,6 @@ class BaseSerializer(WritableField):
raise ValidationError(self.error_messages['required'])
return
- # Set the serializer object if it exists
- obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None
-
- # If we have a model manager or similar object then we need
- # to iterate through each instance.
- if (self.many and
- not hasattr(obj, '__iter__') and
- is_simple_callable(getattr(obj, 'all', None))):
- obj = obj.all()
-
if self.source == '*':
if value:
reverted_data = self.restore_fields(value, {})
@@ -457,6 +448,16 @@ class BaseSerializer(WritableField):
if value in (None, ''):
into[(self.source or field_name)] = None
else:
+ # Set the serializer object if it exists
+ obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None
+
+ # If we have a model manager or similar object then we need
+ # to iterate through each instance.
+ if (self.many and
+ not hasattr(obj, '__iter__') and
+ is_simple_callable(getattr(obj, 'all', None))):
+ obj = obj.all()
+
kwargs = {
'instance': obj,
'data': value,
@@ -757,8 +758,11 @@ class ModelSerializer(Serializer):
field.read_only = True
ret[accessor_name] = field
+
+ # Ensure that 'read_only_fields' is an iterable
+ assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'
- # Add the `read_only` flag to any fields that have bee specified
+ # Add the `read_only` flag to any fields that have been specified
# in the `read_only_fields` option
for field_name in self.opts.read_only_fields:
assert field_name not in self.base_fields.keys(), (
@@ -771,7 +775,10 @@ class ModelSerializer(Serializer):
"on serializer '%s'." %
(field_name, self.__class__.__name__))
ret[field_name].read_only = True
-
+
+ # Ensure that 'write_only_fields' is an iterable
+ assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'
+
for field_name in self.opts.write_only_fields:
assert field_name not in self.base_fields.keys(), (
"field '%s' on serializer '%s' specified in "
@@ -821,6 +828,10 @@ class ModelSerializer(Serializer):
if model_field:
kwargs['required'] = not(model_field.null or model_field.blank)
+ if model_field.help_text is not None:
+ kwargs['help_text'] = model_field.help_text
+ if model_field.verbose_name is not None:
+ kwargs['label'] = model_field.verbose_name
return PrimaryKeyRelatedField(**kwargs)
@@ -881,7 +892,7 @@ class ModelSerializer(Serializer):
except KeyError:
return ModelField(model_field=model_field, **kwargs)
- def get_validation_exclusions(self):
+ def get_validation_exclusions(self, instance=None):
"""
Return a list of field names to exclude from model validation.
"""
@@ -893,7 +904,7 @@ class ModelSerializer(Serializer):
field_name = field.source or field_name
if field_name in exclusions \
and not field.read_only \
- and field.required \
+ and (field.required or hasattr(instance, field_name)) \
and not isinstance(field, Serializer):
exclusions.remove(field_name)
return exclusions
@@ -908,7 +919,7 @@ class ModelSerializer(Serializer):
the full_clean validation checking.
"""
try:
- instance.full_clean(exclude=self.get_validation_exclusions())
+ instance.full_clean(exclude=self.get_validation_exclusions(instance))
except ValidationError as err:
self._errors = err.message_dict
return None
@@ -937,6 +948,8 @@ class ModelSerializer(Serializer):
# Forward m2m relations
for field in meta.many_to_many + meta.virtual_fields:
+ if isinstance(field, GenericForeignKey):
+ continue
if field.name in attrs:
m2m_data[field.name] = attrs.pop(field.name)
@@ -946,17 +959,15 @@ class ModelSerializer(Serializer):
if isinstance(self.fields.get(field_name, None), Serializer):
nested_forward_relations[field_name] = attrs[field_name]
- # Update an existing instance...
- if instance is not None:
- for key, val in attrs.items():
- try:
- setattr(instance, key, val)
- except ValueError:
- self._errors[key] = self.error_messages['required']
+ # Create an empty instance of the model
+ if instance is None:
+ instance = self.opts.model()
- # ...or create a new instance
- else:
- instance = self.opts.model(**attrs)
+ for key, val in attrs.items():
+ try:
+ setattr(instance, key, val)
+ except ValueError:
+ self._errors[key] = self.error_messages['required']
# Any relations that cannot be set until we've
# saved the model get hidden away on these
@@ -1081,6 +1092,10 @@ class HyperlinkedModelSerializer(ModelSerializer):
if model_field:
kwargs['required'] = not(model_field.null or model_field.blank)
+ if model_field.help_text is not None:
+ kwargs['help_text'] = model_field.help_text
+ if model_field.verbose_name is not None:
+ kwargs['label'] = model_field.verbose_name
if self.opts.lookup_field:
kwargs['lookup_field'] = self.opts.lookup_field
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index ce171d6d..38753c96 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -69,6 +69,10 @@ DEFAULTS = {
'PAGINATE_BY_PARAM': None,
'MAX_PAGINATE_BY': None,
+ # Filtering
+ 'SEARCH_PARAM': 'search',
+ 'ORDERING_PARAM': 'ordering',
+
# Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index d19d5a2b..7067ee2f 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -118,7 +118,7 @@
</div>
<div class="response-info">
<pre class="prettyprint"><div class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %}
-{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|break_long_headers|urlize_quoted_links }}</span>
+{% for key, val in response_headers.items %}<b>{{ key }}:</b> <span class="lit">{{ val|break_long_headers|urlize_quoted_links }}</span>
{% endfor %}
</div>{{ content|urlize_quoted_links }}</pre>{% endautoescape %}
</div>
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index 83c046f9..dff176d6 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -6,7 +6,7 @@ from django.utils.encoding import iri_to_uri
from django.utils.html import escape
from django.utils.safestring import SafeData, mark_safe
from rest_framework.compat import urlparse, force_text, six, smart_urlquote
-import re, string
+import re
register = template.Library()
@@ -180,7 +180,7 @@ def add_class(value, css_class):
# Bunch of stuff cloned from urlize
-TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"]
+TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "']", "'}", "'"]
WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('&lt;', '&gt;'),
('"', '"'), ("'", "'")]
word_split_re = re.compile(r'(\s+)')
@@ -189,6 +189,17 @@ simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net
simple_email_re = re.compile(r'^\S+@\S+\.\S+$')
+def smart_urlquote_wrapper(matched_url):
+ """
+ Simple wrapper for smart_urlquote. ValueError("Invalid IPv6 URL") can
+ be raised here, see issue #1386
+ """
+ try:
+ return smart_urlquote(matched_url)
+ except ValueError:
+ return None
+
+
@register.filter
def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True):
"""
@@ -211,7 +222,6 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
safe_input = isinstance(text, SafeData)
words = word_split_re.split(force_text(text))
for i, word in enumerate(words):
- match = None
if '.' in word or '@' in word or ':' in word:
# Deal with punctuation.
lead, middle, trail = '', word, ''
@@ -233,9 +243,9 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
url = None
nofollow_attr = ' rel="nofollow"' if nofollow else ''
if simple_url_re.match(middle):
- url = smart_urlquote(middle)
+ url = smart_urlquote_wrapper(middle)
elif simple_url_2_re.match(middle):
- url = smart_urlquote('http://%s' % middle)
+ url = smart_urlquote_wrapper('http://%s' % middle)
elif not ':' in middle and simple_email_re.match(middle):
local, domain = middle.rsplit('@', 1)
try:
diff --git a/rest_framework/test.py b/rest_framework/test.py
index 234d10a4..df5a5b3b 100644
--- a/rest_framework/test.py
+++ b/rest_framework/test.py
@@ -8,6 +8,7 @@ from django.conf import settings
from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler
from django.test import testcases
+from django.utils.http import urlencode
from rest_framework.settings import api_settings
from rest_framework.compat import RequestFactory as DjangoRequestFactory
from rest_framework.compat import force_bytes_or_smart_bytes, six
@@ -71,6 +72,17 @@ class APIRequestFactory(DjangoRequestFactory):
return ret, content_type
+ def get(self, path, data=None, **extra):
+ r = {
+ 'QUERY_STRING': urlencode(data or {}, doseq=True),
+ }
+ # Fix to support old behavior where you have the arguments in the url
+ # See #1461
+ if not data and '?' in path:
+ r['QUERY_STRING'] = path.split('?')[1]
+ r.update(extra)
+ return self.generic('GET', path, **r)
+
def post(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('POST', path, data, content_type, **extra)
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index 32a726c0..0256697a 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -103,7 +103,7 @@ class BlogPostComment(RESTFrameworkModel):
class Album(RESTFrameworkModel):
title = models.CharField(max_length=100, unique=True)
-
+ ref = models.CharField(max_length=10, unique=True, null=True, blank=True)
class Photo(RESTFrameworkModel):
description = models.TextField()
@@ -143,7 +143,8 @@ class ForeignKeyTarget(RESTFrameworkModel):
class ForeignKeySource(RESTFrameworkModel):
name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
+ target = models.ForeignKey(ForeignKeyTarget, related_name='sources',
+ help_text='Target', verbose_name='Target')
# Nullable ForeignKey
@@ -168,3 +169,10 @@ class NullableOneToOneSource(RESTFrameworkModel):
class BasicModelSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
+
+
+# Models to test filters
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
diff --git a/rest_framework/tests/serializers.py b/rest_framework/tests/serializers.py
new file mode 100644
index 00000000..cc943c7d
--- /dev/null
+++ b/rest_framework/tests/serializers.py
@@ -0,0 +1,8 @@
+from rest_framework import serializers
+
+from rest_framework.tests.models import NullableForeignKeySource
+
+
+class NullableFKSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py
index f072b81b..a1c43d9c 100644
--- a/rest_framework/tests/test_authentication.py
+++ b/rest_framework/tests/test_authentication.py
@@ -3,6 +3,7 @@ from django.contrib.auth.models import User
from django.http import HttpResponse
from django.test import TestCase
from django.utils import unittest
+from django.utils.http import urlencode
from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions
from rest_framework import permissions
@@ -18,8 +19,8 @@ from rest_framework.authentication import (
OAuth2Authentication
)
from rest_framework.authtoken.models import Token
-from rest_framework.compat import patterns, url, include
-from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope
+from rest_framework.compat import patterns, url, include, six
+from rest_framework.compat import oauth2_provider, oauth2_provider_scope
from rest_framework.compat import oauth, oauth_provider
from rest_framework.test import APIRequestFactory, APIClient
from rest_framework.views import APIView
@@ -53,10 +54,14 @@ urlpatterns = patterns('',
permission_classes=[permissions.TokenHasReadWriteScope]))
)
+class OAuth2AuthenticationDebug(OAuth2Authentication):
+ allow_query_params_token = True
+
if oauth2_provider is not None:
urlpatterns += patterns('',
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
+ url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),
url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
permission_classes=[permissions.TokenHasReadWriteScope])),
)
@@ -190,6 +195,12 @@ class TokenAuthTests(TestCase):
token = Token.objects.create(user=self.user)
self.assertTrue(bool(token.key))
+ def test_generate_key_returns_string(self):
+ """Ensure generate_key returns a string"""
+ token = Token()
+ key = token.generate_key()
+ self.assertTrue(isinstance(key, six.string_types))
+
def test_token_login_json(self):
"""Ensure token login view using JSON POST works."""
client = APIClient(enforce_csrf_checks=True)
@@ -488,7 +499,7 @@ class OAuth2Tests(TestCase):
self.ACCESS_TOKEN = "access_token"
self.REFRESH_TOKEN = "refresh_token"
- self.oauth2_client = oauth2_provider_models.Client.objects.create(
+ self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create(
client_id=self.CLIENT_ID,
client_secret=self.CLIENT_SECRET,
redirect_uri='',
@@ -497,12 +508,12 @@ class OAuth2Tests(TestCase):
user=None,
)
- self.access_token = oauth2_provider_models.AccessToken.objects.create(
+ self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create(
token=self.ACCESS_TOKEN,
client=self.oauth2_client,
user=self.user,
)
- self.refresh_token = oauth2_provider_models.RefreshToken.objects.create(
+ self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create(
user=self.user,
access_token=self.access_token,
client=self.oauth2_client
@@ -546,6 +557,27 @@ class OAuth2Tests(TestCase):
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_passing_auth_url_transport(self):
+ """Ensure GETing form over OAuth with correct client credentials in form data succeed"""
+ response = self.csrf_client.post('/oauth2-test/',
+ data={'access_token': self.access_token.token})
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_passing_auth_url_transport(self):
+ """Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True"""
+ query = urlencode({'access_token': self.access_token.token})
+ response = self.csrf_client.get('/oauth2-test-debug/?%s' % query)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_failing_auth_url_transport(self):
+ """Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False"""
+ query = urlencode({'access_token': self.access_token.token})
+ response = self.csrf_client.get('/oauth2-test/?%s' % query)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_passing_auth(self):
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
auth = self._create_authorization_header()
diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py
index 18188186..23226bbc 100644
--- a/rest_framework/tests/test_filters.py
+++ b/rest_framework/tests/test_filters.py
@@ -7,18 +7,15 @@ from django.test import TestCase
from django.utils import unittest
from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url
+from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel
+from .models import FilterableItem
+from .utils import temporary_setting
factory = APIRequestFactory()
-class FilterableItem(models.Model):
- text = models.CharField(max_length=100)
- decimal = models.DecimalField(max_digits=4, decimal_places=2)
- date = models.DateField()
-
-
if django_filters:
# Basic filter on a list view.
class FilterFieldsRootView(generics.ListCreateAPIView):
@@ -128,7 +125,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the decimal filter works.
search_decimal = Decimal('2.25')
- request = factory.get('/?decimal=%s' % search_decimal)
+ request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal]
@@ -136,7 +133,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the date filter works.
search_date = datetime.date(2012, 9, 22)
- request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
+ request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-09-22'
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] == search_date]
@@ -151,7 +148,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the decimal filter works.
search_decimal = Decimal('2.25')
- request = factory.get('/?decimal=%s' % search_decimal)
+ request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal]
@@ -184,7 +181,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the decimal filter set with 'lt' in the filter class works.
search_decimal = Decimal('4.25')
- request = factory.get('/?decimal=%s' % search_decimal)
+ request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] < search_decimal]
@@ -192,7 +189,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the date filter set with 'gt' in the filter class works.
search_date = datetime.date(2012, 10, 2)
- request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
+ request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-10-02'
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date]
@@ -200,7 +197,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the text filter set with 'icontains' in the filter class works.
search_text = 'ff'
- request = factory.get('/?text=%s' % search_text)
+ request = factory.get('/', {'text': '%s' % search_text})
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if search_text in f['text'].lower()]
@@ -209,7 +206,10 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that multiple filters works.
search_decimal = Decimal('5.25')
search_date = datetime.date(2012, 10, 2)
- request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
+ request = factory.get('/', {
+ 'decimal': '%s' % (search_decimal,),
+ 'date': '%s' % (search_date,)
+ })
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date and
@@ -234,7 +234,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
view = FilterFieldsRootView.as_view()
search_integer = 10
- request = factory.get('/?integer=%s' % search_integer)
+ request = factory.get('/', {'integer': '%s' % search_integer})
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -265,14 +265,18 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
# Tests that the decimal filter set that should fail.
search_decimal = Decimal('4.25')
high_item = self.objects.filter(decimal__gt=search_decimal)[0]
- response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal))
+ response = self.client.get(
+ '{url}'.format(url=self._get_url(high_item)),
+ {'decimal': '{param}'.format(param=search_decimal)})
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
# Tests that the decimal filter set that should succeed.
search_decimal = Decimal('4.25')
low_item = self.objects.filter(decimal__lt=search_decimal)[0]
low_item_data = self._serialize_object(low_item)
- response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal))
+ response = self.client.get(
+ '{url}'.format(url=self._get_url(low_item)),
+ {'decimal': '{param}'.format(param=search_decimal)})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, low_item_data)
@@ -281,7 +285,11 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
search_date = datetime.date(2012, 10, 2)
valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
valid_item_data = self._serialize_object(valid_item)
- response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
+ response = self.client.get(
+ '{url}'.format(url=self._get_url(valid_item)), {
+ 'decimal': '{decimal}'.format(decimal=search_decimal),
+ 'date': '{date}'.format(date=search_date)
+ })
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, valid_item_data)
@@ -315,7 +323,7 @@ class SearchFilterTests(TestCase):
search_fields = ('title', 'text')
view = SearchListView.as_view()
- request = factory.get('?search=b')
+ request = factory.get('/', {'search': 'b'})
response = view(request)
self.assertEqual(
response.data,
@@ -332,7 +340,7 @@ class SearchFilterTests(TestCase):
search_fields = ('=title', 'text')
view = SearchListView.as_view()
- request = factory.get('?search=zzz')
+ request = factory.get('/', {'search': 'zzz'})
response = view(request)
self.assertEqual(
response.data,
@@ -348,7 +356,7 @@ class SearchFilterTests(TestCase):
search_fields = ('title', '^text')
view = SearchListView.as_view()
- request = factory.get('?search=b')
+ request = factory.get('/', {'search': 'b'})
response = view(request)
self.assertEqual(
response.data,
@@ -357,6 +365,24 @@ class SearchFilterTests(TestCase):
]
)
+ def test_search_with_nonstandard_search_param(self):
+ with temporary_setting('SEARCH_PARAM', 'query', module=filters):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('/', {'query': 'b'})
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'z', 'text': 'abc'},
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
class OrdringFilterModel(models.Model):
title = models.CharField(max_length=20)
@@ -396,7 +422,7 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',)
view = OrderingListView.as_view()
- request = factory.get('?ordering=text')
+ request = factory.get('/', {'ordering': 'text'})
response = view(request)
self.assertEqual(
response.data,
@@ -415,7 +441,7 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',)
view = OrderingListView.as_view()
- request = factory.get('?ordering=-text')
+ request = factory.get('/', {'ordering': '-text'})
response = view(request)
self.assertEqual(
response.data,
@@ -434,7 +460,7 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',)
view = OrderingListView.as_view()
- request = factory.get('?ordering=foobar')
+ request = factory.get('/', {'ordering': 'foobar'})
response = view(request)
self.assertEqual(
response.data,
@@ -503,7 +529,7 @@ class OrderingFilterTests(TestCase):
models.Count("relateds"))
view = OrderingListView.as_view()
- request = factory.get('?ordering=relateds__count')
+ request = factory.get('/', {'ordering': 'relateds__count'})
response = view(request)
self.assertEqual(
response.data,
@@ -514,6 +540,26 @@ class OrderingFilterTests(TestCase):
]
)
+ def test_ordering_with_nonstandard_ordering_param(self):
+ with temporary_setting('ORDERING_PARAM', 'order', filters):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('/', {'order': 'text'})
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ ]
+ )
+
class SensitiveOrderingFilterModel(models.Model):
username = models.CharField(max_length=20)
@@ -566,7 +612,7 @@ class SensitiveOrderingFilterTests(TestCase):
serializer_class = serializer_cls
view = OrderingListView.as_view()
- request = factory.get('?ordering=-username')
+ request = factory.get('/', {'ordering': '-username'})
response = view(request)
if serializer_cls == SensitiveDataSerializer3:
@@ -596,7 +642,7 @@ class SensitiveOrderingFilterTests(TestCase):
serializer_class = serializer_cls
view = OrderingListView.as_view()
- request = factory.get('?ordering=password')
+ request = factory.get('/', {'ordering': 'password'})
response = view(request)
if serializer_cls == SensitiveDataSerializer3:
@@ -612,4 +658,4 @@ class SensitiveOrderingFilterTests(TestCase):
{'id': 2, username_field: 'userB'}, # PassC
{'id': 3, username_field: 'userC'}, # PassA
]
- ) \ No newline at end of file
+ )
diff --git a/rest_framework/tests/test_genericrelations.py b/rest_framework/tests/test_genericrelations.py
index 2d341344..46a2d863 100644
--- a/rest_framework/tests/test_genericrelations.py
+++ b/rest_framework/tests/test_genericrelations.py
@@ -4,8 +4,10 @@ from django.contrib.contenttypes.generic import GenericRelation, GenericForeignK
from django.db import models
from django.test import TestCase
from rest_framework import serializers
+from rest_framework.compat import python_2_unicode_compatible
+@python_2_unicode_compatible
class Tag(models.Model):
"""
Tags have a descriptive slug, and are attached to an arbitrary object.
@@ -15,10 +17,11 @@ class Tag(models.Model):
object_id = models.PositiveIntegerField()
tagged_item = GenericForeignKey('content_type', 'object_id')
- def __unicode__(self):
+ def __str__(self):
return self.tag
+@python_2_unicode_compatible
class Bookmark(models.Model):
"""
A URL bookmark that may have multiple tags attached.
@@ -26,10 +29,11 @@ class Bookmark(models.Model):
url = models.URLField()
tags = GenericRelation(Tag)
- def __unicode__(self):
+ def __str__(self):
return 'Bookmark: %s' % self.url
+@python_2_unicode_compatible
class Note(models.Model):
"""
A textual note that may have multiple tags attached.
@@ -37,7 +41,7 @@ class Note(models.Model):
text = models.TextField()
tags = GenericRelation(Tag)
- def __unicode__(self):
+ def __str__(self):
return 'Note: %s' % self.text
@@ -127,3 +131,21 @@ class TestGenericRelations(TestCase):
}
]
self.assertEqual(serializer.data, expected)
+
+ def test_restore_object_generic_fk(self):
+ """
+ Ensure an object with a generic foreign key can be restored.
+ """
+
+ class TagSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Tag
+ exclude = ('content_type', 'object_id')
+
+ serializer = TagSerializer()
+
+ bookmark = Bookmark(url='http://example.com')
+ attrs = {'tagged_item': bookmark, 'tag': 'example'}
+
+ tag = serializer.restore_object(attrs)
+ self.assertEqual(tag.tagged_item, bookmark)
diff --git a/rest_framework/tests/test_htmlrenderer.py b/rest_framework/tests/test_htmlrenderer.py
index 8957a43c..514d9e2b 100644
--- a/rest_framework/tests/test_htmlrenderer.py
+++ b/rest_framework/tests/test_htmlrenderer.py
@@ -50,7 +50,7 @@ class TemplateHTMLRendererTests(TestCase):
"""
self.get_template = django.template.loader.get_template
- def get_template(template_name):
+ def get_template(template_name, dirs=None):
if template_name == 'example.html':
return Template("example: {{ object }}")
raise TemplateDoesNotExist(template_name)
@@ -108,11 +108,13 @@ class TemplateHTMLRendererExceptionTests(TestCase):
def test_not_found_html_view_with_template(self):
response = self.client.get('/not_found')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- self.assertEqual(response.content, six.b("404: Not found"))
+ self.assertTrue(response.content in (
+ six.b("404: Not found"), six.b("404 Not Found")))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
def test_permission_denied_html_view_with_template(self):
response = self.client.get('/permission_denied')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertEqual(response.content, six.b("403: Permission denied"))
+ self.assertTrue(response.content in (
+ six.b("403: Permission denied"), six.b("403 Forbidden")))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
diff --git a/rest_framework/tests/test_nullable_fields.py b/rest_framework/tests/test_nullable_fields.py
new file mode 100644
index 00000000..6ee55c00
--- /dev/null
+++ b/rest_framework/tests/test_nullable_fields.py
@@ -0,0 +1,30 @@
+from django.core.urlresolvers import reverse
+
+from rest_framework.compat import patterns, url
+from rest_framework.test import APITestCase
+from rest_framework.tests.models import NullableForeignKeySource
+from rest_framework.tests.serializers import NullableFKSourceSerializer
+from rest_framework.tests.views import NullableFKSourceDetail
+
+
+urlpatterns = patterns(
+ '',
+ url(r'^objects/(?P<pk>\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'),
+)
+
+
+class NullableForeignKeyTests(APITestCase):
+ """
+ DRF should be able to handle nullable foreign keys when a test
+ Client POST/PUT request is made with its own serialized object.
+ """
+ urls = 'rest_framework.tests.test_nullable_fields'
+
+ def test_updating_object_with_null_fk(self):
+ obj = NullableForeignKeySource(name='example', target=None)
+ obj.save()
+ serialized_data = NullableFKSourceSerializer(obj).data
+
+ response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data)
+
+ self.assertEqual(response.data, serialized_data)
diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py
index cadb515f..24c1ba39 100644
--- a/rest_framework/tests/test_pagination.py
+++ b/rest_framework/tests/test_pagination.py
@@ -9,14 +9,18 @@ from rest_framework import generics, status, pagination, filters, serializers
from rest_framework.compat import django_filters
from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel
+from .models import FilterableItem
factory = APIRequestFactory()
+# Helper function to split arguments out of an url
+def split_arguments_from_url(url):
+ if '?' not in url:
+ return url
-class FilterableItem(models.Model):
- text = models.CharField(max_length=100)
- decimal = models.DecimalField(max_digits=4, decimal_places=2)
- date = models.DateField()
+ path, args = url.split('?')
+ args = dict(r.split('=') for r in args.split('&'))
+ return path, args
class RootView(generics.ListCreateAPIView):
@@ -84,7 +88,7 @@ class IntegrationTestPagination(TestCase):
self.assertNotEqual(response.data['next'], None)
self.assertEqual(response.data['previous'], None)
- request = factory.get(response.data['next'])
+ request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(2):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -93,7 +97,7 @@ class IntegrationTestPagination(TestCase):
self.assertNotEqual(response.data['next'], None)
self.assertNotEqual(response.data['previous'], None)
- request = factory.get(response.data['next'])
+ request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(2):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -146,7 +150,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
EXPECTED_NUM_QUERIES = 2
- request = factory.get('/?decimal=15.20')
+ request = factory.get('/', {'decimal': '15.20'})
with self.assertNumQueries(EXPECTED_NUM_QUERIES):
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -155,7 +159,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertNotEqual(response.data['next'], None)
self.assertEqual(response.data['previous'], None)
- request = factory.get(response.data['next'])
+ request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(EXPECTED_NUM_QUERIES):
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -164,7 +168,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertEqual(response.data['next'], None)
self.assertNotEqual(response.data['previous'], None)
- request = factory.get(response.data['previous'])
+ request = factory.get(*split_arguments_from_url(response.data['previous']))
with self.assertNumQueries(EXPECTED_NUM_QUERIES):
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -191,7 +195,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
view = BasicFilterFieldsRootView.as_view()
- request = factory.get('/?decimal=15.20')
+ request = factory.get('/', {'decimal': '15.20'})
with self.assertNumQueries(2):
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -200,7 +204,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertNotEqual(response.data['next'], None)
self.assertEqual(response.data['previous'], None)
- request = factory.get(response.data['next'])
+ request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(2):
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -209,7 +213,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertEqual(response.data['next'], None)
self.assertNotEqual(response.data['previous'], None)
- request = factory.get(response.data['previous'])
+ request = factory.get(*split_arguments_from_url(response.data['previous']))
with self.assertNumQueries(2):
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -317,7 +321,7 @@ class TestCustomPaginateByParam(TestCase):
"""
If paginate_by_param is set, the new kwarg should limit per view requests.
"""
- request = factory.get('/?page_size=5')
+ request = factory.get('/', {'page_size': 5})
response = self.view(request).render()
self.assertEqual(response.data['count'], 13)
self.assertEqual(response.data['results'], self.data[:5])
@@ -345,7 +349,7 @@ class TestMaxPaginateByParam(TestCase):
"""
If max_paginate_by is set, it should limit page size for the view.
"""
- request = factory.get('/?page_size=10')
+ request = factory.get('/', data={'page_size': 10})
response = self.view(request).render()
self.assertEqual(response.data['count'], 13)
self.assertEqual(response.data['results'], self.data[:5])
diff --git a/rest_framework/tests/test_parsers.py b/rest_framework/tests/test_parsers.py
index 7699e10c..8af90677 100644
--- a/rest_framework/tests/test_parsers.py
+++ b/rest_framework/tests/test_parsers.py
@@ -96,7 +96,7 @@ class TestFileUploadParser(TestCase):
request = MockRequest()
request.upload_handlers = (MemoryFileUploadHandler(),)
request.META = {
- 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'),
+ 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt',
'HTTP_CONTENT_LENGTH': 14,
}
self.parser_context = {'request': request, 'kwargs': {}}
@@ -112,4 +112,4 @@ class TestFileUploadParser(TestCase):
def test_get_filename(self):
parser = FileUploadParser()
filename = parser.get_filename(self.stream, None, self.parser_context)
- self.assertEqual(filename, 'file.txt'.encode('utf-8'))
+ self.assertEqual(filename, 'file.txt')
diff --git a/rest_framework/tests/test_relations.py b/rest_framework/tests/test_relations.py
index f52e0e1e..37ac826b 100644
--- a/rest_framework/tests/test_relations.py
+++ b/rest_framework/tests/test_relations.py
@@ -2,8 +2,10 @@
General tests for relational fields.
"""
from __future__ import unicode_literals
+from django import get_version
from django.db import models
from django.test import TestCase
+from django.utils import unittest
from rest_framework import serializers
from rest_framework.tests.models import BlogPost
@@ -118,3 +120,25 @@ class RelatedFieldSourceTests(TestCase):
(serializers.ModelSerializer,), attrs)
with self.assertRaises(AttributeError):
TestSerializer(data={'name': 'foo'})
+
+@unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6')
+class RelatedFieldChoicesTests(TestCase):
+ """
+ Tests for #1408 "Web browseable API doesn't have blank option on drop down list box"
+ https://github.com/tomchristie/django-rest-framework/issues/1408
+ """
+ def test_blank_option_is_added_to_choice_if_required_equals_false(self):
+ """
+
+ """
+ post = BlogPost(title="Checking blank option is added")
+ post.save()
+
+ queryset = BlogPost.objects.all()
+ field = serializers.RelatedField(required=False, queryset=queryset)
+
+ choice_count = BlogPost.objects.count()
+ widget_count = len(field.widget.choices)
+
+ self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added')
+
diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py
index d393b0c3..4d9da489 100644
--- a/rest_framework/tests/test_relations_nested.py
+++ b/rest_framework/tests/test_relations_nested.py
@@ -3,9 +3,7 @@ from django.db import models
from django.test import TestCase
from rest_framework import serializers
-
-class OneToOneTarget(models.Model):
- name = models.CharField(max_length=100)
+from .models import OneToOneTarget
class OneToOneSource(models.Model):
diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py
index fb33df2c..7cb7d0f9 100644
--- a/rest_framework/tests/test_renderers.py
+++ b/rest_framework/tests/test_renderers.py
@@ -12,7 +12,7 @@ from rest_framework.compat import yaml, etree, patterns, url, include, six, Stri
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
- XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer
+ XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer, UnicodeYAMLRenderer
from rest_framework.parsers import YAMLParser, XMLParser
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
@@ -256,6 +256,18 @@ class RendererEndToEndTests(TestCase):
self.assertEqual(resp.get('Content-Type', None), None)
self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT)
+ def test_contains_headers_of_api_response(self):
+ """
+ Issue #1437
+
+ Test we display the headers of the API response and not those from the
+ HTML response
+ """
+ resp = self.client.get('/html1')
+ self.assertContains(resp, '>GET, HEAD, OPTIONS<')
+ self.assertContains(resp, '>application/json<')
+ self.assertNotContains(resp, '>text/html; charset=utf-8<')
+
_flat_repr = '{"foo": ["bar", "baz"]}'
_indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}'
@@ -455,6 +467,17 @@ if yaml:
self.assertTrue(string in content, '%r not in %r' % (string, content))
+ class UnicodeYAMLRendererTests(TestCase):
+ """
+ Tests specific for the Unicode YAML Renderer
+ """
+ def test_proper_encoding(self):
+ obj = {'countries': ['United Kingdom', 'France', 'España']}
+ renderer = UnicodeYAMLRenderer()
+ content = renderer.render(obj, 'application/yaml')
+ self.assertEqual(content.strip(), 'countries: [United Kingdom, France, España]'.encode('utf-8'))
+
+
class XMLRendererTestCase(TestCase):
"""
Tests specific to the XML Renderer
@@ -601,6 +624,10 @@ class CacheRenderTest(TestCase):
method = getattr(self.client, http_method)
resp = method(url)
del resp.client, resp.request
+ try:
+ del resp.wsgi_request
+ except AttributeError:
+ pass
return resp
def test_obj_pickling(self):
diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py
index 6b1e333e..e688c823 100644
--- a/rest_framework/tests/test_serializer.py
+++ b/rest_framework/tests/test_serializer.py
@@ -3,15 +3,42 @@ from __future__ import unicode_literals
from django.db import models
from django.db.models.fields import BLANK_CHOICE_DASH
from django.test import TestCase
+from django.utils import unittest
from django.utils.datastructures import MultiValueDict
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers, fields, relations
from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel,
- ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel)
+ ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel,
+ ForeignKeySource, ManyToManySource)
from rest_framework.tests.models import BasicModelSerializer
import datetime
import pickle
+try:
+ import PIL
+except:
+ PIL = None
+
+
+if PIL is not None:
+ class AMOAFModel(RESTFrameworkModel):
+ char_field = models.CharField(max_length=1024, blank=True)
+ comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True)
+ decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True)
+ email_field = models.EmailField(max_length=1024, blank=True)
+ file_field = models.FileField(upload_to='test', max_length=1024, blank=True)
+ image_field = models.ImageField(upload_to='test', max_length=1024, blank=True)
+ slug_field = models.SlugField(max_length=1024, blank=True)
+ url_field = models.URLField(max_length=1024, blank=True)
+
+ class DVOAFModel(RESTFrameworkModel):
+ positive_integer_field = models.PositiveIntegerField(blank=True)
+ positive_small_integer_field = models.PositiveSmallIntegerField(blank=True)
+ email_field = models.EmailField(blank=True)
+ file_field = models.FileField(upload_to='test', blank=True)
+ image_field = models.ImageField(upload_to='test', blank=True)
+ slug_field = models.SlugField(blank=True)
+ url_field = models.URLField(blank=True)
class SubComment(object):
@@ -141,7 +168,7 @@ class AlbumsSerializer(serializers.ModelSerializer):
class Meta:
model = Album
- fields = ['title'] # lists are also valid options
+ fields = ['title', 'ref'] # lists are also valid options
class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):
@@ -150,6 +177,16 @@ class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):
fields = ['some_integer']
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+
+
+class HyperlinkedForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+
+
class BasicTests(TestCase):
def setUp(self):
self.comment = Comment(
@@ -482,6 +519,32 @@ class ValidationTests(TestCase):
)
self.assertEqual(serializer.is_valid(), True)
+ def test_writable_star_source_on_nested_serializer_with_parent_object(self):
+ class TitleSerializer(serializers.Serializer):
+ title = serializers.WritableField(source='title')
+
+ class AlbumSerializer(serializers.ModelSerializer):
+ nested = TitleSerializer(source='*')
+
+ class Meta:
+ model = Album
+ fields = ('nested',)
+
+ class PhotoSerializer(serializers.ModelSerializer):
+ album = AlbumSerializer(source='album')
+
+ class Meta:
+ model = Photo
+ fields = ('album', )
+
+ photo = Photo(album=Album())
+
+ data = {'album': {'nested': {'title': 'test'}}}
+
+ serializer = PhotoSerializer(photo, data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data, data)
+
def test_writable_star_source_with_inner_source_fields(self):
"""
Tests that a serializer with source="*" correctly expands the
@@ -591,12 +654,15 @@ class ModelValidationTests(TestCase):
"""
Just check if serializers.ModelSerializer handles unique checks via .full_clean()
"""
- serializer = AlbumsSerializer(data={'title': 'a'})
+ serializer = AlbumsSerializer(data={'title': 'a', 'ref': '1'})
serializer.is_valid()
serializer.save()
second_serializer = AlbumsSerializer(data={'title': 'a'})
self.assertFalse(second_serializer.is_valid())
- self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']})
+ self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.'],})
+ third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}])
+ self.assertFalse(third_serializer.is_valid())
+ self.assertEqual(third_serializer.errors, [{'ref': ['Album with this Ref already exists.']}, {}])
def test_foreign_key_is_null_with_partial(self):
"""
@@ -880,6 +946,58 @@ class DefaultValueTests(TestCase):
self.assertEqual(instance.text, 'overridden')
+class WritableFieldDefaultValueTests(TestCase):
+
+ def setUp(self):
+ self.expected = {'default': 'value'}
+ self.create_field = fields.WritableField
+
+ def test_get_default_value_with_noncallable(self):
+ field = self.create_field(default=self.expected)
+ got = field.get_default_value()
+ self.assertEqual(got, self.expected)
+
+ def test_get_default_value_with_callable(self):
+ field = self.create_field(default=lambda : self.expected)
+ got = field.get_default_value()
+ self.assertEqual(got, self.expected)
+
+ def test_get_default_value_when_not_required(self):
+ field = self.create_field(default=self.expected, required=False)
+ got = field.get_default_value()
+ self.assertEqual(got, self.expected)
+
+ def test_get_default_value_returns_None(self):
+ field = self.create_field()
+ got = field.get_default_value()
+ self.assertIsNone(got)
+
+ def test_get_default_value_returns_non_True_values(self):
+ values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause
+ for expected in values:
+ field = self.create_field(default=expected)
+ got = field.get_default_value()
+ self.assertEqual(got, expected)
+
+
+class RelatedFieldDefaultValueTests(WritableFieldDefaultValueTests):
+
+ def setUp(self):
+ self.expected = {'foo': 'bar'}
+ self.create_field = relations.RelatedField
+
+ def test_get_default_value_returns_empty_list(self):
+ field = self.create_field(many=True)
+ got = field.get_default_value()
+ self.assertListEqual(got, [])
+
+ def test_get_default_value_returns_expected(self):
+ expected = [1, 2, 3]
+ field = self.create_field(many=True, default=expected)
+ got = field.get_default_value()
+ self.assertListEqual(got, expected)
+
+
class CallableDefaultValueTests(TestCase):
def setUp(self):
class CallableDefaultValueSerializer(serializers.ModelSerializer):
@@ -1493,18 +1611,23 @@ class ManyFieldHelpTextTest(TestCase):
self.assertEqual('Some help text.', rel_field.help_text)
+class AttributeMappingOnAutogeneratedRelatedFields(TestCase):
+
+ def test_primary_key_related_field(self):
+ serializer = ForeignKeySourceSerializer()
+ self.assertEqual(serializer.fields['target'].help_text, 'Target')
+ self.assertEqual(serializer.fields['target'].label, 'Target')
+
+ def test_hyperlinked_related_field(self):
+ serializer = HyperlinkedForeignKeySourceSerializer()
+ self.assertEqual(serializer.fields['target'].help_text, 'Target')
+ self.assertEqual(serializer.fields['target'].label, 'Target')
+
+
+@unittest.skipUnless(PIL is not None, 'PIL is not installed')
class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
def setUp(self):
- class AMOAFModel(RESTFrameworkModel):
- char_field = models.CharField(max_length=1024, blank=True)
- comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True)
- decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True)
- email_field = models.EmailField(max_length=1024, blank=True)
- file_field = models.FileField(max_length=1024, blank=True)
- image_field = models.ImageField(max_length=1024, blank=True)
- slug_field = models.SlugField(max_length=1024, blank=True)
- url_field = models.URLField(max_length=1024, blank=True)
class AMOAFSerializer(serializers.ModelSerializer):
class Meta:
@@ -1574,17 +1697,10 @@ class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
self.field_test('url_field')
+@unittest.skipUnless(PIL is not None, 'PIL is not installed')
class DefaultValuesOnAutogeneratedFieldsTests(TestCase):
def setUp(self):
- class DVOAFModel(RESTFrameworkModel):
- positive_integer_field = models.PositiveIntegerField(blank=True)
- positive_small_integer_field = models.PositiveSmallIntegerField(blank=True)
- email_field = models.EmailField(blank=True)
- file_field = models.FileField(blank=True)
- image_field = models.ImageField(blank=True)
- slug_field = models.SlugField(blank=True)
- url_field = models.URLField(blank=True)
class DVOAFSerializer(serializers.ModelSerializer):
class Meta:
diff --git a/rest_framework/tests/test_templatetags.py b/rest_framework/tests/test_templatetags.py
index 609a9e08..d4da0c23 100644
--- a/rest_framework/tests/test_templatetags.py
+++ b/rest_framework/tests/test_templatetags.py
@@ -2,7 +2,7 @@
from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.test import APIRequestFactory
-from rest_framework.templatetags.rest_framework import add_query_param
+from rest_framework.templatetags.rest_framework import add_query_param, urlize_quoted_links
factory = APIRequestFactory()
@@ -17,3 +17,35 @@ class TemplateTagTests(TestCase):
json_url = add_query_param(request, "format", "json")
self.assertIn("q=%E6%9F%A5%E8%AF%A2", json_url)
self.assertIn("format=json", json_url)
+
+
+class Issue1386Tests(TestCase):
+ """
+ Covers #1386
+ """
+
+ def test_issue_1386(self):
+ """
+ Test function urlize_quoted_links with different args
+ """
+ correct_urls = [
+ "asdf.com",
+ "asdf.net",
+ "www.as_df.org",
+ "as.d8f.ghj8.gov",
+ ]
+ for i in correct_urls:
+ res = urlize_quoted_links(i)
+ self.assertNotEqual(res, i)
+ self.assertIn(i, res)
+
+ incorrect_urls = [
+ "mailto://asdf@fdf.com",
+ "asdf.netnet",
+ ]
+ for i in incorrect_urls:
+ res = urlize_quoted_links(i)
+ self.assertEqual(i, res)
+
+ # example from issue #1386, this shouldn't raise an exception
+ _ = urlize_quoted_links("asdf:[/p]zxcv.com")
diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py
index 71bd8b55..a55d4b22 100644
--- a/rest_framework/tests/test_testing.py
+++ b/rest_framework/tests/test_testing.py
@@ -152,3 +152,13 @@ class TestAPIRequestFactory(TestCase):
simple_png.name = 'test.png'
factory = APIRequestFactory()
factory.post('/', data={'image': simple_png})
+
+ def test_request_factory_url_arguments(self):
+ """
+ This is a non regression test against #1461
+ """
+ factory = APIRequestFactory()
+ request = factory.get('/view/?demo=test')
+ self.assertEqual(dict(request.GET), {'demo': ['test']})
+ request = factory.get('/view/', {'demo': 'test'})
+ self.assertEqual(dict(request.GET), {'demo': ['test']})
diff --git a/rest_framework/tests/test_urlizer.py b/rest_framework/tests/test_urlizer.py
new file mode 100644
index 00000000..3dc8e8fe
--- /dev/null
+++ b/rest_framework/tests/test_urlizer.py
@@ -0,0 +1,38 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.templatetags.rest_framework import urlize_quoted_links
+import sys
+
+
+class URLizerTests(TestCase):
+ """
+ Test if both JSON and YAML URLs are transformed into links well
+ """
+ def _urlize_dict_check(self, data):
+ """
+ For all items in dict test assert that the value is urlized key
+ """
+ for original, urlized in data.items():
+ assert urlize_quoted_links(original, nofollow=False) == urlized
+
+ def test_json_with_url(self):
+ """
+ Test if JSON URLs are transformed into links well
+ """
+ data = {}
+ data['"url": "http://api/users/1/", '] = \
+ '&quot;url&quot;: &quot;<a href="http://api/users/1/">http://api/users/1/</a>&quot;, '
+ data['"foo_set": [\n "http://api/foos/1/"\n], '] = \
+ '&quot;foo_set&quot;: [\n &quot;<a href="http://api/foos/1/">http://api/foos/1/</a>&quot;\n], '
+ self._urlize_dict_check(data)
+
+ def test_yaml_with_url(self):
+ """
+ Test if YAML URLs are transformed into links well
+ """
+ data = {}
+ data['''{users: 'http://api/users/'}'''] = \
+ '''{users: &#39;<a href="http://api/users/">http://api/users/</a>&#39;}'''
+ data['''foo_set: ['http://api/foos/1/']'''] = \
+ '''foo_set: [&#39;<a href="http://api/foos/1/">http://api/foos/1/</a>&#39;]'''
+ self._urlize_dict_check(data)
diff --git a/rest_framework/tests/test_validation.py b/rest_framework/tests/test_validation.py
index 124c874d..e13e4078 100644
--- a/rest_framework/tests/test_validation.py
+++ b/rest_framework/tests/test_validation.py
@@ -1,4 +1,5 @@
from __future__ import unicode_literals
+from django.core.validators import MaxValueValidator
from django.db import models
from django.test import TestCase
from rest_framework import generics, serializers, status
@@ -102,3 +103,46 @@ class TestAvoidValidation(TestCase):
self.assertFalse(serializer.is_valid())
self.assertDictEqual(serializer.errors,
{'non_field_errors': ['Invalid data']})
+
+
+# regression tests for issue: 1493
+
+class ValidationMaxValueValidatorModel(models.Model):
+ number_value = models.PositiveIntegerField(validators=[MaxValueValidator(100)])
+
+
+class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ValidationMaxValueValidatorModel
+
+
+class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView):
+ model = ValidationMaxValueValidatorModel
+ serializer_class = ValidationMaxValueValidatorModelSerializer
+
+
+class TestMaxValueValidatorValidation(TestCase):
+
+ def test_max_value_validation_serializer_success(self):
+ serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 99})
+ self.assertTrue(serializer.is_valid())
+
+ def test_max_value_validation_serializer_fails(self):
+ serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 101})
+ self.assertFalse(serializer.is_valid())
+ self.assertDictEqual({'number_value': ['Ensure this value is less than or equal to 100.']}, serializer.errors)
+
+ def test_max_value_validation_success(self):
+ obj = ValidationMaxValueValidatorModel.objects.create(number_value=100)
+ request = factory.patch('/{0}'.format(obj.pk), {'number_value': 98}, format='json')
+ view = UpdateMaxValueValidationModel().as_view()
+ response = view(request, pk=obj.pk).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_max_value_validation_fail(self):
+ obj = ValidationMaxValueValidatorModel.objects.create(number_value=100)
+ request = factory.patch('/{0}'.format(obj.pk), {'number_value': 101}, format='json')
+ view = UpdateMaxValueValidationModel().as_view()
+ response = view(request, pk=obj.pk).render()
+ self.assertEqual(response.content, b'{"number_value": ["Ensure this value is less than or equal to 100."]}')
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py
new file mode 100644
index 00000000..a8f2eb0b
--- /dev/null
+++ b/rest_framework/tests/utils.py
@@ -0,0 +1,25 @@
+from contextlib import contextmanager
+from rest_framework.compat import six
+from rest_framework.settings import api_settings
+
+
+@contextmanager
+def temporary_setting(setting, value, module=None):
+ """
+ Temporarily change value of setting for test.
+
+ Optionally reload given module, useful when module uses value of setting on
+ import.
+ """
+ original_value = getattr(api_settings, setting)
+ setattr(api_settings, setting, value)
+
+ if module is not None:
+ six.moves.reload_module(module)
+
+ yield
+
+ setattr(api_settings, setting, original_value)
+
+ if module is not None:
+ six.moves.reload_module(module)
diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py
new file mode 100644
index 00000000..3917b74a
--- /dev/null
+++ b/rest_framework/tests/views.py
@@ -0,0 +1,8 @@
+from rest_framework import generics
+from rest_framework.tests.models import NullableForeignKeySource
+from rest_framework.tests.serializers import NullableFKSourceSerializer
+
+
+class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView):
+ model = NullableForeignKeySource
+ model_serializer_class = NullableFKSourceSerializer
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index c36b58bf..91be9cfd 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -136,6 +136,8 @@ class SimpleRateThrottle(BaseThrottle):
remaining_duration = self.duration
available_requests = self.num_requests - len(self.history) + 1
+ if available_requests <= 0:
+ return None
return remaining_duration / float(available_requests)
diff --git a/rest_framework/utils/mediatypes.py b/rest_framework/utils/mediatypes.py
index c09c2933..92f99efd 100644
--- a/rest_framework/utils/mediatypes.py
+++ b/rest_framework/utils/mediatypes.py
@@ -74,7 +74,7 @@ class _MediaType(object):
return 0
elif self.sub_type == '*':
return 1
- elif not self.params or self.params.keys() == ['q']:
+ elif not self.params or list(self.params.keys()) == ['q']:
return 2
return 3
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 02a6e25a..a2668f2c 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -131,7 +131,7 @@ class APIView(View):
"""
If request is not permitted, determine what kind of exception to raise.
"""
- if not self.request.successful_authenticator:
+ if not request.successful_authenticator:
raise exceptions.NotAuthenticated()
raise exceptions.PermissionDenied()
@@ -295,7 +295,7 @@ class APIView(View):
# Dispatch methods
- def initialize_request(self, request, *args, **kargs):
+ def initialize_request(self, request, *args, **kwargs):
"""
Returns the initial request object.
"""