diff options
Diffstat (limited to 'rest_framework/throttling.py')
| -rw-r--r-- | rest_framework/throttling.py | 36 |
1 files changed, 20 insertions, 16 deletions
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index e7750478..0d33f0e2 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -1,5 +1,6 @@ import time from django.core.cache import cache +from rest_framework import exceptions from rest_framework.settings import api_settings @@ -49,8 +50,9 @@ class SimpleRateThottle(BaseThrottle): def __init__(self, view): super(SimpleRateThottle, self).__init__(view) - rate = self.get_rate_description() - self.num_requests, self.duration = self.parse_rate_description(rate) + if not getattr(self, 'rate', None): + self.rate = self.get_rate() + self.num_requests, self.duration = self.parse_rate(self.rate) def get_cache_key(self, request): """ @@ -61,21 +63,28 @@ class SimpleRateThottle(BaseThrottle): """ raise NotImplementedError('.get_cache_key() must be overridden') - def get_rate_description(self): + def get_rate(self): """ Determine the string representation of the allowed request rate. """ + if not getattr(self, 'scope', None): + msg = ("You must set either `.scope` or `.rate` for '%s' thottle" % + self.__class__.__name__) + raise exceptions.ConfigurationError(msg) + try: - return self.rate - except AttributeError: - return self.settings.DEFAULT_THROTTLE_RATES.get(self.scope) + return self.settings.DEFAULT_THROTTLE_RATES[self.scope] + except KeyError: + msg = "No default throttle rate set for '%s' scope" % self.scope + raise exceptions.ConfigurationError(msg) - def parse_rate_description(self, rate): + def parse_rate(self, rate): """ Given the request rate string, return a two tuple of: <allowed number of requests>, <period of time in seconds> """ - assert rate, "No throttle rate set for '%s'" % self.__class__.__name__ + if rate is None: + return (None, None) num, period = rate.split('/') num_requests = int(num) duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] @@ -88,6 +97,9 @@ class SimpleRateThottle(BaseThrottle): On success calls `throttle_success`. On failure calls `throttle_failure`. """ + if self.rate is None: + return True + self.key = self.get_cache_key(request) self.history = cache.get(self.key, []) self.now = self.timer() @@ -188,14 +200,6 @@ class ScopedRateThrottle(SimpleRateThottle): self.scope = getattr(self.view, self.scope_attr, None) super(ScopedRateThrottle, self).__init__(view) - def parse_rate_description(self, rate): - """ - Subclassed so that we don't fail if `view.throttle_scope` is not set. - """ - if not rate: - return (None, None) - return super(ScopedRateThrottle, self).parse_rate_description(rate) - def get_cache_key(self, request): """ If `view.throttle_scope` is not set, don't apply this throttle. |
