diff options
Diffstat (limited to 'rest_framework/throttling.py')
| -rw-r--r-- | rest_framework/throttling.py | 45 | 
1 files changed, 34 insertions, 11 deletions
| diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index f6bb1cc8..261fc246 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -2,7 +2,7 @@  Provides various throttling policies.  """  from __future__ import unicode_literals -from django.core.cache import cache +from django.core.cache import cache as default_cache  from django.core.exceptions import ImproperlyConfigured  from rest_framework.settings import api_settings  import time @@ -18,6 +18,25 @@ class BaseThrottle(object):          """          raise NotImplementedError('.allow_request() must be overridden') +    def get_ident(self, request): +        """ +        Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR +        if present and number of proxies is > 0. If not use all of +        HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR. +        """ +        xff = request.META.get('HTTP_X_FORWARDED_FOR') +        remote_addr = request.META.get('REMOTE_ADDR') +        num_proxies = api_settings.NUM_PROXIES + +        if num_proxies is not None: +            if num_proxies == 0 or xff is None: +                return remote_addr +            addrs = xff.split(',') +            client_addr = addrs[-min(num_proxies, len(addrs))] +            return client_addr.strip() + +        return ''.join(xff.split()) if xff else remote_addr +      def wait(self):          """          Optionally, return a recommended number of seconds to wait before @@ -39,8 +58,9 @@ class SimpleRateThrottle(BaseThrottle):      Previous request information used for throttling is stored in the cache.      """ +    cache = default_cache      timer = time.time -    cache_format = 'throtte_%(scope)s_%(ident)s' +    cache_format = 'throttle_%(scope)s_%(ident)s'      scope = None      THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES @@ -96,7 +116,10 @@ class SimpleRateThrottle(BaseThrottle):              return True          self.key = self.get_cache_key(request, view) -        self.history = cache.get(self.key, []) +        if self.key is None: +            return True + +        self.history = self.cache.get(self.key, [])          self.now = self.timer()          # Drop any requests from the history which have now passed the @@ -113,7 +136,7 @@ class SimpleRateThrottle(BaseThrottle):          into the cache.          """          self.history.insert(0, self.now) -        cache.set(self.key, self.history, self.duration) +        self.cache.set(self.key, self.history, self.duration)          return True      def throttle_failure(self): @@ -132,6 +155,8 @@ class SimpleRateThrottle(BaseThrottle):              remaining_duration = self.duration          available_requests = self.num_requests - len(self.history) + 1 +        if available_requests <= 0: +            return None          return remaining_duration / float(available_requests) @@ -148,11 +173,9 @@ class AnonRateThrottle(SimpleRateThrottle):          if request.user.is_authenticated():              return None  # Only throttle unauthenticated requests. -        ident = request.META.get('REMOTE_ADDR', None) -          return self.cache_format % {              'scope': self.scope, -            'ident': ident +            'ident': self.get_ident(request)          } @@ -168,9 +191,9 @@ class UserRateThrottle(SimpleRateThrottle):      def get_cache_key(self, request, view):          if request.user.is_authenticated(): -            ident = request.user.id +            ident = request.user.pk          else: -            ident = request.META.get('REMOTE_ADDR', None) +            ident = self.get_ident(request)          return self.cache_format % {              'scope': self.scope, @@ -216,9 +239,9 @@ class ScopedRateThrottle(SimpleRateThrottle):          with the '.throttle_scope` property of the view.          """          if request.user.is_authenticated(): -            ident = request.user.id +            ident = request.user.pk          else: -            ident = request.META.get('REMOTE_ADDR', None) +            ident = self.get_ident(request)          return self.cache_format % {              'scope': self.scope, | 
