diff options
| author | Tom Christie | 2012-09-26 21:47:19 +0100 |
|---|---|---|
| committer | Tom Christie | 2012-09-26 21:47:19 +0100 |
| commit | 0cc7030aab9f2b97ce6b5db55d6d1a8a32d50231 (patch) | |
| tree | 4a47a50776a7611ad59eec2509181a3a6bb4abb2 /rest_framework | |
| parent | 622e001e0bde3742f5e8830436ea0304c892480a (diff) | |
| download | django-rest-framework-0cc7030aab9f2b97ce6b5db55d6d1a8a32d50231.tar.bz2 | |
Fix @api_view decorator tests
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/decorators.py | 8 | ||||
| -rw-r--r-- | rest_framework/exceptions.py | 7 | ||||
| -rw-r--r-- | rest_framework/settings.py | 4 | ||||
| -rw-r--r-- | rest_framework/tests/decorators.py | 29 | ||||
| -rw-r--r-- | rest_framework/throttling.py | 36 |
5 files changed, 47 insertions, 37 deletions
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index fc81489e..0c5fec55 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -1,11 +1,3 @@ -from functools import wraps -from django.utils.decorators import available_attrs -from django.core.exceptions import PermissionDenied -from rest_framework import exceptions -from rest_framework import status -from rest_framework.response import Response -from rest_framework.request import Request -from rest_framework.settings import api_settings from rest_framework.views import APIView diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index e7836ecd..572425b9 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -77,3 +77,10 @@ class Throttled(APIException): self.detail = format % (self.wait, self.wait != 1 and 's' or '') else: self.detail = detail or self.default_detail + + +class ConfigurationError(Exception): + """ + Indicates an internal server error. + """ + pass diff --git a/rest_framework/settings.py b/rest_framework/settings.py index a498b222..cfc89fe1 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -39,6 +39,10 @@ DEFAULTS = { 'DEFAULT_THROTTLES': (), 'DEFAULT_CONTENT_NEGOTIATION': 'rest_framework.negotiation.DefaultContentNegotiation', + 'DEFAULT_THROTTLE_RATES': { + 'user': None, + 'anon': None, + }, 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py index d41f05d4..9aeaf7d8 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/decorators.py @@ -1,10 +1,11 @@ from django.test import TestCase +from rest_framework import status from rest_framework.response import Response from rest_framework.compat import RequestFactory from rest_framework.renderers import JSONRenderer from rest_framework.parsers import JSONParser from rest_framework.authentication import BasicAuthentication -from rest_framework.throttling import SimpleRateThottle +from rest_framework.throttling import UserRateThrottle from rest_framework.permissions import IsAuthenticated from rest_framework.views import APIView from rest_framework.decorators import ( @@ -23,7 +24,6 @@ class DecoratorTestCase(TestCase): self.factory = RequestFactory() def _finalize_response(self, request, response, *args, **kwargs): - print "HAI" response.request = request return APIView.finalize_response(self, request, response, *args, **kwargs) @@ -87,21 +87,24 @@ class DecoratorTestCase(TestCase): @api_view(['GET']) @permission_classes([IsAuthenticated]) def view(request): - self.assertEqual(request.permission_classes, [IsAuthenticated]) return Response({}) request = self.factory.get('/') - view(request) + response = view(request) + self.assertEquals(response.status_code, status.HTTP_403_FORBIDDEN) -# Doesn't look like this bits are working quite yet + def test_throttle_classes(self): + class OncePerDayUserThrottle(UserRateThrottle): + rate = '1/day' -# def test_throttle_classes(self): + @api_view(['GET']) + @throttle_classes([OncePerDayUserThrottle]) + def view(request): + return Response({}) -# @api_view(['GET']) -# @throttle_classes([SimpleRateThottle]) -# def view(request): -# self.assertEqual(request.throttle_classes, [SimpleRateThottle]) -# return Response({}) + request = self.factory.get('/') + response = view(request) + self.assertEquals(response.status_code, status.HTTP_200_OK) -# request = self.factory.get('/') -# view(request) + response = view(request) + self.assertEquals(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index e7750478..0d33f0e2 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -1,5 +1,6 @@ import time from django.core.cache import cache +from rest_framework import exceptions from rest_framework.settings import api_settings @@ -49,8 +50,9 @@ class SimpleRateThottle(BaseThrottle): def __init__(self, view): super(SimpleRateThottle, self).__init__(view) - rate = self.get_rate_description() - self.num_requests, self.duration = self.parse_rate_description(rate) + if not getattr(self, 'rate', None): + self.rate = self.get_rate() + self.num_requests, self.duration = self.parse_rate(self.rate) def get_cache_key(self, request): """ @@ -61,21 +63,28 @@ class SimpleRateThottle(BaseThrottle): """ raise NotImplementedError('.get_cache_key() must be overridden') - def get_rate_description(self): + def get_rate(self): """ Determine the string representation of the allowed request rate. """ + if not getattr(self, 'scope', None): + msg = ("You must set either `.scope` or `.rate` for '%s' thottle" % + self.__class__.__name__) + raise exceptions.ConfigurationError(msg) + try: - return self.rate - except AttributeError: - return self.settings.DEFAULT_THROTTLE_RATES.get(self.scope) + return self.settings.DEFAULT_THROTTLE_RATES[self.scope] + except KeyError: + msg = "No default throttle rate set for '%s' scope" % self.scope + raise exceptions.ConfigurationError(msg) - def parse_rate_description(self, rate): + def parse_rate(self, rate): """ Given the request rate string, return a two tuple of: <allowed number of requests>, <period of time in seconds> """ - assert rate, "No throttle rate set for '%s'" % self.__class__.__name__ + if rate is None: + return (None, None) num, period = rate.split('/') num_requests = int(num) duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] @@ -88,6 +97,9 @@ class SimpleRateThottle(BaseThrottle): On success calls `throttle_success`. On failure calls `throttle_failure`. """ + if self.rate is None: + return True + self.key = self.get_cache_key(request) self.history = cache.get(self.key, []) self.now = self.timer() @@ -188,14 +200,6 @@ class ScopedRateThrottle(SimpleRateThottle): self.scope = getattr(self.view, self.scope_attr, None) super(ScopedRateThrottle, self).__init__(view) - def parse_rate_description(self, rate): - """ - Subclassed so that we don't fail if `view.throttle_scope` is not set. - """ - if not rate: - return (None, None) - return super(ScopedRateThrottle, self).parse_rate_description(rate) - def get_cache_key(self, request): """ If `view.throttle_scope` is not set, don't apply this throttle. |
