aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2012-09-26 21:47:19 +0100
committerTom Christie2012-09-26 21:47:19 +0100
commit0cc7030aab9f2b97ce6b5db55d6d1a8a32d50231 (patch)
tree4a47a50776a7611ad59eec2509181a3a6bb4abb2 /rest_framework
parent622e001e0bde3742f5e8830436ea0304c892480a (diff)
downloaddjango-rest-framework-0cc7030aab9f2b97ce6b5db55d6d1a8a32d50231.tar.bz2
Fix @api_view decorator tests
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/decorators.py8
-rw-r--r--rest_framework/exceptions.py7
-rw-r--r--rest_framework/settings.py4
-rw-r--r--rest_framework/tests/decorators.py29
-rw-r--r--rest_framework/throttling.py36
5 files changed, 47 insertions, 37 deletions
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index fc81489e..0c5fec55 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -1,11 +1,3 @@
-from functools import wraps
-from django.utils.decorators import available_attrs
-from django.core.exceptions import PermissionDenied
-from rest_framework import exceptions
-from rest_framework import status
-from rest_framework.response import Response
-from rest_framework.request import Request
-from rest_framework.settings import api_settings
from rest_framework.views import APIView
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index e7836ecd..572425b9 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -77,3 +77,10 @@ class Throttled(APIException):
self.detail = format % (self.wait, self.wait != 1 and 's' or '')
else:
self.detail = detail or self.default_detail
+
+
+class ConfigurationError(Exception):
+ """
+ Indicates an internal server error.
+ """
+ pass
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index a498b222..cfc89fe1 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -39,6 +39,10 @@ DEFAULTS = {
'DEFAULT_THROTTLES': (),
'DEFAULT_CONTENT_NEGOTIATION':
'rest_framework.negotiation.DefaultContentNegotiation',
+ 'DEFAULT_THROTTLE_RATES': {
+ 'user': None,
+ 'anon': None,
+ },
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py
index d41f05d4..9aeaf7d8 100644
--- a/rest_framework/tests/decorators.py
+++ b/rest_framework/tests/decorators.py
@@ -1,10 +1,11 @@
from django.test import TestCase
+from rest_framework import status
from rest_framework.response import Response
from rest_framework.compat import RequestFactory
from rest_framework.renderers import JSONRenderer
from rest_framework.parsers import JSONParser
from rest_framework.authentication import BasicAuthentication
-from rest_framework.throttling import SimpleRateThottle
+from rest_framework.throttling import UserRateThrottle
from rest_framework.permissions import IsAuthenticated
from rest_framework.views import APIView
from rest_framework.decorators import (
@@ -23,7 +24,6 @@ class DecoratorTestCase(TestCase):
self.factory = RequestFactory()
def _finalize_response(self, request, response, *args, **kwargs):
- print "HAI"
response.request = request
return APIView.finalize_response(self, request, response, *args, **kwargs)
@@ -87,21 +87,24 @@ class DecoratorTestCase(TestCase):
@api_view(['GET'])
@permission_classes([IsAuthenticated])
def view(request):
- self.assertEqual(request.permission_classes, [IsAuthenticated])
return Response({})
request = self.factory.get('/')
- view(request)
+ response = view(request)
+ self.assertEquals(response.status_code, status.HTTP_403_FORBIDDEN)
-# Doesn't look like this bits are working quite yet
+ def test_throttle_classes(self):
+ class OncePerDayUserThrottle(UserRateThrottle):
+ rate = '1/day'
-# def test_throttle_classes(self):
+ @api_view(['GET'])
+ @throttle_classes([OncePerDayUserThrottle])
+ def view(request):
+ return Response({})
-# @api_view(['GET'])
-# @throttle_classes([SimpleRateThottle])
-# def view(request):
-# self.assertEqual(request.throttle_classes, [SimpleRateThottle])
-# return Response({})
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
-# request = self.factory.get('/')
-# view(request)
+ response = view(request)
+ self.assertEquals(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index e7750478..0d33f0e2 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -1,5 +1,6 @@
import time
from django.core.cache import cache
+from rest_framework import exceptions
from rest_framework.settings import api_settings
@@ -49,8 +50,9 @@ class SimpleRateThottle(BaseThrottle):
def __init__(self, view):
super(SimpleRateThottle, self).__init__(view)
- rate = self.get_rate_description()
- self.num_requests, self.duration = self.parse_rate_description(rate)
+ if not getattr(self, 'rate', None):
+ self.rate = self.get_rate()
+ self.num_requests, self.duration = self.parse_rate(self.rate)
def get_cache_key(self, request):
"""
@@ -61,21 +63,28 @@ class SimpleRateThottle(BaseThrottle):
"""
raise NotImplementedError('.get_cache_key() must be overridden')
- def get_rate_description(self):
+ def get_rate(self):
"""
Determine the string representation of the allowed request rate.
"""
+ if not getattr(self, 'scope', None):
+ msg = ("You must set either `.scope` or `.rate` for '%s' thottle" %
+ self.__class__.__name__)
+ raise exceptions.ConfigurationError(msg)
+
try:
- return self.rate
- except AttributeError:
- return self.settings.DEFAULT_THROTTLE_RATES.get(self.scope)
+ return self.settings.DEFAULT_THROTTLE_RATES[self.scope]
+ except KeyError:
+ msg = "No default throttle rate set for '%s' scope" % self.scope
+ raise exceptions.ConfigurationError(msg)
- def parse_rate_description(self, rate):
+ def parse_rate(self, rate):
"""
Given the request rate string, return a two tuple of:
<allowed number of requests>, <period of time in seconds>
"""
- assert rate, "No throttle rate set for '%s'" % self.__class__.__name__
+ if rate is None:
+ return (None, None)
num, period = rate.split('/')
num_requests = int(num)
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
@@ -88,6 +97,9 @@ class SimpleRateThottle(BaseThrottle):
On success calls `throttle_success`.
On failure calls `throttle_failure`.
"""
+ if self.rate is None:
+ return True
+
self.key = self.get_cache_key(request)
self.history = cache.get(self.key, [])
self.now = self.timer()
@@ -188,14 +200,6 @@ class ScopedRateThrottle(SimpleRateThottle):
self.scope = getattr(self.view, self.scope_attr, None)
super(ScopedRateThrottle, self).__init__(view)
- def parse_rate_description(self, rate):
- """
- Subclassed so that we don't fail if `view.throttle_scope` is not set.
- """
- if not rate:
- return (None, None)
- return super(ScopedRateThrottle, self).parse_rate_description(rate)
-
def get_cache_key(self, request):
"""
If `view.throttle_scope` is not set, don't apply this throttle.