diff options
Diffstat (limited to 'rest_framework/throttling.py')
| -rw-r--r-- | rest_framework/throttling.py | 217 |
1 files changed, 217 insertions, 0 deletions
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py new file mode 100644 index 00000000..b66284c3 --- /dev/null +++ b/rest_framework/throttling.py @@ -0,0 +1,217 @@ +from django.core.cache import cache +from rest_framework.settings import api_settings +import time + + +class BaseThrottle(object): + """ + Rate throttling of requests. + """ + + def __init__(self, view=None): + """ + All throttles hold a reference to the instantiating view. + """ + self.view = view + + def allow_request(self, request): + """ + Return `True` if the request should be allowed, `False` otherwise. + """ + raise NotImplementedError('.allow_request() must be overridden') + + def wait(self): + """ + Optionally, return a recommeded number of seconds to wait before + the next request. + """ + return None + + +class SimpleRateThottle(BaseThrottle): + """ + A simple cache implementation, that only requires `.get_cache_key()` + to be overridden. + + The rate (requests / seconds) is set by a :attr:`throttle` attribute + on the :class:`.View` class. The attribute is a string of the form 'number of + requests/period'. + + Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day') + + Previous request information used for throttling is stored in the cache. + """ + + timer = time.time + settings = api_settings + cache_format = 'throtte_%(scope)s_%(ident)s' + scope = None + + def __init__(self, view): + super(SimpleRateThottle, self).__init__(view) + rate = self.get_rate_description() + self.num_requests, self.duration = self.parse_rate_description(rate) + + def get_cache_key(self, request): + """ + Should return a unique cache-key which can be used for throttling. + Must be overridden. + + May return `None` if the request should not be throttled. + """ + raise NotImplementedError('.get_cache_key() must be overridden') + + def get_rate_description(self): + """ + Determine the string representation of the allowed request rate. + """ + try: + return self.rate + except AttributeError: + return self.settings.DEFAULT_THROTTLE_RATES.get(self.scope) + + def parse_rate_description(self, rate): + """ + Given the request rate string, return a two tuple of: + <allowed number of requests>, <period of time in seconds> + """ + assert rate, "No throttle rate set for '%s'" % self.__class__.__name__ + num, period = rate.split('/') + num_requests = int(num) + duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] + return (num_requests, duration) + + def allow_request(self, request): + """ + Implement the check to see if the request should be throttled. + + On success calls `throttle_success`. + On failure calls `throttle_failure`. + """ + self.key = self.get_cache_key(request) + self.history = cache.get(self.key, []) + self.now = self.timer() + + # Drop any requests from the history which have now passed the + # throttle duration + while self.history and self.history[-1] <= self.now - self.duration: + self.history.pop() + if len(self.history) >= self.num_requests: + return self.throttle_failure() + return self.throttle_success() + + def throttle_success(self): + """ + Inserts the current request's timestamp along with the key + into the cache. + """ + self.history.insert(0, self.now) + cache.set(self.key, self.history, self.duration) + return True + + def throttle_failure(self): + """ + Called when a request to the API has failed due to throttling. + """ + return False + + def wait(self): + """ + Returns the recommended next request time in seconds. + """ + if self.history: + remaining_duration = self.duration - (self.now - self.history[-1]) + else: + remaining_duration = self.duration + + available_requests = self.num_requests - len(self.history) + 1 + + return remaining_duration / float(available_requests) + + +class AnonRateThrottle(SimpleRateThottle): + """ + Limits the rate of API calls that may be made by a anonymous users. + + The IP address of the request will be used as the unqiue cache key. + """ + scope = 'anon' + + def get_cache_key(self, request): + if request.user.is_authenticated(): + return None # Only throttle unauthenticated requests. + + ident = request.META.get('REMOTE_ADDR', None) + + return self.cache_format % { + 'scope': self.scope, + 'ident': ident + } + + +class UserRateThrottle(SimpleRateThottle): + """ + Limits the rate of API calls that may be made by a given user. + + The user id will be used as a unique cache key if the user is + authenticated. For anonymous requests, the IP address of the request will + be used. + """ + scope = 'user' + + def get_cache_key(self, request): + if request.user.is_authenticated(): + ident = request.user.id + else: + ident = request.META.get('REMOTE_ADDR', None) + + return self.cache_format % { + 'scope': self.scope, + 'ident': ident + } + + +class ScopedRateThrottle(SimpleRateThottle): + """ + Limits the rate of API calls by different amounts for various parts of + the API. Any view that has the `throttle_scope` property set will be + throttled. The unique cache key will be generated by concatenating the + user id of the request, and the scope of the view being accessed. + """ + + scope_attr = 'throttle_scope' + + def __init__(self, view): + """ + Scope is determined from the view being accessed. + """ + self.scope = getattr(self.view, self.scope_attr, None) + super(ScopedRateThrottle, self).__init__(view) + + def parse_rate_description(self, rate): + """ + Subclassed so that we don't fail if `view.throttle_scope` is not set. + """ + if not rate: + return (None, None) + return super(ScopedRateThrottle, self).parse_rate_description(rate) + + def get_cache_key(self, request): + """ + If `view.throttle_scope` is not set, don't apply this throttle. + + Otherwise generate the unique cache key by concatenating the user id + with the '.throttle_scope` property of the view. + """ + if not self.scope: + return None # Only throttle views if `.throttle_scope` is set. + + if request.user.is_authenticated(): + ident = request.user.id + else: + ident = request.META.get('REMOTE_ADDR', None) + + return self.cache_format % { + 'scope': self.scope, + 'ident': ident + } |
