diff options
| author | Tom Christie | 2013-12-12 15:18:35 -0800 |
|---|---|---|
| committer | Tom Christie | 2013-12-12 15:18:35 -0800 |
| commit | d6d4621c4580ae4902fc895bdca78aced0ec7eab (patch) | |
| tree | dec9003eacb39906dffb7690640e0ebd694893f0 /rest_framework | |
| parent | 36046168b62967f815e328402fcf2afad21c69fb (diff) | |
| parent | 23db6c98495d7b3c18a3069c6cb770d5cbc18ee1 (diff) | |
| download | django-rest-framework-d6d4621c4580ae4902fc895bdca78aced0ec7eab.tar.bz2 | |
Merge pull request #1273 from kahnjw/add_get_ident_method_to_base_throttle
Add get_ident method to base throttle class
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/settings.py | 1 | ||||
| -rw-r--r-- | rest_framework/tests/test_throttling.py | 66 | ||||
| -rw-r--r-- | rest_framework/throttling.py | 25 |
3 files changed, 85 insertions, 7 deletions
diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 8abaf140..383de72e 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -63,6 +63,7 @@ DEFAULTS = { 'user': None, 'anon': None, }, + 'NUM_PROXIES': None, # Pagination 'PAGINATE_BY': None, diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py index 41bff692..8c5eefe9 100644 --- a/rest_framework/tests/test_throttling.py +++ b/rest_framework/tests/test_throttling.py @@ -5,6 +5,7 @@ from __future__ import unicode_literals from django.test import TestCase from django.contrib.auth.models import User from django.core.cache import cache +from rest_framework.settings import api_settings from rest_framework.test import APIRequestFactory from rest_framework.views import APIView from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle @@ -275,3 +276,68 @@ class ScopedRateThrottleTests(TestCase): self.increment_timer() response = self.unscoped_view(request) self.assertEqual(200, response.status_code) + + +class XffTestingBase(TestCase): + def setUp(self): + + class Throttle(ScopedRateThrottle): + THROTTLE_RATES = {'test_limit': '1/day'} + TIMER_SECONDS = 0 + timer = lambda self: self.TIMER_SECONDS + + class View(APIView): + throttle_classes = (Throttle,) + throttle_scope = 'test_limit' + + def get(self, request): + return Response('test_limit') + + cache.clear() + self.throttle = Throttle() + self.view = View.as_view() + self.request = APIRequestFactory().get('/some_uri') + self.request.META['REMOTE_ADDR'] = '3.3.3.3' + self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 2.2.2.2' + + def config_proxy(self, num_proxies): + setattr(api_settings, 'NUM_PROXIES', num_proxies) + + +class IdWithXffBasicTests(XffTestingBase): + def test_accepts_request_under_limit(self): + self.config_proxy(0) + self.assertEqual(200, self.view(self.request).status_code) + + def test_denies_request_over_limit(self): + self.config_proxy(0) + self.view(self.request) + self.assertEqual(429, self.view(self.request).status_code) + + +class XffSpoofingTests(XffTestingBase): + def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self): + self.config_proxy(1) + self.view(self.request) + self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 5.5.5.5, 2.2.2.2' + self.assertEqual(429, self.view(self.request).status_code) + + def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self): + self.config_proxy(2) + self.view(self.request) + self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 1.1.1.1, 2.2.2.2' + self.assertEqual(429, self.view(self.request).status_code) + + +class XffUniqueMachinesTest(XffTestingBase): + def test_unique_clients_are_counted_independently_with_one_proxy(self): + self.config_proxy(1) + self.view(self.request) + self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 7.7.7.7' + self.assertEqual(200, self.view(self.request).status_code) + + def test_unique_clients_are_counted_independently_with_two_proxies(self): + self.config_proxy(2) + self.view(self.request) + self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 7.7.7.7, 2.2.2.2' + self.assertEqual(200, self.view(self.request).status_code) diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index a946d837..60e46d47 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -18,6 +18,21 @@ class BaseThrottle(object): """ raise NotImplementedError('.allow_request() must be overridden') + def get_ident(self, request): + """ + Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR + if present and number of proxies is > 0. If not use all of + HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR. + """ + xff = request.META.get('HTTP_X_FORWARDED_FOR') + remote_addr = request.META.get('REMOTE_ADDR') + num_proxies = api_settings.NUM_PROXIES + + if xff and num_proxies: + return xff.split(',')[-min(num_proxies, len(xff))].strip() + + return xff if xff else remote_addr + def wait(self): """ Optionally, return a recommended number of seconds to wait before @@ -152,13 +167,9 @@ class AnonRateThrottle(SimpleRateThrottle): if request.user.is_authenticated(): return None # Only throttle unauthenticated requests. - ident = request.META.get('HTTP_X_FORWARDED_FOR') - if ident is None: - ident = request.META.get('REMOTE_ADDR') - return self.cache_format % { 'scope': self.scope, - 'ident': ident + 'ident': self.get_ident(request) } @@ -176,7 +187,7 @@ class UserRateThrottle(SimpleRateThrottle): if request.user.is_authenticated(): ident = request.user.id else: - ident = request.META.get('REMOTE_ADDR', None) + ident = self.get_ident(request) return self.cache_format % { 'scope': self.scope, @@ -224,7 +235,7 @@ class ScopedRateThrottle(SimpleRateThrottle): if request.user.is_authenticated(): ident = request.user.id else: - ident = request.META.get('REMOTE_ADDR', None) + ident = self.get_ident(request) return self.cache_format % { 'scope': self.scope, |
