aboutsummaryrefslogtreecommitdiffstats
path: root/djangorestframework
diff options
context:
space:
mode:
Diffstat (limited to 'djangorestframework')
-rw-r--r--djangorestframework/permissions.py12
-rw-r--r--djangorestframework/tests/throttling.py28
-rw-r--r--djangorestframework/throttling.py131
-rw-r--r--djangorestframework/views.py4
4 files changed, 127 insertions, 48 deletions
diff --git a/djangorestframework/permissions.py b/djangorestframework/permissions.py
index 64e455f5..3a669822 100644
--- a/djangorestframework/permissions.py
+++ b/djangorestframework/permissions.py
@@ -28,11 +28,11 @@ class BasePermission(object):
"""
self.view = view
- def check_permission(self, request, obj=None):
+ def has_permission(self, request, obj=None):
"""
Should simply return, or raise an :exc:`response.ImmediateResponse`.
"""
- raise NotImplementedError(".check_permission() must be overridden.")
+ raise NotImplementedError(".has_permission() must be overridden.")
class IsAuthenticated(BasePermission):
@@ -40,7 +40,7 @@ class IsAuthenticated(BasePermission):
Allows access only to authenticated users.
"""
- def check_permission(self, request, obj=None):
+ def has_permission(self, request, obj=None):
if request.user and request.user.is_authenticated():
return True
return False
@@ -51,7 +51,7 @@ class IsAdminUser(BasePermission):
Allows access only to admin users.
"""
- def check_permission(self, request, obj=None):
+ def has_permission(self, request, obj=None):
if request.user and request.user.is_staff:
return True
return False
@@ -62,7 +62,7 @@ class IsAuthenticatedOrReadOnly(BasePermission):
The request is authenticated as a user, or is a read-only request.
"""
- def check_permission(self, request, obj=None):
+ def has_permission(self, request, obj=None):
if (request.method in SAFE_METHODS or
request.user and
request.user.is_authenticated()):
@@ -105,7 +105,7 @@ class DjangoModelPermissions(BasePermission):
}
return [perm % kwargs for perm in self.perms_map[method]]
- def check_permission(self, request, obj=None):
+ def has_permission(self, request, obj=None):
model_cls = self.view.model
perms = self.get_required_permissions(request.method, model_cls)
diff --git a/djangorestframework/tests/throttling.py b/djangorestframework/tests/throttling.py
index d144d956..9ee4ffa4 100644
--- a/djangorestframework/tests/throttling.py
+++ b/djangorestframework/tests/throttling.py
@@ -8,24 +8,30 @@ from django.core.cache import cache
from djangorestframework.compat import RequestFactory
from djangorestframework.views import APIView
-from djangorestframework.throttling import PerUserThrottling, PerViewThrottling
+from djangorestframework.throttling import UserRateThrottle
from djangorestframework.response import Response
-class MockView(APIView):
- throttle_classes = (PerUserThrottling,)
+class User3SecRateThrottle(UserRateThrottle):
rate = '3/sec'
+
+class User3MinRateThrottle(UserRateThrottle):
+ rate = '3/min'
+
+
+class MockView(APIView):
+ throttle_classes = (User3SecRateThrottle,)
+
def get(self, request):
return Response('foo')
-class MockView_PerViewThrottling(MockView):
- throttle_classes = (PerViewThrottling,)
+class MockView_MinuteThrottling(APIView):
+ throttle_classes = (User3MinRateThrottle,)
-
-class MockView_MinuteThrottling(MockView):
- rate = '3/min'
+ def get(self, request):
+ return Response('foo')
class ThrottlingTests(TestCase):
@@ -86,12 +92,6 @@ class ThrottlingTests(TestCase):
"""
self.ensure_is_throttled(MockView, 200)
- def test_request_throttling_is_per_view(self):
- """
- Ensure request rate is limited globally per View for PerViewThrottles
- """
- self.ensure_is_throttled(MockView_PerViewThrottling, 429)
-
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
"""
Ensure the response returns an X-Throttle field with status and next attributes
diff --git a/djangorestframework/throttling.py b/djangorestframework/throttling.py
index a096eab7..f8b098d7 100644
--- a/djangorestframework/throttling.py
+++ b/djangorestframework/throttling.py
@@ -1,4 +1,5 @@
from django.core.cache import cache
+from djangorestframework.settings import api_settings
import time
@@ -13,11 +14,11 @@ class BaseThrottle(object):
"""
self.view = view
- def check_throttle(self, request):
+ def allow_request(self, request):
"""
Return `True` if the request should be allowed, `False` otherwise.
"""
- raise NotImplementedError('.check_throttle() must be overridden')
+ raise NotImplementedError('.allow_request() must be overridden')
def wait(self):
"""
@@ -27,7 +28,7 @@ class BaseThrottle(object):
return None
-class SimpleCachingThrottle(BaseThrottle):
+class SimpleRateThottle(BaseThrottle):
"""
A simple cache implementation, that only requires `.get_cache_key()`
to be overridden.
@@ -41,33 +42,51 @@ class SimpleCachingThrottle(BaseThrottle):
Previous request information used for throttling is stored in the cache.
"""
- attr_name = 'rate'
- rate = '1000/day'
timer = time.time
+ settings = api_settings
+ cache_format = '%(class)s_%(scope)s_%(ident)s'
+ scope = None
def __init__(self, view):
- """
- Check the throttling.
- Return `None` or raise an :exc:`.ImmediateResponse`.
- """
- super(SimpleCachingThrottle, self).__init__(view)
- num, period = getattr(view, self.attr_name, self.rate).split('/')
- self.num_requests = int(num)
- self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
+ super(SimpleRateThottle, self).__init__(view)
+ rate = self.get_rate_description()
+ self.num_requests, self.duration = self.parse_rate_description(rate)
def get_cache_key(self, request):
"""
Should return a unique cache-key which can be used for throttling.
Must be overridden.
+
+ May return `None` if the request should not be throttled.
"""
raise NotImplementedError('.get_cache_key() must be overridden')
- def check_throttle(self, request):
+ def get_rate_description(self):
+ """
+ Determine the string representation of the allowed request rate.
+ """
+ try:
+ return self.rate
+ except AttributeError:
+ return self.settings.DEFAULT_THROTTLE_RATES.get(self.scope)
+
+ def parse_rate_description(self, rate):
+ """
+ Given the request rate string, return a two tuple of:
+ <allowed number of requests>, <period of time in seconds>
+ """
+ assert rate, "No throttle rate set for '%s'" % self.__class__.__name__
+ num, period = rate.split('/')
+ num_requests = int(num)
+ duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
+ return (num_requests, duration)
+
+ def allow_request(self, request):
"""
Implement the check to see if the request should be throttled.
- On success calls :meth:`throttle_success`.
- On failure calls :meth:`throttle_failure`.
+ On success calls `throttle_success`.
+ On failure calls `throttle_failure`.
"""
self.key = self.get_cache_key(request)
self.history = cache.get(self.key, [])
@@ -110,30 +129,90 @@ class SimpleCachingThrottle(BaseThrottle):
return remaining_duration / float(available_requests)
-class PerUserThrottling(SimpleCachingThrottle):
+class AnonRateThrottle(SimpleRateThottle):
+ """
+ Limits the rate of API calls that may be made by a anonymous users.
+
+ The IP address of the request will be used as the unqiue cache key.
+ """
+ scope = 'anon'
+
+ def get_cache_key(self, request):
+ if request.user.is_authenticated():
+ return None # Only throttle unauthenticated requests.
+
+ ident = request.META.get('REMOTE_ADDR', None)
+
+ return self.cache_format % {
+ 'class': self.__class__.__name__,
+ 'scope': self.scope,
+ 'ident': ident
+ }
+
+
+class UserRateThrottle(SimpleRateThottle):
"""
Limits the rate of API calls that may be made by a given user.
- The user id will be used as a unique identifier if the user is
- authenticated. For anonymous requests, the IP address of the client will
+ The user id will be used as a unique cache key if the user is
+ authenticated. For anonymous requests, the IP address of the request will
be used.
"""
+ scope = 'user'
def get_cache_key(self, request):
if request.user.is_authenticated():
ident = request.user.id
else:
ident = request.META.get('REMOTE_ADDR', None)
- return 'throttle_user_%s' % ident
+ return self.cache_format % {
+ 'class': self.__class__.__name__,
+ 'scope': self.scope,
+ 'ident': ident
+ }
-class PerViewThrottling(SimpleCachingThrottle):
- """
- Limits the rate of API calls that may be used on a given view.
- The class name of the view is used as a unique identifier to
- throttle against.
+class ScopedRateThrottle(SimpleRateThottle):
+ """
+ Limits the rate of API calls by different amounts for various parts of
+ the API. Any view that has the `throttle_scope` property set will be
+ throttled. The unique cache key will be generated by concatenating the
+ user id of the request, and the scope of the view being accessed.
"""
+ def __init__(self, view):
+ """
+ Scope is determined from the view being accessed.
+ """
+ self.scope = getattr(self.view, 'throttle_scope', None)
+ super(ScopedRateThrottle, self).__init__(view)
+
+ def parse_rate_description(self, rate):
+ """
+ Subclassed so that we don't fail if `view.throttle_scope` is not set.
+ """
+ if not rate:
+ return (None, None)
+ return super(ScopedRateThrottle, self).parse_rate_description(rate)
+
def get_cache_key(self, request):
- return 'throttle_view_%s' % self.view.__class__.__name__
+ """
+ If `view.throttle_scope` is not set, don't apply this throttle.
+
+ Otherwise generate the unique cache key by concatenating the user id
+ with the '.throttle_scope` property of the view.
+ """
+ if not self.scope:
+ return None # Only throttle views with `.throttle_scope` set.
+
+ if request.user.is_authenticated():
+ ident = request.user.id
+ else:
+ ident = request.META.get('REMOTE_ADDR', None)
+
+ return self.cache_format % {
+ 'class': self.__class__.__name__,
+ 'scope': self.scope,
+ 'ident': ident
+ }
diff --git a/djangorestframework/views.py b/djangorestframework/views.py
index 5ec55d8c..a309386b 100644
--- a/djangorestframework/views.py
+++ b/djangorestframework/views.py
@@ -186,7 +186,7 @@ class APIView(_View):
Check if request should be permitted.
"""
for permission in self.get_permissions():
- if not permission.check_permission(request, obj):
+ if not permission.has_permission(request, obj):
self.permission_denied(request)
def check_throttles(self, request):
@@ -194,7 +194,7 @@ class APIView(_View):
Check if request should be throttled.
"""
for throttle in self.get_throttles():
- if not throttle.check_throttle(request):
+ if not throttle.allow_request(request):
self.throttled(request, throttle.wait())
def initialize_request(self, request, *args, **kargs):