aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2013-08-29 20:52:46 +0100
committerTom Christie2013-08-29 20:52:46 +0100
commit19f9adacb254841d02f43295baf81406ce3c60eb (patch)
treef77644b5515c15e09d49d12aef0855c67262f9ba /rest_framework
parente4d2f54529bcf538be93da5770e05b88a32da1c7 (diff)
parent02b6836ee88498861521dfff743467b0456ad109 (diff)
downloaddjango-rest-framework-19f9adacb254841d02f43295baf81406ce3c60eb.tar.bz2
Merge branch 'master' into display-raw-data
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/fields.py5
-rw-r--r--rest_framework/generics.py11
-rw-r--r--rest_framework/settings.py10
-rw-r--r--rest_framework/tests/test_fields.py8
-rw-r--r--rest_framework/tests/test_pagination.py47
-rw-r--r--rest_framework/throttling.py11
-rw-r--r--rest_framework/utils/breadcrumbs.py7
-rw-r--r--rest_framework/views.py88
8 files changed, 148 insertions, 39 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 3e0ca1a1..210c2537 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -514,6 +514,11 @@ class ChoiceField(WritableField):
return True
return False
+ def from_native(self, value):
+ if value in validators.EMPTY_VALUES:
+ return None
+ return super(ChoiceField, self).from_native(value)
+
class EmailField(CharField):
type_name = 'EmailField'
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 5ecf6310..14feed20 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -14,13 +14,15 @@ from rest_framework.settings import api_settings
import warnings
-def strict_positive_int(integer_string):
+def strict_positive_int(integer_string, cutoff=None):
"""
Cast a string to a strictly positive integer.
"""
ret = int(integer_string)
if ret <= 0:
raise ValueError()
+ if cutoff:
+ ret = min(ret, cutoff)
return ret
def get_object_or_404(queryset, **filter_kwargs):
@@ -56,6 +58,7 @@ class GenericAPIView(views.APIView):
# Pagination settings
paginate_by = api_settings.PAGINATE_BY
paginate_by_param = api_settings.PAGINATE_BY_PARAM
+ max_paginate_by = api_settings.MAX_PAGINATE_BY
pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
page_kwarg = 'page'
@@ -205,9 +208,11 @@ class GenericAPIView(views.APIView):
PendingDeprecationWarning, stacklevel=2)
if self.paginate_by_param:
- query_params = self.request.QUERY_PARAMS
try:
- return int(query_params[self.paginate_by_param])
+ return strict_positive_int(
+ self.request.QUERY_PARAMS[self.paginate_by_param],
+ cutoff=self.max_paginate_by
+ )
except (KeyError, ValueError):
pass
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 7d25e513..8c084751 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -48,7 +48,6 @@ DEFAULTS = {
),
'DEFAULT_THROTTLE_CLASSES': (
),
-
'DEFAULT_CONTENT_NEGOTIATION_CLASS':
'rest_framework.negotiation.DefaultContentNegotiation',
@@ -68,15 +67,16 @@ DEFAULTS = {
# Pagination
'PAGINATE_BY': None,
'PAGINATE_BY_PARAM': None,
-
- # View configuration
- 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name',
- 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',
+ 'MAX_PAGINATE_BY': None,
# Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
+ # View configuration
+ 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name',
+ 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',
+
# Testing
'TEST_REQUEST_RENDERER_CLASSES': (
'rest_framework.renderers.MultiPartRenderer',
diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py
index ebccba7d..34fbab9c 100644
--- a/rest_framework/tests/test_fields.py
+++ b/rest_framework/tests/test_fields.py
@@ -688,6 +688,14 @@ class ChoiceFieldTests(TestCase):
f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES)
self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES)
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.ChoiceField(choices=self.SAMPLE_CHOICES)
+ result = f.from_native('')
+ self.assertEqual(result, None)
+
class EmailFieldTests(TestCase):
"""
diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py
index 85d4640e..4170d4b6 100644
--- a/rest_framework/tests/test_pagination.py
+++ b/rest_framework/tests/test_pagination.py
@@ -42,6 +42,16 @@ class PaginateByParamView(generics.ListAPIView):
paginate_by_param = 'page_size'
+class MaxPaginateByView(generics.ListAPIView):
+ """
+ View for testing custom max_paginate_by usage
+ """
+ model = BasicModel
+ paginate_by = 3
+ max_paginate_by = 5
+ paginate_by_param = 'page_size'
+
+
class IntegrationTestPagination(TestCase):
"""
Integration tests for paginated list views.
@@ -313,6 +323,43 @@ class TestCustomPaginateByParam(TestCase):
self.assertEqual(response.data['results'], self.data[:5])
+class TestMaxPaginateByParam(TestCase):
+ """
+ Tests for list views with max_paginate_by kwarg
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = MaxPaginateByView.as_view()
+
+ def test_max_paginate_by(self):
+ """
+ If max_paginate_by is set, it should limit page size for the view.
+ """
+ request = factory.get('/?page_size=10')
+ response = self.view(request).render()
+ self.assertEqual(response.data['count'], 13)
+ self.assertEqual(response.data['results'], self.data[:5])
+
+ def test_max_paginate_by_without_page_size_param(self):
+ """
+ If max_paginate_by is set, but client does not specifiy page_size,
+ standard `paginate_by` behavior should be used.
+ """
+ request = factory.get('/')
+ response = self.view(request).render()
+ self.assertEqual(response.data['results'], self.data[:3])
+
+
### Tests for context in pagination serializers
class CustomField(serializers.Field):
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index 65b45593..a946d837 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -2,7 +2,7 @@
Provides various throttling policies.
"""
from __future__ import unicode_literals
-from django.core.cache import cache
+from django.core.cache import cache as default_cache
from django.core.exceptions import ImproperlyConfigured
from rest_framework.settings import api_settings
import time
@@ -39,6 +39,7 @@ class SimpleRateThrottle(BaseThrottle):
Previous request information used for throttling is stored in the cache.
"""
+ cache = default_cache
timer = time.time
cache_format = 'throtte_%(scope)s_%(ident)s'
scope = None
@@ -99,7 +100,7 @@ class SimpleRateThrottle(BaseThrottle):
if self.key is None:
return True
- self.history = cache.get(self.key, [])
+ self.history = self.cache.get(self.key, [])
self.now = self.timer()
# Drop any requests from the history which have now passed the
@@ -116,7 +117,7 @@ class SimpleRateThrottle(BaseThrottle):
into the cache.
"""
self.history.insert(0, self.now)
- cache.set(self.key, self.history, self.duration)
+ self.cache.set(self.key, self.history, self.duration)
return True
def throttle_failure(self):
@@ -151,7 +152,9 @@ class AnonRateThrottle(SimpleRateThrottle):
if request.user.is_authenticated():
return None # Only throttle unauthenticated requests.
- ident = request.META.get('REMOTE_ADDR', None)
+ ident = request.META.get('HTTP_X_FORWARDED_FOR')
+ if ident is None:
+ ident = request.META.get('REMOTE_ADDR')
return self.cache_format % {
'scope': self.scope,
diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py
index 0384faba..e6690d17 100644
--- a/rest_framework/utils/breadcrumbs.py
+++ b/rest_framework/utils/breadcrumbs.py
@@ -8,8 +8,11 @@ def get_breadcrumbs(url):
tuple of (name, url).
"""
+ from rest_framework.settings import api_settings
from rest_framework.views import APIView
+ view_name_func = api_settings.VIEW_NAME_FUNCTION
+
def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen):
"""
Add tuples of (name, url) to the breadcrumbs list,
@@ -28,8 +31,8 @@ def get_breadcrumbs(url):
# Don't list the same view twice in a row.
# Probably an optional trailing slash.
if not seen or seen[-1] != view:
- instance = view.cls()
- name = instance.get_view_name()
+ suffix = getattr(view, 'suffix', None)
+ name = view_name_func(cls, suffix)
breadcrumbs_list.insert(0, (name, prefix + url))
seen.append(view)
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 727a9f95..4cff0422 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -15,8 +15,14 @@ from rest_framework.settings import api_settings
from rest_framework.utils import formatting
-def get_view_name(cls, suffix=None):
- name = cls.__name__
+def get_view_name(view_cls, suffix=None):
+ """
+ Given a view class, return a textual name to represent the view.
+ This name is used in the browsable API, and in OPTIONS responses.
+
+ This function is the default for the `VIEW_NAME_FUNCTION` setting.
+ """
+ name = view_cls.__name__
name = formatting.remove_trailing_string(name, 'View')
name = formatting.remove_trailing_string(name, 'ViewSet')
name = formatting.camelcase_to_spaces(name)
@@ -25,17 +31,56 @@ def get_view_name(cls, suffix=None):
return name
-def get_view_description(cls, html=False):
- description = cls.__doc__ or ''
+def get_view_description(view_cls, html=False):
+ """
+ Given a view class, return a textual description to represent the view.
+ This name is used in the browsable API, and in OPTIONS responses.
+
+ This function is the default for the `VIEW_DESCRIPTION_FUNCTION` setting.
+ """
+ description = view_cls.__doc__ or ''
description = formatting.dedent(smart_text(description))
if html:
return formatting.markup_description(description)
return description
+def exception_handler(exc):
+ """
+ Returns the response that should be used for any given exception.
+
+ By default we handle the REST framework `APIException`, and also
+ Django's builtin `Http404` and `PermissionDenied` exceptions.
+
+ Any unhandled exceptions may return `None`, which will cause a 500 error
+ to be raised.
+ """
+ if isinstance(exc, exceptions.APIException):
+ headers = {}
+ if getattr(exc, 'auth_header', None):
+ headers['WWW-Authenticate'] = exc.auth_header
+ if getattr(exc, 'wait', None):
+ headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
+
+ return Response({'detail': exc.detail},
+ status=exc.status_code,
+ headers=headers)
+
+ elif isinstance(exc, Http404):
+ return Response({'detail': 'Not found'},
+ status=status.HTTP_404_NOT_FOUND)
+
+ elif isinstance(exc, PermissionDenied):
+ return Response({'detail': 'Permission denied'},
+ status=status.HTTP_403_FORBIDDEN)
+
+ # Note: Unhandled exceptions will raise a 500 error.
+ return None
+
+
class APIView(View):
- settings = api_settings
+ # The following policies may be set at either globally, or per-view.
renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
parser_classes = api_settings.DEFAULT_PARSER_CLASSES
authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
@@ -43,6 +88,9 @@ class APIView(View):
permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
+ # Allow dependancy injection of other settings to make testing easier.
+ settings = api_settings
+
@classmethod
def as_view(cls, **initkwargs):
"""
@@ -133,7 +181,7 @@ class APIView(View):
Return the view name, as used in OPTIONS responses and in the
browsable API.
"""
- func = api_settings.VIEW_NAME_FUNCTION
+ func = self.settings.VIEW_NAME_FUNCTION
return func(self.__class__, getattr(self, 'suffix', None))
def get_view_description(self, html=False):
@@ -141,7 +189,7 @@ class APIView(View):
Return some descriptive text for the view, as used in OPTIONS responses
and in the browsable API.
"""
- func = api_settings.VIEW_DESCRIPTION_FUNCTION
+ func = self.settings.VIEW_DESCRIPTION_FUNCTION
return func(self.__class__, html)
# API policy instantiation methods
@@ -303,33 +351,23 @@ class APIView(View):
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
"""
- if isinstance(exc, exceptions.Throttled) and exc.wait is not None:
- # Throttle wait header
- self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
-
if isinstance(exc, (exceptions.NotAuthenticated,
exceptions.AuthenticationFailed)):
# WWW-Authenticate header for 401 responses, else coerce to 403
auth_header = self.get_authenticate_header(self.request)
if auth_header:
- self.headers['WWW-Authenticate'] = auth_header
+ exc.auth_header = auth_header
else:
exc.status_code = status.HTTP_403_FORBIDDEN
- if isinstance(exc, exceptions.APIException):
- return Response({'detail': exc.detail},
- status=exc.status_code,
- exception=True)
- elif isinstance(exc, Http404):
- return Response({'detail': 'Not found'},
- status=status.HTTP_404_NOT_FOUND,
- exception=True)
- elif isinstance(exc, PermissionDenied):
- return Response({'detail': 'Permission denied'},
- status=status.HTTP_403_FORBIDDEN,
- exception=True)
- raise
+ response = exception_handler(exc)
+
+ if response is None:
+ raise
+
+ response.exception = True
+ return response
# Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt.