diff options
| -rw-r--r-- | djangorestframework/permissions.py | 117 | ||||
| -rw-r--r-- | djangorestframework/tests/throttling.py | 101 |
2 files changed, 117 insertions, 101 deletions
diff --git a/djangorestframework/permissions.py b/djangorestframework/permissions.py index b3fd212b..b8a5224c 100644 --- a/djangorestframework/permissions.py +++ b/djangorestframework/permissions.py @@ -29,6 +29,10 @@ _503_THROTTLED_RESPONSE = ErrorResponse( {'detail': 'request was throttled'}) +class ConfigurationException(BaseException): + """To alert for bad configuration desicions as a convenience.""" + pass + class BasePermission(object): """ @@ -87,70 +91,83 @@ class IsUserOrIsAnonReadOnly(BasePermission): self.view.method != 'HEAD'): raise _403_FORBIDDEN_RESPONSE - -class PerUserThrottling(BasePermission): +class BaseThrottle(BasePermission): """ - Rate throttling of requests on a per-user basis. + Rate throttling of requests. The rate (requests / seconds) is set by a :attr:`throttle` attribute on the ``View`` class. - The attribute is a two tuple of the form (number of requests, duration in seconds). - - The user id will be used as a unique identifier if the user is authenticated. - For anonymous requests, the IP address of the client will be used. + The attribute is a string of the form 'number of requests/period'. Period must be an element + of (sec, min, hour, day) Previous request information used for throttling is stored in the cache. - """ + """ - def check_permission(self, user): - (num_requests, duration) = getattr(self.view, 'throttle', (0, 0)) - - if user.is_authenticated(): - ident = str(user) - else: - ident = self.view.request.META.get('REMOTE_ADDR', None) + def get_cache_key(self): + """Should return the cache-key corresponding to the semantics of the class that implements + the throttling behaviour. + """ + pass - key = 'throttle_%s' % ident - history = cache.get(key, []) - now = time.time() + def check_permission(self, auth): + num, period = getattr(self.view, 'throttle', '0/sec').split('/') + self.num_requests = int(num) + self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] + self.auth = auth + self.check_throttle() + + def check_throttle(self): + """On success calls `throttle_success`. On failure calls `throttle_failure`. """ + self.key = self.get_cache_key() + self.history = cache.get(self.key, []) + self.now = time.time() # Drop any requests from the history which have now passed the throttle duration - while history and history[0] < now - duration: - history.pop() + while self.history and self.history[0] < self.now - self.duration: + self.history.pop() - if len(history) >= num_requests: - raise _503_THROTTLED_RESPONSE - - history.insert(0, now) - cache.set(key, history, duration) - -class PerResourceThrottling(BasePermission): + if len(self.history) >= self.num_requests: + self.throttle_failure() + else: + self.throttle_success() + + def throttle_success(self): + """Inserts the current request's timesatmp along with the key into the cache.""" + self.history.insert(0, self.now) + cache.set(self.key, self.history, self.duration) + + def throttle_failure(self): + """Raises a 503 """ + raise _503_THROTTLED_RESPONSE + +class PerUserThrottling(BaseThrottle): """ - Rate throttling of requests on a per-resource basis. - - The rate (requests / seconds) is set by a :attr:`throttle` attribute on the ``View`` class. - The attribute is a two tuple of the form (number of requests, duration in seconds). - The user id will be used as a unique identifier if the user is authenticated. For anonymous requests, the IP address of the client will be used. - - Previous request information used for throttling is stored in the cache. """ - def check_permission(self, ignore): - (num_requests, duration) = getattr(self.view, 'throttle', (0, 0)) - - - key = 'throttle_%s' % self.view.__class__.__name__ - - history = cache.get(key, []) - now = time.time() - - # Drop any requests from the history which have now passed the throttle duration - while history and history[0] < now - duration: - history.pop() + def get_cache_key(self): + if self.auth.is_authenticated(): + ident = str(self.auth) + else: + ident = self.view.request.META.get('REMOTE_ADDR', None) + return 'throttle_%s' % ident - if len(history) >= num_requests: - raise _503_THROTTLED_RESPONSE +class PerViewThrottling(BaseThrottle): + """ + The class name of the cuurent view will be used as a unique identifier. + """ + + def get_cache_key(self): + return 'throttle_%s' % self.view.__class__.__name__ + +class PerResourceThrottling(BaseThrottle): + """ + The class name of the cuurent resource will be used as a unique identifier. + Raises :exc:`ConfigurationException` if no resource attribute is set on the view class. + """ - history.insert(0, now) - cache.set(key, history, duration) + def get_cache_key(self): + if self.view.resource != None: + return 'throttle_%s' % self.view.resource.__class__.__name__ + raise ConfigurationException( + "A per-resource throttle was set to a view that does not have a resource.")
\ No newline at end of file diff --git a/djangorestframework/tests/throttling.py b/djangorestframework/tests/throttling.py index a9e6803b..6cd69766 100644 --- a/djangorestframework/tests/throttling.py +++ b/djangorestframework/tests/throttling.py @@ -8,78 +8,77 @@ from django.core.cache import cache from djangorestframework.compat import RequestFactory from djangorestframework.views import View -from djangorestframework.permissions import PerUserThrottling, PerResourceThrottling - +from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling, ConfigurationException +from djangorestframework.resources import FormResource class MockView(View): permissions = ( PerUserThrottling, ) - throttle = (3, 1) # 3 requests per second - - def get(self, request): - return 'foo' - -class MockView1(View): - permissions = ( PerResourceThrottling, ) - throttle = (3, 1) # 3 requests per second + throttle = '3/sec' # 3 requests per second def get(self, request): return 'foo' -urlpatterns = patterns('', - (r'^$', MockView.as_view()), - (r'^1$', MockView1.as_view()), -) +class MockView1(MockView): + permissions = ( PerViewThrottling, ) +class MockView2(MockView): + permissions = ( PerResourceThrottling, ) + #No resource set + +class MockView3(MockView2): + resource = FormResource + class ThrottlingTests(TestCase): urls = 'djangorestframework.tests.throttling' def setUp(self): """Reset the cache so that no throttles will be active""" cache.clear() + self.factory = RequestFactory() def test_requests_are_throttled(self): """Ensure request rate is limited""" - for dummy in range(3): - response = self.client.get('/') - response = self.client.get('/') - self.assertEqual(503, response.status_code) - - def test_request_throttling_is_per_user(self): - """Ensure request rate is only limited per user, not globally""" - for username in ('testuser', 'another_testuser'): - user = User.objects.create(username=username) - user.set_password('test') - user.save() - - self.assertTrue(self.client.login(username='testuser', password='test'), msg='Login Failed') - for dummy in range(3): - response = self.client.get('/') - self.client.logout() - self.assertTrue(self.client.login(username='another_testuser', password='test'), msg='Login failed') - response = self.client.get('/') - self.assertEqual(200, response.status_code) - - def test_request_throttling_is_per_resource(self): - """Ensure request rate is limited globally per View""" - for username in ('testuser', 'another_testuser'): - user = User.objects.create(username=username) - user.set_password('test') - user.save() - - self.assertTrue(self.client.login(username='testuser', password='test'), msg='Login Failed') - for dummy in range(3): - response = self.client.get('/1') - self.client.logout() - self.assertTrue(self.client.login(username='another_testuser', password='test'), msg='Login failed') - response = self.client.get('/1') + request = self.factory.get('/') + for dummy in range(4): + response = MockView.as_view()(request) self.assertEqual(503, response.status_code) def test_request_throttling_expires(self): """Ensure request rate is limited for a limited duration only""" - for dummy in range(3): - response = self.client.get('/') - response = self.client.get('/') + request = self.factory.get('/') + for dummy in range(4): + response = MockView.as_view()(request) self.assertEqual(503, response.status_code) time.sleep(1) - response = self.client.get('/') + response = MockView.as_view()(request) self.assertEqual(200, response.status_code) + + def ensure_is_throttled(self, view, expect): + request = self.factory.get('/') + request.user = User.objects.create(username='a') + for dummy in range(3): + response = view.as_view()(request) + request.user = User.objects.create(username='b') + response = view.as_view()(request) + self.assertEqual(expect, response.status_code) + + def test_request_throttling_is_per_user(self): + """Ensure request rate is only limited per user, not globally for PerUserTrottles""" + self.ensure_is_throttled(MockView, 200) + + def test_request_throttling_is_per_view(self): + """Ensure request rate is limited globally per View for PerViewThrottles""" + self.ensure_is_throttled(MockView1, 503) + + def test_request_throttling_is_per_resource(self): + """Ensure request rate is limited globally per Resource for PerResourceThrottles""" + self.ensure_is_throttled(MockView3, 503) + + def test_raises_no_resource_found(self): + """Ensure an Exception is raised when someone sets at per-resource throttle + on a view with no resource set.""" + request = self.factory.get('/') + view = MockView2.as_view() + self.assertRaises(ConfigurationException, view, request) + +
\ No newline at end of file |
