aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--djangorestframework/permissions.py117
-rw-r--r--djangorestframework/tests/throttling.py101
2 files changed, 117 insertions, 101 deletions
diff --git a/djangorestframework/permissions.py b/djangorestframework/permissions.py
index b3fd212b..b8a5224c 100644
--- a/djangorestframework/permissions.py
+++ b/djangorestframework/permissions.py
@@ -29,6 +29,10 @@ _503_THROTTLED_RESPONSE = ErrorResponse(
{'detail': 'request was throttled'})
+class ConfigurationException(BaseException):
+ """To alert for bad configuration desicions as a convenience."""
+ pass
+
class BasePermission(object):
"""
@@ -87,70 +91,83 @@ class IsUserOrIsAnonReadOnly(BasePermission):
self.view.method != 'HEAD'):
raise _403_FORBIDDEN_RESPONSE
-
-class PerUserThrottling(BasePermission):
+class BaseThrottle(BasePermission):
"""
- Rate throttling of requests on a per-user basis.
+ Rate throttling of requests.
The rate (requests / seconds) is set by a :attr:`throttle` attribute on the ``View`` class.
- The attribute is a two tuple of the form (number of requests, duration in seconds).
-
- The user id will be used as a unique identifier if the user is authenticated.
- For anonymous requests, the IP address of the client will be used.
+ The attribute is a string of the form 'number of requests/period'. Period must be an element
+ of (sec, min, hour, day)
Previous request information used for throttling is stored in the cache.
- """
+ """
- def check_permission(self, user):
- (num_requests, duration) = getattr(self.view, 'throttle', (0, 0))
-
- if user.is_authenticated():
- ident = str(user)
- else:
- ident = self.view.request.META.get('REMOTE_ADDR', None)
+ def get_cache_key(self):
+ """Should return the cache-key corresponding to the semantics of the class that implements
+ the throttling behaviour.
+ """
+ pass
- key = 'throttle_%s' % ident
- history = cache.get(key, [])
- now = time.time()
+ def check_permission(self, auth):
+ num, period = getattr(self.view, 'throttle', '0/sec').split('/')
+ self.num_requests = int(num)
+ self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
+ self.auth = auth
+ self.check_throttle()
+
+ def check_throttle(self):
+ """On success calls `throttle_success`. On failure calls `throttle_failure`. """
+ self.key = self.get_cache_key()
+ self.history = cache.get(self.key, [])
+ self.now = time.time()
# Drop any requests from the history which have now passed the throttle duration
- while history and history[0] < now - duration:
- history.pop()
+ while self.history and self.history[0] < self.now - self.duration:
+ self.history.pop()
- if len(history) >= num_requests:
- raise _503_THROTTLED_RESPONSE
-
- history.insert(0, now)
- cache.set(key, history, duration)
-
-class PerResourceThrottling(BasePermission):
+ if len(self.history) >= self.num_requests:
+ self.throttle_failure()
+ else:
+ self.throttle_success()
+
+ def throttle_success(self):
+ """Inserts the current request's timesatmp along with the key into the cache."""
+ self.history.insert(0, self.now)
+ cache.set(self.key, self.history, self.duration)
+
+ def throttle_failure(self):
+ """Raises a 503 """
+ raise _503_THROTTLED_RESPONSE
+
+class PerUserThrottling(BaseThrottle):
"""
- Rate throttling of requests on a per-resource basis.
-
- The rate (requests / seconds) is set by a :attr:`throttle` attribute on the ``View`` class.
- The attribute is a two tuple of the form (number of requests, duration in seconds).
-
The user id will be used as a unique identifier if the user is authenticated.
For anonymous requests, the IP address of the client will be used.
-
- Previous request information used for throttling is stored in the cache.
"""
- def check_permission(self, ignore):
- (num_requests, duration) = getattr(self.view, 'throttle', (0, 0))
-
-
- key = 'throttle_%s' % self.view.__class__.__name__
-
- history = cache.get(key, [])
- now = time.time()
-
- # Drop any requests from the history which have now passed the throttle duration
- while history and history[0] < now - duration:
- history.pop()
+ def get_cache_key(self):
+ if self.auth.is_authenticated():
+ ident = str(self.auth)
+ else:
+ ident = self.view.request.META.get('REMOTE_ADDR', None)
+ return 'throttle_%s' % ident
- if len(history) >= num_requests:
- raise _503_THROTTLED_RESPONSE
+class PerViewThrottling(BaseThrottle):
+ """
+ The class name of the cuurent view will be used as a unique identifier.
+ """
+
+ def get_cache_key(self):
+ return 'throttle_%s' % self.view.__class__.__name__
+
+class PerResourceThrottling(BaseThrottle):
+ """
+ The class name of the cuurent resource will be used as a unique identifier.
+ Raises :exc:`ConfigurationException` if no resource attribute is set on the view class.
+ """
- history.insert(0, now)
- cache.set(key, history, duration)
+ def get_cache_key(self):
+ if self.view.resource != None:
+ return 'throttle_%s' % self.view.resource.__class__.__name__
+ raise ConfigurationException(
+ "A per-resource throttle was set to a view that does not have a resource.") \ No newline at end of file
diff --git a/djangorestframework/tests/throttling.py b/djangorestframework/tests/throttling.py
index a9e6803b..6cd69766 100644
--- a/djangorestframework/tests/throttling.py
+++ b/djangorestframework/tests/throttling.py
@@ -8,78 +8,77 @@ from django.core.cache import cache
from djangorestframework.compat import RequestFactory
from djangorestframework.views import View
-from djangorestframework.permissions import PerUserThrottling, PerResourceThrottling
-
+from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling, ConfigurationException
+from djangorestframework.resources import FormResource
class MockView(View):
permissions = ( PerUserThrottling, )
- throttle = (3, 1) # 3 requests per second
-
- def get(self, request):
- return 'foo'
-
-class MockView1(View):
- permissions = ( PerResourceThrottling, )
- throttle = (3, 1) # 3 requests per second
+ throttle = '3/sec' # 3 requests per second
def get(self, request):
return 'foo'
-urlpatterns = patterns('',
- (r'^$', MockView.as_view()),
- (r'^1$', MockView1.as_view()),
-)
+class MockView1(MockView):
+ permissions = ( PerViewThrottling, )
+class MockView2(MockView):
+ permissions = ( PerResourceThrottling, )
+ #No resource set
+
+class MockView3(MockView2):
+ resource = FormResource
+
class ThrottlingTests(TestCase):
urls = 'djangorestframework.tests.throttling'
def setUp(self):
"""Reset the cache so that no throttles will be active"""
cache.clear()
+ self.factory = RequestFactory()
def test_requests_are_throttled(self):
"""Ensure request rate is limited"""
- for dummy in range(3):
- response = self.client.get('/')
- response = self.client.get('/')
- self.assertEqual(503, response.status_code)
-
- def test_request_throttling_is_per_user(self):
- """Ensure request rate is only limited per user, not globally"""
- for username in ('testuser', 'another_testuser'):
- user = User.objects.create(username=username)
- user.set_password('test')
- user.save()
-
- self.assertTrue(self.client.login(username='testuser', password='test'), msg='Login Failed')
- for dummy in range(3):
- response = self.client.get('/')
- self.client.logout()
- self.assertTrue(self.client.login(username='another_testuser', password='test'), msg='Login failed')
- response = self.client.get('/')
- self.assertEqual(200, response.status_code)
-
- def test_request_throttling_is_per_resource(self):
- """Ensure request rate is limited globally per View"""
- for username in ('testuser', 'another_testuser'):
- user = User.objects.create(username=username)
- user.set_password('test')
- user.save()
-
- self.assertTrue(self.client.login(username='testuser', password='test'), msg='Login Failed')
- for dummy in range(3):
- response = self.client.get('/1')
- self.client.logout()
- self.assertTrue(self.client.login(username='another_testuser', password='test'), msg='Login failed')
- response = self.client.get('/1')
+ request = self.factory.get('/')
+ for dummy in range(4):
+ response = MockView.as_view()(request)
self.assertEqual(503, response.status_code)
def test_request_throttling_expires(self):
"""Ensure request rate is limited for a limited duration only"""
- for dummy in range(3):
- response = self.client.get('/')
- response = self.client.get('/')
+ request = self.factory.get('/')
+ for dummy in range(4):
+ response = MockView.as_view()(request)
self.assertEqual(503, response.status_code)
time.sleep(1)
- response = self.client.get('/')
+ response = MockView.as_view()(request)
self.assertEqual(200, response.status_code)
+
+ def ensure_is_throttled(self, view, expect):
+ request = self.factory.get('/')
+ request.user = User.objects.create(username='a')
+ for dummy in range(3):
+ response = view.as_view()(request)
+ request.user = User.objects.create(username='b')
+ response = view.as_view()(request)
+ self.assertEqual(expect, response.status_code)
+
+ def test_request_throttling_is_per_user(self):
+ """Ensure request rate is only limited per user, not globally for PerUserTrottles"""
+ self.ensure_is_throttled(MockView, 200)
+
+ def test_request_throttling_is_per_view(self):
+ """Ensure request rate is limited globally per View for PerViewThrottles"""
+ self.ensure_is_throttled(MockView1, 503)
+
+ def test_request_throttling_is_per_resource(self):
+ """Ensure request rate is limited globally per Resource for PerResourceThrottles"""
+ self.ensure_is_throttled(MockView3, 503)
+
+ def test_raises_no_resource_found(self):
+ """Ensure an Exception is raised when someone sets at per-resource throttle
+ on a view with no resource set."""
+ request = self.factory.get('/')
+ view = MockView2.as_view()
+ self.assertRaises(ConfigurationException, view, request)
+
+ \ No newline at end of file