aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorEthan Fremen2013-06-10 10:29:25 -0700
committerEthan Fremen2013-06-10 10:29:25 -0700
commit0e75bcd2592caa8862e0b0166a6b851a3eada749 (patch)
tree4be625563775f1cab06219d3a7a5a62c0ff74bfd /rest_framework
parentde00ec95c3007dd90b5b01f7486b430699ea63c1 (diff)
parent5d0aeef69ecec70242513195c19edcb622e14371 (diff)
downloaddjango-rest-framework-0e75bcd2592caa8862e0b0166a6b851a3eada749.tar.bz2
Merge remote-tracking branch 'upstream/master' into writable-nested-modelserializer
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/authentication.py5
-rw-r--r--rest_framework/exceptions.py7
-rw-r--r--rest_framework/fields.py39
-rw-r--r--rest_framework/generics.py2
-rw-r--r--rest_framework/renderers.py4
-rw-r--r--rest_framework/routers.py3
-rw-r--r--rest_framework/tests/test_authentication.py41
-rw-r--r--rest_framework/tests/test_routers.py16
-rw-r--r--rest_framework/throttling.py6
-rw-r--r--rest_framework/views.py13
10 files changed, 89 insertions, 47 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index 9caca788..f659a172 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -230,8 +230,9 @@ class OAuthAuthentication(BaseAuthentication):
try:
consumer_key = oauth_request.get_parameter('oauth_consumer_key')
consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key)
- except oauth_provider.store.InvalidConsumerError as err:
- raise exceptions.AuthenticationFailed(err)
+ except oauth_provider.store.InvalidConsumerError:
+ msg = 'Invalid consumer token: %s' % oauth_request.get_parameter('oauth_consumer_key')
+ raise exceptions.AuthenticationFailed(msg)
if consumer.status != oauth_provider.consts.ACCEPTED:
msg = 'Invalid consumer key status: %s' % consumer.get_status_display()
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index 0c96ecdd..425a7214 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -86,10 +86,3 @@ class Throttled(APIException):
self.detail = format % (self.wait, self.wait != 1 and 's' or '')
else:
self.detail = detail or self.default_detail
-
-
-class ConfigurationError(Exception):
- """
- Indicates an internal server error.
- """
- pass
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 535aa2ac..32e4c4ae 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -7,25 +7,24 @@ from __future__ import unicode_literals
import copy
import datetime
-from decimal import Decimal, DecimalException
import inspect
import re
import warnings
+from decimal import Decimal, DecimalException
+from django import forms
from django.core import validators
from django.core.exceptions import ValidationError
from django.conf import settings
from django.db.models.fields import BLANK_CHOICE_DASH
-from django import forms
from django.forms import widgets
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
from django.utils.datastructures import SortedDict
from rest_framework import ISO_8601
-from rest_framework.compat import (timezone, parse_date, parse_datetime,
- parse_time)
-from rest_framework.compat import BytesIO
-from rest_framework.compat import six
-from rest_framework.compat import smart_text, force_text, is_non_str_iterable
+from rest_framework.compat import (
+ timezone, parse_date, parse_datetime, parse_time, BytesIO, six, smart_text,
+ force_text, is_non_str_iterable
+)
from rest_framework.settings import api_settings
@@ -256,6 +255,12 @@ class WritableField(Field):
widget = widget()
self.widget = widget
+ def __deepcopy__(self, memo):
+ result = copy.copy(self)
+ memo[id(self)] = result
+ result.validators = self.validators[:]
+ return result
+
def validate(self, value):
if value in validators.EMPTY_VALUES and self.required:
raise ValidationError(self.error_messages['required'])
@@ -428,13 +433,6 @@ class SlugField(CharField):
def __init__(self, *args, **kwargs):
super(SlugField, self).__init__(*args, **kwargs)
- def __deepcopy__(self, memo):
- result = copy.copy(self)
- memo[id(self)] = result
- #result.widget = copy.deepcopy(self.widget, memo)
- result.validators = self.validators[:]
- return result
-
class ChoiceField(WritableField):
type_name = 'ChoiceField'
@@ -503,13 +501,6 @@ class EmailField(CharField):
return None
return ret.strip()
- def __deepcopy__(self, memo):
- result = copy.copy(self)
- memo[id(self)] = result
- #result.widget = copy.deepcopy(self.widget, memo)
- result.validators = self.validators[:]
- return result
-
class RegexField(CharField):
type_name = 'RegexField'
@@ -534,12 +525,6 @@ class RegexField(CharField):
regex = property(_get_regex, _set_regex)
- def __deepcopy__(self, memo):
- result = copy.copy(self)
- memo[id(self)] = result
- result.validators = self.validators[:]
- return result
-
class DateField(WritableField):
type_name = 'DateField'
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 9ccc7898..80efad01 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -285,7 +285,7 @@ class GenericAPIView(views.APIView):
)
filter_kwargs = {self.slug_field: slug}
else:
- raise exceptions.ConfigurationError(
+ raise ImproperlyConfigured(
'Expected view %s to be called with a URL keyword argument '
'named "%s". Fix your URL conf, or set the `.lookup_field` '
'attribute on the view correctly.' %
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index b2fe43ea..8b2428ad 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -11,6 +11,7 @@ from __future__ import unicode_literals
import copy
import json
from django import forms
+from django.core.exceptions import ImproperlyConfigured
from django.http.multipartparser import parse_header
from django.template import RequestContext, loader, Template
from django.utils.xmlutils import SimplerXMLGenerator
@@ -18,7 +19,6 @@ from rest_framework.compat import StringIO
from rest_framework.compat import six
from rest_framework.compat import smart_text
from rest_framework.compat import yaml
-from rest_framework.exceptions import ConfigurationError
from rest_framework.settings import api_settings
from rest_framework.request import clone_request
from rest_framework.utils import encoders
@@ -270,7 +270,7 @@ class TemplateHTMLRenderer(BaseRenderer):
return [self.template_name]
elif hasattr(view, 'get_template_names'):
return view.get_template_names()
- raise ConfigurationError('Returned a template response with no template_name')
+ raise ImproperlyConfigured('Returned a template response with no template_name')
def get_exception_template(self, response):
template_names = [name % {'status_code': response.status_code}
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index 9764e569..f70c2cdb 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -215,6 +215,7 @@ class DefaultRouter(SimpleRouter):
"""
include_root_view = True
include_format_suffixes = True
+ root_view_name = 'api-root'
def get_api_root_view(self):
"""
@@ -244,7 +245,7 @@ class DefaultRouter(SimpleRouter):
urls = []
if self.include_root_view:
- root_url = url(r'^$', self.get_api_root_view(), name='api-root')
+ root_url = url(r'^$', self.get_api_root_view(), name=self.root_view_name)
urls.append(root_url)
default_urls = super(DefaultRouter, self).get_urls()
diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py
index d46ac079..6a50be06 100644
--- a/rest_framework/tests/test_authentication.py
+++ b/rest_framework/tests/test_authentication.py
@@ -428,6 +428,47 @@ class OAuthTests(TestCase):
response = self.csrf_client.post('/oauth-with-scope/', params)
self.assertEqual(response.status_code, 200)
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_bad_consumer_key(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': 'badconsumerkey'
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_bad_token_key(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': 'badtokenkey',
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
class OAuth2Tests(TestCase):
"""OAuth 2.0 authentication"""
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
index a7534f70..291142cf 100644
--- a/rest_framework/tests/test_routers.py
+++ b/rest_framework/tests/test_routers.py
@@ -6,7 +6,7 @@ from rest_framework import serializers, viewsets
from rest_framework.compat import include, patterns, url
from rest_framework.decorators import link, action
from rest_framework.response import Response
-from rest_framework.routers import SimpleRouter
+from rest_framework.routers import SimpleRouter, DefaultRouter
factory = RequestFactory()
@@ -148,3 +148,17 @@ class TestTrailingSlash(TestCase):
expected = ['^notes$', '^notes/(?P<pk>[^/]+)$']
for idx in range(len(expected)):
self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
+
+class TestNameableRoot(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+ self.router = DefaultRouter()
+ self.router.root_view_name = 'nameable-root'
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_router_has_custom_name(self):
+ expected = 'nameable-root'
+ self.assertEqual(expected, self.urls[0].name)
+
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index 93ea9816..9d89d1cb 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -3,7 +3,7 @@ Provides various throttling policies.
"""
from __future__ import unicode_literals
from django.core.cache import cache
-from rest_framework import exceptions
+from django.core.exceptions import ImproperlyConfigured
from rest_framework.settings import api_settings
import time
@@ -65,13 +65,13 @@ class SimpleRateThrottle(BaseThrottle):
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
- raise exceptions.ConfigurationError(msg)
+ raise ImproperlyConfigured(msg)
try:
return self.settings.DEFAULT_THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
- raise exceptions.ConfigurationError(msg)
+ raise ImproperlyConfigured(msg)
def parse_rate(self, rate):
"""
diff --git a/rest_framework/views.py b/rest_framework/views.py
index e1b6705b..c28d2835 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -304,10 +304,10 @@ class APIView(View):
`.dispatch()` is pretty much the same as Django's regular dispatch,
but with extra hooks for startup, finalize, and exception handling.
"""
- request = self.initialize_request(request, *args, **kwargs)
- self.request = request
self.args = args
self.kwargs = kwargs
+ request = self.initialize_request(request, *args, **kwargs)
+ self.request = request
self.headers = self.default_response_headers # deprecate?
try:
@@ -341,8 +341,15 @@ class APIView(View):
Return a dictionary of metadata about the view.
Used to return responses for OPTIONS requests.
"""
+
+ # This is used by ViewSets to disambiguate instance vs list views
+ view_name_suffix = getattr(self, 'suffix', None)
+
+ # By default we can't provide any form-like information, however the
+ # generic views override this implementation and add additional
+ # information for POST and PUT methods, based on the serializer.
ret = SortedDict()
- ret['name'] = get_view_name(self.__class__)
+ ret['name'] = get_view_name(self.__class__, view_name_suffix)
ret['description'] = get_view_description(self.__class__)
ret['renders'] = [renderer.media_type for renderer in self.renderer_classes]
ret['parses'] = [parser.media_type for parser in self.parser_classes]