diff options
| author | Tom Christie | 2011-06-15 14:41:09 +0100 | 
|---|---|---|
| committer | Tom Christie | 2011-06-15 14:41:09 +0100 | 
| commit | 1cb84cd4e82880caea645ebd99a947cead3096b9 (patch) | |
| tree | 5ae955f4a7ef5cd9754f2b5c11caecdedcad1f10 /djangorestframework | |
| parent | ff6e78323f88fd58b1de5b02e2440c2fc24c9c8b (diff) | |
| parent | 49a2817eb5ccf5f176ff5366d69df3a307dfcda2 (diff) | |
| download | django-rest-framework-1cb84cd4e82880caea645ebd99a947cead3096b9.tar.bz2 | |
Merge throttling and fix up a coupla things
Diffstat (limited to 'djangorestframework')
| -rw-r--r-- | djangorestframework/permissions.py | 29 | ||||
| -rw-r--r-- | djangorestframework/runtests/runtests.py | 32 | ||||
| -rw-r--r-- | djangorestframework/tests/throttling.py | 105 | ||||
| -rw-r--r-- | djangorestframework/views.py | 18 | 
4 files changed, 135 insertions, 49 deletions
diff --git a/djangorestframework/permissions.py b/djangorestframework/permissions.py index 34ab5bf4..7dcabcf0 100644 --- a/djangorestframework/permissions.py +++ b/djangorestframework/permissions.py @@ -31,11 +31,6 @@ _503_SERVICE_UNAVAILABLE = ErrorResponse(      {'detail': 'request was throttled'}) -class ConfigurationException(BaseException): -    """To alert for bad configuration decisions as a convenience.""" -    pass - -  class BasePermission(object):      """      A base class from which all permission classes should inherit. @@ -142,14 +137,13 @@ class BaseThrottle(BasePermission):          # Drop any requests from the history which have now passed the          # throttle duration -        while self.history and self.history[0] <= self.now - self.duration: +        while self.history and self.history[-1] <= self.now - self.duration:              self.history.pop() -          if len(self.history) >= self.num_requests:              self.throttle_failure()          else:              self.throttle_success() -     +      def throttle_success(self):          """          Inserts the current request's timestamp along with the key @@ -157,13 +151,30 @@ class BaseThrottle(BasePermission):          """          self.history.insert(0, self.now)          cache.set(self.key, self.history, self.duration) -     +        header = 'status=SUCCESS; next=%s sec' % self.next() +        self.view.add_header('X-Throttle', header) +                  def throttle_failure(self):          """          Called when a request to the API has failed due to throttling.          Raises a '503 service unavailable' response.          """ +        header = 'status=FAILURE; next=%s sec' % self.next() +        self.view.add_header('X-Throttle', header)          raise _503_SERVICE_UNAVAILABLE +     +    def next(self): +        """ +        Returns the recommended next request time in seconds. +        """ +        if self.history: +            remaining_duration = self.duration - (self.now - self.history[-1]) +        else: +            remaining_duration = self.duration + +        available_requests = self.num_requests - len(self.history) + 1 + +        return '%.2f' % (remaining_duration / float(available_requests))  class PerUserThrottling(BaseThrottle): diff --git a/djangorestframework/runtests/runtests.py b/djangorestframework/runtests/runtests.py index b98a496f..1da918f5 100644 --- a/djangorestframework/runtests/runtests.py +++ b/djangorestframework/runtests/runtests.py @@ -13,25 +13,27 @@ os.environ['DJANGO_SETTINGS_MODULE'] = 'djangorestframework.runtests.settings'  from django.conf import settings  from django.test.utils import get_runner +def usage(): +    return """ +    Usage: python runtests.py [UnitTestClass].[method] +     +    You can pass the Class name of the `UnitTestClass` you want to test. +     +    Append a method name if you only want to test a specific method of that class. +    """ +      def main():      TestRunner = get_runner(settings) -    if hasattr(TestRunner, 'func_name'): -        # Pre 1.2 test runners were just functions, -        # and did not support the 'failfast' option. -        import warnings -        warnings.warn( -            'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.', -            DeprecationWarning -        ) -        failures = TestRunner(['djangorestframework']) +    test_runner = TestRunner() +    if len(sys.argv) == 2: +        test_case = '.' + sys.argv[1] +    elif len(sys.argv) == 1: +        test_case = ''      else: -        test_runner = TestRunner() -        if len(sys.argv) > 1: -            test_case = '.' + sys.argv[1] -        else: -            test_case = '' -        failures = test_runner.run_tests(['djangorestframework' + test_case]) +        print usage() +        sys.exit(1) +    failures = test_runner.run_tests(['djangorestframework' + test_case])      sys.exit(failures) diff --git a/djangorestframework/tests/throttling.py b/djangorestframework/tests/throttling.py index be552638..b620ee24 100644 --- a/djangorestframework/tests/throttling.py +++ b/djangorestframework/tests/throttling.py @@ -1,57 +1,65 @@  """  Tests for the throttling implementations in the permissions module.  """ -import time -from django.conf.urls.defaults import patterns  from django.test import TestCase -from django.utils import simplejson as json  from django.contrib.auth.models import User  from django.core.cache import cache  from djangorestframework.compat import RequestFactory  from djangorestframework.views import View -from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling, ConfigurationException +from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling  from djangorestframework.resources import FormResource  class MockView(View):      permissions = ( PerUserThrottling, ) -    throttle = '3/sec' # 3 requests per second +    throttle = '3/sec'      def get(self, request):          return 'foo' -class MockView1(MockView): +class MockView_PerViewThrottling(MockView):      permissions = ( PerViewThrottling, ) -class MockView2(MockView): +class MockView_PerResourceThrottling(MockView):          permissions = ( PerResourceThrottling, ) -    #No resource set -     -class MockView3(MockView2):          resource = FormResource + +class MockView_MinuteThrottling(MockView): +    throttle = '3/min' +  +   class ThrottlingTests(TestCase):      urls = 'djangorestframework.tests.throttling'         def setUp(self): -        """Reset the cache so that no throttles will be active""" +        """ +        Reset the cache so that no throttles will be active +        """          cache.clear()          self.factory = RequestFactory()      def test_requests_are_throttled(self): -        """Ensure request rate is limited""" +        """ +        Ensure request rate is limited +        """          request = self.factory.get('/')          for dummy in range(4):              response = MockView.as_view()(request)          self.assertEqual(503, response.status_code) +    def set_throttle_timer(self, view, value): +        """ +        Explicitly set the timer, overriding time.time() +        """ +        view.permissions[0].timer = lambda self: value +      def test_request_throttling_expires(self):          """          Ensure request rate is limited for a limited duration only          """ -        # Explicitly set the timer, overridding time.time() -        MockView.permissions[0].timer = lambda self: 0 +        self.set_throttle_timer(MockView, 0)          request = self.factory.get('/')          for dummy in range(4): @@ -59,7 +67,7 @@ class ThrottlingTests(TestCase):          self.assertEqual(503, response.status_code)          # Advance the timer by one second -        MockView.permissions[0].timer = lambda self: 1 +        self.set_throttle_timer(MockView, 1)          response = MockView.as_view()(request)          self.assertEqual(200, response.status_code) @@ -68,20 +76,73 @@ class ThrottlingTests(TestCase):          request = self.factory.get('/')          request.user = User.objects.create(username='a')          for dummy in range(3): -            response = view.as_view()(request) +            view.as_view()(request)          request.user = User.objects.create(username='b')          response = view.as_view()(request)          self.assertEqual(expect, response.status_code)      def test_request_throttling_is_per_user(self): -        """Ensure request rate is only limited per user, not globally for PerUserThrottles""" +        """ +        Ensure request rate is only limited per user, not globally for  +        PerUserThrottles +        """          self.ensure_is_throttled(MockView, 200)      def test_request_throttling_is_per_view(self): -        """Ensure request rate is limited globally per View for PerViewThrottles""" -        self.ensure_is_throttled(MockView1, 503) +        """ +        Ensure request rate is limited globally per View for PerViewThrottles +        """ +        self.ensure_is_throttled(MockView_PerViewThrottling, 503)      def test_request_throttling_is_per_resource(self): -        """Ensure request rate is limited globally per Resource for PerResourceThrottles"""         -        self.ensure_is_throttled(MockView3, 503) -    
\ No newline at end of file +        """ +        Ensure request rate is limited globally per Resource for PerResourceThrottles +        """         +        self.ensure_is_throttled(MockView_PerResourceThrottling, 503) +         +         +    def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): +        """ +        Ensure the response returns an X-Throttle field with status and next attributes +        set properly. +        """ +        request = self.factory.get('/') +        for timer, expect in expected_headers: +            self.set_throttle_timer(view, timer) +            response = view.as_view()(request) +            self.assertEquals(response['X-Throttle'], expect) +             +    def test_seconds_fields(self): +        """ +        Ensure for second based throttles. +        """ +        self.ensure_response_header_contains_proper_throttle_field(MockView, +         ((0, 'status=SUCCESS; next=0.33 sec'), +          (0, 'status=SUCCESS; next=0.50 sec'), +          (0, 'status=SUCCESS; next=1.00 sec'), +          (0, 'status=FAILURE; next=1.00 sec') +         )) +             +    def test_minutes_fields(self): +        """ +        Ensure for minute based throttles. +        """ +        self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, +         ((0, 'status=SUCCESS; next=20.00 sec'), +          (0, 'status=SUCCESS; next=30.00 sec'), +          (0, 'status=SUCCESS; next=60.00 sec'), +          (0, 'status=FAILURE; next=60.00 sec') +         )) +     +    def test_next_rate_remains_constant_if_followed(self): +        """ +        If a client follows the recommended next request rate, +        the throttling rate should stay constant. +        """ +        self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, +         ((0, 'status=SUCCESS; next=20.00 sec'), +          (20, 'status=SUCCESS; next=20.00 sec'), +          (40, 'status=SUCCESS; next=20.00 sec'), +          (60, 'status=SUCCESS; next=20.00 sec'), +          (80, 'status=SUCCESS; next=20.00 sec') +         )) diff --git a/djangorestframework/views.py b/djangorestframework/views.py index 49d722c5..18d064e1 100644 --- a/djangorestframework/views.py +++ b/djangorestframework/views.py @@ -64,7 +64,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):      """      permissions = ( permissions.FullAnonAccess, ) - +          @classmethod      def as_view(cls, **initkwargs):          """ @@ -101,6 +101,14 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):          """          pass + +    def add_header(self, field, value): +        """ +        Add *field* and *value* to the :attr:`headers` attribute of the :class:`View` class.  +        """ +        self.headers[field] = value + +      # Note: session based authentication is explicitly CSRF validated,      # all other authentication is CSRF exempt.      @csrf_exempt @@ -108,6 +116,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):          self.request = request          self.args = args          self.kwargs = kwargs +        self.headers = {}          # Calls to 'reverse' will not be fully qualified unless we set the scheme/host/port here.          prefix = '%s://%s' % (request.is_secure() and 'https' or 'http', request.get_host()) @@ -149,9 +158,12 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):          # also it's currently sub-obtimal for HTTP caching - need to sort that out.           response.headers['Allow'] = ', '.join(self.allowed_methods)          response.headers['Vary'] = 'Authenticate, Accept' +         +        # merge with headers possibly set at some point in the view +        response.headers.update(self.headers) +         +        return self.render(response)     -        return self.render(response) -      class ModelView(View):      """A RESTful view that maps to a model in the database."""  | 
