aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/throttling.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/throttling.py')
-rw-r--r--rest_framework/throttling.py217
1 files changed, 217 insertions, 0 deletions
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
new file mode 100644
index 00000000..b66284c3
--- /dev/null
+++ b/rest_framework/throttling.py
@@ -0,0 +1,217 @@
+from django.core.cache import cache
+from rest_framework.settings import api_settings
+import time
+
+
+class BaseThrottle(object):
+ """
+ Rate throttling of requests.
+ """
+
+ def __init__(self, view=None):
+ """
+ All throttles hold a reference to the instantiating view.
+ """
+ self.view = view
+
+ def allow_request(self, request):
+ """
+ Return `True` if the request should be allowed, `False` otherwise.
+ """
+ raise NotImplementedError('.allow_request() must be overridden')
+
+ def wait(self):
+ """
+ Optionally, return a recommeded number of seconds to wait before
+ the next request.
+ """
+ return None
+
+
+class SimpleRateThottle(BaseThrottle):
+ """
+ A simple cache implementation, that only requires `.get_cache_key()`
+ to be overridden.
+
+ The rate (requests / seconds) is set by a :attr:`throttle` attribute
+ on the :class:`.View` class. The attribute is a string of the form 'number of
+ requests/period'.
+
+ Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
+
+ Previous request information used for throttling is stored in the cache.
+ """
+
+ timer = time.time
+ settings = api_settings
+ cache_format = 'throtte_%(scope)s_%(ident)s'
+ scope = None
+
+ def __init__(self, view):
+ super(SimpleRateThottle, self).__init__(view)
+ rate = self.get_rate_description()
+ self.num_requests, self.duration = self.parse_rate_description(rate)
+
+ def get_cache_key(self, request):
+ """
+ Should return a unique cache-key which can be used for throttling.
+ Must be overridden.
+
+ May return `None` if the request should not be throttled.
+ """
+ raise NotImplementedError('.get_cache_key() must be overridden')
+
+ def get_rate_description(self):
+ """
+ Determine the string representation of the allowed request rate.
+ """
+ try:
+ return self.rate
+ except AttributeError:
+ return self.settings.DEFAULT_THROTTLE_RATES.get(self.scope)
+
+ def parse_rate_description(self, rate):
+ """
+ Given the request rate string, return a two tuple of:
+ <allowed number of requests>, <period of time in seconds>
+ """
+ assert rate, "No throttle rate set for '%s'" % self.__class__.__name__
+ num, period = rate.split('/')
+ num_requests = int(num)
+ duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
+ return (num_requests, duration)
+
+ def allow_request(self, request):
+ """
+ Implement the check to see if the request should be throttled.
+
+ On success calls `throttle_success`.
+ On failure calls `throttle_failure`.
+ """
+ self.key = self.get_cache_key(request)
+ self.history = cache.get(self.key, [])
+ self.now = self.timer()
+
+ # Drop any requests from the history which have now passed the
+ # throttle duration
+ while self.history and self.history[-1] <= self.now - self.duration:
+ self.history.pop()
+ if len(self.history) >= self.num_requests:
+ return self.throttle_failure()
+ return self.throttle_success()
+
+ def throttle_success(self):
+ """
+ Inserts the current request's timestamp along with the key
+ into the cache.
+ """
+ self.history.insert(0, self.now)
+ cache.set(self.key, self.history, self.duration)
+ return True
+
+ def throttle_failure(self):
+ """
+ Called when a request to the API has failed due to throttling.
+ """
+ return False
+
+ def wait(self):
+ """
+ Returns the recommended next request time in seconds.
+ """
+ if self.history:
+ remaining_duration = self.duration - (self.now - self.history[-1])
+ else:
+ remaining_duration = self.duration
+
+ available_requests = self.num_requests - len(self.history) + 1
+
+ return remaining_duration / float(available_requests)
+
+
+class AnonRateThrottle(SimpleRateThottle):
+ """
+ Limits the rate of API calls that may be made by a anonymous users.
+
+ The IP address of the request will be used as the unqiue cache key.
+ """
+ scope = 'anon'
+
+ def get_cache_key(self, request):
+ if request.user.is_authenticated():
+ return None # Only throttle unauthenticated requests.
+
+ ident = request.META.get('REMOTE_ADDR', None)
+
+ return self.cache_format % {
+ 'scope': self.scope,
+ 'ident': ident
+ }
+
+
+class UserRateThrottle(SimpleRateThottle):
+ """
+ Limits the rate of API calls that may be made by a given user.
+
+ The user id will be used as a unique cache key if the user is
+ authenticated. For anonymous requests, the IP address of the request will
+ be used.
+ """
+ scope = 'user'
+
+ def get_cache_key(self, request):
+ if request.user.is_authenticated():
+ ident = request.user.id
+ else:
+ ident = request.META.get('REMOTE_ADDR', None)
+
+ return self.cache_format % {
+ 'scope': self.scope,
+ 'ident': ident
+ }
+
+
+class ScopedRateThrottle(SimpleRateThottle):
+ """
+ Limits the rate of API calls by different amounts for various parts of
+ the API. Any view that has the `throttle_scope` property set will be
+ throttled. The unique cache key will be generated by concatenating the
+ user id of the request, and the scope of the view being accessed.
+ """
+
+ scope_attr = 'throttle_scope'
+
+ def __init__(self, view):
+ """
+ Scope is determined from the view being accessed.
+ """
+ self.scope = getattr(self.view, self.scope_attr, None)
+ super(ScopedRateThrottle, self).__init__(view)
+
+ def parse_rate_description(self, rate):
+ """
+ Subclassed so that we don't fail if `view.throttle_scope` is not set.
+ """
+ if not rate:
+ return (None, None)
+ return super(ScopedRateThrottle, self).parse_rate_description(rate)
+
+ def get_cache_key(self, request):
+ """
+ If `view.throttle_scope` is not set, don't apply this throttle.
+
+ Otherwise generate the unique cache key by concatenating the user id
+ with the '.throttle_scope` property of the view.
+ """
+ if not self.scope:
+ return None # Only throttle views if `.throttle_scope` is set.
+
+ if request.user.is_authenticated():
+ ident = request.user.id
+ else:
+ ident = request.META.get('REMOTE_ADDR', None)
+
+ return self.cache_format % {
+ 'scope': self.scope,
+ 'ident': ident
+ }