diff options
Diffstat (limited to 'rest_framework/throttling.py')
| -rw-r--r-- | rest_framework/throttling.py | 217 | 
1 files changed, 217 insertions, 0 deletions
| diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py new file mode 100644 index 00000000..b66284c3 --- /dev/null +++ b/rest_framework/throttling.py @@ -0,0 +1,217 @@ +from django.core.cache import cache +from rest_framework.settings import api_settings +import time + + +class BaseThrottle(object): +    """ +    Rate throttling of requests. +    """ + +    def __init__(self, view=None): +        """ +        All throttles hold a reference to the instantiating view. +        """ +        self.view = view + +    def allow_request(self, request): +        """ +        Return `True` if the request should be allowed, `False` otherwise. +        """ +        raise NotImplementedError('.allow_request() must be overridden') + +    def wait(self): +        """ +        Optionally, return a recommeded number of seconds to wait before +        the next request. +        """ +        return None + + +class SimpleRateThottle(BaseThrottle): +    """ +    A simple cache implementation, that only requires `.get_cache_key()` +    to be overridden. + +    The rate (requests / seconds) is set by a :attr:`throttle` attribute +    on the :class:`.View` class.  The attribute is a string of the form 'number of +    requests/period'. + +    Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day') + +    Previous request information used for throttling is stored in the cache. +    """ + +    timer = time.time +    settings = api_settings +    cache_format = 'throtte_%(scope)s_%(ident)s' +    scope = None + +    def __init__(self, view): +        super(SimpleRateThottle, self).__init__(view) +        rate = self.get_rate_description() +        self.num_requests, self.duration = self.parse_rate_description(rate) + +    def get_cache_key(self, request): +        """ +        Should return a unique cache-key which can be used for throttling. +        Must be overridden. + +        May return `None` if the request should not be throttled. +        """ +        raise NotImplementedError('.get_cache_key() must be overridden') + +    def get_rate_description(self): +        """ +        Determine the string representation of the allowed request rate. +        """ +        try: +            return self.rate +        except AttributeError: +            return self.settings.DEFAULT_THROTTLE_RATES.get(self.scope) + +    def parse_rate_description(self, rate): +        """ +        Given the request rate string, return a two tuple of: +        <allowed number of requests>, <period of time in seconds> +        """ +        assert rate, "No throttle rate set for '%s'" % self.__class__.__name__ +        num, period = rate.split('/') +        num_requests = int(num) +        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] +        return (num_requests, duration) + +    def allow_request(self, request): +        """ +        Implement the check to see if the request should be throttled. + +        On success calls `throttle_success`. +        On failure calls `throttle_failure`. +        """ +        self.key = self.get_cache_key(request) +        self.history = cache.get(self.key, []) +        self.now = self.timer() + +        # Drop any requests from the history which have now passed the +        # throttle duration +        while self.history and self.history[-1] <= self.now - self.duration: +            self.history.pop() +        if len(self.history) >= self.num_requests: +            return self.throttle_failure() +        return self.throttle_success() + +    def throttle_success(self): +        """ +        Inserts the current request's timestamp along with the key +        into the cache. +        """ +        self.history.insert(0, self.now) +        cache.set(self.key, self.history, self.duration) +        return True + +    def throttle_failure(self): +        """ +        Called when a request to the API has failed due to throttling. +        """ +        return False + +    def wait(self): +        """ +        Returns the recommended next request time in seconds. +        """ +        if self.history: +            remaining_duration = self.duration - (self.now - self.history[-1]) +        else: +            remaining_duration = self.duration + +        available_requests = self.num_requests - len(self.history) + 1 + +        return remaining_duration / float(available_requests) + + +class AnonRateThrottle(SimpleRateThottle): +    """ +    Limits the rate of API calls that may be made by a anonymous users. + +    The IP address of the request will be used as the unqiue cache key. +    """ +    scope = 'anon' + +    def get_cache_key(self, request): +        if request.user.is_authenticated(): +            return None  # Only throttle unauthenticated requests. + +        ident = request.META.get('REMOTE_ADDR', None) + +        return self.cache_format % { +            'scope': self.scope, +            'ident': ident +        } + + +class UserRateThrottle(SimpleRateThottle): +    """ +    Limits the rate of API calls that may be made by a given user. + +    The user id will be used as a unique cache key if the user is +    authenticated.  For anonymous requests, the IP address of the request will +    be used. +    """ +    scope = 'user' + +    def get_cache_key(self, request): +        if request.user.is_authenticated(): +            ident = request.user.id +        else: +            ident = request.META.get('REMOTE_ADDR', None) + +        return self.cache_format % { +            'scope': self.scope, +            'ident': ident +        } + + +class ScopedRateThrottle(SimpleRateThottle): +    """ +    Limits the rate of API calls by different amounts for various parts of +    the API.  Any view that has the `throttle_scope` property set will be +    throttled.  The unique cache key will be generated by concatenating the +    user id of the request, and the scope of the view being accessed. +    """ + +    scope_attr = 'throttle_scope' + +    def __init__(self, view): +        """ +        Scope is determined from the view being accessed. +        """ +        self.scope = getattr(self.view, self.scope_attr, None) +        super(ScopedRateThrottle, self).__init__(view) + +    def parse_rate_description(self, rate): +        """ +        Subclassed so that we don't fail if `view.throttle_scope` is not set. +        """ +        if not rate: +            return (None, None) +        return super(ScopedRateThrottle, self).parse_rate_description(rate) + +    def get_cache_key(self, request): +        """ +        If `view.throttle_scope` is not set, don't apply this throttle. + +        Otherwise generate the unique cache key by concatenating the user id +        with the '.throttle_scope` property of the view. +        """ +        if not self.scope: +            return None  # Only throttle views if `.throttle_scope` is set. + +        if request.user.is_authenticated(): +            ident = request.user.id +        else: +            ident = request.META.get('REMOTE_ADDR', None) + +        return self.cache_format % { +            'scope': self.scope, +            'ident': ident +        } | 
