aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/tests/test_throttling.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/tests/test_throttling.py')
-rw-r--r--rest_framework/tests/test_throttling.py281
1 files changed, 0 insertions, 281 deletions
diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py
deleted file mode 100644
index b5ae02cd..00000000
--- a/rest_framework/tests/test_throttling.py
+++ /dev/null
@@ -1,281 +0,0 @@
-"""
-Tests for the throttling implementations in the permissions module.
-"""
-from __future__ import unicode_literals
-from django.test import TestCase
-from django.contrib.auth.models import User
-from django.core.cache import cache
-from rest_framework.test import APIRequestFactory
-from rest_framework.views import APIView
-from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle
-from rest_framework.response import Response
-
-
-class User3SecRateThrottle(UserRateThrottle):
- rate = '3/sec'
- scope = 'seconds'
-
-
-class User3MinRateThrottle(UserRateThrottle):
- rate = '3/min'
- scope = 'minutes'
-
-
-class NonTimeThrottle(BaseThrottle):
- def allow_request(self, request, view):
- if not hasattr(self.__class__, 'called'):
- self.__class__.called = True
- return True
- return False
-
-
-class MockView(APIView):
- throttle_classes = (User3SecRateThrottle,)
-
- def get(self, request):
- return Response('foo')
-
-
-class MockView_MinuteThrottling(APIView):
- throttle_classes = (User3MinRateThrottle,)
-
- def get(self, request):
- return Response('foo')
-
-
-class MockView_NonTimeThrottling(APIView):
- throttle_classes = (NonTimeThrottle,)
-
- def get(self, request):
- return Response('foo')
-
-
-class ThrottlingTests(TestCase):
- def setUp(self):
- """
- Reset the cache so that no throttles will be active
- """
- cache.clear()
- self.factory = APIRequestFactory()
-
- def test_requests_are_throttled(self):
- """
- Ensure request rate is limited
- """
- request = self.factory.get('/')
- for dummy in range(4):
- response = MockView.as_view()(request)
- self.assertEqual(429, response.status_code)
-
- def set_throttle_timer(self, view, value):
- """
- Explicitly set the timer, overriding time.time()
- """
- view.throttle_classes[0].timer = lambda self: value
-
- def test_request_throttling_expires(self):
- """
- Ensure request rate is limited for a limited duration only
- """
- self.set_throttle_timer(MockView, 0)
-
- request = self.factory.get('/')
- for dummy in range(4):
- response = MockView.as_view()(request)
- self.assertEqual(429, response.status_code)
-
- # Advance the timer by one second
- self.set_throttle_timer(MockView, 1)
-
- response = MockView.as_view()(request)
- self.assertEqual(200, response.status_code)
-
- def ensure_is_throttled(self, view, expect):
- request = self.factory.get('/')
- request.user = User.objects.create(username='a')
- for dummy in range(3):
- view.as_view()(request)
- request.user = User.objects.create(username='b')
- response = view.as_view()(request)
- self.assertEqual(expect, response.status_code)
-
- def test_request_throttling_is_per_user(self):
- """
- Ensure request rate is only limited per user, not globally for
- PerUserThrottles
- """
- self.ensure_is_throttled(MockView, 200)
-
- def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
- """
- Ensure the response returns an X-Throttle field with status and next attributes
- set properly.
- """
- request = self.factory.get('/')
- for timer, expect in expected_headers:
- self.set_throttle_timer(view, timer)
- response = view.as_view()(request)
- if expect is not None:
- self.assertEqual(response['X-Throttle-Wait-Seconds'], expect)
- self.assertEqual(response['Retry-After'], expect)
- else:
- self.assertFalse('X-Throttle-Wait-Seconds' in response)
- self.assertFalse('Retry-After' in response)
-
- def test_seconds_fields(self):
- """
- Ensure for second based throttles.
- """
- self.ensure_response_header_contains_proper_throttle_field(MockView,
- ((0, None),
- (0, None),
- (0, None),
- (0, '1')
- ))
-
- def test_minutes_fields(self):
- """
- Ensure for minute based throttles.
- """
- self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
- ((0, None),
- (0, None),
- (0, None),
- (0, '60')
- ))
-
- def test_next_rate_remains_constant_if_followed(self):
- """
- If a client follows the recommended next request rate,
- the throttling rate should stay constant.
- """
- self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
- ((0, None),
- (20, None),
- (40, None),
- (60, None),
- (80, None)
- ))
-
- def test_non_time_throttle(self):
- """
- Ensure for second based throttles.
- """
- request = self.factory.get('/')
-
- self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called'))
-
- response = MockView_NonTimeThrottling.as_view()(request)
- self.assertFalse('X-Throttle-Wait-Seconds' in response)
- self.assertFalse('Retry-After' in response)
-
- self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called)
-
- response = MockView_NonTimeThrottling.as_view()(request)
- self.assertFalse('X-Throttle-Wait-Seconds' in response)
- self.assertFalse('Retry-After' in response)
-
-
-class ScopedRateThrottleTests(TestCase):
- """
- Tests for ScopedRateThrottle.
- """
-
- def setUp(self):
- class XYScopedRateThrottle(ScopedRateThrottle):
- TIMER_SECONDS = 0
- THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}
- timer = lambda self: self.TIMER_SECONDS
-
- class XView(APIView):
- throttle_classes = (XYScopedRateThrottle,)
- throttle_scope = 'x'
-
- def get(self, request):
- return Response('x')
-
- class YView(APIView):
- throttle_classes = (XYScopedRateThrottle,)
- throttle_scope = 'y'
-
- def get(self, request):
- return Response('y')
-
- class UnscopedView(APIView):
- throttle_classes = (XYScopedRateThrottle,)
-
- def get(self, request):
- return Response('y')
-
- self.throttle_class = XYScopedRateThrottle
- self.factory = APIRequestFactory()
- self.x_view = XView.as_view()
- self.y_view = YView.as_view()
- self.unscoped_view = UnscopedView.as_view()
-
- def increment_timer(self, seconds=1):
- self.throttle_class.TIMER_SECONDS += seconds
-
- def test_scoped_rate_throttle(self):
- request = self.factory.get('/')
-
- # Should be able to hit x view 3 times per minute.
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(429, response.status_code)
-
- # Should be able to hit y view 1 time per minute.
- self.increment_timer()
- response = self.y_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.y_view(request)
- self.assertEqual(429, response.status_code)
-
- # Ensure throttles properly reset by advancing the rest of the minute
- self.increment_timer(55)
-
- # Should still be able to hit x view 3 times per minute.
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(429, response.status_code)
-
- # Should still be able to hit y view 1 time per minute.
- self.increment_timer()
- response = self.y_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.y_view(request)
- self.assertEqual(429, response.status_code)
-
- def test_unscoped_view_not_throttled(self):
- request = self.factory.get('/')
-
- for idx in range(10):
- self.increment_timer()
- response = self.unscoped_view(request)
- self.assertEqual(200, response.status_code)