aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py2
-rw-r--r--rest_framework/fields.py16
-rw-r--r--rest_framework/filters.py5
-rw-r--r--rest_framework/generics.py11
-rw-r--r--rest_framework/renderers.py5
-rw-r--r--rest_framework/serializers.py8
-rw-r--r--rest_framework/settings.py6
-rw-r--r--rest_framework/templates/rest_framework/base.html2
-rw-r--r--rest_framework/test.py2
-rw-r--r--rest_framework/tests/test_description.py9
-rw-r--r--rest_framework/tests/test_fields.py9
-rw-r--r--rest_framework/tests/test_testing.py30
-rw-r--r--rest_framework/tests/test_throttling.py33
-rw-r--r--rest_framework/throttling.py3
-rw-r--r--rest_framework/utils/breadcrumbs.py5
-rw-r--r--rest_framework/utils/formatting.py42
-rw-r--r--rest_framework/views.py48
17 files changed, 169 insertions, 67 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 776618ac..087808e0 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '2.3.6'
+__version__ = '2.3.7'
VERSION = __version__ # synonym
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 9ba5c0eb..3e0ca1a1 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -16,6 +16,7 @@ from django.core import validators
from django.core.exceptions import ValidationError
from django.conf import settings
from django.db.models.fields import BLANK_CHOICE_DASH
+from django.http import QueryDict
from django.forms import widgets
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
@@ -402,10 +403,15 @@ class BooleanField(WritableField):
}
empty = False
- # Note: we set default to `False` in order to fill in missing value not
- # supplied by html form. TODO: Fix so that only html form input gets
- # this behavior.
- default = False
+ def field_from_native(self, data, files, field_name, into):
+ # HTML checkboxes do not explicitly represent unchecked as `False`
+ # we deal with that here...
+ if isinstance(data, QueryDict):
+ self.default = False
+
+ return super(BooleanField, self).field_from_native(
+ data, files, field_name, into
+ )
def from_native(self, value):
if value in ('true', 't', 'True', '1'):
@@ -927,7 +933,7 @@ class ImageField(FileField):
if f is None:
return None
- from compat import Image
+ from rest_framework.compat import Image
assert Image is not None, 'PIL must be installed for ImageField support'
# We need to get a file object for PIL. We might have a path or we might
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index c058bc71..4079e1bd 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -109,8 +109,7 @@ class OrderingFilter(BaseFilterBackend):
def get_ordering(self, request):
"""
- Search terms are set by a ?search=... query parameter,
- and may be comma and/or whitespace delimited.
+ Ordering is set by a comma delimited ?ordering=... query parameter.
"""
params = request.QUERY_PARAMS.get(self.ordering_param)
if params:
@@ -134,7 +133,7 @@ class OrderingFilter(BaseFilterBackend):
ordering = self.remove_invalid_fields(queryset, ordering)
if not ordering:
- # Use 'ordering' attribtue by default
+ # Use 'ordering' attribute by default
ordering = self.get_default_ordering(view)
if ordering:
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 99e9782e..5ecf6310 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -14,6 +14,15 @@ from rest_framework.settings import api_settings
import warnings
+def strict_positive_int(integer_string):
+ """
+ Cast a string to a strictly positive integer.
+ """
+ ret = int(integer_string)
+ if ret <= 0:
+ raise ValueError()
+ return ret
+
def get_object_or_404(queryset, **filter_kwargs):
"""
Same as Django's standard shortcut, but make sure to raise 404
@@ -135,7 +144,7 @@ class GenericAPIView(views.APIView):
page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)
page = page_kwarg or page_query_param or 1
try:
- page_number = int(page)
+ page_number = strict_positive_int(page)
except ValueError:
if page == 'last':
page_number = paginator.num_pages
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 3a03ca33..1006e26c 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -24,7 +24,6 @@ from rest_framework.settings import api_settings
from rest_framework.request import clone_request
from rest_framework.utils import encoders
from rest_framework.utils.breadcrumbs import get_breadcrumbs
-from rest_framework.utils.formatting import get_view_name, get_view_description
from rest_framework import exceptions, parsers, status, VERSION
@@ -498,10 +497,10 @@ class BrowsableAPIRenderer(BaseRenderer):
return GenericContentForm()
def get_name(self, view):
- return get_view_name(view.__class__, getattr(view, 'suffix', None))
+ return view.get_view_name()
def get_description(self, view):
- return get_view_description(view.__class__, html=True)
+ return view.get_view_description(html=True)
def get_breadcrumbs(self, request):
return get_breadcrumbs(request.path)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 023f7ccf..31cfa344 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -683,14 +683,14 @@ class ModelSerializer(Serializer):
# in the `read_only_fields` option
for field_name in self.opts.read_only_fields:
assert field_name not in self.base_fields.keys(), \
- "field '%s' on serializer '%s' specfied in " \
+ "field '%s' on serializer '%s' specified in " \
"`read_only_fields`, but also added " \
- "as an explict field. Remove it from `read_only_fields`." % \
+ "as an explicit field. Remove it from `read_only_fields`." % \
(field_name, self.__class__.__name__)
assert field_name in ret, \
- "Noexistant field '%s' specified in `read_only_fields` " \
+ "Non-existant field '%s' specified in `read_only_fields` " \
"on serializer '%s'." % \
- (self.__class__.__name__, field_name)
+ (field_name, self.__class__.__name__)
ret[field_name].read_only = True
return ret
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 8fd177d5..7d25e513 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -69,6 +69,10 @@ DEFAULTS = {
'PAGINATE_BY': None,
'PAGINATE_BY_PARAM': None,
+ # View configuration
+ 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name',
+ 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',
+
# Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
@@ -125,6 +129,8 @@ IMPORT_STRINGS = (
'TEST_REQUEST_RENDERER_CLASSES',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
+ 'VIEW_NAME_FUNCTION',
+ 'VIEW_DESCRIPTION_FUNCTION'
)
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 9d939e73..51f9c291 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -196,7 +196,7 @@
<button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button>
{% endif %}
{% if raw_data_patch_form %}
- <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PATCH" title="Make a PUT request on the {{ name }} resource">PATCH</button>
+ <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PATCH" title="Make a PATCH request on the {{ name }} resource">PATCH</button>
{% endif %}
</div>
</fieldset>
diff --git a/rest_framework/test.py b/rest_framework/test.py
index a18f5a29..234d10a4 100644
--- a/rest_framework/test.py
+++ b/rest_framework/test.py
@@ -134,6 +134,8 @@ class APIClient(APIRequestFactory, DjangoClient):
"""
self.handler._force_user = user
self.handler._force_token = token
+ if user is None:
+ self.logout() # Also clear any possible session info if required
def request(self, **kwargs):
# Ensure that any credentials set get added to every request.
diff --git a/rest_framework/tests/test_description.py b/rest_framework/tests/test_description.py
index 8019f5ec..4c03c1de 100644
--- a/rest_framework/tests/test_description.py
+++ b/rest_framework/tests/test_description.py
@@ -6,7 +6,6 @@ from rest_framework.compat import apply_markdown, smart_text
from rest_framework.views import APIView
from rest_framework.tests.description import ViewWithNonASCIICharactersInDocstring
from rest_framework.tests.description import UTF8_TEST_DOCSTRING
-from rest_framework.utils.formatting import get_view_name, get_view_description
# We check that docstrings get nicely un-indented.
DESCRIPTION = """an example docstring
@@ -58,7 +57,7 @@ class TestViewNamesAndDescriptions(TestCase):
"""
class MockView(APIView):
pass
- self.assertEqual(get_view_name(MockView), 'Mock')
+ self.assertEqual(MockView().get_view_name(), 'Mock')
def test_view_description_uses_docstring(self):
"""Ensure view descriptions are based on the docstring."""
@@ -78,7 +77,7 @@ class TestViewNamesAndDescriptions(TestCase):
# hash style header #"""
- self.assertEqual(get_view_description(MockView), DESCRIPTION)
+ self.assertEqual(MockView().get_view_description(), DESCRIPTION)
def test_view_description_supports_unicode(self):
"""
@@ -86,7 +85,7 @@ class TestViewNamesAndDescriptions(TestCase):
"""
self.assertEqual(
- get_view_description(ViewWithNonASCIICharactersInDocstring),
+ ViewWithNonASCIICharactersInDocstring().get_view_description(),
smart_text(UTF8_TEST_DOCSTRING)
)
@@ -97,7 +96,7 @@ class TestViewNamesAndDescriptions(TestCase):
"""
class MockView(APIView):
pass
- self.assertEqual(get_view_description(MockView), '')
+ self.assertEqual(MockView().get_view_description(), '')
def test_markdown(self):
"""
diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py
index 6836ec86..ebccba7d 100644
--- a/rest_framework/tests/test_fields.py
+++ b/rest_framework/tests/test_fields.py
@@ -896,3 +896,12 @@ class CustomIntegerField(TestCase):
self.assertFalse(serializer.is_valid())
+class BooleanField(TestCase):
+ """
+ Tests for BooleanField
+ """
+ def test_boolean_required(self):
+ class BooleanRequiredSerializer(serializers.Serializer):
+ bool_field = serializers.BooleanField(required=True)
+
+ self.assertFalse(BooleanRequiredSerializer(data={}).is_valid())
diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py
index 49d45fc2..48b8956b 100644
--- a/rest_framework/tests/test_testing.py
+++ b/rest_framework/tests/test_testing.py
@@ -17,8 +17,18 @@ def view(request):
})
+@api_view(['GET', 'POST'])
+def session_view(request):
+ active_session = request.session.get('active_session', False)
+ request.session['active_session'] = True
+ return Response({
+ 'active_session': active_session
+ })
+
+
urlpatterns = patterns('',
url(r'^view/$', view),
+ url(r'^session-view/$', session_view),
)
@@ -46,6 +56,26 @@ class TestAPITestClient(TestCase):
response = self.client.get('/view/')
self.assertEqual(response.data['user'], 'example')
+ def test_force_authenticate_with_sessions(self):
+ """
+ Setting `.force_authenticate()` forcibly authenticates each request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ self.client.force_authenticate(user)
+
+ # First request does not yet have an active session
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], False)
+
+ # Subsequant requests have an active session
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], True)
+
+ # Force authenticating as `None` should also logout the user session.
+ self.client.force_authenticate(None)
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], False)
+
def test_csrf_exempt_by_default(self):
"""
By default, the test client is CSRF exempt.
diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py
index 19bc691a..41bff692 100644
--- a/rest_framework/tests/test_throttling.py
+++ b/rest_framework/tests/test_throttling.py
@@ -7,7 +7,7 @@ from django.contrib.auth.models import User
from django.core.cache import cache
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
-from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle
+from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle
from rest_framework.response import Response
@@ -21,6 +21,14 @@ class User3MinRateThrottle(UserRateThrottle):
scope = 'minutes'
+class NonTimeThrottle(BaseThrottle):
+ def allow_request(self, request, view):
+ if not hasattr(self.__class__, 'called'):
+ self.__class__.called = True
+ return True
+ return False
+
+
class MockView(APIView):
throttle_classes = (User3SecRateThrottle,)
@@ -35,6 +43,13 @@ class MockView_MinuteThrottling(APIView):
return Response('foo')
+class MockView_NonTimeThrottling(APIView):
+ throttle_classes = (NonTimeThrottle,)
+
+ def get(self, request):
+ return Response('foo')
+
+
class ThrottlingTests(TestCase):
def setUp(self):
"""
@@ -140,6 +155,22 @@ class ThrottlingTests(TestCase):
(80, None)
))
+ def test_non_time_throttle(self):
+ """
+ Ensure for second based throttles.
+ """
+ request = self.factory.get('/')
+
+ self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called'))
+
+ response = MockView_NonTimeThrottling.as_view()(request)
+ self.assertFalse('X-Throttle-Wait-Seconds' in response)
+
+ self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called)
+
+ response = MockView_NonTimeThrottling.as_view()(request)
+ self.assertFalse('X-Throttle-Wait-Seconds' in response)
+
class ScopedRateThrottleTests(TestCase):
"""
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index f6bb1cc8..65b45593 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -96,6 +96,9 @@ class SimpleRateThrottle(BaseThrottle):
return True
self.key = self.get_cache_key(request, view)
+ if self.key is None:
+ return True
+
self.history = cache.get(self.key, [])
self.now = self.timer()
diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py
index d51374b0..0384faba 100644
--- a/rest_framework/utils/breadcrumbs.py
+++ b/rest_framework/utils/breadcrumbs.py
@@ -1,6 +1,5 @@
from __future__ import unicode_literals
from django.core.urlresolvers import resolve, get_script_prefix
-from rest_framework.utils.formatting import get_view_name
def get_breadcrumbs(url):
@@ -29,8 +28,8 @@ def get_breadcrumbs(url):
# Don't list the same view twice in a row.
# Probably an optional trailing slash.
if not seen or seen[-1] != view:
- suffix = getattr(view, 'suffix', None)
- name = get_view_name(view.cls, suffix)
+ instance = view.cls()
+ name = instance.get_view_name()
breadcrumbs_list.insert(0, (name, prefix + url))
seen.append(view)
diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py
index 4bec8387..4b59ba84 100644
--- a/rest_framework/utils/formatting.py
+++ b/rest_framework/utils/formatting.py
@@ -5,11 +5,13 @@ from __future__ import unicode_literals
from django.utils.html import escape
from django.utils.safestring import mark_safe
-from rest_framework.compat import apply_markdown, smart_text
+from rest_framework.compat import apply_markdown
+from rest_framework.settings import api_settings
+from textwrap import dedent
import re
-def _remove_trailing_string(content, trailing):
+def remove_trailing_string(content, trailing):
"""
Strip trailing component `trailing` from `content` if it exists.
Used when generating names from view classes.
@@ -19,10 +21,14 @@ def _remove_trailing_string(content, trailing):
return content
-def _remove_leading_indent(content):
+def dedent(content):
"""
Remove leading indent from a block of text.
Used when generating descriptions from docstrings.
+
+ Note that python's `textwrap.dedent` doesn't quite cut it,
+ as it fails to dedent multiline docstrings that include
+ unindented text on the initial line.
"""
whitespace_counts = [len(line) - len(line.lstrip(' '))
for line in content.splitlines()[1:] if line.lstrip()]
@@ -31,11 +37,10 @@ def _remove_leading_indent(content):
if whitespace_counts:
whitespace_pattern = '^' + (' ' * min(whitespace_counts))
content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
- content = content.strip('\n')
- return content
+ return content.strip()
-def _camelcase_to_spaces(content):
+def camelcase_to_spaces(content):
"""
Translate 'CamelCaseNames' to 'Camel Case Names'.
Used when generating names from view classes.
@@ -44,31 +49,6 @@ def _camelcase_to_spaces(content):
content = re.sub(camelcase_boundry, ' \\1', content).strip()
return ' '.join(content.split('_')).title()
-
-def get_view_name(cls, suffix=None):
- """
- Return a formatted name for an `APIView` class or `@api_view` function.
- """
- name = cls.__name__
- name = _remove_trailing_string(name, 'View')
- name = _remove_trailing_string(name, 'ViewSet')
- name = _camelcase_to_spaces(name)
- if suffix:
- name += ' ' + suffix
- return name
-
-
-def get_view_description(cls, html=False):
- """
- Return a description for an `APIView` class or `@api_view` function.
- """
- description = cls.__doc__ or ''
- description = _remove_leading_indent(smart_text(description))
- if html:
- return markup_description(description)
- return description
-
-
def markup_description(description):
"""
Apply HTML markup to the given description.
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 37bba7f0..727a9f95 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -8,11 +8,29 @@ from django.http import Http404
from django.utils.datastructures import SortedDict
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status, exceptions
-from rest_framework.compat import View, HttpResponseBase
+from rest_framework.compat import smart_text, HttpResponseBase, View
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.settings import api_settings
-from rest_framework.utils.formatting import get_view_name, get_view_description
+from rest_framework.utils import formatting
+
+
+def get_view_name(cls, suffix=None):
+ name = cls.__name__
+ name = formatting.remove_trailing_string(name, 'View')
+ name = formatting.remove_trailing_string(name, 'ViewSet')
+ name = formatting.camelcase_to_spaces(name)
+ if suffix:
+ name += ' ' + suffix
+
+ return name
+
+def get_view_description(cls, html=False):
+ description = cls.__doc__ or ''
+ description = formatting.dedent(smart_text(description))
+ if html:
+ return formatting.markup_description(description)
+ return description
class APIView(View):
@@ -110,6 +128,22 @@ class APIView(View):
'request': getattr(self, 'request', None)
}
+ def get_view_name(self):
+ """
+ Return the view name, as used in OPTIONS responses and in the
+ browsable API.
+ """
+ func = api_settings.VIEW_NAME_FUNCTION
+ return func(self.__class__, getattr(self, 'suffix', None))
+
+ def get_view_description(self, html=False):
+ """
+ Return some descriptive text for the view, as used in OPTIONS responses
+ and in the browsable API.
+ """
+ func = api_settings.VIEW_DESCRIPTION_FUNCTION
+ return func(self.__class__, html)
+
# API policy instantiation methods
def get_format_suffix(self, **kwargs):
@@ -269,7 +303,7 @@ class APIView(View):
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
"""
- if isinstance(exc, exceptions.Throttled):
+ if isinstance(exc, exceptions.Throttled) and exc.wait is not None:
# Throttle wait header
self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
@@ -342,16 +376,12 @@ class APIView(View):
Return a dictionary of metadata about the view.
Used to return responses for OPTIONS requests.
"""
-
- # This is used by ViewSets to disambiguate instance vs list views
- view_name_suffix = getattr(self, 'suffix', None)
-
# By default we can't provide any form-like information, however the
# generic views override this implementation and add additional
# information for POST and PUT methods, based on the serializer.
ret = SortedDict()
- ret['name'] = get_view_name(self.__class__, view_name_suffix)
- ret['description'] = get_view_description(self.__class__)
+ ret['name'] = self.get_view_name()
+ ret['description'] = self.get_view_description()
ret['renders'] = [renderer.media_type for renderer in self.renderer_classes]
ret['parses'] = [parser.media_type for parser in self.parser_classes]
return ret