aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/throttling.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/throttling.py')
-rw-r--r--rest_framework/throttling.py39
1 files changed, 27 insertions, 12 deletions
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index 93ea9816..f6bb1cc8 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -3,7 +3,7 @@ Provides various throttling policies.
"""
from __future__ import unicode_literals
from django.core.cache import cache
-from rest_framework import exceptions
+from django.core.exceptions import ImproperlyConfigured
from rest_framework.settings import api_settings
import time
@@ -40,9 +40,9 @@ class SimpleRateThrottle(BaseThrottle):
"""
timer = time.time
- settings = api_settings
cache_format = 'throtte_%(scope)s_%(ident)s'
scope = None
+ THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
def __init__(self):
if not getattr(self, 'rate', None):
@@ -65,13 +65,13 @@ class SimpleRateThrottle(BaseThrottle):
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
- raise exceptions.ConfigurationError(msg)
+ raise ImproperlyConfigured(msg)
try:
- return self.settings.DEFAULT_THROTTLE_RATES[self.scope]
+ return self.THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
- raise exceptions.ConfigurationError(msg)
+ raise ImproperlyConfigured(msg)
def parse_rate(self, rate):
"""
@@ -187,6 +187,27 @@ class ScopedRateThrottle(SimpleRateThrottle):
"""
scope_attr = 'throttle_scope'
+ def __init__(self):
+ # Override the usual SimpleRateThrottle, because we can't determine
+ # the rate until called by the view.
+ pass
+
+ def allow_request(self, request, view):
+ # We can only determine the scope once we're called by the view.
+ self.scope = getattr(view, self.scope_attr, None)
+
+ # If a view does not have a `throttle_scope` always allow the request
+ if not self.scope:
+ return True
+
+ # Determine the allowed request rate as we normally would during
+ # the `__init__` call.
+ self.rate = self.get_rate()
+ self.num_requests, self.duration = self.parse_rate(self.rate)
+
+ # We can now proceed as normal.
+ return super(ScopedRateThrottle, self).allow_request(request, view)
+
def get_cache_key(self, request, view):
"""
If `view.throttle_scope` is not set, don't apply this throttle.
@@ -194,18 +215,12 @@ class ScopedRateThrottle(SimpleRateThrottle):
Otherwise generate the unique cache key by concatenating the user id
with the '.throttle_scope` property of the view.
"""
- scope = getattr(view, self.scope_attr, None)
-
- if not scope:
- # Only throttle views if `.throttle_scope` is set on the view.
- return None
-
if request.user.is_authenticated():
ident = request.user.id
else:
ident = request.META.get('REMOTE_ADDR', None)
return self.cache_format % {
- 'scope': scope,
+ 'scope': self.scope,
'ident': ident
}