aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/throttling.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/throttling.py')
-rw-r--r--rest_framework/throttling.py45
1 files changed, 34 insertions, 11 deletions
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index f6bb1cc8..261fc246 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -2,7 +2,7 @@
Provides various throttling policies.
"""
from __future__ import unicode_literals
-from django.core.cache import cache
+from django.core.cache import cache as default_cache
from django.core.exceptions import ImproperlyConfigured
from rest_framework.settings import api_settings
import time
@@ -18,6 +18,25 @@ class BaseThrottle(object):
"""
raise NotImplementedError('.allow_request() must be overridden')
+ def get_ident(self, request):
+ """
+ Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
+ if present and number of proxies is > 0. If not use all of
+ HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
+ """
+ xff = request.META.get('HTTP_X_FORWARDED_FOR')
+ remote_addr = request.META.get('REMOTE_ADDR')
+ num_proxies = api_settings.NUM_PROXIES
+
+ if num_proxies is not None:
+ if num_proxies == 0 or xff is None:
+ return remote_addr
+ addrs = xff.split(',')
+ client_addr = addrs[-min(num_proxies, len(addrs))]
+ return client_addr.strip()
+
+ return ''.join(xff.split()) if xff else remote_addr
+
def wait(self):
"""
Optionally, return a recommended number of seconds to wait before
@@ -39,8 +58,9 @@ class SimpleRateThrottle(BaseThrottle):
Previous request information used for throttling is stored in the cache.
"""
+ cache = default_cache
timer = time.time
- cache_format = 'throtte_%(scope)s_%(ident)s'
+ cache_format = 'throttle_%(scope)s_%(ident)s'
scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
@@ -96,7 +116,10 @@ class SimpleRateThrottle(BaseThrottle):
return True
self.key = self.get_cache_key(request, view)
- self.history = cache.get(self.key, [])
+ if self.key is None:
+ return True
+
+ self.history = self.cache.get(self.key, [])
self.now = self.timer()
# Drop any requests from the history which have now passed the
@@ -113,7 +136,7 @@ class SimpleRateThrottle(BaseThrottle):
into the cache.
"""
self.history.insert(0, self.now)
- cache.set(self.key, self.history, self.duration)
+ self.cache.set(self.key, self.history, self.duration)
return True
def throttle_failure(self):
@@ -132,6 +155,8 @@ class SimpleRateThrottle(BaseThrottle):
remaining_duration = self.duration
available_requests = self.num_requests - len(self.history) + 1
+ if available_requests <= 0:
+ return None
return remaining_duration / float(available_requests)
@@ -148,11 +173,9 @@ class AnonRateThrottle(SimpleRateThrottle):
if request.user.is_authenticated():
return None # Only throttle unauthenticated requests.
- ident = request.META.get('REMOTE_ADDR', None)
-
return self.cache_format % {
'scope': self.scope,
- 'ident': ident
+ 'ident': self.get_ident(request)
}
@@ -168,9 +191,9 @@ class UserRateThrottle(SimpleRateThrottle):
def get_cache_key(self, request, view):
if request.user.is_authenticated():
- ident = request.user.id
+ ident = request.user.pk
else:
- ident = request.META.get('REMOTE_ADDR', None)
+ ident = self.get_ident(request)
return self.cache_format % {
'scope': self.scope,
@@ -216,9 +239,9 @@ class ScopedRateThrottle(SimpleRateThrottle):
with the '.throttle_scope` property of the view.
"""
if request.user.is_authenticated():
- ident = request.user.id
+ ident = request.user.pk
else:
- ident = request.META.get('REMOTE_ADDR', None)
+ ident = self.get_ident(request)
return self.cache_format % {
'scope': self.scope,