diff options
| author | Tom Christie | 2013-06-14 14:18:40 +0100 | 
|---|---|---|
| committer | Tom Christie | 2013-06-14 14:18:40 +0100 | 
| commit | df957c8625c79e36c33f314c943a2c593f3a2701 (patch) | |
| tree | f4bbf53718323ba92a0cb9afdb833d788afcccf6 /rest_framework/tests | |
| parent | 6cc4fe56373ba844a02173ea2efdf288d8fbac03 (diff) | |
| download | django-rest-framework-df957c8625c79e36c33f314c943a2c593f3a2701.tar.bz2 | |
Fix and tests for ScopedRateThrottle. Closes #935
Diffstat (limited to 'rest_framework/tests')
| -rw-r--r-- | rest_framework/tests/test_throttling.py | 109 | 
1 files changed, 106 insertions, 3 deletions
diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py index da400b2f..d35d3709 100644 --- a/rest_framework/tests/test_throttling.py +++ b/rest_framework/tests/test_throttling.py @@ -7,7 +7,7 @@ from django.contrib.auth.models import User  from django.core.cache import cache  from django.test.client import RequestFactory  from rest_framework.views import APIView -from rest_framework.throttling import UserRateThrottle +from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle  from rest_framework.response import Response @@ -36,8 +36,6 @@ class MockView_MinuteThrottling(APIView):  class ThrottlingTests(TestCase): -    urls = 'rest_framework.tests.test_throttling' -      def setUp(self):          """          Reset the cache so that no throttles will be active @@ -141,3 +139,108 @@ class ThrottlingTests(TestCase):            (60, None),            (80, None)           )) + + +class ScopedRateThrottleTests(TestCase): +    """ +    Tests for ScopedRateThrottle. +    """ + +    def setUp(self): +        class XYScopedRateThrottle(ScopedRateThrottle): +            TIMER_SECONDS = 0 +            THROTTLE_RATES = {'x': '3/min', 'y': '1/min'} +            timer = lambda self: self.TIMER_SECONDS + +        class XView(APIView): +            throttle_classes = (XYScopedRateThrottle,) +            throttle_scope = 'x' + +            def get(self, request): +                return Response('x') + +        class YView(APIView): +            throttle_classes = (XYScopedRateThrottle,) +            throttle_scope = 'y' + +            def get(self, request): +                return Response('y') + +        class UnscopedView(APIView): +            throttle_classes = (XYScopedRateThrottle,) + +            def get(self, request): +                return Response('y') + +        self.throttle_class = XYScopedRateThrottle +        self.factory = RequestFactory() +        self.x_view = XView.as_view() +        self.y_view = YView.as_view() +        self.unscoped_view = UnscopedView.as_view() + +    def increment_timer(self, seconds=1): +        self.throttle_class.TIMER_SECONDS += seconds + +    def test_scoped_rate_throttle(self): +        request = self.factory.get('/') + +        # Should be able to hit x view 3 times per minute. +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(429, response.status_code) + +        # Should be able to hit y view 1 time per minute. +        self.increment_timer() +        response = self.y_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.y_view(request) +        self.assertEqual(429, response.status_code) + +        # Ensure throttles properly reset by advancing the rest of the minute +        self.increment_timer(55) + +        # Should still be able to hit x view 3 times per minute. +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(429, response.status_code) + +        # Should still be able to hit y view 1 time per minute. +        self.increment_timer() +        response = self.y_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.y_view(request) +        self.assertEqual(429, response.status_code) + +    def test_unscoped_view_not_throttled(self): +        request = self.factory.get('/') + +        for idx in range(10): +            self.increment_timer() +            response = self.unscoped_view(request) +            self.assertEqual(200, response.status_code)  | 
