diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/__init__.py | 2 | ||||
| -rw-r--r-- | rest_framework/authentication.py | 14 | ||||
| -rw-r--r-- | rest_framework/authtoken/models.py | 4 | ||||
| -rw-r--r-- | rest_framework/compat.py | 26 | ||||
| -rw-r--r-- | rest_framework/exceptions.py | 7 | ||||
| -rw-r--r-- | rest_framework/fields.py | 51 | ||||
| -rw-r--r-- | rest_framework/generics.py | 4 | ||||
| -rw-r--r-- | rest_framework/permissions.py | 2 | ||||
| -rw-r--r-- | rest_framework/renderers.py | 4 | ||||
| -rw-r--r-- | rest_framework/routers.py | 20 | ||||
| -rw-r--r-- | rest_framework/runtests/settings.py | 2 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 46 | ||||
| -rw-r--r-- | rest_framework/tests/description.py | 26 | ||||
| -rw-r--r-- | rest_framework/tests/test_authentication.py | 41 | ||||
| -rw-r--r-- | rest_framework/tests/test_description.py | 13 | ||||
| -rw-r--r-- | rest_framework/tests/test_fields.py | 30 | ||||
| -rw-r--r-- | rest_framework/tests/test_hyperlinkedserializers.py | 27 | ||||
| -rw-r--r-- | rest_framework/tests/test_routers.py | 73 | ||||
| -rw-r--r-- | rest_framework/tests/test_throttling.py | 109 | ||||
| -rw-r--r-- | rest_framework/throttling.py | 39 | ||||
| -rw-r--r-- | rest_framework/utils/formatting.py | 4 | ||||
| -rw-r--r-- | rest_framework/views.py | 24 |
22 files changed, 449 insertions, 119 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 0a210186..776618ac 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.3.5' +__version__ = '2.3.6' VERSION = __version__ # synonym diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 9caca788..10298027 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -3,14 +3,13 @@ Provides various authentication policies. """ from __future__ import unicode_literals import base64 -from datetime import datetime from django.contrib.auth import authenticate from django.core.exceptions import ImproperlyConfigured from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework.compat import CsrfViewMiddleware from rest_framework.compat import oauth, oauth_provider, oauth_provider_store -from rest_framework.compat import oauth2_provider +from rest_framework.compat import oauth2_provider, provider_now from rest_framework.authtoken.models import Token @@ -230,8 +229,9 @@ class OAuthAuthentication(BaseAuthentication): try: consumer_key = oauth_request.get_parameter('oauth_consumer_key') consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key) - except oauth_provider.store.InvalidConsumerError as err: - raise exceptions.AuthenticationFailed(err) + except oauth_provider.store.InvalidConsumerError: + msg = 'Invalid consumer token: %s' % oauth_request.get_parameter('oauth_consumer_key') + raise exceptions.AuthenticationFailed(msg) if consumer.status != oauth_provider.consts.ACCEPTED: msg = 'Invalid consumer key status: %s' % consumer.get_status_display() @@ -319,9 +319,9 @@ class OAuth2Authentication(BaseAuthentication): try: token = oauth2_provider.models.AccessToken.objects.select_related('user') - # TODO: Change to timezone aware datetime when oauth2_provider add - # support to it. - token = token.get(token=access_token, expires__gt=datetime.now()) + # provider_now switches to timezone aware datetime when + # the oauth2_provider version supports to it. + token = token.get(token=access_token, expires__gt=provider_now()) except oauth2_provider.models.AccessToken.DoesNotExist: raise exceptions.AuthenticationFailed('Invalid token') diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py index 52c45ad1..7601f5b7 100644 --- a/rest_framework/authtoken/models.py +++ b/rest_framework/authtoken/models.py @@ -1,7 +1,7 @@ import uuid import hmac from hashlib import sha1 -from rest_framework.compat import User +from rest_framework.compat import AUTH_USER_MODEL from django.conf import settings from django.db import models @@ -11,7 +11,7 @@ class Token(models.Model): The default authorization token model. """ key = models.CharField(max_length=40, primary_key=True) - user = models.OneToOneField(User, related_name='auth_token') + user = models.OneToOneField(AUTH_USER_MODEL, related_name='auth_token') created = models.DateTimeField(auto_now_add=True) class Meta: diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 76dc0052..b748dcc5 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -2,6 +2,7 @@ The `compat` module provides support for backwards compatibility with older versions of django/python, and compatibility wrappers around optional packages. """ + # flake8: noqa from __future__ import unicode_literals @@ -33,6 +34,12 @@ except ImportError: from django.utils.encoding import force_unicode as force_text +# HttpResponseBase only exists from 1.5 onwards +try: + from django.http.response import HttpResponseBase +except ImportError: + from django.http import HttpResponse as HttpResponseBase + # django-filter is optional try: import django_filters @@ -77,15 +84,9 @@ def get_concrete_model(model_cls): # Django 1.5 add support for custom auth user model if django.VERSION >= (1, 5): from django.conf import settings - if hasattr(settings, 'AUTH_USER_MODEL'): - User = settings.AUTH_USER_MODEL - else: - from django.contrib.auth.models import User + AUTH_USER_MODEL = settings.AUTH_USER_MODEL else: - try: - from django.contrib.auth.models import User - except ImportError: - raise ImportError("User model is not to be found.") + AUTH_USER_MODEL = 'auth.User' if django.VERSION >= (1, 5): @@ -489,12 +490,21 @@ try: from provider.oauth2 import forms as oauth2_provider_forms from provider import scope as oauth2_provider_scope from provider import constants as oauth2_constants + from provider import __version__ as provider_version + if provider_version in ('0.2.3', '0.2.4'): + # 0.2.3 and 0.2.4 are supported version that do not support + # timezone aware datetimes + from datetime.datetime import now as provider_now + else: + # Any other supported version does use timezone aware datetimes + from django.utils.timezone import now as provider_now except ImportError: oauth2_provider = None oauth2_provider_models = None oauth2_provider_forms = None oauth2_provider_scope = None oauth2_constants = None + provider_now = None # Handle lazy strings from django.utils.functional import Promise diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 0c96ecdd..425a7214 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -86,10 +86,3 @@ class Throttled(APIException): self.detail = format % (self.wait, self.wait != 1 and 's' or '') else: self.detail = detail or self.default_detail - - -class ConfigurationError(Exception): - """ - Indicates an internal server error. - """ - pass diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 535aa2ac..35848b4c 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -7,25 +7,24 @@ from __future__ import unicode_literals import copy import datetime -from decimal import Decimal, DecimalException import inspect import re import warnings +from decimal import Decimal, DecimalException +from django import forms from django.core import validators from django.core.exceptions import ValidationError from django.conf import settings from django.db.models.fields import BLANK_CHOICE_DASH -from django import forms from django.forms import widgets from django.utils.encoding import is_protected_type from django.utils.translation import ugettext_lazy as _ from django.utils.datastructures import SortedDict from rest_framework import ISO_8601 -from rest_framework.compat import (timezone, parse_date, parse_datetime, - parse_time) -from rest_framework.compat import BytesIO -from rest_framework.compat import six -from rest_framework.compat import smart_text, force_text, is_non_str_iterable +from rest_framework.compat import ( + timezone, parse_date, parse_datetime, parse_time, BytesIO, six, smart_text, + force_text, is_non_str_iterable +) from rest_framework.settings import api_settings @@ -256,6 +255,12 @@ class WritableField(Field): widget = widget() self.widget = widget + def __deepcopy__(self, memo): + result = copy.copy(self) + memo[id(self)] = result + result.validators = self.validators[:] + return result + def validate(self, value): if value in validators.EMPTY_VALUES and self.required: raise ValidationError(self.error_messages['required']) @@ -331,9 +336,13 @@ class ModelField(WritableField): raise ValueError("ModelField requires 'model_field' kwarg") self.min_length = kwargs.pop('min_length', - getattr(self.model_field, 'min_length', None)) + getattr(self.model_field, 'min_length', None)) self.max_length = kwargs.pop('max_length', - getattr(self.model_field, 'max_length', None)) + getattr(self.model_field, 'max_length', None)) + self.min_value = kwargs.pop('min_value', + getattr(self.model_field, 'min_value', None)) + self.max_value = kwargs.pop('max_value', + getattr(self.model_field, 'max_value', None)) super(ModelField, self).__init__(*args, **kwargs) @@ -341,6 +350,10 @@ class ModelField(WritableField): self.validators.append(validators.MinLengthValidator(self.min_length)) if self.max_length is not None: self.validators.append(validators.MaxLengthValidator(self.max_length)) + if self.min_value is not None: + self.validators.append(validators.MinValueValidator(self.min_value)) + if self.max_value is not None: + self.validators.append(validators.MaxValueValidator(self.max_value)) def from_native(self, value): rel = getattr(self.model_field, "rel", None) @@ -428,13 +441,6 @@ class SlugField(CharField): def __init__(self, *args, **kwargs): super(SlugField, self).__init__(*args, **kwargs) - def __deepcopy__(self, memo): - result = copy.copy(self) - memo[id(self)] = result - #result.widget = copy.deepcopy(self.widget, memo) - result.validators = self.validators[:] - return result - class ChoiceField(WritableField): type_name = 'ChoiceField' @@ -503,13 +509,6 @@ class EmailField(CharField): return None return ret.strip() - def __deepcopy__(self, memo): - result = copy.copy(self) - memo[id(self)] = result - #result.widget = copy.deepcopy(self.widget, memo) - result.validators = self.validators[:] - return result - class RegexField(CharField): type_name = 'RegexField' @@ -534,12 +533,6 @@ class RegexField(CharField): regex = property(_get_regex, _set_regex) - def __deepcopy__(self, memo): - result = copy.copy(self) - memo[id(self)] = result - result.validators = self.validators[:] - return result - class DateField(WritableField): type_name = 'DateField' diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 9ccc7898..99e9782e 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -212,7 +212,7 @@ class GenericAPIView(views.APIView): You may want to override this if you need to provide different serializations depending on the incoming request. - (Eg. admins get full serialization, others get basic serilization) + (Eg. admins get full serialization, others get basic serialization) """ serializer_class = self.serializer_class if serializer_class is not None: @@ -285,7 +285,7 @@ class GenericAPIView(views.APIView): ) filter_kwargs = {self.slug_field: slug} else: - raise exceptions.ConfigurationError( + raise ImproperlyConfigured( 'Expected view %s to be called with a URL keyword argument ' 'named "%s". Fix your URL conf, or set the `.lookup_field` ' 'attribute on the view correctly.' % diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 45fcfd66..1036663e 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -128,7 +128,7 @@ class DjangoModelPermissions(BasePermission): # Workaround to ensure DjangoModelPermissions are not applied # to the root view when using DefaultRouter. - if model_cls is None and getattr(view, '_ignore_model_permissions'): + if model_cls is None and getattr(view, '_ignore_model_permissions', False): return True assert model_cls, ('Cannot apply DjangoModelPermissions on a view that' diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index b2fe43ea..8b2428ad 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -11,6 +11,7 @@ from __future__ import unicode_literals import copy import json from django import forms +from django.core.exceptions import ImproperlyConfigured from django.http.multipartparser import parse_header from django.template import RequestContext, loader, Template from django.utils.xmlutils import SimplerXMLGenerator @@ -18,7 +19,6 @@ from rest_framework.compat import StringIO from rest_framework.compat import six from rest_framework.compat import smart_text from rest_framework.compat import yaml -from rest_framework.exceptions import ConfigurationError from rest_framework.settings import api_settings from rest_framework.request import clone_request from rest_framework.utils import encoders @@ -270,7 +270,7 @@ class TemplateHTMLRenderer(BaseRenderer): return [self.template_name] elif hasattr(view, 'get_template_names'): return view.get_template_names() - raise ConfigurationError('Returned a template response with no template_name') + raise ImproperlyConfigured('Returned a template response with no template_name') def get_exception_template(self, response): template_names = [name % {'status_code': response.status_code} diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 9764e569..930011d3 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -15,7 +15,9 @@ For example, you might have a `urls.py` that looks something like this: """ from __future__ import unicode_literals +import itertools from collections import namedtuple +from django.core.exceptions import ImproperlyConfigured from rest_framework import views from rest_framework.compat import patterns, url from rest_framework.response import Response @@ -38,6 +40,13 @@ def replace_methodname(format_string, methodname): return ret +def flatten(list_of_lists): + """ + Takes an iterable of iterables, returns a single iterable containing all items + """ + return itertools.chain(*list_of_lists) + + class BaseRouter(object): def __init__(self): self.registry = [] @@ -117,7 +126,7 @@ class SimpleRouter(BaseRouter): if model_cls is None and queryset is not None: model_cls = queryset.model - assert model_cls, '`name` not argument not specified, and could ' \ + assert model_cls, '`base_name` argument not specified, and could ' \ 'not automatically determine the name from the viewset, as ' \ 'it does not have a `.model` or `.queryset` attribute.' @@ -130,12 +139,18 @@ class SimpleRouter(BaseRouter): Returns a list of the Route namedtuple. """ + known_actions = flatten([route.mapping.values() for route in self.routes]) + # Determine any `@action` or `@link` decorated methods on the viewset dynamic_routes = [] for methodname in dir(viewset): attr = getattr(viewset, methodname) httpmethods = getattr(attr, 'bind_to_methods', None) if httpmethods: + if methodname in known_actions: + raise ImproperlyConfigured('Cannot use @action or @link decorator on ' + 'method "%s" as it is an existing route' % methodname) + httpmethods = [method.lower() for method in httpmethods] dynamic_routes.append((httpmethods, methodname)) ret = [] @@ -215,6 +230,7 @@ class DefaultRouter(SimpleRouter): """ include_root_view = True include_format_suffixes = True + root_view_name = 'api-root' def get_api_root_view(self): """ @@ -244,7 +260,7 @@ class DefaultRouter(SimpleRouter): urls = [] if self.include_root_view: - root_url = url(r'^$', self.get_api_root_view(), name='api-root') + root_url = url(r'^$', self.get_api_root_view(), name=self.root_view_name) urls.append(root_url) default_urls = super(DefaultRouter, self).get_urls() diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py index 9dd7b545..b3702d0b 100644 --- a/rest_framework/runtests/settings.py +++ b/rest_framework/runtests/settings.py @@ -134,6 +134,8 @@ PASSWORD_HASHERS = ( 'django.contrib.auth.hashers.CryptPasswordHasher', ) +AUTH_USER_MODEL = 'auth.User' + import django if django.VERSION < (1, 3): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 4acbc704..d8f9145e 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -944,34 +944,23 @@ class HyperlinkedModelSerializer(ModelSerializer): _default_view_name = '%(model_name)s-detail' _hyperlink_field_class = HyperlinkedRelatedField - # Just a placeholder to ensure 'url' is the first field - # The field itself is actually created on initialization, - # when the view_name and lookup_field arguments are available. - url = Field() - - def __init__(self, *args, **kwargs): - super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs) + def get_default_fields(self): + fields = super(HyperlinkedModelSerializer, self).get_default_fields() if self.opts.view_name is None: self.opts.view_name = self._get_default_view_name(self.opts.model) - url_field = HyperlinkedIdentityField( - view_name=self.opts.view_name, - lookup_field=self.opts.lookup_field - ) - url_field.initialize(self, 'url') - self.fields['url'] = url_field + if 'url' not in fields: + url_field = HyperlinkedIdentityField( + view_name=self.opts.view_name, + lookup_field=self.opts.lookup_field + ) + ret = self._dict_class() + ret['url'] = url_field + ret.update(fields) + fields = ret - def _get_default_view_name(self, model): - """ - Return the view name to use if 'view_name' is not specified in 'Meta' - """ - model_meta = model._meta - format_kwargs = { - 'app_label': model_meta.app_label, - 'model_name': model_meta.object_name.lower() - } - return self._default_view_name % format_kwargs + return fields def get_pk_field(self, model_field): if self.opts.fields and model_field.name in self.opts.fields: @@ -1006,3 +995,14 @@ class HyperlinkedModelSerializer(ModelSerializer): return data.get('url', None) except AttributeError: return None + + def _get_default_view_name(self, model): + """ + Return the view name to use if 'view_name' is not specified in 'Meta' + """ + model_meta = model._meta + format_kwargs = { + 'app_label': model_meta.app_label, + 'model_name': model_meta.object_name.lower() + } + return self._default_view_name % format_kwargs diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py new file mode 100644 index 00000000..b46d7f54 --- /dev/null +++ b/rest_framework/tests/description.py @@ -0,0 +1,26 @@ +# -- coding: utf-8 -- + +# Apparently there is a python 2.6 issue where docstrings of imported view classes +# do not retain their encoding information even if a module has a proper +# encoding declaration at the top of its source file. Therefore for tests +# to catch unicode related errors, a mock view has to be declared in a separate +# module. + +from rest_framework.views import APIView + + +# test strings snatched from http://www.columbia.edu/~fdc/utf8/, +# http://winrus.com/utf8-jap.htm and memory +UTF8_TEST_DOCSTRING = ( + 'zażółć gęślą jaźń' + 'Sîne klâwen durh die wolken sint geslagen' + 'Τη γλώσσα μου έδωσαν ελληνική' + 'யாமறிந்த மொழிகளிலே தமிழ்மொழி' + 'На берегу пустынных волн' + 'てすと' + 'アイウエオカキクケコサシスセソタチツテ' +) + + +class ViewWithNonASCIICharactersInDocstring(APIView): + __doc__ = UTF8_TEST_DOCSTRING diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index d46ac079..6a50be06 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -428,6 +428,47 @@ class OAuthTests(TestCase): response = self.csrf_client.post('/oauth-with-scope/', params) self.assertEqual(response.status_code, 200) + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_bad_consumer_key(self): + """Ensure POSTing using HMAC_SHA1 signature method passes""" + params = { + 'oauth_version': "1.0", + 'oauth_nonce': oauth.generate_nonce(), + 'oauth_timestamp': int(time.time()), + 'oauth_token': self.token.key, + 'oauth_consumer_key': 'badconsumerkey' + } + + req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) + + signature_method = oauth.SignatureMethod_HMAC_SHA1() + req.sign_request(signature_method, self.consumer, self.token) + auth = req.to_header()["Authorization"] + + response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_bad_token_key(self): + """Ensure POSTing using HMAC_SHA1 signature method passes""" + params = { + 'oauth_version': "1.0", + 'oauth_nonce': oauth.generate_nonce(), + 'oauth_timestamp': int(time.time()), + 'oauth_token': 'badtokenkey', + 'oauth_consumer_key': self.consumer.key + } + + req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) + + signature_method = oauth.SignatureMethod_HMAC_SHA1() + req.sign_request(signature_method, self.consumer, self.token) + auth = req.to_header()["Authorization"] + + response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) class OAuth2Tests(TestCase): """OAuth 2.0 authentication""" diff --git a/rest_framework/tests/test_description.py b/rest_framework/tests/test_description.py index 52c1a34c..8019f5ec 100644 --- a/rest_framework/tests/test_description.py +++ b/rest_framework/tests/test_description.py @@ -2,8 +2,10 @@ from __future__ import unicode_literals from django.test import TestCase +from rest_framework.compat import apply_markdown, smart_text from rest_framework.views import APIView -from rest_framework.compat import apply_markdown +from rest_framework.tests.description import ViewWithNonASCIICharactersInDocstring +from rest_framework.tests.description import UTF8_TEST_DOCSTRING from rest_framework.utils.formatting import get_view_name, get_view_description # We check that docstrings get nicely un-indented. @@ -83,11 +85,10 @@ class TestViewNamesAndDescriptions(TestCase): Unicode in docstrings should be respected. """ - class MockView(APIView): - """Проверка""" - pass - - self.assertEqual(get_view_description(MockView), "Проверка") + self.assertEqual( + get_view_description(ViewWithNonASCIICharactersInDocstring), + smart_text(UTF8_TEST_DOCSTRING) + ) def test_view_description_can_be_empty(self): """ diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py index 69a0468e..6836ec86 100644 --- a/rest_framework/tests/test_fields.py +++ b/rest_framework/tests/test_fields.py @@ -866,3 +866,33 @@ class FieldCallableDefault(TestCase): into = {} field.field_from_native({}, {}, 'field', into) self.assertEqual(into, {'field': 'foo bar'}) + + +class CustomIntegerField(TestCase): + """ + Test that custom fields apply min_value and max_value constraints + """ + def test_custom_fields_can_be_validated_for_value(self): + + class MoneyField(models.PositiveIntegerField): + pass + + class EntryModel(models.Model): + bank = MoneyField(validators=[validators.MaxValueValidator(100)]) + + class EntrySerializer(serializers.ModelSerializer): + class Meta: + model = EntryModel + + entry = EntryModel(bank=1) + + serializer = EntrySerializer(entry, data={"bank": 11}) + self.assertTrue(serializer.is_valid()) + + serializer = EntrySerializer(entry, data={"bank": -1}) + self.assertFalse(serializer.is_valid()) + + serializer = EntrySerializer(entry, data={"bank": 101}) + self.assertFalse(serializer.is_valid()) + + diff --git a/rest_framework/tests/test_hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py index 1894ddb2..129600cb 100644 --- a/rest_framework/tests/test_hyperlinkedserializers.py +++ b/rest_framework/tests/test_hyperlinkedserializers.py @@ -301,3 +301,30 @@ class TestOptionalRelationHyperlinkedView(TestCase): data=json.dumps(self.data), content_type='application/json') self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class TestOverriddenURLField(TestCase): + def setUp(self): + class OverriddenURLSerializer(serializers.HyperlinkedModelSerializer): + url = serializers.SerializerMethodField('get_url') + + class Meta: + model = BlogPost + fields = ('title', 'url') + + def get_url(self, obj): + return 'foo bar' + + self.Serializer = OverriddenURLSerializer + self.obj = BlogPost.objects.create(title='New blog post') + + def test_overridden_url_field(self): + """ + The 'url' field should respect overriding. + Regression test for #936. + """ + serializer = self.Serializer(self.obj) + self.assertEqual( + serializer.data, + {'title': 'New blog post', 'url': 'foo bar'} + ) diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py index a7534f70..d375f4a8 100644 --- a/rest_framework/tests/test_routers.py +++ b/rest_framework/tests/test_routers.py @@ -2,11 +2,12 @@ from __future__ import unicode_literals from django.db import models from django.test import TestCase from django.test.client import RequestFactory -from rest_framework import serializers, viewsets +from django.core.exceptions import ImproperlyConfigured +from rest_framework import serializers, viewsets, permissions from rest_framework.compat import include, patterns, url from rest_framework.decorators import link, action from rest_framework.response import Response -from rest_framework.routers import SimpleRouter +from rest_framework.routers import SimpleRouter, DefaultRouter factory = RequestFactory() @@ -120,7 +121,7 @@ class TestCustomLookupFields(TestCase): ) -class TestTrailingSlash(TestCase): +class TestTrailingSlashIncluded(TestCase): def setUp(self): class NoteViewSet(viewsets.ModelViewSet): model = RouterTestModel @@ -135,7 +136,7 @@ class TestTrailingSlash(TestCase): self.assertEqual(expected[idx], self.urls[idx].regex.pattern) -class TestTrailingSlash(TestCase): +class TestTrailingSlashRemoved(TestCase): def setUp(self): class NoteViewSet(viewsets.ModelViewSet): model = RouterTestModel @@ -148,3 +149,67 @@ class TestTrailingSlash(TestCase): expected = ['^notes$', '^notes/(?P<pk>[^/]+)$'] for idx in range(len(expected)): self.assertEqual(expected[idx], self.urls[idx].regex.pattern) + + +class TestNameableRoot(TestCase): + def setUp(self): + class NoteViewSet(viewsets.ModelViewSet): + model = RouterTestModel + self.router = DefaultRouter() + self.router.root_view_name = 'nameable-root' + self.router.register(r'notes', NoteViewSet) + self.urls = self.router.urls + + def test_router_has_custom_name(self): + expected = 'nameable-root' + self.assertEqual(expected, self.urls[0].name) + + +class TestActionKeywordArgs(TestCase): + """ + Ensure keyword arguments passed in the `@action` decorator + are properly handled. Refs #940. + """ + + def setUp(self): + class TestViewSet(viewsets.ModelViewSet): + permission_classes = [] + + @action(permission_classes=[permissions.AllowAny]) + def custom(self, request, *args, **kwargs): + return Response({ + 'permission_classes': self.permission_classes + }) + + self.router = SimpleRouter() + self.router.register(r'test', TestViewSet, base_name='test') + self.view = self.router.urls[-1].callback + + def test_action_kwargs(self): + request = factory.post('/test/0/custom/') + response = self.view(request) + self.assertEqual( + response.data, + {'permission_classes': [permissions.AllowAny]} + ) + +class TestActionAppliedToExistingRoute(TestCase): + """ + Ensure `@action` decorator raises an except when applied + to an existing route + """ + + def test_exception_raised_when_action_applied_to_existing_route(self): + class TestViewSet(viewsets.ModelViewSet): + + @action() + def retrieve(self, request, *args, **kwargs): + return Response({ + 'hello': 'world' + }) + + self.router = SimpleRouter() + self.router.register(r'test', TestViewSet, base_name='test') + + with self.assertRaises(ImproperlyConfigured): + self.router.urls diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py index da400b2f..d35d3709 100644 --- a/rest_framework/tests/test_throttling.py +++ b/rest_framework/tests/test_throttling.py @@ -7,7 +7,7 @@ from django.contrib.auth.models import User from django.core.cache import cache from django.test.client import RequestFactory from rest_framework.views import APIView -from rest_framework.throttling import UserRateThrottle +from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle from rest_framework.response import Response @@ -36,8 +36,6 @@ class MockView_MinuteThrottling(APIView): class ThrottlingTests(TestCase): - urls = 'rest_framework.tests.test_throttling' - def setUp(self): """ Reset the cache so that no throttles will be active @@ -141,3 +139,108 @@ class ThrottlingTests(TestCase): (60, None), (80, None) )) + + +class ScopedRateThrottleTests(TestCase): + """ + Tests for ScopedRateThrottle. + """ + + def setUp(self): + class XYScopedRateThrottle(ScopedRateThrottle): + TIMER_SECONDS = 0 + THROTTLE_RATES = {'x': '3/min', 'y': '1/min'} + timer = lambda self: self.TIMER_SECONDS + + class XView(APIView): + throttle_classes = (XYScopedRateThrottle,) + throttle_scope = 'x' + + def get(self, request): + return Response('x') + + class YView(APIView): + throttle_classes = (XYScopedRateThrottle,) + throttle_scope = 'y' + + def get(self, request): + return Response('y') + + class UnscopedView(APIView): + throttle_classes = (XYScopedRateThrottle,) + + def get(self, request): + return Response('y') + + self.throttle_class = XYScopedRateThrottle + self.factory = RequestFactory() + self.x_view = XView.as_view() + self.y_view = YView.as_view() + self.unscoped_view = UnscopedView.as_view() + + def increment_timer(self, seconds=1): + self.throttle_class.TIMER_SECONDS += seconds + + def test_scoped_rate_throttle(self): + request = self.factory.get('/') + + # Should be able to hit x view 3 times per minute. + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(429, response.status_code) + + # Should be able to hit y view 1 time per minute. + self.increment_timer() + response = self.y_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.y_view(request) + self.assertEqual(429, response.status_code) + + # Ensure throttles properly reset by advancing the rest of the minute + self.increment_timer(55) + + # Should still be able to hit x view 3 times per minute. + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(429, response.status_code) + + # Should still be able to hit y view 1 time per minute. + self.increment_timer() + response = self.y_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.y_view(request) + self.assertEqual(429, response.status_code) + + def test_unscoped_view_not_throttled(self): + request = self.factory.get('/') + + for idx in range(10): + self.increment_timer() + response = self.unscoped_view(request) + self.assertEqual(200, response.status_code) diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 93ea9816..f6bb1cc8 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -3,7 +3,7 @@ Provides various throttling policies. """ from __future__ import unicode_literals from django.core.cache import cache -from rest_framework import exceptions +from django.core.exceptions import ImproperlyConfigured from rest_framework.settings import api_settings import time @@ -40,9 +40,9 @@ class SimpleRateThrottle(BaseThrottle): """ timer = time.time - settings = api_settings cache_format = 'throtte_%(scope)s_%(ident)s' scope = None + THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES def __init__(self): if not getattr(self, 'rate', None): @@ -65,13 +65,13 @@ class SimpleRateThrottle(BaseThrottle): if not getattr(self, 'scope', None): msg = ("You must set either `.scope` or `.rate` for '%s' throttle" % self.__class__.__name__) - raise exceptions.ConfigurationError(msg) + raise ImproperlyConfigured(msg) try: - return self.settings.DEFAULT_THROTTLE_RATES[self.scope] + return self.THROTTLE_RATES[self.scope] except KeyError: msg = "No default throttle rate set for '%s' scope" % self.scope - raise exceptions.ConfigurationError(msg) + raise ImproperlyConfigured(msg) def parse_rate(self, rate): """ @@ -187,6 +187,27 @@ class ScopedRateThrottle(SimpleRateThrottle): """ scope_attr = 'throttle_scope' + def __init__(self): + # Override the usual SimpleRateThrottle, because we can't determine + # the rate until called by the view. + pass + + def allow_request(self, request, view): + # We can only determine the scope once we're called by the view. + self.scope = getattr(view, self.scope_attr, None) + + # If a view does not have a `throttle_scope` always allow the request + if not self.scope: + return True + + # Determine the allowed request rate as we normally would during + # the `__init__` call. + self.rate = self.get_rate() + self.num_requests, self.duration = self.parse_rate(self.rate) + + # We can now proceed as normal. + return super(ScopedRateThrottle, self).allow_request(request, view) + def get_cache_key(self, request, view): """ If `view.throttle_scope` is not set, don't apply this throttle. @@ -194,18 +215,12 @@ class ScopedRateThrottle(SimpleRateThrottle): Otherwise generate the unique cache key by concatenating the user id with the '.throttle_scope` property of the view. """ - scope = getattr(view, self.scope_attr, None) - - if not scope: - # Only throttle views if `.throttle_scope` is set on the view. - return None - if request.user.is_authenticated(): ident = request.user.id else: ident = request.META.get('REMOTE_ADDR', None) return self.cache_format % { - 'scope': scope, + 'scope': self.scope, 'ident': ident } diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index ebadb3a6..4bec8387 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -5,7 +5,7 @@ from __future__ import unicode_literals from django.utils.html import escape from django.utils.safestring import mark_safe -from rest_framework.compat import apply_markdown +from rest_framework.compat import apply_markdown, smart_text import re @@ -63,7 +63,7 @@ def get_view_description(cls, html=False): Return a description for an `APIView` class or `@api_view` function. """ description = cls.__doc__ or '' - description = _remove_leading_indent(description) + description = _remove_leading_indent(smart_text(description)) if html: return markup_description(description) return description diff --git a/rest_framework/views.py b/rest_framework/views.py index e1b6705b..37bba7f0 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -4,11 +4,11 @@ Provides an APIView class that is the base of all views in REST framework. from __future__ import unicode_literals from django.core.exceptions import PermissionDenied -from django.http import Http404, HttpResponse +from django.http import Http404 from django.utils.datastructures import SortedDict from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions -from rest_framework.compat import View +from rest_framework.compat import View, HttpResponseBase from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings @@ -244,9 +244,10 @@ class APIView(View): Returns the final response object. """ # Make the error obvious if a proper response is not returned - assert isinstance(response, HttpResponse), ( - 'Expected a `Response` to be returned from the view, ' - 'but received a `%s`' % type(response) + assert isinstance(response, HttpResponseBase), ( + 'Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` ' + 'to be returned from the view, but received a `%s`' + % type(response) ) if isinstance(response, Response): @@ -304,10 +305,10 @@ class APIView(View): `.dispatch()` is pretty much the same as Django's regular dispatch, but with extra hooks for startup, finalize, and exception handling. """ - request = self.initialize_request(request, *args, **kwargs) - self.request = request self.args = args self.kwargs = kwargs + request = self.initialize_request(request, *args, **kwargs) + self.request = request self.headers = self.default_response_headers # deprecate? try: @@ -341,8 +342,15 @@ class APIView(View): Return a dictionary of metadata about the view. Used to return responses for OPTIONS requests. """ + + # This is used by ViewSets to disambiguate instance vs list views + view_name_suffix = getattr(self, 'suffix', None) + + # By default we can't provide any form-like information, however the + # generic views override this implementation and add additional + # information for POST and PUT methods, based on the serializer. ret = SortedDict() - ret['name'] = get_view_name(self.__class__) + ret['name'] = get_view_name(self.__class__, view_name_suffix) ret['description'] = get_view_description(self.__class__) ret['renders'] = [renderer.media_type for renderer in self.renderer_classes] ret['parses'] = [parser.media_type for parser in self.parser_classes] |
