diff options
| author | markotibold | 2011-06-13 20:42:37 +0200 |
|---|---|---|
| committer | markotibold | 2011-06-13 20:42:37 +0200 |
| commit | 437a062b6c389530b337e809c472fb470827aa78 (patch) | |
| tree | b563e86dfb0b490a9643714069379167586291d0 /djangorestframework | |
| parent | 1720c449045fba54f7af776f0259d6dc84e7e54b (diff) | |
| download | django-rest-framework-437a062b6c389530b337e809c472fb470827aa78.tar.bz2 | |
implemeneted #28
Diffstat (limited to 'djangorestframework')
| -rw-r--r-- | djangorestframework/permissions.py | 22 | ||||
| -rw-r--r-- | djangorestframework/tests/throttling.py | 78 | ||||
| -rw-r--r-- | djangorestframework/views.py | 17 |
3 files changed, 91 insertions, 26 deletions
diff --git a/djangorestframework/permissions.py b/djangorestframework/permissions.py index 34ab5bf4..4825a174 100644 --- a/djangorestframework/permissions.py +++ b/djangorestframework/permissions.py @@ -31,11 +31,6 @@ _503_SERVICE_UNAVAILABLE = ErrorResponse( {'detail': 'request was throttled'}) -class ConfigurationException(BaseException): - """To alert for bad configuration decisions as a convenience.""" - pass - - class BasePermission(object): """ A base class from which all permission classes should inherit. @@ -144,12 +139,11 @@ class BaseThrottle(BasePermission): # throttle duration while self.history and self.history[0] <= self.now - self.duration: self.history.pop() - if len(self.history) >= self.num_requests: self.throttle_failure() else: self.throttle_success() - + def throttle_success(self): """ Inserts the current request's timestamp along with the key @@ -157,15 +151,23 @@ class BaseThrottle(BasePermission): """ self.history.insert(0, self.now) cache.set(self.key, self.history, self.duration) - + self.view.add_header('X-Throttle', 'status=SUCCESS; next=%s sec' % self.next()) + def throttle_failure(self): """ Called when a request to the API has failed due to throttling. Raises a '503 service unavailable' response. """ + self.view.add_header('X-Throttle', 'status=FAILURE; next=%s sec' % self.next()) raise _503_SERVICE_UNAVAILABLE - - + + def next(self): + """ + Returns the recommended next request time in seconds. + """ + return '%.2f' % (self.duration / (self.num_requests - len(self.history) *1.0 + 1)) + + class PerUserThrottling(BaseThrottle): """ Limits the rate of API calls that may be made by a given user. diff --git a/djangorestframework/tests/throttling.py b/djangorestframework/tests/throttling.py index be552638..80cfc2e1 100644 --- a/djangorestframework/tests/throttling.py +++ b/djangorestframework/tests/throttling.py @@ -1,17 +1,14 @@ """ Tests for the throttling implementations in the permissions module. """ -import time -from django.conf.urls.defaults import patterns from django.test import TestCase -from django.utils import simplejson as json from django.contrib.auth.models import User from django.core.cache import cache from djangorestframework.compat import RequestFactory from djangorestframework.views import View -from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling, ConfigurationException +from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling from djangorestframework.resources import FormResource class MockView(View): @@ -30,28 +27,40 @@ class MockView2(MockView): class MockView3(MockView2): resource = FormResource + +class MockView4(MockView): + throttle = '3/min' # 3 request per minute class ThrottlingTests(TestCase): urls = 'djangorestframework.tests.throttling' def setUp(self): - """Reset the cache so that no throttles will be active""" + """ + Reset the cache so that no throttles will be active + """ cache.clear() self.factory = RequestFactory() def test_requests_are_throttled(self): - """Ensure request rate is limited""" + """ + Ensure request rate is limited + """ request = self.factory.get('/') for dummy in range(4): response = MockView.as_view()(request) self.assertEqual(503, response.status_code) + def set_throttle_timer(self, view, value): + """ + Explicitly set the timer, overriding time.time() + """ + view.permissions[0].timer = lambda self: value + def test_request_throttling_expires(self): """ Ensure request rate is limited for a limited duration only """ - # Explicitly set the timer, overridding time.time() - MockView.permissions[0].timer = lambda self: 0 + self.set_throttle_timer(MockView, 0) request = self.factory.get('/') for dummy in range(4): @@ -59,7 +68,7 @@ class ThrottlingTests(TestCase): self.assertEqual(503, response.status_code) # Advance the timer by one second - MockView.permissions[0].timer = lambda self: 1 + self.set_throttle_timer(MockView, 1) response = MockView.as_view()(request) self.assertEqual(200, response.status_code) @@ -68,20 +77,61 @@ class ThrottlingTests(TestCase): request = self.factory.get('/') request.user = User.objects.create(username='a') for dummy in range(3): - response = view.as_view()(request) + view.as_view()(request) request.user = User.objects.create(username='b') response = view.as_view()(request) self.assertEqual(expect, response.status_code) def test_request_throttling_is_per_user(self): - """Ensure request rate is only limited per user, not globally for PerUserThrottles""" + """ + Ensure request rate is only limited per user, not globally for + PerUserThrottles + """ self.ensure_is_throttled(MockView, 200) def test_request_throttling_is_per_view(self): - """Ensure request rate is limited globally per View for PerViewThrottles""" + """ + Ensure request rate is limited globally per View for PerViewThrottles + """ self.ensure_is_throttled(MockView1, 503) def test_request_throttling_is_per_resource(self): - """Ensure request rate is limited globally per Resource for PerResourceThrottles""" + """ + Ensure request rate is limited globally per Resource for PerResourceThrottles + """ self.ensure_is_throttled(MockView3, 503) -
\ No newline at end of file + + + def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): + """ + Ensure the response returns an X-Throttle field with status and next attributes + set properly. + """ + request = self.factory.get('/') + for expect in expected_headers: + self.set_throttle_timer(view, 0) + response = view.as_view()(request) + self.assertEquals(response['X-Throttle'], expect) + + def test_seconds_fields(self): + """ + Ensure for second based throttles. + """ + self.ensure_response_header_contains_proper_throttle_field(MockView, + ('status=SUCCESS; next=0.33 sec', + 'status=SUCCESS; next=0.50 sec', + 'status=SUCCESS; next=1.00 sec', + 'status=FAILURE; next=1.00 sec' + )) + + def test_minutes_fields(self): + """ + Ensure for minute based throttles. + """ + self.ensure_response_header_contains_proper_throttle_field(MockView4, + ('status=SUCCESS; next=20.00 sec', + 'status=SUCCESS; next=30.00 sec', + 'status=SUCCESS; next=60.00 sec', + 'status=FAILURE; next=60.00 sec' + )) +
\ No newline at end of file diff --git a/djangorestframework/views.py b/djangorestframework/views.py index 6f2ab5b7..e38207ac 100644 --- a/djangorestframework/views.py +++ b/djangorestframework/views.py @@ -64,7 +64,11 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): """ permissions = ( permissions.FullAnonAccess, ) - + """ + Headers to be sent with response. + """ + headers = {} + @classmethod def as_view(cls, **initkwargs): """ @@ -101,6 +105,12 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): """ pass + def add_header(self, field, value): + """ + Add *field* and *value* to the :attr:`headers` attribute of the :class:`View` class. + """ + self.headers[field] = value + # Note: session based authentication is explicitly CSRF validated, # all other authentication is CSRF exempt. @csrf_exempt @@ -149,7 +159,10 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): # also it's currently sub-obtimal for HTTP caching - need to sort that out. response.headers['Allow'] = ', '.join(self.allowed_methods) response.headers['Vary'] = 'Authenticate, Accept' - + + # merge with headers possibly set by a Throttle class + response.headers = dict(response.headers.items() + self.headers.items()) + return self.render(response) |
