aboutsummaryrefslogtreecommitdiffstats
path: root/djangorestframework/throttling.py
diff options
context:
space:
mode:
Diffstat (limited to 'djangorestframework/throttling.py')
-rw-r--r--djangorestframework/throttling.py139
1 files changed, 139 insertions, 0 deletions
diff --git a/djangorestframework/throttling.py b/djangorestframework/throttling.py
new file mode 100644
index 00000000..a096eab7
--- /dev/null
+++ b/djangorestframework/throttling.py
@@ -0,0 +1,139 @@
+from django.core.cache import cache
+import time
+
+
+class BaseThrottle(object):
+ """
+ Rate throttling of requests.
+ """
+
+ def __init__(self, view=None):
+ """
+ All throttles hold a reference to the instantiating view.
+ """
+ self.view = view
+
+ def check_throttle(self, request):
+ """
+ Return `True` if the request should be allowed, `False` otherwise.
+ """
+ raise NotImplementedError('.check_throttle() must be overridden')
+
+ def wait(self):
+ """
+ Optionally, return a recommeded number of seconds to wait before
+ the next request.
+ """
+ return None
+
+
+class SimpleCachingThrottle(BaseThrottle):
+ """
+ A simple cache implementation, that only requires `.get_cache_key()`
+ to be overridden.
+
+ The rate (requests / seconds) is set by a :attr:`throttle` attribute
+ on the :class:`.View` class. The attribute is a string of the form 'number of
+ requests/period'.
+
+ Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
+
+ Previous request information used for throttling is stored in the cache.
+ """
+
+ attr_name = 'rate'
+ rate = '1000/day'
+ timer = time.time
+
+ def __init__(self, view):
+ """
+ Check the throttling.
+ Return `None` or raise an :exc:`.ImmediateResponse`.
+ """
+ super(SimpleCachingThrottle, self).__init__(view)
+ num, period = getattr(view, self.attr_name, self.rate).split('/')
+ self.num_requests = int(num)
+ self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
+
+ def get_cache_key(self, request):
+ """
+ Should return a unique cache-key which can be used for throttling.
+ Must be overridden.
+ """
+ raise NotImplementedError('.get_cache_key() must be overridden')
+
+ def check_throttle(self, request):
+ """
+ Implement the check to see if the request should be throttled.
+
+ On success calls :meth:`throttle_success`.
+ On failure calls :meth:`throttle_failure`.
+ """
+ self.key = self.get_cache_key(request)
+ self.history = cache.get(self.key, [])
+ self.now = self.timer()
+
+ # Drop any requests from the history which have now passed the
+ # throttle duration
+ while self.history and self.history[-1] <= self.now - self.duration:
+ self.history.pop()
+ if len(self.history) >= self.num_requests:
+ return self.throttle_failure()
+ return self.throttle_success()
+
+ def throttle_success(self):
+ """
+ Inserts the current request's timestamp along with the key
+ into the cache.
+ """
+ self.history.insert(0, self.now)
+ cache.set(self.key, self.history, self.duration)
+ return True
+
+ def throttle_failure(self):
+ """
+ Called when a request to the API has failed due to throttling.
+ """
+ return False
+
+ def wait(self):
+ """
+ Returns the recommended next request time in seconds.
+ """
+ if self.history:
+ remaining_duration = self.duration - (self.now - self.history[-1])
+ else:
+ remaining_duration = self.duration
+
+ available_requests = self.num_requests - len(self.history) + 1
+
+ return remaining_duration / float(available_requests)
+
+
+class PerUserThrottling(SimpleCachingThrottle):
+ """
+ Limits the rate of API calls that may be made by a given user.
+
+ The user id will be used as a unique identifier if the user is
+ authenticated. For anonymous requests, the IP address of the client will
+ be used.
+ """
+
+ def get_cache_key(self, request):
+ if request.user.is_authenticated():
+ ident = request.user.id
+ else:
+ ident = request.META.get('REMOTE_ADDR', None)
+ return 'throttle_user_%s' % ident
+
+
+class PerViewThrottling(SimpleCachingThrottle):
+ """
+ Limits the rate of API calls that may be used on a given view.
+
+ The class name of the view is used as a unique identifier to
+ throttle against.
+ """
+
+ def get_cache_key(self, request):
+ return 'throttle_view_%s' % self.view.__class__.__name__