aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py2
-rw-r--r--rest_framework/authentication.py14
-rw-r--r--rest_framework/authtoken/models.py4
-rw-r--r--rest_framework/compat.py26
-rw-r--r--rest_framework/exceptions.py7
-rw-r--r--rest_framework/fields.py51
-rw-r--r--rest_framework/generics.py4
-rw-r--r--rest_framework/permissions.py2
-rw-r--r--rest_framework/renderers.py4
-rw-r--r--rest_framework/routers.py20
-rw-r--r--rest_framework/runtests/settings.py2
-rw-r--r--rest_framework/serializers.py46
-rw-r--r--rest_framework/tests/description.py26
-rw-r--r--rest_framework/tests/test_authentication.py41
-rw-r--r--rest_framework/tests/test_description.py13
-rw-r--r--rest_framework/tests/test_fields.py30
-rw-r--r--rest_framework/tests/test_hyperlinkedserializers.py27
-rw-r--r--rest_framework/tests/test_routers.py73
-rw-r--r--rest_framework/tests/test_throttling.py109
-rw-r--r--rest_framework/throttling.py39
-rw-r--r--rest_framework/utils/formatting.py4
-rw-r--r--rest_framework/views.py24
22 files changed, 449 insertions, 119 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 0a210186..776618ac 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '2.3.5'
+__version__ = '2.3.6'
VERSION = __version__ # synonym
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index 9caca788..10298027 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -3,14 +3,13 @@ Provides various authentication policies.
"""
from __future__ import unicode_literals
import base64
-from datetime import datetime
from django.contrib.auth import authenticate
from django.core.exceptions import ImproperlyConfigured
from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
-from rest_framework.compat import oauth2_provider
+from rest_framework.compat import oauth2_provider, provider_now
from rest_framework.authtoken.models import Token
@@ -230,8 +229,9 @@ class OAuthAuthentication(BaseAuthentication):
try:
consumer_key = oauth_request.get_parameter('oauth_consumer_key')
consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key)
- except oauth_provider.store.InvalidConsumerError as err:
- raise exceptions.AuthenticationFailed(err)
+ except oauth_provider.store.InvalidConsumerError:
+ msg = 'Invalid consumer token: %s' % oauth_request.get_parameter('oauth_consumer_key')
+ raise exceptions.AuthenticationFailed(msg)
if consumer.status != oauth_provider.consts.ACCEPTED:
msg = 'Invalid consumer key status: %s' % consumer.get_status_display()
@@ -319,9 +319,9 @@ class OAuth2Authentication(BaseAuthentication):
try:
token = oauth2_provider.models.AccessToken.objects.select_related('user')
- # TODO: Change to timezone aware datetime when oauth2_provider add
- # support to it.
- token = token.get(token=access_token, expires__gt=datetime.now())
+ # provider_now switches to timezone aware datetime when
+ # the oauth2_provider version supports to it.
+ token = token.get(token=access_token, expires__gt=provider_now())
except oauth2_provider.models.AccessToken.DoesNotExist:
raise exceptions.AuthenticationFailed('Invalid token')
diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py
index 52c45ad1..7601f5b7 100644
--- a/rest_framework/authtoken/models.py
+++ b/rest_framework/authtoken/models.py
@@ -1,7 +1,7 @@
import uuid
import hmac
from hashlib import sha1
-from rest_framework.compat import User
+from rest_framework.compat import AUTH_USER_MODEL
from django.conf import settings
from django.db import models
@@ -11,7 +11,7 @@ class Token(models.Model):
The default authorization token model.
"""
key = models.CharField(max_length=40, primary_key=True)
- user = models.OneToOneField(User, related_name='auth_token')
+ user = models.OneToOneField(AUTH_USER_MODEL, related_name='auth_token')
created = models.DateTimeField(auto_now_add=True)
class Meta:
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 76dc0052..b748dcc5 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -2,6 +2,7 @@
The `compat` module provides support for backwards compatibility with older
versions of django/python, and compatibility wrappers around optional packages.
"""
+
# flake8: noqa
from __future__ import unicode_literals
@@ -33,6 +34,12 @@ except ImportError:
from django.utils.encoding import force_unicode as force_text
+# HttpResponseBase only exists from 1.5 onwards
+try:
+ from django.http.response import HttpResponseBase
+except ImportError:
+ from django.http import HttpResponse as HttpResponseBase
+
# django-filter is optional
try:
import django_filters
@@ -77,15 +84,9 @@ def get_concrete_model(model_cls):
# Django 1.5 add support for custom auth user model
if django.VERSION >= (1, 5):
from django.conf import settings
- if hasattr(settings, 'AUTH_USER_MODEL'):
- User = settings.AUTH_USER_MODEL
- else:
- from django.contrib.auth.models import User
+ AUTH_USER_MODEL = settings.AUTH_USER_MODEL
else:
- try:
- from django.contrib.auth.models import User
- except ImportError:
- raise ImportError("User model is not to be found.")
+ AUTH_USER_MODEL = 'auth.User'
if django.VERSION >= (1, 5):
@@ -489,12 +490,21 @@ try:
from provider.oauth2 import forms as oauth2_provider_forms
from provider import scope as oauth2_provider_scope
from provider import constants as oauth2_constants
+ from provider import __version__ as provider_version
+ if provider_version in ('0.2.3', '0.2.4'):
+ # 0.2.3 and 0.2.4 are supported version that do not support
+ # timezone aware datetimes
+ from datetime.datetime import now as provider_now
+ else:
+ # Any other supported version does use timezone aware datetimes
+ from django.utils.timezone import now as provider_now
except ImportError:
oauth2_provider = None
oauth2_provider_models = None
oauth2_provider_forms = None
oauth2_provider_scope = None
oauth2_constants = None
+ provider_now = None
# Handle lazy strings
from django.utils.functional import Promise
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index 0c96ecdd..425a7214 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -86,10 +86,3 @@ class Throttled(APIException):
self.detail = format % (self.wait, self.wait != 1 and 's' or '')
else:
self.detail = detail or self.default_detail
-
-
-class ConfigurationError(Exception):
- """
- Indicates an internal server error.
- """
- pass
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 535aa2ac..35848b4c 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -7,25 +7,24 @@ from __future__ import unicode_literals
import copy
import datetime
-from decimal import Decimal, DecimalException
import inspect
import re
import warnings
+from decimal import Decimal, DecimalException
+from django import forms
from django.core import validators
from django.core.exceptions import ValidationError
from django.conf import settings
from django.db.models.fields import BLANK_CHOICE_DASH
-from django import forms
from django.forms import widgets
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
from django.utils.datastructures import SortedDict
from rest_framework import ISO_8601
-from rest_framework.compat import (timezone, parse_date, parse_datetime,
- parse_time)
-from rest_framework.compat import BytesIO
-from rest_framework.compat import six
-from rest_framework.compat import smart_text, force_text, is_non_str_iterable
+from rest_framework.compat import (
+ timezone, parse_date, parse_datetime, parse_time, BytesIO, six, smart_text,
+ force_text, is_non_str_iterable
+)
from rest_framework.settings import api_settings
@@ -256,6 +255,12 @@ class WritableField(Field):
widget = widget()
self.widget = widget
+ def __deepcopy__(self, memo):
+ result = copy.copy(self)
+ memo[id(self)] = result
+ result.validators = self.validators[:]
+ return result
+
def validate(self, value):
if value in validators.EMPTY_VALUES and self.required:
raise ValidationError(self.error_messages['required'])
@@ -331,9 +336,13 @@ class ModelField(WritableField):
raise ValueError("ModelField requires 'model_field' kwarg")
self.min_length = kwargs.pop('min_length',
- getattr(self.model_field, 'min_length', None))
+ getattr(self.model_field, 'min_length', None))
self.max_length = kwargs.pop('max_length',
- getattr(self.model_field, 'max_length', None))
+ getattr(self.model_field, 'max_length', None))
+ self.min_value = kwargs.pop('min_value',
+ getattr(self.model_field, 'min_value', None))
+ self.max_value = kwargs.pop('max_value',
+ getattr(self.model_field, 'max_value', None))
super(ModelField, self).__init__(*args, **kwargs)
@@ -341,6 +350,10 @@ class ModelField(WritableField):
self.validators.append(validators.MinLengthValidator(self.min_length))
if self.max_length is not None:
self.validators.append(validators.MaxLengthValidator(self.max_length))
+ if self.min_value is not None:
+ self.validators.append(validators.MinValueValidator(self.min_value))
+ if self.max_value is not None:
+ self.validators.append(validators.MaxValueValidator(self.max_value))
def from_native(self, value):
rel = getattr(self.model_field, "rel", None)
@@ -428,13 +441,6 @@ class SlugField(CharField):
def __init__(self, *args, **kwargs):
super(SlugField, self).__init__(*args, **kwargs)
- def __deepcopy__(self, memo):
- result = copy.copy(self)
- memo[id(self)] = result
- #result.widget = copy.deepcopy(self.widget, memo)
- result.validators = self.validators[:]
- return result
-
class ChoiceField(WritableField):
type_name = 'ChoiceField'
@@ -503,13 +509,6 @@ class EmailField(CharField):
return None
return ret.strip()
- def __deepcopy__(self, memo):
- result = copy.copy(self)
- memo[id(self)] = result
- #result.widget = copy.deepcopy(self.widget, memo)
- result.validators = self.validators[:]
- return result
-
class RegexField(CharField):
type_name = 'RegexField'
@@ -534,12 +533,6 @@ class RegexField(CharField):
regex = property(_get_regex, _set_regex)
- def __deepcopy__(self, memo):
- result = copy.copy(self)
- memo[id(self)] = result
- result.validators = self.validators[:]
- return result
-
class DateField(WritableField):
type_name = 'DateField'
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 9ccc7898..99e9782e 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -212,7 +212,7 @@ class GenericAPIView(views.APIView):
You may want to override this if you need to provide different
serializations depending on the incoming request.
- (Eg. admins get full serialization, others get basic serilization)
+ (Eg. admins get full serialization, others get basic serialization)
"""
serializer_class = self.serializer_class
if serializer_class is not None:
@@ -285,7 +285,7 @@ class GenericAPIView(views.APIView):
)
filter_kwargs = {self.slug_field: slug}
else:
- raise exceptions.ConfigurationError(
+ raise ImproperlyConfigured(
'Expected view %s to be called with a URL keyword argument '
'named "%s". Fix your URL conf, or set the `.lookup_field` '
'attribute on the view correctly.' %
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index 45fcfd66..1036663e 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -128,7 +128,7 @@ class DjangoModelPermissions(BasePermission):
# Workaround to ensure DjangoModelPermissions are not applied
# to the root view when using DefaultRouter.
- if model_cls is None and getattr(view, '_ignore_model_permissions'):
+ if model_cls is None and getattr(view, '_ignore_model_permissions', False):
return True
assert model_cls, ('Cannot apply DjangoModelPermissions on a view that'
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index b2fe43ea..8b2428ad 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -11,6 +11,7 @@ from __future__ import unicode_literals
import copy
import json
from django import forms
+from django.core.exceptions import ImproperlyConfigured
from django.http.multipartparser import parse_header
from django.template import RequestContext, loader, Template
from django.utils.xmlutils import SimplerXMLGenerator
@@ -18,7 +19,6 @@ from rest_framework.compat import StringIO
from rest_framework.compat import six
from rest_framework.compat import smart_text
from rest_framework.compat import yaml
-from rest_framework.exceptions import ConfigurationError
from rest_framework.settings import api_settings
from rest_framework.request import clone_request
from rest_framework.utils import encoders
@@ -270,7 +270,7 @@ class TemplateHTMLRenderer(BaseRenderer):
return [self.template_name]
elif hasattr(view, 'get_template_names'):
return view.get_template_names()
- raise ConfigurationError('Returned a template response with no template_name')
+ raise ImproperlyConfigured('Returned a template response with no template_name')
def get_exception_template(self, response):
template_names = [name % {'status_code': response.status_code}
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index 9764e569..930011d3 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -15,7 +15,9 @@ For example, you might have a `urls.py` that looks something like this:
"""
from __future__ import unicode_literals
+import itertools
from collections import namedtuple
+from django.core.exceptions import ImproperlyConfigured
from rest_framework import views
from rest_framework.compat import patterns, url
from rest_framework.response import Response
@@ -38,6 +40,13 @@ def replace_methodname(format_string, methodname):
return ret
+def flatten(list_of_lists):
+ """
+ Takes an iterable of iterables, returns a single iterable containing all items
+ """
+ return itertools.chain(*list_of_lists)
+
+
class BaseRouter(object):
def __init__(self):
self.registry = []
@@ -117,7 +126,7 @@ class SimpleRouter(BaseRouter):
if model_cls is None and queryset is not None:
model_cls = queryset.model
- assert model_cls, '`name` not argument not specified, and could ' \
+ assert model_cls, '`base_name` argument not specified, and could ' \
'not automatically determine the name from the viewset, as ' \
'it does not have a `.model` or `.queryset` attribute.'
@@ -130,12 +139,18 @@ class SimpleRouter(BaseRouter):
Returns a list of the Route namedtuple.
"""
+ known_actions = flatten([route.mapping.values() for route in self.routes])
+
# Determine any `@action` or `@link` decorated methods on the viewset
dynamic_routes = []
for methodname in dir(viewset):
attr = getattr(viewset, methodname)
httpmethods = getattr(attr, 'bind_to_methods', None)
if httpmethods:
+ if methodname in known_actions:
+ raise ImproperlyConfigured('Cannot use @action or @link decorator on '
+ 'method "%s" as it is an existing route' % methodname)
+ httpmethods = [method.lower() for method in httpmethods]
dynamic_routes.append((httpmethods, methodname))
ret = []
@@ -215,6 +230,7 @@ class DefaultRouter(SimpleRouter):
"""
include_root_view = True
include_format_suffixes = True
+ root_view_name = 'api-root'
def get_api_root_view(self):
"""
@@ -244,7 +260,7 @@ class DefaultRouter(SimpleRouter):
urls = []
if self.include_root_view:
- root_url = url(r'^$', self.get_api_root_view(), name='api-root')
+ root_url = url(r'^$', self.get_api_root_view(), name=self.root_view_name)
urls.append(root_url)
default_urls = super(DefaultRouter, self).get_urls()
diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py
index 9dd7b545..b3702d0b 100644
--- a/rest_framework/runtests/settings.py
+++ b/rest_framework/runtests/settings.py
@@ -134,6 +134,8 @@ PASSWORD_HASHERS = (
'django.contrib.auth.hashers.CryptPasswordHasher',
)
+AUTH_USER_MODEL = 'auth.User'
+
import django
if django.VERSION < (1, 3):
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 4acbc704..d8f9145e 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -944,34 +944,23 @@ class HyperlinkedModelSerializer(ModelSerializer):
_default_view_name = '%(model_name)s-detail'
_hyperlink_field_class = HyperlinkedRelatedField
- # Just a placeholder to ensure 'url' is the first field
- # The field itself is actually created on initialization,
- # when the view_name and lookup_field arguments are available.
- url = Field()
-
- def __init__(self, *args, **kwargs):
- super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs)
+ def get_default_fields(self):
+ fields = super(HyperlinkedModelSerializer, self).get_default_fields()
if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name(self.opts.model)
- url_field = HyperlinkedIdentityField(
- view_name=self.opts.view_name,
- lookup_field=self.opts.lookup_field
- )
- url_field.initialize(self, 'url')
- self.fields['url'] = url_field
+ if 'url' not in fields:
+ url_field = HyperlinkedIdentityField(
+ view_name=self.opts.view_name,
+ lookup_field=self.opts.lookup_field
+ )
+ ret = self._dict_class()
+ ret['url'] = url_field
+ ret.update(fields)
+ fields = ret
- def _get_default_view_name(self, model):
- """
- Return the view name to use if 'view_name' is not specified in 'Meta'
- """
- model_meta = model._meta
- format_kwargs = {
- 'app_label': model_meta.app_label,
- 'model_name': model_meta.object_name.lower()
- }
- return self._default_view_name % format_kwargs
+ return fields
def get_pk_field(self, model_field):
if self.opts.fields and model_field.name in self.opts.fields:
@@ -1006,3 +995,14 @@ class HyperlinkedModelSerializer(ModelSerializer):
return data.get('url', None)
except AttributeError:
return None
+
+ def _get_default_view_name(self, model):
+ """
+ Return the view name to use if 'view_name' is not specified in 'Meta'
+ """
+ model_meta = model._meta
+ format_kwargs = {
+ 'app_label': model_meta.app_label,
+ 'model_name': model_meta.object_name.lower()
+ }
+ return self._default_view_name % format_kwargs
diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py
new file mode 100644
index 00000000..b46d7f54
--- /dev/null
+++ b/rest_framework/tests/description.py
@@ -0,0 +1,26 @@
+# -- coding: utf-8 --
+
+# Apparently there is a python 2.6 issue where docstrings of imported view classes
+# do not retain their encoding information even if a module has a proper
+# encoding declaration at the top of its source file. Therefore for tests
+# to catch unicode related errors, a mock view has to be declared in a separate
+# module.
+
+from rest_framework.views import APIView
+
+
+# test strings snatched from http://www.columbia.edu/~fdc/utf8/,
+# http://winrus.com/utf8-jap.htm and memory
+UTF8_TEST_DOCSTRING = (
+ 'zażółć gęślą jaźń'
+ 'Sîne klâwen durh die wolken sint geslagen'
+ 'Τη γλώσσα μου έδωσαν ελληνική'
+ 'யாமறிந்த மொழிகளிலே தமிழ்மொழி'
+ 'На берегу пустынных волн'
+ 'てすと'
+ 'アイウエオカキクケコサシスセソタチツテ'
+)
+
+
+class ViewWithNonASCIICharactersInDocstring(APIView):
+ __doc__ = UTF8_TEST_DOCSTRING
diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py
index d46ac079..6a50be06 100644
--- a/rest_framework/tests/test_authentication.py
+++ b/rest_framework/tests/test_authentication.py
@@ -428,6 +428,47 @@ class OAuthTests(TestCase):
response = self.csrf_client.post('/oauth-with-scope/', params)
self.assertEqual(response.status_code, 200)
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_bad_consumer_key(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': 'badconsumerkey'
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_bad_token_key(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': 'badtokenkey',
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
class OAuth2Tests(TestCase):
"""OAuth 2.0 authentication"""
diff --git a/rest_framework/tests/test_description.py b/rest_framework/tests/test_description.py
index 52c1a34c..8019f5ec 100644
--- a/rest_framework/tests/test_description.py
+++ b/rest_framework/tests/test_description.py
@@ -2,8 +2,10 @@
from __future__ import unicode_literals
from django.test import TestCase
+from rest_framework.compat import apply_markdown, smart_text
from rest_framework.views import APIView
-from rest_framework.compat import apply_markdown
+from rest_framework.tests.description import ViewWithNonASCIICharactersInDocstring
+from rest_framework.tests.description import UTF8_TEST_DOCSTRING
from rest_framework.utils.formatting import get_view_name, get_view_description
# We check that docstrings get nicely un-indented.
@@ -83,11 +85,10 @@ class TestViewNamesAndDescriptions(TestCase):
Unicode in docstrings should be respected.
"""
- class MockView(APIView):
- """Проверка"""
- pass
-
- self.assertEqual(get_view_description(MockView), "Проверка")
+ self.assertEqual(
+ get_view_description(ViewWithNonASCIICharactersInDocstring),
+ smart_text(UTF8_TEST_DOCSTRING)
+ )
def test_view_description_can_be_empty(self):
"""
diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py
index 69a0468e..6836ec86 100644
--- a/rest_framework/tests/test_fields.py
+++ b/rest_framework/tests/test_fields.py
@@ -866,3 +866,33 @@ class FieldCallableDefault(TestCase):
into = {}
field.field_from_native({}, {}, 'field', into)
self.assertEqual(into, {'field': 'foo bar'})
+
+
+class CustomIntegerField(TestCase):
+ """
+ Test that custom fields apply min_value and max_value constraints
+ """
+ def test_custom_fields_can_be_validated_for_value(self):
+
+ class MoneyField(models.PositiveIntegerField):
+ pass
+
+ class EntryModel(models.Model):
+ bank = MoneyField(validators=[validators.MaxValueValidator(100)])
+
+ class EntrySerializer(serializers.ModelSerializer):
+ class Meta:
+ model = EntryModel
+
+ entry = EntryModel(bank=1)
+
+ serializer = EntrySerializer(entry, data={"bank": 11})
+ self.assertTrue(serializer.is_valid())
+
+ serializer = EntrySerializer(entry, data={"bank": -1})
+ self.assertFalse(serializer.is_valid())
+
+ serializer = EntrySerializer(entry, data={"bank": 101})
+ self.assertFalse(serializer.is_valid())
+
+
diff --git a/rest_framework/tests/test_hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py
index 1894ddb2..129600cb 100644
--- a/rest_framework/tests/test_hyperlinkedserializers.py
+++ b/rest_framework/tests/test_hyperlinkedserializers.py
@@ -301,3 +301,30 @@ class TestOptionalRelationHyperlinkedView(TestCase):
data=json.dumps(self.data),
content_type='application/json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class TestOverriddenURLField(TestCase):
+ def setUp(self):
+ class OverriddenURLSerializer(serializers.HyperlinkedModelSerializer):
+ url = serializers.SerializerMethodField('get_url')
+
+ class Meta:
+ model = BlogPost
+ fields = ('title', 'url')
+
+ def get_url(self, obj):
+ return 'foo bar'
+
+ self.Serializer = OverriddenURLSerializer
+ self.obj = BlogPost.objects.create(title='New blog post')
+
+ def test_overridden_url_field(self):
+ """
+ The 'url' field should respect overriding.
+ Regression test for #936.
+ """
+ serializer = self.Serializer(self.obj)
+ self.assertEqual(
+ serializer.data,
+ {'title': 'New blog post', 'url': 'foo bar'}
+ )
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
index a7534f70..d375f4a8 100644
--- a/rest_framework/tests/test_routers.py
+++ b/rest_framework/tests/test_routers.py
@@ -2,11 +2,12 @@ from __future__ import unicode_literals
from django.db import models
from django.test import TestCase
from django.test.client import RequestFactory
-from rest_framework import serializers, viewsets
+from django.core.exceptions import ImproperlyConfigured
+from rest_framework import serializers, viewsets, permissions
from rest_framework.compat import include, patterns, url
from rest_framework.decorators import link, action
from rest_framework.response import Response
-from rest_framework.routers import SimpleRouter
+from rest_framework.routers import SimpleRouter, DefaultRouter
factory = RequestFactory()
@@ -120,7 +121,7 @@ class TestCustomLookupFields(TestCase):
)
-class TestTrailingSlash(TestCase):
+class TestTrailingSlashIncluded(TestCase):
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
model = RouterTestModel
@@ -135,7 +136,7 @@ class TestTrailingSlash(TestCase):
self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
-class TestTrailingSlash(TestCase):
+class TestTrailingSlashRemoved(TestCase):
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
model = RouterTestModel
@@ -148,3 +149,67 @@ class TestTrailingSlash(TestCase):
expected = ['^notes$', '^notes/(?P<pk>[^/]+)$']
for idx in range(len(expected)):
self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
+
+
+class TestNameableRoot(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+ self.router = DefaultRouter()
+ self.router.root_view_name = 'nameable-root'
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_router_has_custom_name(self):
+ expected = 'nameable-root'
+ self.assertEqual(expected, self.urls[0].name)
+
+
+class TestActionKeywordArgs(TestCase):
+ """
+ Ensure keyword arguments passed in the `@action` decorator
+ are properly handled. Refs #940.
+ """
+
+ def setUp(self):
+ class TestViewSet(viewsets.ModelViewSet):
+ permission_classes = []
+
+ @action(permission_classes=[permissions.AllowAny])
+ def custom(self, request, *args, **kwargs):
+ return Response({
+ 'permission_classes': self.permission_classes
+ })
+
+ self.router = SimpleRouter()
+ self.router.register(r'test', TestViewSet, base_name='test')
+ self.view = self.router.urls[-1].callback
+
+ def test_action_kwargs(self):
+ request = factory.post('/test/0/custom/')
+ response = self.view(request)
+ self.assertEqual(
+ response.data,
+ {'permission_classes': [permissions.AllowAny]}
+ )
+
+class TestActionAppliedToExistingRoute(TestCase):
+ """
+ Ensure `@action` decorator raises an except when applied
+ to an existing route
+ """
+
+ def test_exception_raised_when_action_applied_to_existing_route(self):
+ class TestViewSet(viewsets.ModelViewSet):
+
+ @action()
+ def retrieve(self, request, *args, **kwargs):
+ return Response({
+ 'hello': 'world'
+ })
+
+ self.router = SimpleRouter()
+ self.router.register(r'test', TestViewSet, base_name='test')
+
+ with self.assertRaises(ImproperlyConfigured):
+ self.router.urls
diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py
index da400b2f..d35d3709 100644
--- a/rest_framework/tests/test_throttling.py
+++ b/rest_framework/tests/test_throttling.py
@@ -7,7 +7,7 @@ from django.contrib.auth.models import User
from django.core.cache import cache
from django.test.client import RequestFactory
from rest_framework.views import APIView
-from rest_framework.throttling import UserRateThrottle
+from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle
from rest_framework.response import Response
@@ -36,8 +36,6 @@ class MockView_MinuteThrottling(APIView):
class ThrottlingTests(TestCase):
- urls = 'rest_framework.tests.test_throttling'
-
def setUp(self):
"""
Reset the cache so that no throttles will be active
@@ -141,3 +139,108 @@ class ThrottlingTests(TestCase):
(60, None),
(80, None)
))
+
+
+class ScopedRateThrottleTests(TestCase):
+ """
+ Tests for ScopedRateThrottle.
+ """
+
+ def setUp(self):
+ class XYScopedRateThrottle(ScopedRateThrottle):
+ TIMER_SECONDS = 0
+ THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}
+ timer = lambda self: self.TIMER_SECONDS
+
+ class XView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+ throttle_scope = 'x'
+
+ def get(self, request):
+ return Response('x')
+
+ class YView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+ throttle_scope = 'y'
+
+ def get(self, request):
+ return Response('y')
+
+ class UnscopedView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+
+ def get(self, request):
+ return Response('y')
+
+ self.throttle_class = XYScopedRateThrottle
+ self.factory = RequestFactory()
+ self.x_view = XView.as_view()
+ self.y_view = YView.as_view()
+ self.unscoped_view = UnscopedView.as_view()
+
+ def increment_timer(self, seconds=1):
+ self.throttle_class.TIMER_SECONDS += seconds
+
+ def test_scoped_rate_throttle(self):
+ request = self.factory.get('/')
+
+ # Should be able to hit x view 3 times per minute.
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Should be able to hit y view 1 time per minute.
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Ensure throttles properly reset by advancing the rest of the minute
+ self.increment_timer(55)
+
+ # Should still be able to hit x view 3 times per minute.
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Should still be able to hit y view 1 time per minute.
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(429, response.status_code)
+
+ def test_unscoped_view_not_throttled(self):
+ request = self.factory.get('/')
+
+ for idx in range(10):
+ self.increment_timer()
+ response = self.unscoped_view(request)
+ self.assertEqual(200, response.status_code)
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index 93ea9816..f6bb1cc8 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -3,7 +3,7 @@ Provides various throttling policies.
"""
from __future__ import unicode_literals
from django.core.cache import cache
-from rest_framework import exceptions
+from django.core.exceptions import ImproperlyConfigured
from rest_framework.settings import api_settings
import time
@@ -40,9 +40,9 @@ class SimpleRateThrottle(BaseThrottle):
"""
timer = time.time
- settings = api_settings
cache_format = 'throtte_%(scope)s_%(ident)s'
scope = None
+ THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
def __init__(self):
if not getattr(self, 'rate', None):
@@ -65,13 +65,13 @@ class SimpleRateThrottle(BaseThrottle):
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
- raise exceptions.ConfigurationError(msg)
+ raise ImproperlyConfigured(msg)
try:
- return self.settings.DEFAULT_THROTTLE_RATES[self.scope]
+ return self.THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
- raise exceptions.ConfigurationError(msg)
+ raise ImproperlyConfigured(msg)
def parse_rate(self, rate):
"""
@@ -187,6 +187,27 @@ class ScopedRateThrottle(SimpleRateThrottle):
"""
scope_attr = 'throttle_scope'
+ def __init__(self):
+ # Override the usual SimpleRateThrottle, because we can't determine
+ # the rate until called by the view.
+ pass
+
+ def allow_request(self, request, view):
+ # We can only determine the scope once we're called by the view.
+ self.scope = getattr(view, self.scope_attr, None)
+
+ # If a view does not have a `throttle_scope` always allow the request
+ if not self.scope:
+ return True
+
+ # Determine the allowed request rate as we normally would during
+ # the `__init__` call.
+ self.rate = self.get_rate()
+ self.num_requests, self.duration = self.parse_rate(self.rate)
+
+ # We can now proceed as normal.
+ return super(ScopedRateThrottle, self).allow_request(request, view)
+
def get_cache_key(self, request, view):
"""
If `view.throttle_scope` is not set, don't apply this throttle.
@@ -194,18 +215,12 @@ class ScopedRateThrottle(SimpleRateThrottle):
Otherwise generate the unique cache key by concatenating the user id
with the '.throttle_scope` property of the view.
"""
- scope = getattr(view, self.scope_attr, None)
-
- if not scope:
- # Only throttle views if `.throttle_scope` is set on the view.
- return None
-
if request.user.is_authenticated():
ident = request.user.id
else:
ident = request.META.get('REMOTE_ADDR', None)
return self.cache_format % {
- 'scope': scope,
+ 'scope': self.scope,
'ident': ident
}
diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py
index ebadb3a6..4bec8387 100644
--- a/rest_framework/utils/formatting.py
+++ b/rest_framework/utils/formatting.py
@@ -5,7 +5,7 @@ from __future__ import unicode_literals
from django.utils.html import escape
from django.utils.safestring import mark_safe
-from rest_framework.compat import apply_markdown
+from rest_framework.compat import apply_markdown, smart_text
import re
@@ -63,7 +63,7 @@ def get_view_description(cls, html=False):
Return a description for an `APIView` class or `@api_view` function.
"""
description = cls.__doc__ or ''
- description = _remove_leading_indent(description)
+ description = _remove_leading_indent(smart_text(description))
if html:
return markup_description(description)
return description
diff --git a/rest_framework/views.py b/rest_framework/views.py
index e1b6705b..37bba7f0 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -4,11 +4,11 @@ Provides an APIView class that is the base of all views in REST framework.
from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied
-from django.http import Http404, HttpResponse
+from django.http import Http404
from django.utils.datastructures import SortedDict
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status, exceptions
-from rest_framework.compat import View
+from rest_framework.compat import View, HttpResponseBase
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.settings import api_settings
@@ -244,9 +244,10 @@ class APIView(View):
Returns the final response object.
"""
# Make the error obvious if a proper response is not returned
- assert isinstance(response, HttpResponse), (
- 'Expected a `Response` to be returned from the view, '
- 'but received a `%s`' % type(response)
+ assert isinstance(response, HttpResponseBase), (
+ 'Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` '
+ 'to be returned from the view, but received a `%s`'
+ % type(response)
)
if isinstance(response, Response):
@@ -304,10 +305,10 @@ class APIView(View):
`.dispatch()` is pretty much the same as Django's regular dispatch,
but with extra hooks for startup, finalize, and exception handling.
"""
- request = self.initialize_request(request, *args, **kwargs)
- self.request = request
self.args = args
self.kwargs = kwargs
+ request = self.initialize_request(request, *args, **kwargs)
+ self.request = request
self.headers = self.default_response_headers # deprecate?
try:
@@ -341,8 +342,15 @@ class APIView(View):
Return a dictionary of metadata about the view.
Used to return responses for OPTIONS requests.
"""
+
+ # This is used by ViewSets to disambiguate instance vs list views
+ view_name_suffix = getattr(self, 'suffix', None)
+
+ # By default we can't provide any form-like information, however the
+ # generic views override this implementation and add additional
+ # information for POST and PUT methods, based on the serializer.
ret = SortedDict()
- ret['name'] = get_view_name(self.__class__)
+ ret['name'] = get_view_name(self.__class__, view_name_suffix)
ret['description'] = get_view_description(self.__class__)
ret['renders'] = [renderer.media_type for renderer in self.renderer_classes]
ret['parses'] = [parser.media_type for parser in self.parser_classes]