aboutsummaryrefslogtreecommitdiffstats
path: root/djangorestframework
diff options
context:
space:
mode:
Diffstat (limited to 'djangorestframework')
-rw-r--r--djangorestframework/compat.py7
-rw-r--r--djangorestframework/mixins.py45
-rw-r--r--djangorestframework/parsers.py44
-rw-r--r--djangorestframework/renderers.py55
-rw-r--r--djangorestframework/resources.py50
-rw-r--r--djangorestframework/runtests/settings.py10
-rw-r--r--djangorestframework/serializer.py8
-rw-r--r--djangorestframework/templates/renderer.html8
-rw-r--r--djangorestframework/tests/oauthentication.py212
-rw-r--r--djangorestframework/tests/renderers.py98
-rw-r--r--djangorestframework/tests/reverse.py6
11 files changed, 450 insertions, 93 deletions
diff --git a/djangorestframework/compat.py b/djangorestframework/compat.py
index 827b4adf..230172c3 100644
--- a/djangorestframework/compat.py
+++ b/djangorestframework/compat.py
@@ -156,6 +156,7 @@ except ImportError:
def head(self, request, *args, **kwargs):
return self.get(request, *args, **kwargs)
+# Markdown is optional
try:
import markdown
import re
@@ -204,3 +205,9 @@ try:
except ImportError:
apply_markdown = None
+
+# Yaml is optional
+try:
+ import yaml
+except ImportError:
+ yaml = None
diff --git a/djangorestframework/mixins.py b/djangorestframework/mixins.py
index 910d06ae..b1ba0596 100644
--- a/djangorestframework/mixins.py
+++ b/djangorestframework/mixins.py
@@ -11,6 +11,7 @@ from django.http.multipartparser import LimitBytes
from djangorestframework import status
from djangorestframework.parsers import FormParser, MultiPartParser
+from djangorestframework.renderers import BaseRenderer
from djangorestframework.resources import Resource, FormResource, ModelResource
from djangorestframework.response import Response, ErrorResponse
from djangorestframework.utils import as_tuple, MSIE_USER_AGENT_REGEX
@@ -290,7 +291,7 @@ class ResponseMixin(object):
accept_list = [token.strip() for token in request.META["HTTP_ACCEPT"].split(',')]
else:
# No accept header specified
- return (self._default_renderer(self), self._default_renderer.media_type)
+ accept_list = ['*/*']
# Check the acceptable media types against each renderer,
# attempting more specific media types first
@@ -298,12 +299,12 @@ class ResponseMixin(object):
# Worst case is we're looping over len(accept_list) * len(self.renderers)
renderers = [renderer_cls(self) for renderer_cls in self.renderers]
- for media_type_lst in order_by_precedence(accept_list):
+ for accepted_media_type_lst in order_by_precedence(accept_list):
for renderer in renderers:
- for media_type in media_type_lst:
- if renderer.can_handle_response(media_type):
- return renderer, media_type
-
+ for accepted_media_type in accepted_media_type_lst:
+ if renderer.can_handle_response(accepted_media_type):
+ return renderer, accepted_media_type
+
# No acceptable renderers were found
raise ErrorResponse(status.HTTP_406_NOT_ACCEPTABLE,
{'detail': 'Could not satisfy the client\'s Accept header',
@@ -316,6 +317,13 @@ class ResponseMixin(object):
Return an list of all the media types that this view can render.
"""
return [renderer.media_type for renderer in self.renderers]
+
+ @property
+ def _rendered_formats(self):
+ """
+ Return a list of all the formats that this view can render.
+ """
+ return [renderer.format for renderer in self.renderers]
@property
def _default_renderer(self):
@@ -483,14 +491,17 @@ class ReadModelMixin(object):
try:
if args:
# If we have any none kwargs then assume the last represents the primrary key
- instance = model.objects.get(pk=args[-1], **kwargs)
+ self.model_instance = model.objects.get(pk=args[-1], **kwargs)
else:
# Otherwise assume the kwargs uniquely identify the model
- instance = model.objects.get(**kwargs)
+ filtered_keywords = kwargs.copy()
+ if BaseRenderer._FORMAT_QUERY_PARAM in filtered_keywords:
+ del filtered_keywords[BaseRenderer._FORMAT_QUERY_PARAM]
+ self.model_instance = model.objects.get(**filtered_keywords)
except model.DoesNotExist:
raise ErrorResponse(status.HTTP_404_NOT_FOUND)
- return instance
+ return self.model_instance
class CreateModelMixin(object):
@@ -529,19 +540,19 @@ class UpdateModelMixin(object):
try:
if args:
# If we have any none kwargs then assume the last represents the primrary key
- instance = model.objects.get(pk=args[-1], **kwargs)
+ self.model_instance = model.objects.get(pk=args[-1], **kwargs)
else:
# Otherwise assume the kwargs uniquely identify the model
- instance = model.objects.get(**kwargs)
+ self.model_instance = model.objects.get(**kwargs)
for (key, val) in self.CONTENT.items():
- setattr(instance, key, val)
+ setattr(self.model_instance, key, val)
except model.DoesNotExist:
- instance = model(**self.CONTENT)
- instance.save()
+ self.model_instance = model(**self.CONTENT)
+ self.model_instance.save()
- instance.save()
- return instance
+ self.model_instance.save()
+ return self.model_instance
class DeleteModelMixin(object):
@@ -587,7 +598,7 @@ class ListModelMixin(object):
def get(self, request, *args, **kwargs):
model = self.resource.model
- queryset = self.queryset if self.queryset else model.objects.all()
+ queryset = self.queryset if self.queryset is not None else model.objects.all()
if hasattr(self, 'resource'):
ordering = getattr(self.resource, 'ordering', None)
diff --git a/djangorestframework/parsers.py b/djangorestframework/parsers.py
index a25ca89e..5f19c521 100644
--- a/djangorestframework/parsers.py
+++ b/djangorestframework/parsers.py
@@ -13,12 +13,13 @@ We need a method to be able to:
from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
+from django.http.multipartparser import MultiPartParserError
from django.utils import simplejson as json
from djangorestframework import status
+from djangorestframework.compat import yaml
from djangorestframework.response import ErrorResponse
from djangorestframework.utils.mediatypes import media_type_matches
-import yaml
__all__ = (
'BaseParser',
@@ -87,25 +88,26 @@ class JSONParser(BaseParser):
{'detail': 'JSON parse error - %s' % unicode(exc)})
-class YAMLParser(BaseParser):
- """
- Parses YAML-serialized data.
- """
-
- media_type = 'application/yaml'
-
- def parse(self, stream):
+if yaml:
+ class YAMLParser(BaseParser):
"""
- Returns a 2-tuple of `(data, files)`.
-
- `data` will be an object which is the parsed content of the response.
- `files` will always be `None`.
+ Parses YAML-serialized data.
"""
- try:
- return (yaml.safe_load(stream), None)
- except ValueError, exc:
- raise ErrorResponse(status.HTTP_400_BAD_REQUEST,
- {'detail': 'YAML parse error - %s' % unicode(exc)})
+
+ media_type = 'application/yaml'
+
+ def parse(self, stream):
+ """
+ Returns a 2-tuple of `(data, files)`.
+
+ `data` will be an object which is the parsed content of the response.
+ `files` will always be `None`.
+ """
+ try:
+ return (yaml.safe_load(stream), None)
+ except ValueError, exc:
+ raise ErrorResponse(status.HTTP_400_BAD_REQUEST,
+ {'detail': 'YAML parse error - %s' % unicode(exc)})
class PlainTextParser(BaseParser):
@@ -158,6 +160,10 @@ class MultiPartParser(BaseParser):
`files` will be a :class:`QueryDict` containing all the form files.
"""
upload_handlers = self.view.request._get_upload_handlers()
- django_parser = DjangoMultiPartParser(self.view.request.META, stream, upload_handlers)
+ try:
+ django_parser = DjangoMultiPartParser(self.view.request.META, stream, upload_handlers)
+ except MultiPartParserError, exc:
+ raise ErrorResponse(status.HTTP_400_BAD_REQUEST,
+ {'detail': 'multipart parse error - %s' % unicode(exc)})
return django_parser.parse()
diff --git a/djangorestframework/renderers.py b/djangorestframework/renderers.py
index 18ffbf66..aae2cab2 100644
--- a/djangorestframework/renderers.py
+++ b/djangorestframework/renderers.py
@@ -12,7 +12,7 @@ from django.template import RequestContext, loader
from django.utils import simplejson as json
-from djangorestframework.compat import apply_markdown
+from djangorestframework.compat import apply_markdown, yaml
from djangorestframework.utils import dict2xml, url_resolves
from djangorestframework.utils.breadcrumbs import get_breadcrumbs
from djangorestframework.utils.description import get_name, get_description
@@ -21,7 +21,6 @@ from djangorestframework import VERSION
import string
from urllib import quote_plus
-import yaml
__all__ = (
'BaseRenderer',
@@ -40,8 +39,11 @@ class BaseRenderer(object):
All renderers must extend this class, set the :attr:`media_type` attribute,
and override the :meth:`render` method.
"""
+
+ _FORMAT_QUERY_PARAM = 'format'
media_type = None
+ format = None
def __init__(self, view):
self.view = view
@@ -58,6 +60,11 @@ class BaseRenderer(object):
This may be overridden to provide for other behavior, but typically you'll
instead want to just set the :attr:`media_type` attribute on the class.
"""
+ format = self.view.kwargs.get(self._FORMAT_QUERY_PARAM, None)
+ if format is None:
+ format = self.view.request.GET.get(self._FORMAT_QUERY_PARAM, None)
+ if format is not None:
+ return format == self.format
return media_type_matches(self.media_type, accept)
def render(self, obj=None, media_type=None):
@@ -84,6 +91,7 @@ class JSONRenderer(BaseRenderer):
"""
media_type = 'application/json'
+ format = 'json'
def render(self, obj=None, media_type=None):
"""
@@ -111,6 +119,7 @@ class XMLRenderer(BaseRenderer):
"""
media_type = 'application/xml'
+ format = 'xml'
def render(self, obj=None, media_type=None):
"""
@@ -120,20 +129,27 @@ class XMLRenderer(BaseRenderer):
return ''
return dict2xml(obj)
-class YAMLRenderer(BaseRenderer):
- """
- Renderer which serializes to YAML.
- """
- media_type = 'application/yaml'
-
- def render(self, obj=None, media_type=None):
+if yaml:
+ class YAMLRenderer(BaseRenderer):
"""
- Renders *obj* into serialized YAML.
+ Renderer which serializes to YAML.
"""
- if obj is None:
- return ''
- return yaml.dump(obj)
+
+ media_type = 'application/yaml'
+ format = 'yaml'
+
+ def render(self, obj=None, media_type=None):
+ """
+ Renders *obj* into serialized YAML.
+ """
+ if obj is None:
+ return ''
+
+ return yaml.dump(obj)
+else:
+ YAMLRenderer = None
+
class TemplateRenderer(BaseRenderer):
"""
@@ -303,12 +319,12 @@ class DocumentingTemplateRenderer(BaseRenderer):
'version': VERSION,
'markeddown': markeddown,
'breadcrumblist': breadcrumb_list,
- 'available_media_types': self.view._rendered_media_types,
+ 'available_formats': self.view._rendered_formats,
'put_form': put_form_instance,
'post_form': post_form_instance,
'login_url': login_url,
'logout_url': logout_url,
- 'ACCEPT_PARAM': getattr(self.view, '_ACCEPT_QUERY_PARAM', None),
+ 'FORMAT_PARAM': self._FORMAT_QUERY_PARAM,
'METHOD_PARAM': getattr(self.view, '_METHOD_PARAM', None),
'ADMIN_MEDIA_PREFIX': settings.ADMIN_MEDIA_PREFIX
})
@@ -331,6 +347,7 @@ class DocumentingHTMLRenderer(DocumentingTemplateRenderer):
"""
media_type = 'text/html'
+ format = 'html'
template = 'renderer.html'
@@ -342,6 +359,7 @@ class DocumentingXHTMLRenderer(DocumentingTemplateRenderer):
"""
media_type = 'application/xhtml+xml'
+ format = 'xhtml'
template = 'renderer.html'
@@ -353,6 +371,7 @@ class DocumentingPlainTextRenderer(DocumentingTemplateRenderer):
"""
media_type = 'text/plain'
+ format = 'txt'
template = 'renderer.txt'
@@ -360,7 +379,7 @@ DEFAULT_RENDERERS = ( JSONRenderer,
DocumentingHTMLRenderer,
DocumentingXHTMLRenderer,
DocumentingPlainTextRenderer,
- XMLRenderer,
- YAMLRenderer )
-
+ XMLRenderer )
+if YAMLRenderer:
+ DEFAULT_RENDERERS += (YAMLRenderer,)
diff --git a/djangorestframework/resources.py b/djangorestframework/resources.py
index b42bd952..be361ab8 100644
--- a/djangorestframework/resources.py
+++ b/djangorestframework/resources.py
@@ -177,14 +177,12 @@ class FormResource(Resource):
# Return HTTP 400 response (BAD REQUEST)
raise ErrorResponse(400, detail)
-
- def get_bound_form(self, data=None, files=None, method=None):
+
+ def get_form_class(self, method=None):
"""
- Given some content return a Django form bound to that content.
- If form validation is turned off (:attr:`form` class attribute is :const:`None`) then returns :const:`None`.
+ Returns the form class used to validate this resource.
"""
-
# A form on the view overrides a form on the resource.
form = getattr(self.view, 'form', None) or self.form
@@ -200,6 +198,16 @@ class FormResource(Resource):
form = getattr(self, '%s_form' % method.lower(), form)
form = getattr(self.view, '%s_form' % method.lower(), form)
+ return form
+
+
+ def get_bound_form(self, data=None, files=None, method=None):
+ """
+ Given some content return a Django form bound to that content.
+ If form validation is turned off (:attr:`form` class attribute is :const:`None`) then returns :const:`None`.
+ """
+ form = self.get_form_class(method)
+
if not form:
return None
@@ -306,31 +314,31 @@ class ModelResource(FormResource):
If the :attr:`form` class attribute has been explicitly set then that class will be used
to create the Form, otherwise the model will be used to create a ModelForm.
"""
+ form = self.get_form_class(method)
- form = super(ModelResource, self).get_bound_form(data, files, method=method)
-
- # Use an explict Form if it exists
- if form:
- return form
-
- elif self.model:
+ if not form and self.model:
# Fall back to ModelForm which we create on the fly
class OnTheFlyModelForm(forms.ModelForm):
class Meta:
model = self.model
#fields = tuple(self._model_fields_set)
- # Instantiate the ModelForm as appropriate
- if data and isinstance(data, models.Model):
- # Bound to an existing model instance
- return OnTheFlyModelForm(instance=content)
- elif data is not None:
- return OnTheFlyModelForm(data, files)
- return OnTheFlyModelForm()
+ form = OnTheFlyModelForm
# Both form and model not set? Okay bruv, whatevs...
- return None
-
+ if not form:
+ return None
+
+ # Instantiate the ModelForm as appropriate
+ if data is not None or files is not None:
+ if issubclass(form, forms.ModelForm) and hasattr(self.view, 'model_instance'):
+ # Bound to an existing model instance
+ return form(data, files, instance=self.view.model_instance)
+ else:
+ return form(data, files)
+
+ return form()
+
def url(self, instance):
"""
diff --git a/djangorestframework/runtests/settings.py b/djangorestframework/runtests/settings.py
index 0cc7f4e3..9b3c2c92 100644
--- a/djangorestframework/runtests/settings.py
+++ b/djangorestframework/runtests/settings.py
@@ -2,6 +2,7 @@
DEBUG = True
TEMPLATE_DEBUG = DEBUG
+DEBUG_PROPAGATE_EXCEPTIONS = True
ADMINS = (
# ('Your Name', 'your_email@domain.com'),
@@ -96,6 +97,15 @@ INSTALLED_APPS = (
'djangorestframework',
)
+# OAuth support is optional, so we only test oauth if it's installed.
+try:
+ import oauth_provider
+except ImportError:
+ pass
+else:
+ INSTALLED_APPS += ('oauth_provider',)
+
+# If we're running on the Jenkins server we want to archive the coverage reports as XML.
import os
if os.environ.get('HUDSON_URL', None):
TEST_RUNNER = 'xmlrunner.extra.djangotestrunner.XMLTestRunner'
diff --git a/djangorestframework/serializer.py b/djangorestframework/serializer.py
index da8036e9..82aeb53f 100644
--- a/djangorestframework/serializer.py
+++ b/djangorestframework/serializer.py
@@ -4,7 +4,7 @@ Customizable serialization.
from django.db import models
from django.db.models.query import QuerySet
from django.db.models.fields.related import RelatedField
-from django.utils.encoding import smart_unicode, is_protected_type
+from django.utils.encoding import smart_unicode, is_protected_type, smart_str
import decimal
import inspect
@@ -177,7 +177,7 @@ class Serializer(object):
Keys serialize to their string value,
unless they exist in the `rename` dict.
"""
- return getattr(self.rename, key, key)
+ return getattr(self.rename, smart_str(key), smart_str(key))
def serialize_val(self, key, obj):
@@ -228,12 +228,12 @@ class Serializer(object):
# serialize each required field
for fname in fields:
- if hasattr(self, fname):
+ if hasattr(self, smart_str(fname)):
# check for a method 'fname' on self first
meth = getattr(self, fname)
if inspect.ismethod(meth) and len(inspect.getargspec(meth)[0]) == 2:
obj = meth(instance)
- elif hasattr(instance, fname):
+ elif hasattr(instance, smart_str(fname)):
# now check for an attribute 'fname' on the instance
obj = getattr(instance, fname)
elif fname in instance:
diff --git a/djangorestframework/templates/renderer.html b/djangorestframework/templates/renderer.html
index 44e032aa..3dd5faf3 100644
--- a/djangorestframework/templates/renderer.html
+++ b/djangorestframework/templates/renderer.html
@@ -50,9 +50,9 @@
<h2>GET {{ name }}</h2>
<div class='submit-row' style='margin: 0; border: 0'>
<a href='{{ request.get_full_path }}' rel="nofollow" style='float: left'>GET</a>
- {% for media_type in available_media_types %}
- {% with ACCEPT_PARAM|add:"="|add:media_type as param %}
- [<a href='{{ request.get_full_path|add_query_param:param }}' rel="nofollow">{{ media_type }}</a>]
+ {% for format in available_formats %}
+ {% with FORMAT_PARAM|add:"="|add:format as param %}
+ [<a href='{{ request.get_full_path|add_query_param:param }}' rel="nofollow">{{ format }}</a>]
{% endwith %}
{% endfor %}
</div>
@@ -124,4 +124,4 @@
</div>
</div>
</body>
-</html> \ No newline at end of file
+</html>
diff --git a/djangorestframework/tests/oauthentication.py b/djangorestframework/tests/oauthentication.py
new file mode 100644
index 00000000..109d9a72
--- /dev/null
+++ b/djangorestframework/tests/oauthentication.py
@@ -0,0 +1,212 @@
+import time
+
+from django.conf.urls.defaults import patterns, url, include
+from django.contrib.auth.models import User
+from django.test import Client, TestCase
+
+from djangorestframework.views import View
+
+# Since oauth2 / django-oauth-plus are optional dependancies, we don't want to
+# always run these tests.
+
+# Unfortunatly we can't skip tests easily until 2.7, se we'll just do this for now.
+try:
+ import oauth2 as oauth
+ from oauth_provider.decorators import oauth_required
+ from oauth_provider.models import Resource, Consumer, Token
+
+except ImportError:
+ pass
+
+else:
+ # Alrighty, we're good to go here.
+ class ClientView(View):
+ def get(self, request):
+ return {'resource': 'Protected!'}
+
+ urlpatterns = patterns('',
+ url(r'^$', oauth_required(ClientView.as_view())),
+ url(r'^oauth/', include('oauth_provider.urls')),
+ url(r'^accounts/login/$', 'djangorestframework.utils.staticviews.api_login'),
+ )
+
+
+ class OAuthTests(TestCase):
+ """
+ OAuth authentication:
+ * the user would like to access his API data from a third-party website
+ * the third-party website proposes a link to get that API data
+ * the user is redirected to the API and must log in if not authenticated
+ * the API displays a webpage to confirm that the user trusts the third-party website
+ * if confirmed, the user is redirected to the third-party website through the callback view
+ * the third-party website is able to retrieve data from the API
+ """
+ urls = 'djangorestframework.tests.oauthentication'
+
+ def setUp(self):
+ self.client = Client()
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ # OAuth requirements
+ self.resource = Resource(name='data', url='/')
+ self.resource.save()
+ self.CONSUMER_KEY = 'dpf43f3p2l4k3l03'
+ self.CONSUMER_SECRET = 'kd94hf93k423kf44'
+ self.consumer = Consumer(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
+ name='api.example.com', user=self.user)
+ self.consumer.save()
+
+ def test_oauth_invalid_and_anonymous_access(self):
+ """
+ Verify that the resource is protected and the OAuth authorization view
+ require the user to be logged in.
+ """
+ response = self.client.get('/')
+ self.assertEqual(response.content, 'Invalid request parameters.')
+ self.assertEqual(response.status_code, 401)
+ response = self.client.get('/oauth/authorize/', follow=True)
+ self.assertRedirects(response, '/accounts/login/?next=/oauth/authorize/')
+
+ def test_oauth_authorize_access(self):
+ """
+ Verify that once logged in, the user can access the authorization page
+ but can't display the page because the request token is not specified.
+ """
+ self.client.login(username=self.username, password=self.password)
+ response = self.client.get('/oauth/authorize/', follow=True)
+ self.assertEqual(response.content, 'No request token specified.')
+
+ def _create_request_token_parameters(self):
+ """
+ A shortcut to create request's token parameters.
+ """
+ return {
+ 'oauth_consumer_key': self.CONSUMER_KEY,
+ 'oauth_signature_method': 'PLAINTEXT',
+ 'oauth_signature': '%s&' % self.CONSUMER_SECRET,
+ 'oauth_timestamp': str(int(time.time())),
+ 'oauth_nonce': 'requestnonce',
+ 'oauth_version': '1.0',
+ 'oauth_callback': 'http://api.example.com/request_token_ready',
+ 'scope': 'data',
+ }
+
+ def test_oauth_request_token_retrieval(self):
+ """
+ Verify that the request token can be retrieved by the server.
+ """
+ response = self.client.get("/oauth/request_token/",
+ self._create_request_token_parameters())
+ self.assertEqual(response.status_code, 200)
+ token = list(Token.objects.all())[-1]
+ self.failIf(token.key not in response.content)
+ self.failIf(token.secret not in response.content)
+
+ def test_oauth_user_request_authorization(self):
+ """
+ Verify that the user can access the authorization page once logged in
+ and the request token has been retrieved.
+ """
+ # Setup
+ response = self.client.get("/oauth/request_token/",
+ self._create_request_token_parameters())
+ token = list(Token.objects.all())[-1]
+
+ # Starting the test here
+ self.client.login(username=self.username, password=self.password)
+ parameters = {'oauth_token': token.key,}
+ response = self.client.get("/oauth/authorize/", parameters)
+ self.assertEqual(response.status_code, 200)
+ self.failIf(not response.content.startswith('Fake authorize view for api.example.com with params: oauth_token='))
+ self.assertEqual(token.is_approved, 0)
+ parameters['authorize_access'] = 1 # fake authorization by the user
+ response = self.client.post("/oauth/authorize/", parameters)
+ self.assertEqual(response.status_code, 302)
+ self.failIf(not response['Location'].startswith('http://api.example.com/request_token_ready?oauth_verifier='))
+ token = Token.objects.get(key=token.key)
+ self.failIf(token.key not in response['Location'])
+ self.assertEqual(token.is_approved, 1)
+
+ def _create_access_token_parameters(self, token):
+ """
+ A shortcut to create access' token parameters.
+ """
+ return {
+ 'oauth_consumer_key': self.CONSUMER_KEY,
+ 'oauth_token': token.key,
+ 'oauth_signature_method': 'PLAINTEXT',
+ 'oauth_signature': '%s&%s' % (self.CONSUMER_SECRET, token.secret),
+ 'oauth_timestamp': str(int(time.time())),
+ 'oauth_nonce': 'accessnonce',
+ 'oauth_version': '1.0',
+ 'oauth_verifier': token.verifier,
+ 'scope': 'data',
+ }
+
+ def test_oauth_access_token_retrieval(self):
+ """
+ Verify that the request token can be retrieved by the server.
+ """
+ # Setup
+ response = self.client.get("/oauth/request_token/",
+ self._create_request_token_parameters())
+ token = list(Token.objects.all())[-1]
+ self.client.login(username=self.username, password=self.password)
+ parameters = {'oauth_token': token.key,}
+ response = self.client.get("/oauth/authorize/", parameters)
+ parameters['authorize_access'] = 1 # fake authorization by the user
+ response = self.client.post("/oauth/authorize/", parameters)
+ token = Token.objects.get(key=token.key)
+
+ # Starting the test here
+ response = self.client.get("/oauth/access_token/", self._create_access_token_parameters(token))
+ self.assertEqual(response.status_code, 200)
+ self.failIf(not response.content.startswith('oauth_token_secret='))
+ access_token = list(Token.objects.filter(token_type=Token.ACCESS))[-1]
+ self.failIf(access_token.key not in response.content)
+ self.failIf(access_token.secret not in response.content)
+ self.assertEqual(access_token.user.username, 'john')
+
+ def _create_access_parameters(self, access_token):
+ """
+ A shortcut to create access' parameters.
+ """
+ parameters = {
+ 'oauth_consumer_key': self.CONSUMER_KEY,
+ 'oauth_token': access_token.key,
+ 'oauth_signature_method': 'HMAC-SHA1',
+ 'oauth_timestamp': str(int(time.time())),
+ 'oauth_nonce': 'accessresourcenonce',
+ 'oauth_version': '1.0',
+ }
+ oauth_request = oauth.Request.from_token_and_callback(access_token,
+ http_url='http://testserver/', parameters=parameters)
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ signature = signature_method.sign(oauth_request, self.consumer, access_token)
+ parameters['oauth_signature'] = signature
+ return parameters
+
+ def test_oauth_protected_resource_access(self):
+ """
+ Verify that the request token can be retrieved by the server.
+ """
+ # Setup
+ response = self.client.get("/oauth/request_token/",
+ self._create_request_token_parameters())
+ token = list(Token.objects.all())[-1]
+ self.client.login(username=self.username, password=self.password)
+ parameters = {'oauth_token': token.key,}
+ response = self.client.get("/oauth/authorize/", parameters)
+ parameters['authorize_access'] = 1 # fake authorization by the user
+ response = self.client.post("/oauth/authorize/", parameters)
+ token = Token.objects.get(key=token.key)
+ response = self.client.get("/oauth/access_token/", self._create_access_token_parameters(token))
+ access_token = list(Token.objects.filter(token_type=Token.ACCESS))[-1]
+
+ # Starting the test here
+ response = self.client.get("/", self._create_access_token_parameters(access_token))
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, '{"resource": "Protected!"}')
diff --git a/djangorestframework/tests/renderers.py b/djangorestframework/tests/renderers.py
index 569eb640..d2046212 100644
--- a/djangorestframework/tests/renderers.py
+++ b/djangorestframework/tests/renderers.py
@@ -2,16 +2,17 @@ from django.conf.urls.defaults import patterns, url
from django import http
from django.test import TestCase
+from djangorestframework import status
from djangorestframework.compat import View as DjangoView
-from djangorestframework.renderers import BaseRenderer, JSONRenderer
-from djangorestframework.parsers import JSONParser
+from djangorestframework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer
+from djangorestframework.parsers import JSONParser, YAMLParser
from djangorestframework.mixins import ResponseMixin
from djangorestframework.response import Response
from djangorestframework.utils.mediatypes import add_media_type_param
from StringIO import StringIO
-DUMMYSTATUS = 200
+DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = 'dummycontent'
RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x
@@ -19,12 +20,14 @@ RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x
class RendererA(BaseRenderer):
media_type = 'mock/renderera'
+ format="formata"
def render(self, obj=None, media_type=None):
return RENDERER_A_SERIALIZER(obj)
class RendererB(BaseRenderer):
media_type = 'mock/rendererb'
+ format="formatb"
def render(self, obj=None, media_type=None):
return RENDERER_B_SERIALIZER(obj)
@@ -32,11 +35,13 @@ class RendererB(BaseRenderer):
class MockView(ResponseMixin, DjangoView):
renderers = (RendererA, RendererB)
- def get(self, request):
+ def get(self, request, **kwargs):
response = Response(DUMMYSTATUS, DUMMYCONTENT)
return self.render(response)
+
urlpatterns = patterns('',
+ url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderers=[RendererA, RendererB])),
url(r'^$', MockView.as_view(renderers=[RendererA, RendererB])),
)
@@ -85,10 +90,58 @@ class RendererIntegrationTests(TestCase):
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
+ def test_specified_renderer_serializes_content_on_accept_query(self):
+ """The '_accept' query string should behave in the same way as the Accept header."""
+ resp = self.client.get('/?_accept=%s' % RendererB.media_type)
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
"""If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
- self.assertEquals(resp.status_code, 406)
+ self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
+
+ def test_specified_renderer_serializes_content_on_format_query(self):
+ """If a 'format' query is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ resp = self.client.get('/?format=%s' % RendererB.format)
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_format_kwargs(self):
+ """If a 'format' keyword arg is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ resp = self.client.get('/something.formatb')
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
+ """If both a 'format' query and a matching Accept header specified,
+ the renderer with the matching format attribute should serialize the response."""
+ resp = self.client.get('/?format=%s' % RendererB.format,
+ HTTP_ACCEPT=RendererB.media_type)
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_conflicting_format_query_and_accept_ignores_accept(self):
+ """If a 'format' query is specified that does not match the Accept
+ header, we should only honor the 'format' query string."""
+ resp = self.client.get('/?format=%s' % RendererB.format,
+ HTTP_ACCEPT='dummy')
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_bla(self):
+ resp = self.client.get('/?format=formatb',
+ HTTP_ACCEPT='text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8')
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
_flat_repr = '{"foo": ["bar", "baz"]}'
@@ -136,3 +189,38 @@ class JSONRendererTests(TestCase):
content = renderer.render(obj, 'application/json')
(data, files) = parser.parse(StringIO(content))
self.assertEquals(obj, data)
+
+
+
+if YAMLRenderer:
+ _yaml_repr = 'foo: [bar, baz]\n'
+
+
+ class YAMLRendererTests(TestCase):
+ """
+ Tests specific to the JSON Renderer
+ """
+
+ def test_render(self):
+ """
+ Test basic YAML rendering.
+ """
+ obj = {'foo':['bar','baz']}
+ renderer = YAMLRenderer(None)
+ content = renderer.render(obj, 'application/yaml')
+ self.assertEquals(content, _yaml_repr)
+
+
+ def test_render_and_parse(self):
+ """
+ Test rendering and then parsing returns the original object.
+ IE obj -> render -> parse -> obj.
+ """
+ obj = {'foo':['bar','baz']}
+
+ renderer = YAMLRenderer(None)
+ parser = YAMLParser(None)
+
+ content = renderer.render(obj, 'application/yaml')
+ (data, files) = parser.parse(StringIO(content))
+ self.assertEquals(obj, data) \ No newline at end of file
diff --git a/djangorestframework/tests/reverse.py b/djangorestframework/tests/reverse.py
index b4b0a793..2d1ca79e 100644
--- a/djangorestframework/tests/reverse.py
+++ b/djangorestframework/tests/reverse.py
@@ -24,9 +24,5 @@ class ReverseTests(TestCase):
urls = 'djangorestframework.tests.reverse'
def test_reversed_urls_are_fully_qualified(self):
- try:
- response = self.client.get('/')
- except:
- import traceback
- traceback.print_exc()
+ response = self.client.get('/')
self.assertEqual(json.loads(response.content), 'http://testserver/another')