aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTom Christie2012-09-04 21:58:35 +0100
committerTom Christie2012-09-04 21:58:35 +0100
commitc28b719333b16935e53c76fef79b096cb11322ed (patch)
tree5b22784601e52b9b9f7db9385cceb51339681065
parent8457c871963264c9f62552f30307e98221a1c25d (diff)
downloaddjango-rest-framework-c28b719333b16935e53c76fef79b096cb11322ed.tar.bz2
Refactored throttling
-rw-r--r--djangorestframework/exceptions.py12
-rw-r--r--djangorestframework/parsers.py2
-rw-r--r--djangorestframework/permissions.py172
-rw-r--r--djangorestframework/tests/throttling.py43
-rw-r--r--djangorestframework/throttling.py139
-rw-r--r--djangorestframework/views.py49
6 files changed, 234 insertions, 183 deletions
diff --git a/djangorestframework/exceptions.py b/djangorestframework/exceptions.py
index 51c5dbb7..0b4dacf7 100644
--- a/djangorestframework/exceptions.py
+++ b/djangorestframework/exceptions.py
@@ -49,8 +49,14 @@ class UnsupportedMediaType(APIException):
class Throttled(APIException):
status_code = status.HTTP_429_TOO_MANY_REQUESTS
- default_detail = "Request was throttled. Expected available in %d seconds."
+ default_detail = "Request was throttled."
+ extra_detail = "Expected available in %d second%s."
- def __init__(self, wait, detail=None):
+ def __init__(self, wait=None, detail=None):
import math
- self.detail = (detail or self.default_detail) % int(math.ceil(wait))
+ self.wait = wait and math.ceil(wait) or None
+ if wait is not None:
+ format = detail or self.default_detail + self.extra_detail
+ self.detail = format % (self.wait, self.wait != 1 and 's' or '')
+ else:
+ self.detail = detail or self.default_detail
diff --git a/djangorestframework/parsers.py b/djangorestframework/parsers.py
index 96dd81ed..fb08c5a0 100644
--- a/djangorestframework/parsers.py
+++ b/djangorestframework/parsers.py
@@ -81,7 +81,7 @@ class BaseParser(object):
Should return parsed data, or a DataAndFiles object consisting of the
parsed data and files.
"""
- raise NotImplementedError(".parse_stream() Must be overridden to be implemented.")
+ raise NotImplementedError(".parse_stream() must be overridden.")
class JSONParser(BaseParser):
diff --git a/djangorestframework/permissions.py b/djangorestframework/permissions.py
index bdda4def..d6405a36 100644
--- a/djangorestframework/permissions.py
+++ b/djangorestframework/permissions.py
@@ -5,10 +5,6 @@ for checking if a request passes a certain set of constraints.
Permission behavior is provided by mixing the :class:`mixins.PermissionsMixin` class into a :class:`View` class.
"""
-from django.core.cache import cache
-from djangorestframework.exceptions import PermissionDenied, Throttled
-import time
-
__all__ = (
'BasePermission',
'FullAnonAccess',
@@ -32,20 +28,11 @@ class BasePermission(object):
"""
self.view = view
- def check_permission(self, auth):
+ def check_permission(self, request, obj=None):
"""
Should simply return, or raise an :exc:`response.ImmediateResponse`.
"""
- pass
-
-
-class FullAnonAccess(BasePermission):
- """
- Allows full access.
- """
-
- def check_permission(self, user):
- pass
+ raise NotImplementedError(".check_permission() must be overridden.")
class IsAuthenticated(BasePermission):
@@ -53,9 +40,10 @@ class IsAuthenticated(BasePermission):
Allows access only to authenticated users.
"""
- def check_permission(self, user):
- if not user.is_authenticated():
- raise PermissionDenied()
+ def check_permission(self, request, obj=None):
+ if request.user.is_authenticated():
+ return True
+ return False
class IsAdminUser(BasePermission):
@@ -63,20 +51,22 @@ class IsAdminUser(BasePermission):
Allows access only to admin users.
"""
- def check_permission(self, user):
- if not user.is_staff:
- raise PermissionDenied()
+ def check_permission(self, request, obj=None):
+ if request.user.is_staff:
+ return True
+ return False
-class IsUserOrIsAnonReadOnly(BasePermission):
+class IsAuthenticatedOrReadOnly(BasePermission):
"""
The request is authenticated as a user, or is a read-only request.
"""
- def check_permission(self, user):
- if (not user.is_authenticated() and
- self.view.method not in SAFE_METHODS):
- raise PermissionDenied()
+ def check_permission(self, request, obj=None):
+ if (request.user.is_authenticated() or
+ request.method in SAFE_METHODS):
+ return True
+ return False
class DjangoModelPermissions(BasePermission):
@@ -114,128 +104,10 @@ class DjangoModelPermissions(BasePermission):
}
return [perm % kwargs for perm in self.perms_map[method]]
- def check_permission(self, user):
- method = self.view.method
- model_cls = self.view.resource.model
- perms = self.get_required_permissions(method, model_cls)
-
- if not user.is_authenticated or not user.has_perms(perms):
- raise PermissionDenied()
-
-
-class BaseThrottle(BasePermission):
- """
- Rate throttling of requests.
-
- The rate (requests / seconds) is set by a :attr:`throttle` attribute
- on the :class:`.View` class. The attribute is a string of the form 'number of
- requests/period'.
-
- Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
-
- Previous request information used for throttling is stored in the cache.
- """
-
- attr_name = 'throttle'
- default = '0/sec'
- timer = time.time
-
- def get_cache_key(self):
- """
- Should return a unique cache-key which can be used for throttling.
- Must be overridden.
- """
- pass
-
- def check_permission(self, auth):
- """
- Check the throttling.
- Return `None` or raise an :exc:`.ImmediateResponse`.
- """
- num, period = getattr(self.view, self.attr_name, self.default).split('/')
- self.num_requests = int(num)
- self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
- self.auth = auth
- self.check_throttle()
-
- def check_throttle(self):
- """
- Implement the check to see if the request should be throttled.
-
- On success calls :meth:`throttle_success`.
- On failure calls :meth:`throttle_failure`.
- """
- self.key = self.get_cache_key()
- self.history = cache.get(self.key, [])
- self.now = self.timer()
-
- # Drop any requests from the history which have now passed the
- # throttle duration
- while self.history and self.history[-1] <= self.now - self.duration:
- self.history.pop()
- if len(self.history) >= self.num_requests:
- self.throttle_failure()
- else:
- self.throttle_success()
-
- def throttle_success(self):
- """
- Inserts the current request's timestamp along with the key
- into the cache.
- """
- self.history.insert(0, self.now)
- cache.set(self.key, self.history, self.duration)
- header = 'status=SUCCESS; next=%.2f sec' % self.next()
- self.view.headers['X-Throttle'] = header
-
- def throttle_failure(self):
- """
- Called when a request to the API has failed due to throttling.
- Raises a '503 service unavailable' response.
- """
- wait = self.next()
- header = 'status=FAILURE; next=%.2f sec' % wait
- self.view.headers['X-Throttle'] = header
- raise Throttled(wait)
-
- def next(self):
- """
- Returns the recommended next request time in seconds.
- """
- if self.history:
- remaining_duration = self.duration - (self.now - self.history[-1])
- else:
- remaining_duration = self.duration
-
- available_requests = self.num_requests - len(self.history) + 1
-
- return remaining_duration / float(available_requests)
-
-
-class PerUserThrottling(BaseThrottle):
- """
- Limits the rate of API calls that may be made by a given user.
-
- The user id will be used as a unique identifier if the user is
- authenticated. For anonymous requests, the IP address of the client will
- be used.
- """
-
- def get_cache_key(self):
- if self.auth.is_authenticated():
- ident = self.auth.id
- else:
- ident = self.view.request.META.get('REMOTE_ADDR', None)
- return 'throttle_user_%s' % ident
-
-
-class PerViewThrottling(BaseThrottle):
- """
- Limits the rate of API calls that may be used on a given view.
-
- The class name of the view is used as a unique identifier to
- throttle against.
- """
+ def check_permission(self, request, obj=None):
+ model_cls = self.view.model
+ perms = self.get_required_permissions(request.method, model_cls)
- def get_cache_key(self):
- return 'throttle_view_%s' % self.view.__class__.__name__
+ if request.user.is_authenticated() and request.user.has_perms(perms, obj):
+ return True
+ return False
diff --git a/djangorestframework/tests/throttling.py b/djangorestframework/tests/throttling.py
index a8e446e8..d144d956 100644
--- a/djangorestframework/tests/throttling.py
+++ b/djangorestframework/tests/throttling.py
@@ -8,24 +8,24 @@ from django.core.cache import cache
from djangorestframework.compat import RequestFactory
from djangorestframework.views import APIView
-from djangorestframework.permissions import PerUserThrottling, PerViewThrottling
+from djangorestframework.throttling import PerUserThrottling, PerViewThrottling
from djangorestframework.response import Response
class MockView(APIView):
- permission_classes = (PerUserThrottling,)
- throttle = '3/sec'
+ throttle_classes = (PerUserThrottling,)
+ rate = '3/sec'
def get(self, request):
return Response('foo')
class MockView_PerViewThrottling(MockView):
- permission_classes = (PerViewThrottling,)
+ throttle_classes = (PerViewThrottling,)
class MockView_MinuteThrottling(MockView):
- throttle = '3/min'
+ rate = '3/min'
class ThrottlingTests(TestCase):
@@ -51,7 +51,7 @@ class ThrottlingTests(TestCase):
"""
Explicitly set the timer, overriding time.time()
"""
- view.permission_classes[0].timer = lambda self: value
+ view.throttle_classes[0].timer = lambda self: value
def test_request_throttling_expires(self):
"""
@@ -101,17 +101,20 @@ class ThrottlingTests(TestCase):
for timer, expect in expected_headers:
self.set_throttle_timer(view, timer)
response = view.as_view()(request)
- self.assertEquals(response['X-Throttle'], expect)
+ if expect is not None:
+ self.assertEquals(response['X-Throttle-Wait-Seconds'], expect)
+ else:
+ self.assertFalse('X-Throttle-Wait-Seconds' in response.headers)
def test_seconds_fields(self):
"""
Ensure for second based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView,
- ((0, 'status=SUCCESS; next=0.33 sec'),
- (0, 'status=SUCCESS; next=0.50 sec'),
- (0, 'status=SUCCESS; next=1.00 sec'),
- (0, 'status=FAILURE; next=1.00 sec')
+ ((0, None),
+ (0, None),
+ (0, None),
+ (0, '1')
))
def test_minutes_fields(self):
@@ -119,10 +122,10 @@ class ThrottlingTests(TestCase):
Ensure for minute based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
- ((0, 'status=SUCCESS; next=20.00 sec'),
- (0, 'status=SUCCESS; next=30.00 sec'),
- (0, 'status=SUCCESS; next=60.00 sec'),
- (0, 'status=FAILURE; next=60.00 sec')
+ ((0, None),
+ (0, None),
+ (0, None),
+ (0, '60')
))
def test_next_rate_remains_constant_if_followed(self):
@@ -131,9 +134,9 @@ class ThrottlingTests(TestCase):
the throttling rate should stay constant.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
- ((0, 'status=SUCCESS; next=20.00 sec'),
- (20, 'status=SUCCESS; next=20.00 sec'),
- (40, 'status=SUCCESS; next=20.00 sec'),
- (60, 'status=SUCCESS; next=20.00 sec'),
- (80, 'status=SUCCESS; next=20.00 sec')
+ ((0, None),
+ (20, None),
+ (40, None),
+ (60, None),
+ (80, None)
))
diff --git a/djangorestframework/throttling.py b/djangorestframework/throttling.py
new file mode 100644
index 00000000..a096eab7
--- /dev/null
+++ b/djangorestframework/throttling.py
@@ -0,0 +1,139 @@
+from django.core.cache import cache
+import time
+
+
+class BaseThrottle(object):
+ """
+ Rate throttling of requests.
+ """
+
+ def __init__(self, view=None):
+ """
+ All throttles hold a reference to the instantiating view.
+ """
+ self.view = view
+
+ def check_throttle(self, request):
+ """
+ Return `True` if the request should be allowed, `False` otherwise.
+ """
+ raise NotImplementedError('.check_throttle() must be overridden')
+
+ def wait(self):
+ """
+ Optionally, return a recommeded number of seconds to wait before
+ the next request.
+ """
+ return None
+
+
+class SimpleCachingThrottle(BaseThrottle):
+ """
+ A simple cache implementation, that only requires `.get_cache_key()`
+ to be overridden.
+
+ The rate (requests / seconds) is set by a :attr:`throttle` attribute
+ on the :class:`.View` class. The attribute is a string of the form 'number of
+ requests/period'.
+
+ Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
+
+ Previous request information used for throttling is stored in the cache.
+ """
+
+ attr_name = 'rate'
+ rate = '1000/day'
+ timer = time.time
+
+ def __init__(self, view):
+ """
+ Check the throttling.
+ Return `None` or raise an :exc:`.ImmediateResponse`.
+ """
+ super(SimpleCachingThrottle, self).__init__(view)
+ num, period = getattr(view, self.attr_name, self.rate).split('/')
+ self.num_requests = int(num)
+ self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
+
+ def get_cache_key(self, request):
+ """
+ Should return a unique cache-key which can be used for throttling.
+ Must be overridden.
+ """
+ raise NotImplementedError('.get_cache_key() must be overridden')
+
+ def check_throttle(self, request):
+ """
+ Implement the check to see if the request should be throttled.
+
+ On success calls :meth:`throttle_success`.
+ On failure calls :meth:`throttle_failure`.
+ """
+ self.key = self.get_cache_key(request)
+ self.history = cache.get(self.key, [])
+ self.now = self.timer()
+
+ # Drop any requests from the history which have now passed the
+ # throttle duration
+ while self.history and self.history[-1] <= self.now - self.duration:
+ self.history.pop()
+ if len(self.history) >= self.num_requests:
+ return self.throttle_failure()
+ return self.throttle_success()
+
+ def throttle_success(self):
+ """
+ Inserts the current request's timestamp along with the key
+ into the cache.
+ """
+ self.history.insert(0, self.now)
+ cache.set(self.key, self.history, self.duration)
+ return True
+
+ def throttle_failure(self):
+ """
+ Called when a request to the API has failed due to throttling.
+ """
+ return False
+
+ def wait(self):
+ """
+ Returns the recommended next request time in seconds.
+ """
+ if self.history:
+ remaining_duration = self.duration - (self.now - self.history[-1])
+ else:
+ remaining_duration = self.duration
+
+ available_requests = self.num_requests - len(self.history) + 1
+
+ return remaining_duration / float(available_requests)
+
+
+class PerUserThrottling(SimpleCachingThrottle):
+ """
+ Limits the rate of API calls that may be made by a given user.
+
+ The user id will be used as a unique identifier if the user is
+ authenticated. For anonymous requests, the IP address of the client will
+ be used.
+ """
+
+ def get_cache_key(self, request):
+ if request.user.is_authenticated():
+ ident = request.user.id
+ else:
+ ident = request.META.get('REMOTE_ADDR', None)
+ return 'throttle_user_%s' % ident
+
+
+class PerViewThrottling(SimpleCachingThrottle):
+ """
+ Limits the rate of API calls that may be used on a given view.
+
+ The class name of the view is used as a unique identifier to
+ throttle against.
+ """
+
+ def get_cache_key(self, request):
+ return 'throttle_view_%s' % self.view.__class__.__name__
diff --git a/djangorestframework/views.py b/djangorestframework/views.py
index 3f0138d8..9796b362 100644
--- a/djangorestframework/views.py
+++ b/djangorestframework/views.py
@@ -18,7 +18,7 @@ from djangorestframework.compat import View as _View, apply_markdown
from djangorestframework.response import Response
from djangorestframework.request import Request
from djangorestframework.settings import api_settings
-from djangorestframework import parsers, authentication, permissions, status, exceptions, mixins
+from djangorestframework import parsers, authentication, status, exceptions, mixins
__all__ = (
@@ -86,7 +86,12 @@ class APIView(_View):
List of all authenticating methods to attempt.
"""
- permission_classes = (permissions.FullAnonAccess,)
+ throttle_classes = ()
+ """
+ List of all throttles to check.
+ """
+
+ permission_classes = ()
"""
List of all permissions that must be checked.
"""
@@ -195,12 +200,27 @@ class APIView(_View):
"""
return [permission(self) for permission in self.permission_classes]
- def check_permissions(self, user):
+ def get_throttles(self):
"""
- Check user permissions and either raise an ``ImmediateResponse`` or return.
+ Instantiates and returns the list of thottles that this view requires.
+ """
+ return [throttle(self) for throttle in self.throttle_classes]
+
+ def check_permissions(self, request, obj=None):
+ """
+ Check user permissions and either raise an ``PermissionDenied`` or return.
"""
for permission in self.get_permissions():
- permission.check_permission(user)
+ if not permission.check_permission(request, obj):
+ raise exceptions.PermissionDenied()
+
+ def check_throttles(self, request):
+ """
+ Check throttles and either raise a `Throttled` exception or return.
+ """
+ for throttle in self.get_throttles():
+ if not throttle.check_throttle(request):
+ raise exceptions.Throttled(throttle.wait())
def initial(self, request, *args, **kargs):
"""
@@ -232,6 +252,9 @@ class APIView(_View):
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
"""
+ if isinstance(exc, exceptions.Throttled):
+ self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
+
if isinstance(exc, exceptions.APIException):
return Response({'detail': exc.detail}, status=exc.status_code)
elif isinstance(exc, Http404):
@@ -255,8 +278,9 @@ class APIView(_View):
try:
self.initial(request, *args, **kwargs)
- # check that user has the relevant permissions
- self.check_permissions(request.user)
+ # Check that the request is allowed
+ self.check_permissions(request)
+ self.check_throttles(request)
# Get the appropriate handler method
if request.method.lower() in self.http_method_names:
@@ -283,11 +307,12 @@ class BaseView(APIView):
serializer_class = None
def get_serializer(self, data=None, files=None, instance=None):
+ # TODO: add support for files
context = {
'request': self.request,
'format': self.kwargs.get('format', None)
}
- return self.serializer_class(data, context=context)
+ return self.serializer_class(data, instance=instance, context=context)
class MultipleObjectBaseView(MultipleObjectMixin, BaseView):
@@ -301,7 +326,13 @@ class SingleObjectBaseView(SingleObjectMixin, BaseView):
"""
Base class for generic views onto a model instance.
"""
- pass
+
+ def get_object(self):
+ """
+ Override default to add support for object-level permissions.
+ """
+ super(self, SingleObjectBaseView).get_object()
+ self.check_permissions(self.request, self.object)
# Concrete view classes that provide method handlers