aboutsummaryrefslogtreecommitdiffstats
path: root/djangorestframework
diff options
context:
space:
mode:
Diffstat (limited to 'djangorestframework')
-rw-r--r--djangorestframework/permissions.py29
-rw-r--r--djangorestframework/runtests/runtests.py32
-rw-r--r--djangorestframework/tests/throttling.py105
-rw-r--r--djangorestframework/views.py18
4 files changed, 135 insertions, 49 deletions
diff --git a/djangorestframework/permissions.py b/djangorestframework/permissions.py
index 34ab5bf4..7dcabcf0 100644
--- a/djangorestframework/permissions.py
+++ b/djangorestframework/permissions.py
@@ -31,11 +31,6 @@ _503_SERVICE_UNAVAILABLE = ErrorResponse(
{'detail': 'request was throttled'})
-class ConfigurationException(BaseException):
- """To alert for bad configuration decisions as a convenience."""
- pass
-
-
class BasePermission(object):
"""
A base class from which all permission classes should inherit.
@@ -142,14 +137,13 @@ class BaseThrottle(BasePermission):
# Drop any requests from the history which have now passed the
# throttle duration
- while self.history and self.history[0] <= self.now - self.duration:
+ while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
-
if len(self.history) >= self.num_requests:
self.throttle_failure()
else:
self.throttle_success()
-
+
def throttle_success(self):
"""
Inserts the current request's timestamp along with the key
@@ -157,13 +151,30 @@ class BaseThrottle(BasePermission):
"""
self.history.insert(0, self.now)
cache.set(self.key, self.history, self.duration)
-
+ header = 'status=SUCCESS; next=%s sec' % self.next()
+ self.view.add_header('X-Throttle', header)
+
def throttle_failure(self):
"""
Called when a request to the API has failed due to throttling.
Raises a '503 service unavailable' response.
"""
+ header = 'status=FAILURE; next=%s sec' % self.next()
+ self.view.add_header('X-Throttle', header)
raise _503_SERVICE_UNAVAILABLE
+
+ def next(self):
+ """
+ Returns the recommended next request time in seconds.
+ """
+ if self.history:
+ remaining_duration = self.duration - (self.now - self.history[-1])
+ else:
+ remaining_duration = self.duration
+
+ available_requests = self.num_requests - len(self.history) + 1
+
+ return '%.2f' % (remaining_duration / float(available_requests))
class PerUserThrottling(BaseThrottle):
diff --git a/djangorestframework/runtests/runtests.py b/djangorestframework/runtests/runtests.py
index b98a496f..1da918f5 100644
--- a/djangorestframework/runtests/runtests.py
+++ b/djangorestframework/runtests/runtests.py
@@ -13,25 +13,27 @@ os.environ['DJANGO_SETTINGS_MODULE'] = 'djangorestframework.runtests.settings'
from django.conf import settings
from django.test.utils import get_runner
+def usage():
+ return """
+ Usage: python runtests.py [UnitTestClass].[method]
+
+ You can pass the Class name of the `UnitTestClass` you want to test.
+
+ Append a method name if you only want to test a specific method of that class.
+ """
+
def main():
TestRunner = get_runner(settings)
- if hasattr(TestRunner, 'func_name'):
- # Pre 1.2 test runners were just functions,
- # and did not support the 'failfast' option.
- import warnings
- warnings.warn(
- 'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.',
- DeprecationWarning
- )
- failures = TestRunner(['djangorestframework'])
+ test_runner = TestRunner()
+ if len(sys.argv) == 2:
+ test_case = '.' + sys.argv[1]
+ elif len(sys.argv) == 1:
+ test_case = ''
else:
- test_runner = TestRunner()
- if len(sys.argv) > 1:
- test_case = '.' + sys.argv[1]
- else:
- test_case = ''
- failures = test_runner.run_tests(['djangorestframework' + test_case])
+ print usage()
+ sys.exit(1)
+ failures = test_runner.run_tests(['djangorestframework' + test_case])
sys.exit(failures)
diff --git a/djangorestframework/tests/throttling.py b/djangorestframework/tests/throttling.py
index be552638..b620ee24 100644
--- a/djangorestframework/tests/throttling.py
+++ b/djangorestframework/tests/throttling.py
@@ -1,57 +1,65 @@
"""
Tests for the throttling implementations in the permissions module.
"""
-import time
-from django.conf.urls.defaults import patterns
from django.test import TestCase
-from django.utils import simplejson as json
from django.contrib.auth.models import User
from django.core.cache import cache
from djangorestframework.compat import RequestFactory
from djangorestframework.views import View
-from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling, ConfigurationException
+from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling
from djangorestframework.resources import FormResource
class MockView(View):
permissions = ( PerUserThrottling, )
- throttle = '3/sec' # 3 requests per second
+ throttle = '3/sec'
def get(self, request):
return 'foo'
-class MockView1(MockView):
+class MockView_PerViewThrottling(MockView):
permissions = ( PerViewThrottling, )
-class MockView2(MockView):
+class MockView_PerResourceThrottling(MockView):
permissions = ( PerResourceThrottling, )
- #No resource set
-
-class MockView3(MockView2):
resource = FormResource
+
+class MockView_MinuteThrottling(MockView):
+ throttle = '3/min'
+
+
class ThrottlingTests(TestCase):
urls = 'djangorestframework.tests.throttling'
def setUp(self):
- """Reset the cache so that no throttles will be active"""
+ """
+ Reset the cache so that no throttles will be active
+ """
cache.clear()
self.factory = RequestFactory()
def test_requests_are_throttled(self):
- """Ensure request rate is limited"""
+ """
+ Ensure request rate is limited
+ """
request = self.factory.get('/')
for dummy in range(4):
response = MockView.as_view()(request)
self.assertEqual(503, response.status_code)
+ def set_throttle_timer(self, view, value):
+ """
+ Explicitly set the timer, overriding time.time()
+ """
+ view.permissions[0].timer = lambda self: value
+
def test_request_throttling_expires(self):
"""
Ensure request rate is limited for a limited duration only
"""
- # Explicitly set the timer, overridding time.time()
- MockView.permissions[0].timer = lambda self: 0
+ self.set_throttle_timer(MockView, 0)
request = self.factory.get('/')
for dummy in range(4):
@@ -59,7 +67,7 @@ class ThrottlingTests(TestCase):
self.assertEqual(503, response.status_code)
# Advance the timer by one second
- MockView.permissions[0].timer = lambda self: 1
+ self.set_throttle_timer(MockView, 1)
response = MockView.as_view()(request)
self.assertEqual(200, response.status_code)
@@ -68,20 +76,73 @@ class ThrottlingTests(TestCase):
request = self.factory.get('/')
request.user = User.objects.create(username='a')
for dummy in range(3):
- response = view.as_view()(request)
+ view.as_view()(request)
request.user = User.objects.create(username='b')
response = view.as_view()(request)
self.assertEqual(expect, response.status_code)
def test_request_throttling_is_per_user(self):
- """Ensure request rate is only limited per user, not globally for PerUserThrottles"""
+ """
+ Ensure request rate is only limited per user, not globally for
+ PerUserThrottles
+ """
self.ensure_is_throttled(MockView, 200)
def test_request_throttling_is_per_view(self):
- """Ensure request rate is limited globally per View for PerViewThrottles"""
- self.ensure_is_throttled(MockView1, 503)
+ """
+ Ensure request rate is limited globally per View for PerViewThrottles
+ """
+ self.ensure_is_throttled(MockView_PerViewThrottling, 503)
def test_request_throttling_is_per_resource(self):
- """Ensure request rate is limited globally per Resource for PerResourceThrottles"""
- self.ensure_is_throttled(MockView3, 503)
- \ No newline at end of file
+ """
+ Ensure request rate is limited globally per Resource for PerResourceThrottles
+ """
+ self.ensure_is_throttled(MockView_PerResourceThrottling, 503)
+
+
+ def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
+ """
+ Ensure the response returns an X-Throttle field with status and next attributes
+ set properly.
+ """
+ request = self.factory.get('/')
+ for timer, expect in expected_headers:
+ self.set_throttle_timer(view, timer)
+ response = view.as_view()(request)
+ self.assertEquals(response['X-Throttle'], expect)
+
+ def test_seconds_fields(self):
+ """
+ Ensure for second based throttles.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(MockView,
+ ((0, 'status=SUCCESS; next=0.33 sec'),
+ (0, 'status=SUCCESS; next=0.50 sec'),
+ (0, 'status=SUCCESS; next=1.00 sec'),
+ (0, 'status=FAILURE; next=1.00 sec')
+ ))
+
+ def test_minutes_fields(self):
+ """
+ Ensure for minute based throttles.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
+ ((0, 'status=SUCCESS; next=20.00 sec'),
+ (0, 'status=SUCCESS; next=30.00 sec'),
+ (0, 'status=SUCCESS; next=60.00 sec'),
+ (0, 'status=FAILURE; next=60.00 sec')
+ ))
+
+ def test_next_rate_remains_constant_if_followed(self):
+ """
+ If a client follows the recommended next request rate,
+ the throttling rate should stay constant.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
+ ((0, 'status=SUCCESS; next=20.00 sec'),
+ (20, 'status=SUCCESS; next=20.00 sec'),
+ (40, 'status=SUCCESS; next=20.00 sec'),
+ (60, 'status=SUCCESS; next=20.00 sec'),
+ (80, 'status=SUCCESS; next=20.00 sec')
+ ))
diff --git a/djangorestframework/views.py b/djangorestframework/views.py
index 49d722c5..18d064e1 100644
--- a/djangorestframework/views.py
+++ b/djangorestframework/views.py
@@ -64,7 +64,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
"""
permissions = ( permissions.FullAnonAccess, )
-
+
@classmethod
def as_view(cls, **initkwargs):
"""
@@ -101,6 +101,14 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
"""
pass
+
+ def add_header(self, field, value):
+ """
+ Add *field* and *value* to the :attr:`headers` attribute of the :class:`View` class.
+ """
+ self.headers[field] = value
+
+
# Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt.
@csrf_exempt
@@ -108,6 +116,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
self.request = request
self.args = args
self.kwargs = kwargs
+ self.headers = {}
# Calls to 'reverse' will not be fully qualified unless we set the scheme/host/port here.
prefix = '%s://%s' % (request.is_secure() and 'https' or 'http', request.get_host())
@@ -149,9 +158,12 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
# also it's currently sub-obtimal for HTTP caching - need to sort that out.
response.headers['Allow'] = ', '.join(self.allowed_methods)
response.headers['Vary'] = 'Authenticate, Accept'
+
+ # merge with headers possibly set at some point in the view
+ response.headers.update(self.headers)
+
+ return self.render(response)
- return self.render(response)
-
class ModelView(View):
"""A RESTful view that maps to a model in the database."""