aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/tests/test_throttling.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/tests/test_throttling.py')
-rw-r--r--rest_framework/tests/test_throttling.py109
1 files changed, 106 insertions, 3 deletions
diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py
index da400b2f..d35d3709 100644
--- a/rest_framework/tests/test_throttling.py
+++ b/rest_framework/tests/test_throttling.py
@@ -7,7 +7,7 @@ from django.contrib.auth.models import User
from django.core.cache import cache
from django.test.client import RequestFactory
from rest_framework.views import APIView
-from rest_framework.throttling import UserRateThrottle
+from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle
from rest_framework.response import Response
@@ -36,8 +36,6 @@ class MockView_MinuteThrottling(APIView):
class ThrottlingTests(TestCase):
- urls = 'rest_framework.tests.test_throttling'
-
def setUp(self):
"""
Reset the cache so that no throttles will be active
@@ -141,3 +139,108 @@ class ThrottlingTests(TestCase):
(60, None),
(80, None)
))
+
+
+class ScopedRateThrottleTests(TestCase):
+ """
+ Tests for ScopedRateThrottle.
+ """
+
+ def setUp(self):
+ class XYScopedRateThrottle(ScopedRateThrottle):
+ TIMER_SECONDS = 0
+ THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}
+ timer = lambda self: self.TIMER_SECONDS
+
+ class XView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+ throttle_scope = 'x'
+
+ def get(self, request):
+ return Response('x')
+
+ class YView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+ throttle_scope = 'y'
+
+ def get(self, request):
+ return Response('y')
+
+ class UnscopedView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+
+ def get(self, request):
+ return Response('y')
+
+ self.throttle_class = XYScopedRateThrottle
+ self.factory = RequestFactory()
+ self.x_view = XView.as_view()
+ self.y_view = YView.as_view()
+ self.unscoped_view = UnscopedView.as_view()
+
+ def increment_timer(self, seconds=1):
+ self.throttle_class.TIMER_SECONDS += seconds
+
+ def test_scoped_rate_throttle(self):
+ request = self.factory.get('/')
+
+ # Should be able to hit x view 3 times per minute.
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Should be able to hit y view 1 time per minute.
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Ensure throttles properly reset by advancing the rest of the minute
+ self.increment_timer(55)
+
+ # Should still be able to hit x view 3 times per minute.
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Should still be able to hit y view 1 time per minute.
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(429, response.status_code)
+
+ def test_unscoped_view_not_throttled(self):
+ request = self.factory.get('/')
+
+ for idx in range(10):
+ self.increment_timer()
+ response = self.unscoped_view(request)
+ self.assertEqual(200, response.status_code)