diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/__init__.py | 2 | ||||
| -rw-r--r-- | rest_framework/fields.py | 16 | ||||
| -rw-r--r-- | rest_framework/filters.py | 5 | ||||
| -rw-r--r-- | rest_framework/generics.py | 11 | ||||
| -rw-r--r-- | rest_framework/renderers.py | 5 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 8 | ||||
| -rw-r--r-- | rest_framework/settings.py | 6 | ||||
| -rw-r--r-- | rest_framework/templates/rest_framework/base.html | 2 | ||||
| -rw-r--r-- | rest_framework/test.py | 2 | ||||
| -rw-r--r-- | rest_framework/tests/test_description.py | 9 | ||||
| -rw-r--r-- | rest_framework/tests/test_fields.py | 9 | ||||
| -rw-r--r-- | rest_framework/tests/test_testing.py | 30 | ||||
| -rw-r--r-- | rest_framework/tests/test_throttling.py | 33 | ||||
| -rw-r--r-- | rest_framework/throttling.py | 3 | ||||
| -rw-r--r-- | rest_framework/utils/breadcrumbs.py | 5 | ||||
| -rw-r--r-- | rest_framework/utils/formatting.py | 42 | ||||
| -rw-r--r-- | rest_framework/views.py | 48 |
17 files changed, 169 insertions, 67 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 776618ac..087808e0 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.3.6' +__version__ = '2.3.7' VERSION = __version__ # synonym diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 9ba5c0eb..3e0ca1a1 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -16,6 +16,7 @@ from django.core import validators from django.core.exceptions import ValidationError from django.conf import settings from django.db.models.fields import BLANK_CHOICE_DASH +from django.http import QueryDict from django.forms import widgets from django.utils.encoding import is_protected_type from django.utils.translation import ugettext_lazy as _ @@ -402,10 +403,15 @@ class BooleanField(WritableField): } empty = False - # Note: we set default to `False` in order to fill in missing value not - # supplied by html form. TODO: Fix so that only html form input gets - # this behavior. - default = False + def field_from_native(self, data, files, field_name, into): + # HTML checkboxes do not explicitly represent unchecked as `False` + # we deal with that here... + if isinstance(data, QueryDict): + self.default = False + + return super(BooleanField, self).field_from_native( + data, files, field_name, into + ) def from_native(self, value): if value in ('true', 't', 'True', '1'): @@ -927,7 +933,7 @@ class ImageField(FileField): if f is None: return None - from compat import Image + from rest_framework.compat import Image assert Image is not None, 'PIL must be installed for ImageField support' # We need to get a file object for PIL. We might have a path or we might diff --git a/rest_framework/filters.py b/rest_framework/filters.py index c058bc71..4079e1bd 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -109,8 +109,7 @@ class OrderingFilter(BaseFilterBackend): def get_ordering(self, request): """ - Search terms are set by a ?search=... query parameter, - and may be comma and/or whitespace delimited. + Ordering is set by a comma delimited ?ordering=... query parameter. """ params = request.QUERY_PARAMS.get(self.ordering_param) if params: @@ -134,7 +133,7 @@ class OrderingFilter(BaseFilterBackend): ordering = self.remove_invalid_fields(queryset, ordering) if not ordering: - # Use 'ordering' attribtue by default + # Use 'ordering' attribute by default ordering = self.get_default_ordering(view) if ordering: diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 99e9782e..5ecf6310 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -14,6 +14,15 @@ from rest_framework.settings import api_settings import warnings +def strict_positive_int(integer_string): + """ + Cast a string to a strictly positive integer. + """ + ret = int(integer_string) + if ret <= 0: + raise ValueError() + return ret + def get_object_or_404(queryset, **filter_kwargs): """ Same as Django's standard shortcut, but make sure to raise 404 @@ -135,7 +144,7 @@ class GenericAPIView(views.APIView): page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) page = page_kwarg or page_query_param or 1 try: - page_number = int(page) + page_number = strict_positive_int(page) except ValueError: if page == 'last': page_number = paginator.num_pages diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 3a03ca33..1006e26c 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -24,7 +24,6 @@ from rest_framework.settings import api_settings from rest_framework.request import clone_request from rest_framework.utils import encoders from rest_framework.utils.breadcrumbs import get_breadcrumbs -from rest_framework.utils.formatting import get_view_name, get_view_description from rest_framework import exceptions, parsers, status, VERSION @@ -498,10 +497,10 @@ class BrowsableAPIRenderer(BaseRenderer): return GenericContentForm() def get_name(self, view): - return get_view_name(view.__class__, getattr(view, 'suffix', None)) + return view.get_view_name() def get_description(self, view): - return get_view_description(view.__class__, html=True) + return view.get_view_description(html=True) def get_breadcrumbs(self, request): return get_breadcrumbs(request.path) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 023f7ccf..31cfa344 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -683,14 +683,14 @@ class ModelSerializer(Serializer): # in the `read_only_fields` option for field_name in self.opts.read_only_fields: assert field_name not in self.base_fields.keys(), \ - "field '%s' on serializer '%s' specfied in " \ + "field '%s' on serializer '%s' specified in " \ "`read_only_fields`, but also added " \ - "as an explict field. Remove it from `read_only_fields`." % \ + "as an explicit field. Remove it from `read_only_fields`." % \ (field_name, self.__class__.__name__) assert field_name in ret, \ - "Noexistant field '%s' specified in `read_only_fields` " \ + "Non-existant field '%s' specified in `read_only_fields` " \ "on serializer '%s'." % \ - (self.__class__.__name__, field_name) + (field_name, self.__class__.__name__) ret[field_name].read_only = True return ret diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 8fd177d5..7d25e513 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -69,6 +69,10 @@ DEFAULTS = { 'PAGINATE_BY': None, 'PAGINATE_BY_PARAM': None, + # View configuration + 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name', + 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description', + # Authentication 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, @@ -125,6 +129,8 @@ IMPORT_STRINGS = ( 'TEST_REQUEST_RENDERER_CLASSES', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', + 'VIEW_NAME_FUNCTION', + 'VIEW_DESCRIPTION_FUNCTION' ) diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 9d939e73..51f9c291 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -196,7 +196,7 @@ <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button> {% endif %} {% if raw_data_patch_form %} - <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PATCH" title="Make a PUT request on the {{ name }} resource">PATCH</button> + <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PATCH" title="Make a PATCH request on the {{ name }} resource">PATCH</button> {% endif %} </div> </fieldset> diff --git a/rest_framework/test.py b/rest_framework/test.py index a18f5a29..234d10a4 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -134,6 +134,8 @@ class APIClient(APIRequestFactory, DjangoClient): """ self.handler._force_user = user self.handler._force_token = token + if user is None: + self.logout() # Also clear any possible session info if required def request(self, **kwargs): # Ensure that any credentials set get added to every request. diff --git a/rest_framework/tests/test_description.py b/rest_framework/tests/test_description.py index 8019f5ec..4c03c1de 100644 --- a/rest_framework/tests/test_description.py +++ b/rest_framework/tests/test_description.py @@ -6,7 +6,6 @@ from rest_framework.compat import apply_markdown, smart_text from rest_framework.views import APIView from rest_framework.tests.description import ViewWithNonASCIICharactersInDocstring from rest_framework.tests.description import UTF8_TEST_DOCSTRING -from rest_framework.utils.formatting import get_view_name, get_view_description # We check that docstrings get nicely un-indented. DESCRIPTION = """an example docstring @@ -58,7 +57,7 @@ class TestViewNamesAndDescriptions(TestCase): """ class MockView(APIView): pass - self.assertEqual(get_view_name(MockView), 'Mock') + self.assertEqual(MockView().get_view_name(), 'Mock') def test_view_description_uses_docstring(self): """Ensure view descriptions are based on the docstring.""" @@ -78,7 +77,7 @@ class TestViewNamesAndDescriptions(TestCase): # hash style header #""" - self.assertEqual(get_view_description(MockView), DESCRIPTION) + self.assertEqual(MockView().get_view_description(), DESCRIPTION) def test_view_description_supports_unicode(self): """ @@ -86,7 +85,7 @@ class TestViewNamesAndDescriptions(TestCase): """ self.assertEqual( - get_view_description(ViewWithNonASCIICharactersInDocstring), + ViewWithNonASCIICharactersInDocstring().get_view_description(), smart_text(UTF8_TEST_DOCSTRING) ) @@ -97,7 +96,7 @@ class TestViewNamesAndDescriptions(TestCase): """ class MockView(APIView): pass - self.assertEqual(get_view_description(MockView), '') + self.assertEqual(MockView().get_view_description(), '') def test_markdown(self): """ diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py index 6836ec86..ebccba7d 100644 --- a/rest_framework/tests/test_fields.py +++ b/rest_framework/tests/test_fields.py @@ -896,3 +896,12 @@ class CustomIntegerField(TestCase): self.assertFalse(serializer.is_valid()) +class BooleanField(TestCase): + """ + Tests for BooleanField + """ + def test_boolean_required(self): + class BooleanRequiredSerializer(serializers.Serializer): + bool_field = serializers.BooleanField(required=True) + + self.assertFalse(BooleanRequiredSerializer(data={}).is_valid()) diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index 49d45fc2..48b8956b 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -17,8 +17,18 @@ def view(request): }) +@api_view(['GET', 'POST']) +def session_view(request): + active_session = request.session.get('active_session', False) + request.session['active_session'] = True + return Response({ + 'active_session': active_session + }) + + urlpatterns = patterns('', url(r'^view/$', view), + url(r'^session-view/$', session_view), ) @@ -46,6 +56,26 @@ class TestAPITestClient(TestCase): response = self.client.get('/view/') self.assertEqual(response.data['user'], 'example') + def test_force_authenticate_with_sessions(self): + """ + Setting `.force_authenticate()` forcibly authenticates each request. + """ + user = User.objects.create_user('example', 'example@example.com') + self.client.force_authenticate(user) + + # First request does not yet have an active session + response = self.client.get('/session-view/') + self.assertEqual(response.data['active_session'], False) + + # Subsequant requests have an active session + response = self.client.get('/session-view/') + self.assertEqual(response.data['active_session'], True) + + # Force authenticating as `None` should also logout the user session. + self.client.force_authenticate(None) + response = self.client.get('/session-view/') + self.assertEqual(response.data['active_session'], False) + def test_csrf_exempt_by_default(self): """ By default, the test client is CSRF exempt. diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py index 19bc691a..41bff692 100644 --- a/rest_framework/tests/test_throttling.py +++ b/rest_framework/tests/test_throttling.py @@ -7,7 +7,7 @@ from django.contrib.auth.models import User from django.core.cache import cache from rest_framework.test import APIRequestFactory from rest_framework.views import APIView -from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle +from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle from rest_framework.response import Response @@ -21,6 +21,14 @@ class User3MinRateThrottle(UserRateThrottle): scope = 'minutes' +class NonTimeThrottle(BaseThrottle): + def allow_request(self, request, view): + if not hasattr(self.__class__, 'called'): + self.__class__.called = True + return True + return False + + class MockView(APIView): throttle_classes = (User3SecRateThrottle,) @@ -35,6 +43,13 @@ class MockView_MinuteThrottling(APIView): return Response('foo') +class MockView_NonTimeThrottling(APIView): + throttle_classes = (NonTimeThrottle,) + + def get(self, request): + return Response('foo') + + class ThrottlingTests(TestCase): def setUp(self): """ @@ -140,6 +155,22 @@ class ThrottlingTests(TestCase): (80, None) )) + def test_non_time_throttle(self): + """ + Ensure for second based throttles. + """ + request = self.factory.get('/') + + self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called')) + + response = MockView_NonTimeThrottling.as_view()(request) + self.assertFalse('X-Throttle-Wait-Seconds' in response) + + self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called) + + response = MockView_NonTimeThrottling.as_view()(request) + self.assertFalse('X-Throttle-Wait-Seconds' in response) + class ScopedRateThrottleTests(TestCase): """ diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index f6bb1cc8..65b45593 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -96,6 +96,9 @@ class SimpleRateThrottle(BaseThrottle): return True self.key = self.get_cache_key(request, view) + if self.key is None: + return True + self.history = cache.get(self.key, []) self.now = self.timer() diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index d51374b0..0384faba 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -1,6 +1,5 @@ from __future__ import unicode_literals from django.core.urlresolvers import resolve, get_script_prefix -from rest_framework.utils.formatting import get_view_name def get_breadcrumbs(url): @@ -29,8 +28,8 @@ def get_breadcrumbs(url): # Don't list the same view twice in a row. # Probably an optional trailing slash. if not seen or seen[-1] != view: - suffix = getattr(view, 'suffix', None) - name = get_view_name(view.cls, suffix) + instance = view.cls() + name = instance.get_view_name() breadcrumbs_list.insert(0, (name, prefix + url)) seen.append(view) diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index 4bec8387..4b59ba84 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -5,11 +5,13 @@ from __future__ import unicode_literals from django.utils.html import escape from django.utils.safestring import mark_safe -from rest_framework.compat import apply_markdown, smart_text +from rest_framework.compat import apply_markdown +from rest_framework.settings import api_settings +from textwrap import dedent import re -def _remove_trailing_string(content, trailing): +def remove_trailing_string(content, trailing): """ Strip trailing component `trailing` from `content` if it exists. Used when generating names from view classes. @@ -19,10 +21,14 @@ def _remove_trailing_string(content, trailing): return content -def _remove_leading_indent(content): +def dedent(content): """ Remove leading indent from a block of text. Used when generating descriptions from docstrings. + + Note that python's `textwrap.dedent` doesn't quite cut it, + as it fails to dedent multiline docstrings that include + unindented text on the initial line. """ whitespace_counts = [len(line) - len(line.lstrip(' ')) for line in content.splitlines()[1:] if line.lstrip()] @@ -31,11 +37,10 @@ def _remove_leading_indent(content): if whitespace_counts: whitespace_pattern = '^' + (' ' * min(whitespace_counts)) content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) - content = content.strip('\n') - return content + return content.strip() -def _camelcase_to_spaces(content): +def camelcase_to_spaces(content): """ Translate 'CamelCaseNames' to 'Camel Case Names'. Used when generating names from view classes. @@ -44,31 +49,6 @@ def _camelcase_to_spaces(content): content = re.sub(camelcase_boundry, ' \\1', content).strip() return ' '.join(content.split('_')).title() - -def get_view_name(cls, suffix=None): - """ - Return a formatted name for an `APIView` class or `@api_view` function. - """ - name = cls.__name__ - name = _remove_trailing_string(name, 'View') - name = _remove_trailing_string(name, 'ViewSet') - name = _camelcase_to_spaces(name) - if suffix: - name += ' ' + suffix - return name - - -def get_view_description(cls, html=False): - """ - Return a description for an `APIView` class or `@api_view` function. - """ - description = cls.__doc__ or '' - description = _remove_leading_indent(smart_text(description)) - if html: - return markup_description(description) - return description - - def markup_description(description): """ Apply HTML markup to the given description. diff --git a/rest_framework/views.py b/rest_framework/views.py index 37bba7f0..727a9f95 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -8,11 +8,29 @@ from django.http import Http404 from django.utils.datastructures import SortedDict from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions -from rest_framework.compat import View, HttpResponseBase +from rest_framework.compat import smart_text, HttpResponseBase, View from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings -from rest_framework.utils.formatting import get_view_name, get_view_description +from rest_framework.utils import formatting + + +def get_view_name(cls, suffix=None): + name = cls.__name__ + name = formatting.remove_trailing_string(name, 'View') + name = formatting.remove_trailing_string(name, 'ViewSet') + name = formatting.camelcase_to_spaces(name) + if suffix: + name += ' ' + suffix + + return name + +def get_view_description(cls, html=False): + description = cls.__doc__ or '' + description = formatting.dedent(smart_text(description)) + if html: + return formatting.markup_description(description) + return description class APIView(View): @@ -110,6 +128,22 @@ class APIView(View): 'request': getattr(self, 'request', None) } + def get_view_name(self): + """ + Return the view name, as used in OPTIONS responses and in the + browsable API. + """ + func = api_settings.VIEW_NAME_FUNCTION + return func(self.__class__, getattr(self, 'suffix', None)) + + def get_view_description(self, html=False): + """ + Return some descriptive text for the view, as used in OPTIONS responses + and in the browsable API. + """ + func = api_settings.VIEW_DESCRIPTION_FUNCTION + return func(self.__class__, html) + # API policy instantiation methods def get_format_suffix(self, **kwargs): @@ -269,7 +303,7 @@ class APIView(View): Handle any exception that occurs, by returning an appropriate response, or re-raising the error. """ - if isinstance(exc, exceptions.Throttled): + if isinstance(exc, exceptions.Throttled) and exc.wait is not None: # Throttle wait header self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait @@ -342,16 +376,12 @@ class APIView(View): Return a dictionary of metadata about the view. Used to return responses for OPTIONS requests. """ - - # This is used by ViewSets to disambiguate instance vs list views - view_name_suffix = getattr(self, 'suffix', None) - # By default we can't provide any form-like information, however the # generic views override this implementation and add additional # information for POST and PUT methods, based on the serializer. ret = SortedDict() - ret['name'] = get_view_name(self.__class__, view_name_suffix) - ret['description'] = get_view_description(self.__class__) + ret['name'] = self.get_view_name() + ret['description'] = self.get_view_description() ret['renders'] = [renderer.media_type for renderer in self.renderer_classes] ret['parses'] = [parser.media_type for parser in self.parser_classes] return ret |
