diff options
Diffstat (limited to 'djangorestframework')
| -rw-r--r-- | djangorestframework/permissions.py | 12 | ||||
| -rw-r--r-- | djangorestframework/tests/throttling.py | 28 | ||||
| -rw-r--r-- | djangorestframework/throttling.py | 131 | ||||
| -rw-r--r-- | djangorestframework/views.py | 4 |
4 files changed, 127 insertions, 48 deletions
diff --git a/djangorestframework/permissions.py b/djangorestframework/permissions.py index 64e455f5..3a669822 100644 --- a/djangorestframework/permissions.py +++ b/djangorestframework/permissions.py @@ -28,11 +28,11 @@ class BasePermission(object): """ self.view = view - def check_permission(self, request, obj=None): + def has_permission(self, request, obj=None): """ Should simply return, or raise an :exc:`response.ImmediateResponse`. """ - raise NotImplementedError(".check_permission() must be overridden.") + raise NotImplementedError(".has_permission() must be overridden.") class IsAuthenticated(BasePermission): @@ -40,7 +40,7 @@ class IsAuthenticated(BasePermission): Allows access only to authenticated users. """ - def check_permission(self, request, obj=None): + def has_permission(self, request, obj=None): if request.user and request.user.is_authenticated(): return True return False @@ -51,7 +51,7 @@ class IsAdminUser(BasePermission): Allows access only to admin users. """ - def check_permission(self, request, obj=None): + def has_permission(self, request, obj=None): if request.user and request.user.is_staff: return True return False @@ -62,7 +62,7 @@ class IsAuthenticatedOrReadOnly(BasePermission): The request is authenticated as a user, or is a read-only request. """ - def check_permission(self, request, obj=None): + def has_permission(self, request, obj=None): if (request.method in SAFE_METHODS or request.user and request.user.is_authenticated()): @@ -105,7 +105,7 @@ class DjangoModelPermissions(BasePermission): } return [perm % kwargs for perm in self.perms_map[method]] - def check_permission(self, request, obj=None): + def has_permission(self, request, obj=None): model_cls = self.view.model perms = self.get_required_permissions(request.method, model_cls) diff --git a/djangorestframework/tests/throttling.py b/djangorestframework/tests/throttling.py index d144d956..9ee4ffa4 100644 --- a/djangorestframework/tests/throttling.py +++ b/djangorestframework/tests/throttling.py @@ -8,24 +8,30 @@ from django.core.cache import cache from djangorestframework.compat import RequestFactory from djangorestframework.views import APIView -from djangorestframework.throttling import PerUserThrottling, PerViewThrottling +from djangorestframework.throttling import UserRateThrottle from djangorestframework.response import Response -class MockView(APIView): - throttle_classes = (PerUserThrottling,) +class User3SecRateThrottle(UserRateThrottle): rate = '3/sec' + +class User3MinRateThrottle(UserRateThrottle): + rate = '3/min' + + +class MockView(APIView): + throttle_classes = (User3SecRateThrottle,) + def get(self, request): return Response('foo') -class MockView_PerViewThrottling(MockView): - throttle_classes = (PerViewThrottling,) +class MockView_MinuteThrottling(APIView): + throttle_classes = (User3MinRateThrottle,) - -class MockView_MinuteThrottling(MockView): - rate = '3/min' + def get(self, request): + return Response('foo') class ThrottlingTests(TestCase): @@ -86,12 +92,6 @@ class ThrottlingTests(TestCase): """ self.ensure_is_throttled(MockView, 200) - def test_request_throttling_is_per_view(self): - """ - Ensure request rate is limited globally per View for PerViewThrottles - """ - self.ensure_is_throttled(MockView_PerViewThrottling, 429) - def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): """ Ensure the response returns an X-Throttle field with status and next attributes diff --git a/djangorestframework/throttling.py b/djangorestframework/throttling.py index a096eab7..f8b098d7 100644 --- a/djangorestframework/throttling.py +++ b/djangorestframework/throttling.py @@ -1,4 +1,5 @@ from django.core.cache import cache +from djangorestframework.settings import api_settings import time @@ -13,11 +14,11 @@ class BaseThrottle(object): """ self.view = view - def check_throttle(self, request): + def allow_request(self, request): """ Return `True` if the request should be allowed, `False` otherwise. """ - raise NotImplementedError('.check_throttle() must be overridden') + raise NotImplementedError('.allow_request() must be overridden') def wait(self): """ @@ -27,7 +28,7 @@ class BaseThrottle(object): return None -class SimpleCachingThrottle(BaseThrottle): +class SimpleRateThottle(BaseThrottle): """ A simple cache implementation, that only requires `.get_cache_key()` to be overridden. @@ -41,33 +42,51 @@ class SimpleCachingThrottle(BaseThrottle): Previous request information used for throttling is stored in the cache. """ - attr_name = 'rate' - rate = '1000/day' timer = time.time + settings = api_settings + cache_format = '%(class)s_%(scope)s_%(ident)s' + scope = None def __init__(self, view): - """ - Check the throttling. - Return `None` or raise an :exc:`.ImmediateResponse`. - """ - super(SimpleCachingThrottle, self).__init__(view) - num, period = getattr(view, self.attr_name, self.rate).split('/') - self.num_requests = int(num) - self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] + super(SimpleRateThottle, self).__init__(view) + rate = self.get_rate_description() + self.num_requests, self.duration = self.parse_rate_description(rate) def get_cache_key(self, request): """ Should return a unique cache-key which can be used for throttling. Must be overridden. + + May return `None` if the request should not be throttled. """ raise NotImplementedError('.get_cache_key() must be overridden') - def check_throttle(self, request): + def get_rate_description(self): + """ + Determine the string representation of the allowed request rate. + """ + try: + return self.rate + except AttributeError: + return self.settings.DEFAULT_THROTTLE_RATES.get(self.scope) + + def parse_rate_description(self, rate): + """ + Given the request rate string, return a two tuple of: + <allowed number of requests>, <period of time in seconds> + """ + assert rate, "No throttle rate set for '%s'" % self.__class__.__name__ + num, period = rate.split('/') + num_requests = int(num) + duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] + return (num_requests, duration) + + def allow_request(self, request): """ Implement the check to see if the request should be throttled. - On success calls :meth:`throttle_success`. - On failure calls :meth:`throttle_failure`. + On success calls `throttle_success`. + On failure calls `throttle_failure`. """ self.key = self.get_cache_key(request) self.history = cache.get(self.key, []) @@ -110,30 +129,90 @@ class SimpleCachingThrottle(BaseThrottle): return remaining_duration / float(available_requests) -class PerUserThrottling(SimpleCachingThrottle): +class AnonRateThrottle(SimpleRateThottle): + """ + Limits the rate of API calls that may be made by a anonymous users. + + The IP address of the request will be used as the unqiue cache key. + """ + scope = 'anon' + + def get_cache_key(self, request): + if request.user.is_authenticated(): + return None # Only throttle unauthenticated requests. + + ident = request.META.get('REMOTE_ADDR', None) + + return self.cache_format % { + 'class': self.__class__.__name__, + 'scope': self.scope, + 'ident': ident + } + + +class UserRateThrottle(SimpleRateThottle): """ Limits the rate of API calls that may be made by a given user. - The user id will be used as a unique identifier if the user is - authenticated. For anonymous requests, the IP address of the client will + The user id will be used as a unique cache key if the user is + authenticated. For anonymous requests, the IP address of the request will be used. """ + scope = 'user' def get_cache_key(self, request): if request.user.is_authenticated(): ident = request.user.id else: ident = request.META.get('REMOTE_ADDR', None) - return 'throttle_user_%s' % ident + return self.cache_format % { + 'class': self.__class__.__name__, + 'scope': self.scope, + 'ident': ident + } -class PerViewThrottling(SimpleCachingThrottle): - """ - Limits the rate of API calls that may be used on a given view. - The class name of the view is used as a unique identifier to - throttle against. +class ScopedRateThrottle(SimpleRateThottle): + """ + Limits the rate of API calls by different amounts for various parts of + the API. Any view that has the `throttle_scope` property set will be + throttled. The unique cache key will be generated by concatenating the + user id of the request, and the scope of the view being accessed. """ + def __init__(self, view): + """ + Scope is determined from the view being accessed. + """ + self.scope = getattr(self.view, 'throttle_scope', None) + super(ScopedRateThrottle, self).__init__(view) + + def parse_rate_description(self, rate): + """ + Subclassed so that we don't fail if `view.throttle_scope` is not set. + """ + if not rate: + return (None, None) + return super(ScopedRateThrottle, self).parse_rate_description(rate) + def get_cache_key(self, request): - return 'throttle_view_%s' % self.view.__class__.__name__ + """ + If `view.throttle_scope` is not set, don't apply this throttle. + + Otherwise generate the unique cache key by concatenating the user id + with the '.throttle_scope` property of the view. + """ + if not self.scope: + return None # Only throttle views with `.throttle_scope` set. + + if request.user.is_authenticated(): + ident = request.user.id + else: + ident = request.META.get('REMOTE_ADDR', None) + + return self.cache_format % { + 'class': self.__class__.__name__, + 'scope': self.scope, + 'ident': ident + } diff --git a/djangorestframework/views.py b/djangorestframework/views.py index 5ec55d8c..a309386b 100644 --- a/djangorestframework/views.py +++ b/djangorestframework/views.py @@ -186,7 +186,7 @@ class APIView(_View): Check if request should be permitted. """ for permission in self.get_permissions(): - if not permission.check_permission(request, obj): + if not permission.has_permission(request, obj): self.permission_denied(request) def check_throttles(self, request): @@ -194,7 +194,7 @@ class APIView(_View): Check if request should be throttled. """ for throttle in self.get_throttles(): - if not throttle.check_throttle(request): + if not throttle.allow_request(request): self.throttled(request, throttle.wait()) def initialize_request(self, request, *args, **kargs): |
