aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--AUTHORS13
-rw-r--r--README10
-rw-r--r--djangorestframework/__init__.py2
-rw-r--r--djangorestframework/compat.py12
-rw-r--r--djangorestframework/mixins.py42
-rw-r--r--djangorestframework/parsers.py12
-rw-r--r--djangorestframework/permissions.py145
-rw-r--r--djangorestframework/renderers.py23
-rw-r--r--djangorestframework/resources.py139
-rw-r--r--djangorestframework/response.py4
-rw-r--r--djangorestframework/runtests/runtests.py28
-rw-r--r--djangorestframework/runtests/settings.py14
-rw-r--r--djangorestframework/serializer.py310
-rw-r--r--djangorestframework/templates/renderer.html10
-rw-r--r--djangorestframework/tests/content.py10
-rw-r--r--djangorestframework/tests/files.py12
-rw-r--r--djangorestframework/tests/methods.py6
-rw-r--r--djangorestframework/tests/oauthentication.py212
-rw-r--r--djangorestframework/tests/parsers.py22
-rw-r--r--djangorestframework/tests/renderers.py89
-rw-r--r--djangorestframework/tests/resources.py60
-rw-r--r--djangorestframework/tests/reverse.py6
-rw-r--r--djangorestframework/tests/serializer.py117
-rw-r--r--djangorestframework/tests/throttling.py164
-rw-r--r--djangorestframework/views.py16
-rw-r--r--docs/examples/blogpost.rst16
-rw-r--r--docs/examples/modelviews.rst9
-rw-r--r--docs/howto/alternativeframeworks.rst37
-rw-r--r--docs/index.rst7
-rw-r--r--docs/library/serializer.rst5
-rw-r--r--examples/blogpost/models.py3
-rw-r--r--examples/blogpost/resources.py27
-rw-r--r--examples/blogpost/tests.py22
-rw-r--r--examples/blogpost/urls.py29
-rw-r--r--examples/modelresourceexample/resources.py7
-rw-r--r--examples/modelresourceexample/urls.py10
-rw-r--r--examples/permissionsexample/__init__.py (renamed from examples/modelresourceexample/views.py)0
-rw-r--r--examples/permissionsexample/urls.py6
-rw-r--r--examples/permissionsexample/views.py20
-rw-r--r--examples/pygments_api/views.py11
-rw-r--r--examples/requirements.txt6
-rw-r--r--examples/sandbox/views.py4
-rw-r--r--examples/urls.py1
43 files changed, 1328 insertions, 370 deletions
diff --git a/AUTHORS b/AUTHORS
index 103423ab..0da0bca8 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -1,11 +1,16 @@
-Tom Christie <tomchristie> - tom@tomchristie.com, @thisneonsoul
-
+Tom Christie <tomchristie> - tom@tomchristie.com, @thisneonsoul - Author.
Paul Bagwell <pbgwl> - Suggestions & bugfixes.
-Marko Tibold <markotibold> - Contributions & Providing the Hudson CI Server.
+Marko Tibold <markotibold> - Contributions & Providing the Jenkins CI Server.
Sébastien Piquemal <sebpiq> - Contributions.
Carmen Wick <cwick> - Bugfixes.
Alex Ehlke <aehlke> - Design Contributions.
+Alen Mujezinovic <flashingpumpkin> - Contributions.
+Carles Barrobés <txels> - HEAD support.
+Michael Fötsch <mfoetsch> - File format support.
+David Larlet <david> - OAuth support.
+Andrew Straw <astraw> - Bugfixes.
THANKS TO:
+
Jesper Noehr <jespern> & the django-piston contributors for providing the starting point for this project.
-And of course, to the Django core team and the Django community at large.
+And of course, to the Django core team and the Django community at large. You guys rock.
diff --git a/README b/README
index c3705442..3c740486 100644
--- a/README
+++ b/README
@@ -52,6 +52,16 @@ To create the sdist packages
Release Notes
=============
+0.2.3
+
+ * Fix some throttling bugs
+ * X-Throttle header on throttling
+ * Support for nesting resources on related models
+
+0.2.2
+
+ * Throttling support complete
+
0.2.1
* Couple of simple bugfixes over 0.2.0
diff --git a/djangorestframework/__init__.py b/djangorestframework/__init__.py
index 1ee96d36..b1ef6dda 100644
--- a/djangorestframework/__init__.py
+++ b/djangorestframework/__init__.py
@@ -1,3 +1,3 @@
-__version__ = '0.2.1'
+__version__ = '0.2.3'
VERSION = __version__ # synonym
diff --git a/djangorestframework/compat.py b/djangorestframework/compat.py
index 0274511a..827b4adf 100644
--- a/djangorestframework/compat.py
+++ b/djangorestframework/compat.py
@@ -67,6 +67,14 @@ except ImportError:
# django.views.generic.View (Django >= 1.3)
try:
from django.views.generic import View
+ if not hasattr(View, 'head'):
+ # First implementation of Django class-based views did not include head method
+ # in base View class - https://code.djangoproject.com/ticket/15668
+ class ViewPlusHead(View):
+ def head(self, request, *args, **kwargs):
+ return self.get(request, *args, **kwargs)
+ View = ViewPlusHead
+
except ImportError:
from django import http
from django.utils.functional import update_wrapper
@@ -145,6 +153,8 @@ except ImportError:
#)
return http.HttpResponseNotAllowed(allowed_methods)
+ def head(self, request, *args, **kwargs):
+ return self.get(request, *args, **kwargs)
try:
import markdown
@@ -193,4 +203,4 @@ try:
return md.convert(text)
except ImportError:
- apply_markdown = None \ No newline at end of file
+ apply_markdown = None
diff --git a/djangorestframework/mixins.py b/djangorestframework/mixins.py
index 11e3bb38..1b3aa241 100644
--- a/djangorestframework/mixins.py
+++ b/djangorestframework/mixins.py
@@ -11,6 +11,7 @@ from django.http.multipartparser import LimitBytes
from djangorestframework import status
from djangorestframework.parsers import FormParser, MultiPartParser
+from djangorestframework.renderers import BaseRenderer
from djangorestframework.resources import Resource, FormResource, ModelResource
from djangorestframework.response import Response, ErrorResponse
from djangorestframework.utils import as_tuple, MSIE_USER_AGENT_REGEX
@@ -290,7 +291,7 @@ class ResponseMixin(object):
accept_list = [token.strip() for token in request.META["HTTP_ACCEPT"].split(',')]
else:
# No accept header specified
- return (self._default_renderer(self), self._default_renderer.media_type)
+ accept_list = ['*/*']
# Check the acceptable media types against each renderer,
# attempting more specific media types first
@@ -298,12 +299,12 @@ class ResponseMixin(object):
# Worst case is we're looping over len(accept_list) * len(self.renderers)
renderers = [renderer_cls(self) for renderer_cls in self.renderers]
- for media_type_lst in order_by_precedence(accept_list):
+ for accepted_media_type_lst in order_by_precedence(accept_list):
for renderer in renderers:
- for media_type in media_type_lst:
- if renderer.can_handle_response(media_type):
- return renderer, media_type
-
+ for accepted_media_type in accepted_media_type_lst:
+ if renderer.can_handle_response(accepted_media_type):
+ return renderer, accepted_media_type
+
# No acceptable renderers were found
raise ErrorResponse(status.HTTP_406_NOT_ACCEPTABLE,
{'detail': 'Could not satisfy the client\'s Accept header',
@@ -316,6 +317,13 @@ class ResponseMixin(object):
Return an list of all the media types that this view can render.
"""
return [renderer.media_type for renderer in self.renderers]
+
+ @property
+ def _rendered_formats(self):
+ """
+ Return a list of all the formats that this view can render.
+ """
+ return [renderer.format for renderer in self.renderers]
@property
def _default_renderer(self):
@@ -466,7 +474,7 @@ class InstanceMixin(object):
# We do a little dance when we store the view callable...
# we need to store it wrapped in a 1-tuple, so that inspect will treat it
# as a function when we later look it up (rather than turning it into a method).
- # This makes sure our URL reversing works ok.
+ # This makes sure our URL reversing works ok.
resource.view_callable = (view,)
return view
@@ -479,13 +487,17 @@ class ReadModelMixin(object):
"""
def get(self, request, *args, **kwargs):
model = self.resource.model
+
try:
if args:
# If we have any none kwargs then assume the last represents the primrary key
instance = model.objects.get(pk=args[-1], **kwargs)
else:
# Otherwise assume the kwargs uniquely identify the model
- instance = model.objects.get(**kwargs)
+ filtered_keywords = kwargs.copy()
+ if BaseRenderer._FORMAT_QUERY_PARAM in filtered_keywords:
+ del filtered_keywords[BaseRenderer._FORMAT_QUERY_PARAM]
+ instance = model.objects.get(**filtered_keywords)
except model.DoesNotExist:
raise ErrorResponse(status.HTTP_404_NOT_FOUND)
@@ -498,6 +510,7 @@ class CreateModelMixin(object):
"""
def post(self, request, *args, **kwargs):
model = self.resource.model
+
# translated 'related_field' kwargs into 'related_field_id'
for related_name in [field.name for field in model._meta.fields if isinstance(field, RelatedField)]:
if kwargs.has_key(related_name):
@@ -522,6 +535,7 @@ class UpdateModelMixin(object):
"""
def put(self, request, *args, **kwargs):
model = self.resource.model
+
# TODO: update on the url of a non-existing resource url doesn't work correctly at the moment - will end up with a new url
try:
if args:
@@ -547,6 +561,7 @@ class DeleteModelMixin(object):
"""
def delete(self, request, *args, **kwargs):
model = self.resource.model
+
try:
if args:
# If we have any none kwargs then assume the last represents the primrary key
@@ -581,8 +596,15 @@ class ListModelMixin(object):
queryset = None
def get(self, request, *args, **kwargs):
- queryset = self.queryset if self.queryset else self.resource.model.objects.all()
- ordering = getattr(self.resource, 'ordering', None)
+ model = self.resource.model
+
+ queryset = self.queryset if self.queryset else model.objects.all()
+
+ if hasattr(self, 'resource'):
+ ordering = getattr(self.resource, 'ordering', None)
+ else:
+ ordering = None
+
if ordering:
args = as_tuple(ordering)
queryset = queryset.order_by(*args)
diff --git a/djangorestframework/parsers.py b/djangorestframework/parsers.py
index 726e09e9..37882984 100644
--- a/djangorestframework/parsers.py
+++ b/djangorestframework/parsers.py
@@ -11,12 +11,12 @@ We need a method to be able to:
and multipart/form-data. (eg also handle multipart/json)
"""
+from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
+from django.http.multipartparser import MultiPartParserError
from django.utils import simplejson as json
from djangorestframework import status
-from djangorestframework.compat import parse_qs
from djangorestframework.response import ErrorResponse
-from djangorestframework.utils import as_tuple
from djangorestframework.utils.mediatypes import media_type_matches
__all__ = (
@@ -117,7 +117,7 @@ class FormParser(BaseParser):
`data` will be a :class:`QueryDict` containing all the form parameters.
`files` will always be :const:`None`.
"""
- data = parse_qs(stream.read(), keep_blank_values=True)
+ data = QueryDict(stream.read())
return (data, None)
@@ -136,6 +136,10 @@ class MultiPartParser(BaseParser):
`files` will be a :class:`QueryDict` containing all the form files.
"""
upload_handlers = self.view.request._get_upload_handlers()
- django_parser = DjangoMultiPartParser(self.view.request.META, stream, upload_handlers)
+ try:
+ django_parser = DjangoMultiPartParser(self.view.request.META, stream, upload_handlers)
+ except MultiPartParserError, exc:
+ raise ErrorResponse(status.HTTP_400_BAD_REQUEST,
+ {'detail': 'multipart parse error - %s' % unicode(exc)})
return django_parser.parse()
diff --git a/djangorestframework/permissions.py b/djangorestframework/permissions.py
index 1f6151f8..7dcabcf0 100644
--- a/djangorestframework/permissions.py
+++ b/djangorestframework/permissions.py
@@ -1,6 +1,6 @@
"""
The :mod:`permissions` module bundles a set of permission classes that are used
-for checking if a request passes a certain set of constraints. You can assign a permision
+for checking if a request passes a certain set of constraints. You can assign a permission
class to your view by setting your View's :attr:`permissions` class attribute.
"""
@@ -15,7 +15,9 @@ __all__ = (
'IsAuthenticated',
'IsAdminUser',
'IsUserOrIsAnonReadOnly',
- 'PerUserThrottling'
+ 'PerUserThrottling',
+ 'PerViewThrottling',
+ 'PerResourceThrottling'
)
@@ -24,12 +26,11 @@ _403_FORBIDDEN_RESPONSE = ErrorResponse(
{'detail': 'You do not have permission to access this resource. ' +
'You may need to login or otherwise authenticate the request.'})
-_503_THROTTLED_RESPONSE = ErrorResponse(
+_503_SERVICE_UNAVAILABLE = ErrorResponse(
status.HTTP_503_SERVICE_UNAVAILABLE,
{'detail': 'request was throttled'})
-
class BasePermission(object):
"""
A base class from which all permission classes should inherit.
@@ -88,37 +89,131 @@ class IsUserOrIsAnonReadOnly(BasePermission):
raise _403_FORBIDDEN_RESPONSE
-class PerUserThrottling(BasePermission):
+class BaseThrottle(BasePermission):
"""
- Rate throttling of requests on a per-user basis.
+ Rate throttling of requests.
- The rate (requests / seconds) is set by a :attr:`throttle` attribute on the ``View`` class.
- The attribute is a two tuple of the form (number of requests, duration in seconds).
+ The rate (requests / seconds) is set by a :attr:`throttle` attribute
+ on the :class:`.View` class. The attribute is a string of the form 'number of
+ requests/period'.
- The user id will be used as a unique identifier if the user is authenticated.
- For anonymous requests, the IP address of the client will be used.
+ Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
Previous request information used for throttling is stored in the cache.
+ """
+
+ attr_name = 'throttle'
+ default = '0/sec'
+ timer = time.time
+
+ def get_cache_key(self):
+ """
+ Should return a unique cache-key which can be used for throttling.
+ Muse be overridden.
+ """
+ pass
+
+ def check_permission(self, auth):
+ """
+ Check the throttling.
+ Return `None` or raise an :exc:`.ErrorResponse`.
+ """
+ num, period = getattr(self.view, self.attr_name, self.default).split('/')
+ self.num_requests = int(num)
+ self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
+ self.auth = auth
+ self.check_throttle()
+
+ def check_throttle(self):
+ """
+ Implement the check to see if the request should be throttled.
+
+ On success calls :meth:`throttle_success`.
+ On failure calls :meth:`throttle_failure`.
+ """
+ self.key = self.get_cache_key()
+ self.history = cache.get(self.key, [])
+ self.now = self.timer()
+
+ # Drop any requests from the history which have now passed the
+ # throttle duration
+ while self.history and self.history[-1] <= self.now - self.duration:
+ self.history.pop()
+ if len(self.history) >= self.num_requests:
+ self.throttle_failure()
+ else:
+ self.throttle_success()
+
+ def throttle_success(self):
+ """
+ Inserts the current request's timestamp along with the key
+ into the cache.
+ """
+ self.history.insert(0, self.now)
+ cache.set(self.key, self.history, self.duration)
+ header = 'status=SUCCESS; next=%s sec' % self.next()
+ self.view.add_header('X-Throttle', header)
+
+ def throttle_failure(self):
+ """
+ Called when a request to the API has failed due to throttling.
+ Raises a '503 service unavailable' response.
+ """
+ header = 'status=FAILURE; next=%s sec' % self.next()
+ self.view.add_header('X-Throttle', header)
+ raise _503_SERVICE_UNAVAILABLE
+
+ def next(self):
+ """
+ Returns the recommended next request time in seconds.
+ """
+ if self.history:
+ remaining_duration = self.duration - (self.now - self.history[-1])
+ else:
+ remaining_duration = self.duration
+
+ available_requests = self.num_requests - len(self.history) + 1
+
+ return '%.2f' % (remaining_duration / float(available_requests))
+
+
+class PerUserThrottling(BaseThrottle):
"""
+ Limits the rate of API calls that may be made by a given user.
- def check_permission(self, user):
- (num_requests, duration) = getattr(self.view, 'throttle', (0, 0))
+ The user id will be used as a unique identifier if the user is
+ authenticated. For anonymous requests, the IP address of the client will
+ be used.
+ """
- if user.is_authenticated():
- ident = str(auth)
+ def get_cache_key(self):
+ if self.auth.is_authenticated():
+ ident = str(self.auth)
else:
ident = self.view.request.META.get('REMOTE_ADDR', None)
+ return 'throttle_user_%s' % ident
- key = 'throttle_%s' % ident
- history = cache.get(key, [])
- now = time.time()
-
- # Drop any requests from the history which have now passed the throttle duration
- while history and history[0] < now - duration:
- history.pop()
- if len(history) >= num_requests:
- raise _503_THROTTLED_RESPONSE
+class PerViewThrottling(BaseThrottle):
+ """
+ Limits the rate of API calls that may be used on a given view.
+
+ The class name of the view is used as a unique identifier to
+ throttle against.
+ """
+
+ def get_cache_key(self):
+ return 'throttle_view_%s' % self.view.__class__.__name__
+
+
+class PerResourceThrottling(BaseThrottle):
+ """
+ Limits the rate of API calls that may be used against all views on
+ a given resource.
+
+ The class name of the resource is used as a unique identifier to
+ throttle against.
+ """
- history.insert(0, now)
- cache.set(key, history, duration)
+ def get_cache_key(self):
+ return 'throttle_resource_%s' % self.view.resource.__class__.__name__
diff --git a/djangorestframework/renderers.py b/djangorestframework/renderers.py
index 9834ba5e..e09e2abc 100644
--- a/djangorestframework/renderers.py
+++ b/djangorestframework/renderers.py
@@ -17,6 +17,7 @@ from djangorestframework.utils import dict2xml, url_resolves
from djangorestframework.utils.breadcrumbs import get_breadcrumbs
from djangorestframework.utils.description import get_name, get_description
from djangorestframework.utils.mediatypes import get_media_type_params, add_media_type_param, media_type_matches
+from djangorestframework import VERSION
from decimal import Decimal
import re
@@ -39,8 +40,11 @@ class BaseRenderer(object):
All renderers must extend this class, set the :attr:`media_type` attribute,
and override the :meth:`render` method.
"""
+
+ _FORMAT_QUERY_PARAM = 'format'
media_type = None
+ format = None
def __init__(self, view):
self.view = view
@@ -57,6 +61,11 @@ class BaseRenderer(object):
This may be overridden to provide for other behavior, but typically you'll
instead want to just set the :attr:`media_type` attribute on the class.
"""
+ format = self.view.kwargs.get(self._FORMAT_QUERY_PARAM, None)
+ if format is None:
+ format = self.view.request.GET.get(self._FORMAT_QUERY_PARAM, None)
+ if format is not None:
+ return format == self.format
return media_type_matches(self.media_type, accept)
def render(self, obj=None, media_type=None):
@@ -83,6 +92,7 @@ class JSONRenderer(BaseRenderer):
"""
media_type = 'application/json'
+ format = 'json'
def render(self, obj=None, media_type=None):
"""
@@ -108,7 +118,9 @@ class XMLRenderer(BaseRenderer):
"""
Renderer which serializes to XML.
"""
+
media_type = 'application/xml'
+ format = 'xml'
def render(self, obj=None, media_type=None):
"""
@@ -181,7 +193,7 @@ class DocumentingTemplateRenderer(BaseRenderer):
# Get the form instance if we have one bound to the input
form_instance = None
- if method == view.method.lower():
+ if method == getattr(view, 'method', view.request.method).lower():
form_instance = getattr(view, 'bound_form_instance', None)
if not form_instance and hasattr(view, 'get_bound_form'):
@@ -251,6 +263,7 @@ class DocumentingTemplateRenderer(BaseRenderer):
The context used in the template contains all the information
needed to self-document the response to this request.
"""
+
content = self._get_content(self.view, self.view.request, obj, media_type)
put_form_instance = self._get_form_instance(self.view, 'put')
@@ -283,14 +296,15 @@ class DocumentingTemplateRenderer(BaseRenderer):
'response': self.view.response,
'description': description,
'name': name,
+ 'version': VERSION,
'markeddown': markeddown,
'breadcrumblist': breadcrumb_list,
- 'available_media_types': self.view._rendered_media_types,
+ 'available_formats': self.view._rendered_formats,
'put_form': put_form_instance,
'post_form': post_form_instance,
'login_url': login_url,
'logout_url': logout_url,
- 'ACCEPT_PARAM': getattr(self.view, '_ACCEPT_QUERY_PARAM', None),
+ 'FORMAT_PARAM': self._FORMAT_QUERY_PARAM,
'METHOD_PARAM': getattr(self.view, '_METHOD_PARAM', None),
'ADMIN_MEDIA_PREFIX': settings.ADMIN_MEDIA_PREFIX
})
@@ -313,6 +327,7 @@ class DocumentingHTMLRenderer(DocumentingTemplateRenderer):
"""
media_type = 'text/html'
+ format = 'html'
template = 'renderer.html'
@@ -324,6 +339,7 @@ class DocumentingXHTMLRenderer(DocumentingTemplateRenderer):
"""
media_type = 'application/xhtml+xml'
+ format = 'xhtml'
template = 'renderer.html'
@@ -335,6 +351,7 @@ class DocumentingPlainTextRenderer(DocumentingTemplateRenderer):
"""
media_type = 'text/plain'
+ format = 'txt'
template = 'renderer.txt'
diff --git a/djangorestframework/resources.py b/djangorestframework/resources.py
index 4b81bf79..b42bd952 100644
--- a/djangorestframework/resources.py
+++ b/djangorestframework/resources.py
@@ -6,6 +6,7 @@ from django.db.models.fields.related import RelatedField
from django.utils.encoding import smart_unicode
from djangorestframework.response import ErrorResponse
+from djangorestframework.serializer import Serializer, _SkipField
from djangorestframework.utils import as_tuple
import decimal
@@ -13,122 +14,9 @@ import inspect
import re
-# TODO: _IgnoreFieldException
-# Map model classes to resource classes
-#_model_to_resource = {}
-
-def _model_to_dict(instance, resource=None):
- """
- Given a model instance, return a ``dict`` representing the model.
-
- The implementation is similar to Django's ``django.forms.model_to_dict``, except:
-
- * It doesn't coerce related objects into primary keys.
- * It doesn't drop ``editable=False`` fields.
- * It also supports attribute or method fields on the instance or resource.
- """
- opts = instance._meta
- data = {}
-
- #print [rel.name for rel in opts.get_all_related_objects()]
- #related = [rel.get_accessor_name() for rel in opts.get_all_related_objects()]
- #print [getattr(instance, rel) for rel in related]
- #if resource.fields:
- # fields = resource.fields
- #else:
- # fields = set(opts.fields + opts.many_to_many)
-
- fields = resource and resource.fields or ()
- include = resource and resource.include or ()
- exclude = resource and resource.exclude or ()
-
- extra_fields = fields and list(fields) or list(include)
-
- # Model fields
- for f in opts.fields + opts.many_to_many:
- if fields and not f.name in fields:
- continue
- if exclude and f.name in exclude:
- continue
- if isinstance(f, models.ForeignKey):
- data[f.name] = getattr(instance, f.name)
- else:
- data[f.name] = f.value_from_object(instance)
-
- if extra_fields and f.name in extra_fields:
- extra_fields.remove(f.name)
-
- # Method fields
- for fname in extra_fields:
-
- if isinstance(fname, (tuple, list)):
- fname, fields = fname
- else:
- fname, fields = fname, False
-
- try:
- if hasattr(resource, fname):
- # check the resource first, to allow it to override fields
- obj = getattr(resource, fname)
- # if it's a method like foo(self, instance), then call it
- if inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) == 2:
- obj = obj(instance)
- elif hasattr(instance, fname):
- # now check the object instance
- obj = getattr(instance, fname)
- else:
- continue
-
- # TODO: It would be nicer if this didn't recurse here.
- # Let's keep _model_to_dict flat, and _object_to_data recursive.
- if fields:
- Resource = type('Resource', (object,), {'fields': fields,
- 'include': (),
- 'exclude': ()})
- data[fname] = _object_to_data(obj, Resource())
- else:
- data[fname] = _object_to_data(obj)
-
- except NoReverseMatch:
- # Ug, bit of a hack for now
- pass
-
- return data
-
-
-def _object_to_data(obj, resource=None):
- """
- Convert an object into a serializable representation.
- """
- if isinstance(obj, dict):
- # dictionaries
- # TODO: apply same _model_to_dict logic fields/exclude here
- return dict([ (key, _object_to_data(val)) for key, val in obj.iteritems() ])
- if isinstance(obj, (tuple, list, set, QuerySet)):
- # basic iterables
- return [_object_to_data(item, resource) for item in obj]
- if isinstance(obj, models.Manager):
- # Manager objects
- return [_object_to_data(item, resource) for item in obj.all()]
- if isinstance(obj, models.Model):
- # Model instances
- return _object_to_data(_model_to_dict(obj, resource))
- if isinstance(obj, decimal.Decimal):
- # Decimals (force to string representation)
- return str(obj)
- if inspect.isfunction(obj) and not inspect.getargspec(obj)[0]:
- # function with no args
- return _object_to_data(obj(), resource)
- if inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) <= 1:
- # bound method
- return _object_to_data(obj(), resource)
-
- return smart_unicode(obj, strings_only=True)
-
-
-class BaseResource(object):
+class BaseResource(Serializer):
"""
Base class for all Resource classes, which simply defines the interface they provide.
"""
@@ -136,7 +24,8 @@ class BaseResource(object):
include = None
exclude = None
- def __init__(self, view):
+ def __init__(self, view=None, depth=None, stack=[], **kwargs):
+ super(BaseResource, self).__init__(depth, stack, **kwargs)
self.view = view
def validate_request(self, data, files=None):
@@ -150,7 +39,7 @@ class BaseResource(object):
"""
Given the response content, filter it into a serializable object.
"""
- return _object_to_data(obj, self)
+ return self.serialize(obj)
class Resource(BaseResource):
@@ -241,14 +130,12 @@ class FormResource(Resource):
# In addition to regular validation we also ensure no additional fields are being passed in...
unknown_fields = seen_fields_set - (form_fields_set | allowed_extra_fields_set)
unknown_fields = unknown_fields - set(('csrfmiddlewaretoken', '_accept', '_method')) # TODO: Ugh.
-
+
# Check using both regular validation, and our stricter no additional fields rule
if bound_form.is_valid() and not unknown_fields:
# Validation succeeded...
cleaned_data = bound_form.cleaned_data
- cleaned_data.update(bound_form.files)
-
# Add in any extra fields to the cleaned content...
for key in (allowed_extra_fields_set & seen_fields_set) - set(cleaned_data.keys()):
cleaned_data[key] = data[key]
@@ -299,7 +186,7 @@ class FormResource(Resource):
"""
# A form on the view overrides a form on the resource.
- form = getattr(self.view, 'form', self.form)
+ form = getattr(self.view, 'form', None) or self.form
# Use the requested method or determine the request method
if method is None and hasattr(self.view, 'request') and hasattr(self.view, 'method'):
@@ -316,7 +203,7 @@ class FormResource(Resource):
if not form:
return None
- if data is not None:
+ if data is not None or files is not None:
return form(data, files)
return form()
@@ -392,8 +279,8 @@ class ModelResource(FormResource):
"""
super(ModelResource, self).__init__(view)
- if getattr(view, 'model', None):
- self.model = view.model
+ self.model = getattr(view, 'model', None) or self.model
+
def validate_request(self, data, files=None):
"""
@@ -455,7 +342,7 @@ class ModelResource(FormResource):
"""
if not hasattr(self, 'view_callable'):
- raise NoReverseMatch
+ raise _SkipField
# dis does teh magicks...
urlconf = get_urlconf()
@@ -478,13 +365,13 @@ class ModelResource(FormResource):
if isinstance(attr, models.Model):
instance_attrs[param] = attr.pk
else:
- instance_attrs[param] = attr
+ instance_attrs[param] = attr
try:
return reverse(self.view_callable[0], kwargs=instance_attrs)
except NoReverseMatch:
pass
- raise NoReverseMatch
+ raise _SkipField
@property
diff --git a/djangorestframework/response.py b/djangorestframework/response.py
index d68ececf..311e0bb7 100644
--- a/djangorestframework/response.py
+++ b/djangorestframework/response.py
@@ -16,13 +16,13 @@ class Response(object):
An HttpResponse that may include content that hasn't yet been serialized.
"""
- def __init__(self, status=200, content=None, headers={}):
+ def __init__(self, status=200, content=None, headers=None):
self.status = status
self.media_type = None
self.has_content_body = content is not None
self.raw_content = content # content prior to filtering
self.cleaned_content = content # content after filtering
- self.headers = headers
+ self.headers = headers or {}
@property
def status_text(self):
diff --git a/djangorestframework/runtests/runtests.py b/djangorestframework/runtests/runtests.py
index a3cdfa67..1da918f5 100644
--- a/djangorestframework/runtests/runtests.py
+++ b/djangorestframework/runtests/runtests.py
@@ -13,21 +13,27 @@ os.environ['DJANGO_SETTINGS_MODULE'] = 'djangorestframework.runtests.settings'
from django.conf import settings
from django.test.utils import get_runner
+def usage():
+ return """
+ Usage: python runtests.py [UnitTestClass].[method]
+
+ You can pass the Class name of the `UnitTestClass` you want to test.
+
+ Append a method name if you only want to test a specific method of that class.
+ """
+
def main():
TestRunner = get_runner(settings)
- if hasattr(TestRunner, 'func_name'):
- # Pre 1.2 test runners were just functions,
- # and did not support the 'failfast' option.
- import warnings
- warnings.warn(
- 'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.',
- DeprecationWarning
- )
- failures = TestRunner(['djangorestframework'])
+ test_runner = TestRunner()
+ if len(sys.argv) == 2:
+ test_case = '.' + sys.argv[1]
+ elif len(sys.argv) == 1:
+ test_case = ''
else:
- test_runner = TestRunner()
- failures = test_runner.run_tests(['djangorestframework'])
+ print usage()
+ sys.exit(1)
+ failures = test_runner.run_tests(['djangorestframework' + test_case])
sys.exit(failures)
diff --git a/djangorestframework/runtests/settings.py b/djangorestframework/runtests/settings.py
index 0cc7f4e3..006727bc 100644
--- a/djangorestframework/runtests/settings.py
+++ b/djangorestframework/runtests/settings.py
@@ -2,6 +2,7 @@
DEBUG = True
TEMPLATE_DEBUG = DEBUG
+DEBUG_PROPAGATE_EXCEPTIONS = True
ADMINS = (
# ('Your Name', 'your_email@domain.com'),
@@ -83,7 +84,7 @@ TEMPLATE_DIRS = (
# Don't forget to use absolute paths, not relative paths.
)
-INSTALLED_APPS = (
+INSTALLED_APPS = [
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
@@ -94,8 +95,17 @@ INSTALLED_APPS = (
# Uncomment the next line to enable admin documentation:
# 'django.contrib.admindocs',
'djangorestframework',
-)
+]
+
+# OAuth support is optional, so we only test oauth if it's installed.
+try:
+ import oauth_provider
+except:
+ pass
+else:
+ INSTALLED_APPS.append('oauth_provider')
+# If we're running on the Jenkins server we want to archive the coverage reports as XML.
import os
if os.environ.get('HUDSON_URL', None):
TEST_RUNNER = 'xmlrunner.extra.djangotestrunner.XMLTestRunner'
diff --git a/djangorestframework/serializer.py b/djangorestframework/serializer.py
new file mode 100644
index 00000000..da8036e9
--- /dev/null
+++ b/djangorestframework/serializer.py
@@ -0,0 +1,310 @@
+"""
+Customizable serialization.
+"""
+from django.db import models
+from django.db.models.query import QuerySet
+from django.db.models.fields.related import RelatedField
+from django.utils.encoding import smart_unicode, is_protected_type
+
+import decimal
+import inspect
+import types
+
+
+# We register serializer classes, so that we can refer to them by their
+# class names, if there are cyclical serialization heirachys.
+_serializers = {}
+
+
+def _field_to_tuple(field):
+ """
+ Convert an item in the `fields` attribute into a 2-tuple.
+ """
+ if isinstance(field, (tuple, list)):
+ return (field[0], field[1])
+ return (field, None)
+
+def _fields_to_list(fields):
+ """
+ Return a list of field names.
+ """
+ return [_field_to_tuple(field)[0] for field in fields or ()]
+
+def _fields_to_dict(fields):
+ """
+ Return a `dict` of field name -> None, or tuple of fields, or Serializer class
+ """
+ return dict([_field_to_tuple(field) for field in fields or ()])
+
+
+class _SkipField(Exception):
+ """
+ Signals that a serialized field should be ignored.
+ We use this mechanism as the default behavior for ensuring
+ that we don't infinitely recurse when dealing with nested data.
+ """
+ pass
+
+
+class _RegisterSerializer(type):
+ """
+ Metaclass to register serializers.
+ """
+ def __new__(cls, name, bases, attrs):
+ # Build the class and register it.
+ ret = super(_RegisterSerializer, cls).__new__(cls, name, bases, attrs)
+ _serializers[name] = ret
+ return ret
+
+
+class Serializer(object):
+ """
+ Converts python objects into plain old native types suitable for
+ serialization. In particular it handles models and querysets.
+
+ The output format is specified by setting a number of attributes
+ on the class.
+
+ You may also override any of the serialization methods, to provide
+ for more flexible behavior.
+
+ Valid output types include anything that may be directly rendered into
+ json, xml etc...
+ """
+ __metaclass__ = _RegisterSerializer
+
+ fields = ()
+ """
+ Specify the fields to be serialized on a model or dict.
+ Overrides `include` and `exclude`.
+ """
+
+ include = ()
+ """
+ Fields to add to the default set to be serialized on a model/dict.
+ """
+
+ exclude = ()
+ """
+ Fields to remove from the default set to be serialized on a model/dict.
+ """
+
+ rename = {}
+ """
+ A dict of key->name to use for the field keys.
+ """
+
+ related_serializer = None
+ """
+ The default serializer class to use for any related models.
+ """
+
+ depth = None
+ """
+ The maximum depth to serialize to, or `None`.
+ """
+
+
+ def __init__(self, depth=None, stack=[], **kwargs):
+ self.depth = depth or self.depth
+ self.stack = stack
+
+
+ def get_fields(self, obj):
+ """
+ Return the set of field names/keys to use for a model instance/dict.
+ """
+ fields = self.fields
+
+ # If `fields` is not set, we use the default fields and modify
+ # them with `include` and `exclude`
+ if not fields:
+ default = self.get_default_fields(obj)
+ include = self.include or ()
+ exclude = self.exclude or ()
+ fields = set(default + list(include)) - set(exclude)
+
+ else:
+ fields = _fields_to_list(self.fields)
+
+ return fields
+
+
+ def get_default_fields(self, obj):
+ """
+ Return the default list of field names/keys for a model instance/dict.
+ These are used if `fields` is not given.
+ """
+ if isinstance(obj, models.Model):
+ opts = obj._meta
+ return [field.name for field in opts.fields + opts.many_to_many]
+ else:
+ return obj.keys()
+
+
+ def get_related_serializer(self, key):
+ info = _fields_to_dict(self.fields).get(key, None)
+
+ # If an element in `fields` is a 2-tuple of (str, tuple)
+ # then the second element of the tuple is the fields to
+ # set on the related serializer
+ if isinstance(info, (list, tuple)):
+ class OnTheFlySerializer(Serializer):
+ fields = info
+ return OnTheFlySerializer
+
+ # If an element in `fields` is a 2-tuple of (str, Serializer)
+ # then the second element of the tuple is the Serializer
+ # class to use for that field.
+ elif isinstance(info, type) and issubclass(info, Serializer):
+ return info
+
+ # If an element in `fields` is a 2-tuple of (str, str)
+ # then the second element of the tuple is the name of the Serializer
+ # class to use for that field.
+ #
+ # Black magic to deal with cyclical Serializer dependancies.
+ # Similar to what Django does for cyclically related models.
+ elif isinstance(info, str) and info in _serializers:
+ return _serializers[info]
+
+ # Otherwise use `related_serializer` or fall back to `Serializer`
+ return getattr(self, 'related_serializer') or Serializer
+
+
+ def serialize_key(self, key):
+ """
+ Keys serialize to their string value,
+ unless they exist in the `rename` dict.
+ """
+ return getattr(self.rename, key, key)
+
+
+ def serialize_val(self, key, obj):
+ """
+ Convert a model field or dict value into a serializable representation.
+ """
+ related_serializer = self.get_related_serializer(key)
+
+ if self.depth is None:
+ depth = None
+ elif self.depth <= 0:
+ return self.serialize_max_depth(obj)
+ else:
+ depth = self.depth - 1
+
+ if any([obj is elem for elem in self.stack]):
+ return self.serialize_recursion(obj)
+ else:
+ stack = self.stack[:]
+ stack.append(obj)
+
+ return related_serializer(depth=depth, stack=stack).serialize(obj)
+
+
+ def serialize_max_depth(self, obj):
+ """
+ Determine how objects should be serialized once `depth` is exceeded.
+ The default behavior is to ignore the field.
+ """
+ raise _SkipField
+
+
+ def serialize_recursion(self, obj):
+ """
+ Determine how objects should be serialized if recursion occurs.
+ The default behavior is to ignore the field.
+ """
+ raise _SkipField
+
+
+ def serialize_model(self, instance):
+ """
+ Given a model instance or dict, serialize it to a dict..
+ """
+ data = {}
+
+ fields = self.get_fields(instance)
+
+ # serialize each required field
+ for fname in fields:
+ if hasattr(self, fname):
+ # check for a method 'fname' on self first
+ meth = getattr(self, fname)
+ if inspect.ismethod(meth) and len(inspect.getargspec(meth)[0]) == 2:
+ obj = meth(instance)
+ elif hasattr(instance, fname):
+ # now check for an attribute 'fname' on the instance
+ obj = getattr(instance, fname)
+ elif fname in instance:
+ # finally check for a key 'fname' on the instance
+ obj = instance[fname]
+ else:
+ continue
+
+ try:
+ key = self.serialize_key(fname)
+ val = self.serialize_val(fname, obj)
+ data[key] = val
+ except _SkipField:
+ pass
+
+ return data
+
+
+ def serialize_iter(self, obj):
+ """
+ Convert iterables into a serializable representation.
+ """
+ return [self.serialize(item) for item in obj]
+
+
+ def serialize_func(self, obj):
+ """
+ Convert no-arg methods and functions into a serializable representation.
+ """
+ return self.serialize(obj())
+
+
+ def serialize_manager(self, obj):
+ """
+ Convert a model manager into a serializable representation.
+ """
+ return self.serialize_iter(obj.all())
+
+
+ def serialize_fallback(self, obj):
+ """
+ Convert any unhandled object into a serializable representation.
+ """
+ return smart_unicode(obj, strings_only=True)
+
+
+ def serialize(self, obj):
+ """
+ Convert any object into a serializable representation.
+ """
+
+ if isinstance(obj, (dict, models.Model)):
+ # Model instances & dictionaries
+ return self.serialize_model(obj)
+ elif isinstance(obj, (tuple, list, set, QuerySet, types.GeneratorType)):
+ # basic iterables
+ return self.serialize_iter(obj)
+ elif isinstance(obj, models.Manager):
+ # Manager objects
+ return self.serialize_manager(obj)
+ elif inspect.isfunction(obj) and not inspect.getargspec(obj)[0]:
+ # function with no args
+ return self.serialize_func(obj)
+ elif inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) <= 1:
+ # bound method
+ return self.serialize_func(obj)
+
+ # Protected types are passed through as is.
+ # (i.e. Primitives like None, numbers, dates, and Decimals.)
+ if is_protected_type(obj):
+ return obj
+
+ # All other values are converted to string.
+ return self.serialize_fallback(obj)
diff --git a/djangorestframework/templates/renderer.html b/djangorestframework/templates/renderer.html
index 94748d28..5b32d1ec 100644
--- a/djangorestframework/templates/renderer.html
+++ b/djangorestframework/templates/renderer.html
@@ -18,7 +18,7 @@
<div id="header">
<div id="branding">
- <h1 id="site-name"><a href='http://django-rest-framework.org'>Django REST framework</a></h1>
+ <h1 id="site-name"><a href='http://django-rest-framework.org'>Django REST framework</a> <small>{{ version }}</small></h1>
</div>
<div id="user-tools">
{% if user.is_active %}Welcome, {{ user }}.{% if logout_url %} <a href='{{ logout_url }}'>Log out</a>{% endif %}{% else %}Anonymous {% if login_url %}<a href='{{ login_url }}'>Log in</a>{% endif %}{% endif %}
@@ -48,9 +48,9 @@
<h2>GET {{ name }}</h2>
<div class='submit-row' style='margin: 0; border: 0'>
<a href='{{ request.get_full_path }}' rel="nofollow" style='float: left'>GET</a>
- {% for media_type in available_media_types %}
- {% with ACCEPT_PARAM|add:"="|add:media_type as param %}
- [<a href='{{ request.get_full_path|add_query_param:param }}' rel="nofollow">{{ media_type }}</a>]
+ {% for format in available_formats %}
+ {% with FORMAT_PARAM|add:"="|add:format as param %}
+ [<a href='{{ request.get_full_path|add_query_param:param }}' rel="nofollow">{{ format }}</a>]
{% endwith %}
{% endfor %}
</div>
@@ -122,4 +122,4 @@
</div>
</div>
</body>
-</html> \ No newline at end of file
+</html>
diff --git a/djangorestframework/tests/content.py b/djangorestframework/tests/content.py
index ee3597a4..83ad72d0 100644
--- a/djangorestframework/tests/content.py
+++ b/djangorestframework/tests/content.py
@@ -6,7 +6,6 @@ from djangorestframework.compat import RequestFactory
from djangorestframework.mixins import RequestMixin
from djangorestframework.parsers import FormParser, MultiPartParser, PlainTextParser
-
class TestContentParsing(TestCase):
def setUp(self):
self.req = RequestFactory()
@@ -16,6 +15,11 @@ class TestContentParsing(TestCase):
view.request = self.req.get('/')
self.assertEqual(view.DATA, None)
+ def ensure_determines_no_content_HEAD(self, view):
+ """Ensure view.DATA returns None for HEAD request."""
+ view.request = self.req.head('/')
+ self.assertEqual(view.DATA, None)
+
def ensure_determines_form_content_POST(self, view):
"""Ensure view.DATA returns content for POST request with form content."""
form_data = {'qwerty': 'uiop'}
@@ -50,6 +54,10 @@ class TestContentParsing(TestCase):
"""Ensure view.DATA returns None for GET request with no content."""
self.ensure_determines_no_content_GET(RequestMixin())
+ def test_standard_behaviour_determines_no_content_HEAD(self):
+ """Ensure view.DATA returns None for HEAD request."""
+ self.ensure_determines_no_content_HEAD(RequestMixin())
+
def test_standard_behaviour_determines_form_content_POST(self):
"""Ensure view.DATA returns content for POST request with form content."""
self.ensure_determines_form_content_POST(RequestMixin())
diff --git a/djangorestframework/tests/files.py b/djangorestframework/tests/files.py
index 25aad9b4..992d3cba 100644
--- a/djangorestframework/tests/files.py
+++ b/djangorestframework/tests/files.py
@@ -12,20 +12,16 @@ class UploadFilesTests(TestCase):
def test_upload_file(self):
-
class FileForm(forms.Form):
- file = forms.FileField
-
- class MockResource(FormResource):
- form = FileForm
+ file = forms.FileField()
class MockView(View):
permissions = ()
- resource = MockResource
+ form = FileForm
def post(self, request, *args, **kwargs):
- return {'FILE_NAME': self.CONTENT['file'][0].name,
- 'FILE_CONTENT': self.CONTENT['file'][0].read()}
+ return {'FILE_NAME': self.CONTENT['file'].name,
+ 'FILE_CONTENT': self.CONTENT['file'].read()}
file = StringIO.StringIO('stuff')
file.name = 'stuff.txt'
diff --git a/djangorestframework/tests/methods.py b/djangorestframework/tests/methods.py
index d8f0d919..c3a3a28d 100644
--- a/djangorestframework/tests/methods.py
+++ b/djangorestframework/tests/methods.py
@@ -24,3 +24,9 @@ class TestMethodOverloading(TestCase):
view = RequestMixin()
view.request = self.req.post('/', {view._METHOD_PARAM: 'DELETE'})
self.assertEqual(view.method, 'DELETE')
+
+ def test_HEAD_is_a_valid_method(self):
+ """HEAD requests identified"""
+ view = RequestMixin()
+ view.request = self.req.head('/')
+ self.assertEqual(view.method, 'HEAD')
diff --git a/djangorestframework/tests/oauthentication.py b/djangorestframework/tests/oauthentication.py
new file mode 100644
index 00000000..7f74b804
--- /dev/null
+++ b/djangorestframework/tests/oauthentication.py
@@ -0,0 +1,212 @@
+import time
+
+from django.conf.urls.defaults import patterns, url, include
+from django.contrib.auth.models import User
+from django.test import Client, TestCase
+
+from djangorestframework.views import View
+
+# Since oauth2 / django-oauth-plus are optional dependancies, we don't want to
+# always run these tests.
+
+# Unfortunatly we can't skip tests easily until 2.7, se we'll just do this for now.
+try:
+ import oauth2 as oauth
+ from oauth_provider.decorators import oauth_required
+ from oauth_provider.models import Resource, Consumer, Token
+
+except:
+ pass
+
+else:
+ # Alrighty, we're good to go here.
+ class ClientView(View):
+ def get(self, request):
+ return {'resource': 'Protected!'}
+
+ urlpatterns = patterns('',
+ url(r'^$', oauth_required(ClientView.as_view())),
+ url(r'^oauth/', include('oauth_provider.urls')),
+ url(r'^accounts/login/$', 'djangorestframework.utils.staticviews.api_login'),
+ )
+
+
+ class OAuthTests(TestCase):
+ """
+ OAuth authentication:
+ * the user would like to access his API data from a third-party website
+ * the third-party website proposes a link to get that API data
+ * the user is redirected to the API and must log in if not authenticated
+ * the API displays a webpage to confirm that the user trusts the third-party website
+ * if confirmed, the user is redirected to the third-party website through the callback view
+ * the third-party website is able to retrieve data from the API
+ """
+ urls = 'djangorestframework.tests.oauthentication'
+
+ def setUp(self):
+ self.client = Client()
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ # OAuth requirements
+ self.resource = Resource(name='data', url='/')
+ self.resource.save()
+ self.CONSUMER_KEY = 'dpf43f3p2l4k3l03'
+ self.CONSUMER_SECRET = 'kd94hf93k423kf44'
+ self.consumer = Consumer(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
+ name='api.example.com', user=self.user)
+ self.consumer.save()
+
+ def test_oauth_invalid_and_anonymous_access(self):
+ """
+ Verify that the resource is protected and the OAuth authorization view
+ require the user to be logged in.
+ """
+ response = self.client.get('/')
+ self.assertEqual(response.content, 'Invalid request parameters.')
+ self.assertEqual(response.status_code, 401)
+ response = self.client.get('/oauth/authorize/', follow=True)
+ self.assertRedirects(response, '/accounts/login/?next=/oauth/authorize/')
+
+ def test_oauth_authorize_access(self):
+ """
+ Verify that once logged in, the user can access the authorization page
+ but can't display the page because the request token is not specified.
+ """
+ self.client.login(username=self.username, password=self.password)
+ response = self.client.get('/oauth/authorize/', follow=True)
+ self.assertEqual(response.content, 'No request token specified.')
+
+ def _create_request_token_parameters(self):
+ """
+ A shortcut to create request's token parameters.
+ """
+ return {
+ 'oauth_consumer_key': self.CONSUMER_KEY,
+ 'oauth_signature_method': 'PLAINTEXT',
+ 'oauth_signature': '%s&' % self.CONSUMER_SECRET,
+ 'oauth_timestamp': str(int(time.time())),
+ 'oauth_nonce': 'requestnonce',
+ 'oauth_version': '1.0',
+ 'oauth_callback': 'http://api.example.com/request_token_ready',
+ 'scope': 'data',
+ }
+
+ def test_oauth_request_token_retrieval(self):
+ """
+ Verify that the request token can be retrieved by the server.
+ """
+ response = self.client.get("/oauth/request_token/",
+ self._create_request_token_parameters())
+ self.assertEqual(response.status_code, 200)
+ token = list(Token.objects.all())[-1]
+ self.failIf(token.key not in response.content)
+ self.failIf(token.secret not in response.content)
+
+ def test_oauth_user_request_authorization(self):
+ """
+ Verify that the user can access the authorization page once logged in
+ and the request token has been retrieved.
+ """
+ # Setup
+ response = self.client.get("/oauth/request_token/",
+ self._create_request_token_parameters())
+ token = list(Token.objects.all())[-1]
+
+ # Starting the test here
+ self.client.login(username=self.username, password=self.password)
+ parameters = {'oauth_token': token.key,}
+ response = self.client.get("/oauth/authorize/", parameters)
+ self.assertEqual(response.status_code, 200)
+ self.failIf(not response.content.startswith('Fake authorize view for api.example.com with params: oauth_token='))
+ self.assertEqual(token.is_approved, 0)
+ parameters['authorize_access'] = 1 # fake authorization by the user
+ response = self.client.post("/oauth/authorize/", parameters)
+ self.assertEqual(response.status_code, 302)
+ self.failIf(not response['Location'].startswith('http://api.example.com/request_token_ready?oauth_verifier='))
+ token = Token.objects.get(key=token.key)
+ self.failIf(token.key not in response['Location'])
+ self.assertEqual(token.is_approved, 1)
+
+ def _create_access_token_parameters(self, token):
+ """
+ A shortcut to create access' token parameters.
+ """
+ return {
+ 'oauth_consumer_key': self.CONSUMER_KEY,
+ 'oauth_token': token.key,
+ 'oauth_signature_method': 'PLAINTEXT',
+ 'oauth_signature': '%s&%s' % (self.CONSUMER_SECRET, token.secret),
+ 'oauth_timestamp': str(int(time.time())),
+ 'oauth_nonce': 'accessnonce',
+ 'oauth_version': '1.0',
+ 'oauth_verifier': token.verifier,
+ 'scope': 'data',
+ }
+
+ def test_oauth_access_token_retrieval(self):
+ """
+ Verify that the request token can be retrieved by the server.
+ """
+ # Setup
+ response = self.client.get("/oauth/request_token/",
+ self._create_request_token_parameters())
+ token = list(Token.objects.all())[-1]
+ self.client.login(username=self.username, password=self.password)
+ parameters = {'oauth_token': token.key,}
+ response = self.client.get("/oauth/authorize/", parameters)
+ parameters['authorize_access'] = 1 # fake authorization by the user
+ response = self.client.post("/oauth/authorize/", parameters)
+ token = Token.objects.get(key=token.key)
+
+ # Starting the test here
+ response = self.client.get("/oauth/access_token/", self._create_access_token_parameters(token))
+ self.assertEqual(response.status_code, 200)
+ self.failIf(not response.content.startswith('oauth_token_secret='))
+ access_token = list(Token.objects.filter(token_type=Token.ACCESS))[-1]
+ self.failIf(access_token.key not in response.content)
+ self.failIf(access_token.secret not in response.content)
+ self.assertEqual(access_token.user.username, 'john')
+
+ def _create_access_parameters(self, access_token):
+ """
+ A shortcut to create access' parameters.
+ """
+ parameters = {
+ 'oauth_consumer_key': self.CONSUMER_KEY,
+ 'oauth_token': access_token.key,
+ 'oauth_signature_method': 'HMAC-SHA1',
+ 'oauth_timestamp': str(int(time.time())),
+ 'oauth_nonce': 'accessresourcenonce',
+ 'oauth_version': '1.0',
+ }
+ oauth_request = oauth.Request.from_token_and_callback(access_token,
+ http_url='http://testserver/', parameters=parameters)
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ signature = signature_method.sign(oauth_request, self.consumer, access_token)
+ parameters['oauth_signature'] = signature
+ return parameters
+
+ def test_oauth_protected_resource_access(self):
+ """
+ Verify that the request token can be retrieved by the server.
+ """
+ # Setup
+ response = self.client.get("/oauth/request_token/",
+ self._create_request_token_parameters())
+ token = list(Token.objects.all())[-1]
+ self.client.login(username=self.username, password=self.password)
+ parameters = {'oauth_token': token.key,}
+ response = self.client.get("/oauth/authorize/", parameters)
+ parameters['authorize_access'] = 1 # fake authorization by the user
+ response = self.client.post("/oauth/authorize/", parameters)
+ token = Token.objects.get(key=token.key)
+ response = self.client.get("/oauth/access_token/", self._create_access_token_parameters(token))
+ access_token = list(Token.objects.filter(token_type=Token.ACCESS))[-1]
+
+ # Starting the test here
+ response = self.client.get("/", self._create_access_token_parameters(access_token))
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, '{"resource": "Protected!"}')
diff --git a/djangorestframework/tests/parsers.py b/djangorestframework/tests/parsers.py
index 3ab1a61c..deba688e 100644
--- a/djangorestframework/tests/parsers.py
+++ b/djangorestframework/tests/parsers.py
@@ -131,3 +131,25 @@
# self.assertEqual(data['key1'], 'val1')
# self.assertEqual(files['file1'].read(), 'blablabla')
+from StringIO import StringIO
+from cgi import parse_qs
+from django import forms
+from django.test import TestCase
+from djangorestframework.parsers import FormParser
+
+class Form(forms.Form):
+ field1 = forms.CharField(max_length=3)
+ field2 = forms.CharField()
+
+class TestFormParser(TestCase):
+ def setUp(self):
+ self.string = "field1=abc&field2=defghijk"
+
+ def test_parse(self):
+ """ Make sure the `QueryDict` works OK """
+ parser = FormParser(None)
+
+ stream = StringIO(self.string)
+ (data, files) = parser.parse(stream)
+
+ self.assertEqual(Form(data).is_valid(), True)
diff --git a/djangorestframework/tests/renderers.py b/djangorestframework/tests/renderers.py
index 54276993..bf135e55 100644
--- a/djangorestframework/tests/renderers.py
+++ b/djangorestframework/tests/renderers.py
@@ -1,13 +1,18 @@
from django.conf.urls.defaults import patterns, url
from django import http
from django.test import TestCase
+
+from djangorestframework import status
from djangorestframework.compat import View as DjangoView
from djangorestframework.renderers import BaseRenderer, JSONRenderer
+from djangorestframework.parsers import JSONParser
from djangorestframework.mixins import ResponseMixin
from djangorestframework.response import Response
from djangorestframework.utils.mediatypes import add_media_type_param
-DUMMYSTATUS = 200
+from StringIO import StringIO
+
+DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = 'dummycontent'
RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x
@@ -15,12 +20,14 @@ RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x
class RendererA(BaseRenderer):
media_type = 'mock/renderera'
+ format="formata"
def render(self, obj=None, media_type=None):
return RENDERER_A_SERIALIZER(obj)
class RendererB(BaseRenderer):
media_type = 'mock/rendererb'
+ format="formatb"
def render(self, obj=None, media_type=None):
return RENDERER_B_SERIALIZER(obj)
@@ -28,11 +35,13 @@ class RendererB(BaseRenderer):
class MockView(ResponseMixin, DjangoView):
renderers = (RendererA, RendererB)
- def get(self, request):
+ def get(self, request, **kwargs):
response = Response(DUMMYSTATUS, DUMMYCONTENT)
return self.render(response)
+
urlpatterns = patterns('',
+ url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderers=[RendererA, RendererB])),
url(r'^$', MockView.as_view(renderers=[RendererA, RendererB])),
)
@@ -51,6 +60,13 @@ class RendererIntegrationTests(TestCase):
self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
+ def test_head_method_serializes_no_content(self):
+ """No response must be included in HEAD requests."""
+ resp = self.client.head('/')
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEquals(resp['Content-Type'], RendererA.media_type)
+ self.assertEquals(resp.content, '')
+
def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*')
@@ -74,12 +90,58 @@ class RendererIntegrationTests(TestCase):
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
+ def test_specified_renderer_serializes_content_on_accept_query(self):
+ """The '_accept' query string should behave in the same way as the Accept header."""
+ resp = self.client.get('/?_accept=%s' % RendererB.media_type)
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
"""If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
- self.assertEquals(resp.status_code, 406)
+ self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
+ def test_specified_renderer_serializes_content_on_format_query(self):
+ """If a 'format' query is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ resp = self.client.get('/?format=%s' % RendererB.format)
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+ def test_specified_renderer_serializes_content_on_format_kwargs(self):
+ """If a 'format' keyword arg is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ resp = self.client.get('/something.formatb')
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
+ """If both a 'format' query and a matching Accept header specified,
+ the renderer with the matching format attribute should serialize the response."""
+ resp = self.client.get('/?format=%s' % RendererB.format,
+ HTTP_ACCEPT=RendererB.media_type)
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_conflicting_format_query_and_accept_ignores_accept(self):
+ """If a 'format' query is specified that does not match the Accept
+ header, we should only honor the 'format' query string."""
+ resp = self.client.get('/?format=%s' % RendererB.format,
+ HTTP_ACCEPT='dummy')
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_bla(self):
+ resp = self.client.get('/?format=formatb',
+ HTTP_ACCEPT='text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8')
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
_flat_repr = '{"foo": ["bar", "baz"]}'
@@ -95,14 +157,35 @@ class JSONRendererTests(TestCase):
"""
Tests specific to the JSON Renderer
"""
+
def test_without_content_type_args(self):
+ """
+ Test basic JSON rendering.
+ """
obj = {'foo':['bar','baz']}
renderer = JSONRenderer(None)
content = renderer.render(obj, 'application/json')
self.assertEquals(content, _flat_repr)
def test_with_content_type_args(self):
+ """
+ Test JSON rendering with additional content type arguments supplied.
+ """
obj = {'foo':['bar','baz']}
renderer = JSONRenderer(None)
content = renderer.render(obj, 'application/json; indent=2')
self.assertEquals(content, _indented_repr)
+
+ def test_render_and_parse(self):
+ """
+ Test rendering and then parsing returns the original object.
+ IE obj -> render -> parse -> obj.
+ """
+ obj = {'foo':['bar','baz']}
+
+ renderer = JSONRenderer(None)
+ parser = JSONParser(None)
+
+ content = renderer.render(obj, 'application/json')
+ (data, files) = parser.parse(StringIO(content))
+ self.assertEquals(obj, data)
diff --git a/djangorestframework/tests/resources.py b/djangorestframework/tests/resources.py
deleted file mode 100644
index 088e3159..00000000
--- a/djangorestframework/tests/resources.py
+++ /dev/null
@@ -1,60 +0,0 @@
-"""Tests for the resource module"""
-from django.test import TestCase
-from djangorestframework.resources import _object_to_data
-
-from django.db import models
-
-import datetime
-import decimal
-
-class TestObjectToData(TestCase):
- """Tests for the _object_to_data function"""
-
- def test_decimal(self):
- """Decimals need to be converted to a string representation."""
- self.assertEquals(_object_to_data(decimal.Decimal('1.5')), '1.5')
-
- def test_function(self):
- """Functions with no arguments should be called."""
- def foo():
- return 1
- self.assertEquals(_object_to_data(foo), 1)
-
- def test_method(self):
- """Methods with only a ``self`` argument should be called."""
- class Foo(object):
- def foo(self):
- return 1
- self.assertEquals(_object_to_data(Foo().foo), 1)
-
- def test_datetime(self):
- """datetime objects are left as-is."""
- now = datetime.datetime.now()
- self.assertEquals(_object_to_data(now), now)
-
- def test_tuples(self):
- """ Test tuple serialisation """
- class M1(models.Model):
- field1 = models.CharField()
- field2 = models.CharField()
-
- class M2(models.Model):
- field = models.OneToOneField(M1)
-
- class M3(models.Model):
- field = models.ForeignKey(M1)
-
- m1 = M1(field1='foo', field2='bar')
- m2 = M2(field=m1)
- m3 = M3(field=m1)
-
- Resource = type('Resource', (object,), {'fields':(), 'include':(), 'exclude':()})
-
- r = Resource()
- r.fields = (('field', ('field1')),)
-
- self.assertEqual(_object_to_data(m2, r), dict(field=dict(field1=u'foo')))
-
- r.fields = (('field', ('field2')),)
- self.assertEqual(_object_to_data(m3, r), dict(field=dict(field2=u'bar')))
-
diff --git a/djangorestframework/tests/reverse.py b/djangorestframework/tests/reverse.py
index b4b0a793..2d1ca79e 100644
--- a/djangorestframework/tests/reverse.py
+++ b/djangorestframework/tests/reverse.py
@@ -24,9 +24,5 @@ class ReverseTests(TestCase):
urls = 'djangorestframework.tests.reverse'
def test_reversed_urls_are_fully_qualified(self):
- try:
- response = self.client.get('/')
- except:
- import traceback
- traceback.print_exc()
+ response = self.client.get('/')
self.assertEqual(json.loads(response.content), 'http://testserver/another')
diff --git a/djangorestframework/tests/serializer.py b/djangorestframework/tests/serializer.py
new file mode 100644
index 00000000..9f629050
--- /dev/null
+++ b/djangorestframework/tests/serializer.py
@@ -0,0 +1,117 @@
+"""Tests for the resource module"""
+from django.test import TestCase
+from djangorestframework.serializer import Serializer
+
+from django.db import models
+
+import datetime
+import decimal
+
+class TestObjectToData(TestCase):
+ """
+ Tests for the Serializer class.
+ """
+
+ def setUp(self):
+ self.serializer = Serializer()
+ self.serialize = self.serializer.serialize
+
+ def test_decimal(self):
+ """Decimals need to be converted to a string representation."""
+ self.assertEquals(self.serialize(decimal.Decimal('1.5')), decimal.Decimal('1.5'))
+
+ def test_function(self):
+ """Functions with no arguments should be called."""
+ def foo():
+ return 1
+ self.assertEquals(self.serialize(foo), 1)
+
+ def test_method(self):
+ """Methods with only a ``self`` argument should be called."""
+ class Foo(object):
+ def foo(self):
+ return 1
+ self.assertEquals(self.serialize(Foo().foo), 1)
+
+ def test_datetime(self):
+ """
+ datetime objects are left as-is.
+ """
+ now = datetime.datetime.now()
+ self.assertEquals(self.serialize(now), now)
+
+
+class TestFieldNesting(TestCase):
+ """
+ Test nesting the fields in the Serializer class
+ """
+ def setUp(self):
+ self.serializer = Serializer()
+ self.serialize = self.serializer.serialize
+
+ class M1(models.Model):
+ field1 = models.CharField()
+ field2 = models.CharField()
+
+ class M2(models.Model):
+ field = models.OneToOneField(M1)
+
+ class M3(models.Model):
+ field = models.ForeignKey(M1)
+
+ self.m1 = M1(field1='foo', field2='bar')
+ self.m2 = M2(field=self.m1)
+ self.m3 = M3(field=self.m1)
+
+
+ def test_tuple_nesting(self):
+ """
+ Test tuple nesting on `fields` attr
+ """
+ class SerializerM2(Serializer):
+ fields = (('field', ('field1',)),)
+
+ class SerializerM3(Serializer):
+ fields = (('field', ('field2',)),)
+
+ self.assertEqual(SerializerM2().serialize(self.m2), {'field': {'field1': u'foo'}})
+ self.assertEqual(SerializerM3().serialize(self.m3), {'field': {'field2': u'bar'}})
+
+
+ def test_serializer_class_nesting(self):
+ """
+ Test related model serialization
+ """
+ class NestedM2(Serializer):
+ fields = ('field1', )
+
+ class NestedM3(Serializer):
+ fields = ('field2', )
+
+ class SerializerM2(Serializer):
+ fields = [('field', NestedM2)]
+
+ class SerializerM3(Serializer):
+ fields = [('field', NestedM3)]
+
+ self.assertEqual(SerializerM2().serialize(self.m2), {'field': {'field1': u'foo'}})
+ self.assertEqual(SerializerM3().serialize(self.m3), {'field': {'field2': u'bar'}})
+
+ def test_serializer_classname_nesting(self):
+ """
+ Test related model serialization
+ """
+ class SerializerM2(Serializer):
+ fields = [('field', 'NestedM2')]
+
+ class SerializerM3(Serializer):
+ fields = [('field', 'NestedM3')]
+
+ class NestedM2(Serializer):
+ fields = ('field1', )
+
+ class NestedM3(Serializer):
+ fields = ('field2', )
+
+ self.assertEqual(SerializerM2().serialize(self.m2), {'field': {'field1': u'foo'}})
+ self.assertEqual(SerializerM3().serialize(self.m3), {'field': {'field2': u'bar'}})
diff --git a/djangorestframework/tests/throttling.py b/djangorestframework/tests/throttling.py
index a8f08b18..b620ee24 100644
--- a/djangorestframework/tests/throttling.py
+++ b/djangorestframework/tests/throttling.py
@@ -1,38 +1,148 @@
-from django.conf.urls.defaults import patterns
+"""
+Tests for the throttling implementations in the permissions module.
+"""
+
from django.test import TestCase
-from django.utils import simplejson as json
+from django.contrib.auth.models import User
+from django.core.cache import cache
from djangorestframework.compat import RequestFactory
from djangorestframework.views import View
-from djangorestframework.permissions import PerUserThrottling
-
+from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling
+from djangorestframework.resources import FormResource
class MockView(View):
permissions = ( PerUserThrottling, )
- throttle = (3, 1) # 3 requests per second
+ throttle = '3/sec'
def get(self, request):
return 'foo'
-urlpatterns = patterns('',
- (r'^$', MockView.as_view()),
-)
-
-
-#class ThrottlingTests(TestCase):
-# """Basic authentication"""
-# urls = 'djangorestframework.tests.throttling'
-#
-# def test_requests_are_throttled(self):
-# """Ensure request rate is limited"""
-# for dummy in range(3):
-# response = self.client.get('/')
-# response = self.client.get('/')
-#
-# def test_request_throttling_is_per_user(self):
-# """Ensure request rate is only limited per user, not globally"""
-# pass
-#
-# def test_request_throttling_expires(self):
-# """Ensure request rate is limited for a limited duration only"""
-# pass
+class MockView_PerViewThrottling(MockView):
+ permissions = ( PerViewThrottling, )
+
+class MockView_PerResourceThrottling(MockView):
+ permissions = ( PerResourceThrottling, )
+ resource = FormResource
+
+class MockView_MinuteThrottling(MockView):
+ throttle = '3/min'
+
+
+
+class ThrottlingTests(TestCase):
+ urls = 'djangorestframework.tests.throttling'
+
+ def setUp(self):
+ """
+ Reset the cache so that no throttles will be active
+ """
+ cache.clear()
+ self.factory = RequestFactory()
+
+ def test_requests_are_throttled(self):
+ """
+ Ensure request rate is limited
+ """
+ request = self.factory.get('/')
+ for dummy in range(4):
+ response = MockView.as_view()(request)
+ self.assertEqual(503, response.status_code)
+
+ def set_throttle_timer(self, view, value):
+ """
+ Explicitly set the timer, overriding time.time()
+ """
+ view.permissions[0].timer = lambda self: value
+
+ def test_request_throttling_expires(self):
+ """
+ Ensure request rate is limited for a limited duration only
+ """
+ self.set_throttle_timer(MockView, 0)
+
+ request = self.factory.get('/')
+ for dummy in range(4):
+ response = MockView.as_view()(request)
+ self.assertEqual(503, response.status_code)
+
+ # Advance the timer by one second
+ self.set_throttle_timer(MockView, 1)
+
+ response = MockView.as_view()(request)
+ self.assertEqual(200, response.status_code)
+
+ def ensure_is_throttled(self, view, expect):
+ request = self.factory.get('/')
+ request.user = User.objects.create(username='a')
+ for dummy in range(3):
+ view.as_view()(request)
+ request.user = User.objects.create(username='b')
+ response = view.as_view()(request)
+ self.assertEqual(expect, response.status_code)
+
+ def test_request_throttling_is_per_user(self):
+ """
+ Ensure request rate is only limited per user, not globally for
+ PerUserThrottles
+ """
+ self.ensure_is_throttled(MockView, 200)
+
+ def test_request_throttling_is_per_view(self):
+ """
+ Ensure request rate is limited globally per View for PerViewThrottles
+ """
+ self.ensure_is_throttled(MockView_PerViewThrottling, 503)
+
+ def test_request_throttling_is_per_resource(self):
+ """
+ Ensure request rate is limited globally per Resource for PerResourceThrottles
+ """
+ self.ensure_is_throttled(MockView_PerResourceThrottling, 503)
+
+
+ def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
+ """
+ Ensure the response returns an X-Throttle field with status and next attributes
+ set properly.
+ """
+ request = self.factory.get('/')
+ for timer, expect in expected_headers:
+ self.set_throttle_timer(view, timer)
+ response = view.as_view()(request)
+ self.assertEquals(response['X-Throttle'], expect)
+
+ def test_seconds_fields(self):
+ """
+ Ensure for second based throttles.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(MockView,
+ ((0, 'status=SUCCESS; next=0.33 sec'),
+ (0, 'status=SUCCESS; next=0.50 sec'),
+ (0, 'status=SUCCESS; next=1.00 sec'),
+ (0, 'status=FAILURE; next=1.00 sec')
+ ))
+
+ def test_minutes_fields(self):
+ """
+ Ensure for minute based throttles.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
+ ((0, 'status=SUCCESS; next=20.00 sec'),
+ (0, 'status=SUCCESS; next=30.00 sec'),
+ (0, 'status=SUCCESS; next=60.00 sec'),
+ (0, 'status=FAILURE; next=60.00 sec')
+ ))
+
+ def test_next_rate_remains_constant_if_followed(self):
+ """
+ If a client follows the recommended next request rate,
+ the throttling rate should stay constant.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
+ ((0, 'status=SUCCESS; next=20.00 sec'),
+ (20, 'status=SUCCESS; next=20.00 sec'),
+ (40, 'status=SUCCESS; next=20.00 sec'),
+ (60, 'status=SUCCESS; next=20.00 sec'),
+ (80, 'status=SUCCESS; next=20.00 sec')
+ ))
diff --git a/djangorestframework/views.py b/djangorestframework/views.py
index 6f2ab5b7..18d064e1 100644
--- a/djangorestframework/views.py
+++ b/djangorestframework/views.py
@@ -64,7 +64,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
"""
permissions = ( permissions.FullAnonAccess, )
-
+
@classmethod
def as_view(cls, **initkwargs):
"""
@@ -101,6 +101,14 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
"""
pass
+
+ def add_header(self, field, value):
+ """
+ Add *field* and *value* to the :attr:`headers` attribute of the :class:`View` class.
+ """
+ self.headers[field] = value
+
+
# Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt.
@csrf_exempt
@@ -108,6 +116,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
self.request = request
self.args = args
self.kwargs = kwargs
+ self.headers = {}
# Calls to 'reverse' will not be fully qualified unless we set the scheme/host/port here.
prefix = '%s://%s' % (request.is_secure() and 'https' or 'http', request.get_host())
@@ -149,7 +158,10 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
# also it's currently sub-obtimal for HTTP caching - need to sort that out.
response.headers['Allow'] = ', '.join(self.allowed_methods)
response.headers['Vary'] = 'Authenticate, Accept'
-
+
+ # merge with headers possibly set at some point in the view
+ response.headers.update(self.headers)
+
return self.render(response)
diff --git a/docs/examples/blogpost.rst b/docs/examples/blogpost.rst
index 36b9d982..be91913d 100644
--- a/docs/examples/blogpost.rst
+++ b/docs/examples/blogpost.rst
@@ -18,9 +18,19 @@ In this example we're working from two related models:
Creating the resources
----------------------
-Once we have some existing models there's very little we need to do to create the API.
-Firstly create a resource for each model that defines which fields we want to expose on the model.
-Secondly we map a base view and an instance view for each resource.
+We need to create two resources that we map to our two existing models, in order to describe how the models should be serialized.
+Our resource descriptions will typically go into a module called something like 'resources.py'
+
+``resources.py``
+
+.. include:: ../../examples/blogpost/resources.py
+ :literal:
+
+Creating views for our resources
+--------------------------------
+
+Once we've created the resources there's very little we need to do to create the API.
+For each resource we'll create a base view, and an instance view.
The generic views :class:`.ListOrCreateModelView` and :class:`.InstanceModelView` provide default operations for listing, creating and updating our models via the API, and also automatically provide input validation using default ModelForms for each model.
``urls.py``
diff --git a/docs/examples/modelviews.rst b/docs/examples/modelviews.rst
index 7cc78d39..c60c9f24 100644
--- a/docs/examples/modelviews.rst
+++ b/docs/examples/modelviews.rst
@@ -25,7 +25,14 @@ Here's the model we're working from in this example:
.. include:: ../../examples/modelresourceexample/models.py
:literal:
-To add an API for the model, all we need to do is create a Resource for the model, and map a couple of views to it in our urlconf.
+To add an API for the model, first we need to create a Resource for the model.
+
+``resources.py``
+
+.. include:: ../../examples/modelresourceexample/resources.py
+ :literal:
+
+Then we simply map a couple of views to the Resource in our urlconf.
``urls.py``
diff --git a/docs/howto/alternativeframeworks.rst b/docs/howto/alternativeframeworks.rst
index c6eba1dd..dc8d1ea6 100644
--- a/docs/howto/alternativeframeworks.rst
+++ b/docs/howto/alternativeframeworks.rst
@@ -1,6 +1,35 @@
-Alternative Frameworks
-======================
+Alternative frameworks & Why Django REST framework
+==================================================
-#. `django-piston <https://bitbucket.org/jespern/django-piston/wiki/Home>`_ is excellent, and has a great community behind it. This project is based on piston code in parts.
+Alternative frameworks
+----------------------
-#. `django-tasypie <https://github.com/toastdriven/django-tastypie>`_ is also well worth looking at.
+There are a number of alternative REST frameworks for Django:
+
+* `django-piston <https://bitbucket.org/jespern/django-piston/wiki/Home>`_ is very mature, and has a large community behind it. This project was originally based on piston code in parts.
+* `django-tasypie <https://github.com/toastdriven/django-tastypie>`_ is also very good, and has a very active and helpful developer community and maintainers.
+* Other interesting projects include `dagny <https://github.com/zacharyvoase/dagny>`_ and `dj-webmachine <http://benoitc.github.com/dj-webmachine/>`_
+
+
+Why use Django REST framework?
+------------------------------
+
+The big benefits of using Django REST framework come down to:
+
+1. It's based on Django's class based views, which makes it simple, modular, and future-proof.
+2. It stays as close as possible to Django idioms and language throughout.
+3. The browse-able API makes working with the APIs extremely quick and easy.
+
+
+Why was this project created?
+-----------------------------
+
+For me the browse-able API is the most important aspect of Django REST framework.
+
+I wanted to show that Web APIs could easily be made Web browse-able,
+and demonstrate how much better browse-able Web APIs are to work with.
+
+Being able to navigate and use a Web API directly in the browser is a huge win over only having command line and programmatic
+access to the API. It enables the API to be properly self-describing, and it makes it much much quicker and easier to work with.
+There's no fundamental reason why the Web APIs we're creating shouldn't be able to render to HTML as well as JSON/XML/whatever,
+and I really think that more Web API frameworks *in whatever language* ought to be taking a similar approach.
diff --git a/docs/index.rst b/docs/index.rst
index dfa361bd..8a285271 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -31,7 +31,7 @@ Resources
* The ``djangorestframework`` package is `available on PyPI <http://pypi.python.org/pypi/djangorestframework>`_.
* We have an active `discussion group <http://groups.google.com/group/django-rest-framework>`_ and a `project blog <http://blog.django-rest-framework.org>`_.
* Bug reports are handled on the `issue tracker <https://github.com/tomchristie/django-rest-framework/issues>`_.
-* There is a `Jenkins CI server <http://datacenter.tibold.nl/job/djangorestframework/>`_ which tracks test status and coverage reporting. (Thanks Marko!)
+* There is a `Jenkins CI server <http://jenkins.tibold.nl/job/djangorestframework/>`_ which tracks test status and coverage reporting. (Thanks Marko!)
Any and all questions, thoughts, bug reports and contributions are *hugely appreciated*.
@@ -83,8 +83,8 @@ Using Django REST framework can be as simple as adding a few lines to your urlco
model = MyModel
urlpatterns = patterns('',
- url(r'^$', RootModelResource.as_view(resource=MyResource)),
- url(r'^(?P<pk>[^/]+)/$', ModelResource.as_view(resource=MyResource)),
+ url(r'^$', ListOrCreateModelView.as_view(resource=MyResource)),
+ url(r'^(?P<pk>[^/]+)/$', InstanceModelView.as_view(resource=MyResource)),
)
Django REST framework comes with two "getting started" examples.
@@ -134,6 +134,7 @@ Library Reference
library/renderers
library/resource
library/response
+ library/serializer
library/status
library/views
diff --git a/docs/library/serializer.rst b/docs/library/serializer.rst
new file mode 100644
index 00000000..63dd3308
--- /dev/null
+++ b/docs/library/serializer.rst
@@ -0,0 +1,5 @@
+:mod:`serializer`
+=================
+
+.. automodule:: serializer
+ :members:
diff --git a/examples/blogpost/models.py b/examples/blogpost/models.py
index c4925a15..d77f530d 100644
--- a/examples/blogpost/models.py
+++ b/examples/blogpost/models.py
@@ -22,6 +22,9 @@ class BlogPost(models.Model):
slug = models.SlugField(editable=False, default='')
def save(self, *args, **kwargs):
+ """
+ For the purposes of the sandbox, limit the maximum number of stored models.
+ """
self.slug = slugify(self.title)
super(self.__class__, self).save(*args, **kwargs)
for obj in self.__class__.objects.order_by('-created')[MAX_POSTS:]:
diff --git a/examples/blogpost/resources.py b/examples/blogpost/resources.py
new file mode 100644
index 00000000..9b91ed73
--- /dev/null
+++ b/examples/blogpost/resources.py
@@ -0,0 +1,27 @@
+from django.core.urlresolvers import reverse
+from djangorestframework.resources import ModelResource
+from blogpost.models import BlogPost, Comment
+
+
+class BlogPostResource(ModelResource):
+ """
+ A Blog Post has a *title* and *content*, and can be associated with zero or more comments.
+ """
+ model = BlogPost
+ fields = ('created', 'title', 'slug', 'content', 'url', 'comments')
+ ordering = ('-created',)
+
+ def comments(self, instance):
+ return reverse('comments', kwargs={'blogpost': instance.key})
+
+
+class CommentResource(ModelResource):
+ """
+ A Comment is associated with a given Blog Post and has a *username* and *comment*, and optionally a *rating*.
+ """
+ model = Comment
+ fields = ('username', 'comment', 'created', 'rating', 'url', 'blogpost')
+ ordering = ('-created',)
+
+ def blogpost(self, instance):
+ return reverse('blog-post', kwargs={'key': instance.blogpost.key}) \ No newline at end of file
diff --git a/examples/blogpost/tests.py b/examples/blogpost/tests.py
index 30b152fa..e55f0f90 100644
--- a/examples/blogpost/tests.py
+++ b/examples/blogpost/tests.py
@@ -7,9 +7,10 @@ from django.core.urlresolvers import reverse
from django.utils import simplejson as json
from djangorestframework.compat import RequestFactory
+from djangorestframework.views import InstanceModelView, ListOrCreateModelView
-from blogpost import models
-import blogpost
+from blogpost import models, urls
+#import blogpost
# class AcceptHeaderTests(TestCase):
@@ -178,32 +179,33 @@ class TestRotation(TestCase):
models.BlogPost.objects.all().delete()
def test_get_to_root(self):
- '''Simple test to demonstrate how the requestfactory needs to be used'''
+ '''Simple get to the *root* url of blogposts'''
request = self.factory.get('/blog-post')
- view = views.BlogPosts.as_view()
+ view = ListOrCreateModelView.as_view(resource=urls.BlogPostResource)
response = view(request)
self.assertEqual(response.status_code, 200)
def test_blogposts_not_exceed_MAX_POSTS(self):
'''Posting blog-posts should not result in more than MAX_POSTS items stored.'''
- for post in range(views.MAX_POSTS + 5):
+ for post in range(models.MAX_POSTS + 5):
form_data = {'title': 'This is post #%s' % post, 'content': 'This is the content of post #%s' % post}
request = self.factory.post('/blog-post', data=form_data)
- view = views.BlogPosts.as_view()
+ view = ListOrCreateModelView.as_view(resource=urls.BlogPostResource)
view(request)
- self.assertEquals(len(models.BlogPost.objects.all()),views.MAX_POSTS)
+ self.assertEquals(len(models.BlogPost.objects.all()),models.MAX_POSTS)
def test_fifo_behaviour(self):
'''It's fine that the Blogposts are capped off at MAX_POSTS. But we want to make sure we see FIFO behaviour.'''
for post in range(15):
form_data = {'title': '%s' % post, 'content': 'This is the content of post #%s' % post}
request = self.factory.post('/blog-post', data=form_data)
- view = views.BlogPosts.as_view()
+ view = ListOrCreateModelView.as_view(resource=urls.BlogPostResource)
view(request)
request = self.factory.get('/blog-post')
- view = views.BlogPosts.as_view()
+ view = ListOrCreateModelView.as_view(resource=urls.BlogPostResource)
response = view(request)
response_posts = json.loads(response.content)
response_titles = [d['title'] for d in response_posts]
- self.assertEquals(response_titles, ['%s' % i for i in range(views.MAX_POSTS - 5, views.MAX_POSTS + 5)])
+ response_titles.reverse()
+ self.assertEquals(response_titles, ['%s' % i for i in range(models.MAX_POSTS - 5, models.MAX_POSTS + 5)])
\ No newline at end of file
diff --git a/examples/blogpost/urls.py b/examples/blogpost/urls.py
index c677b8fa..e9bd2754 100644
--- a/examples/blogpost/urls.py
+++ b/examples/blogpost/urls.py
@@ -1,36 +1,11 @@
from django.conf.urls.defaults import patterns, url
-from django.core.urlresolvers import reverse
-
from djangorestframework.views import ListOrCreateModelView, InstanceModelView
-from djangorestframework.resources import ModelResource
-
-from blogpost.models import BlogPost, Comment
-
-
-class BlogPostResource(ModelResource):
- """
- A Blog Post has a *title* and *content*, and can be associated with zero or more comments.
- """
- model = BlogPost
- fields = ('created', 'title', 'slug', 'content', 'url', 'comments')
- ordering = ('-created',)
-
- def comments(self, instance):
- return reverse('comments', kwargs={'blogpost': instance.key})
-
-
-class CommentResource(ModelResource):
- """
- A Comment is associated with a given Blog Post and has a *username* and *comment*, and optionally a *rating*.
- """
- model = Comment
- fields = ('username', 'comment', 'created', 'rating', 'url', 'blogpost')
- ordering = ('-created',)
+from blogpost.resources import BlogPostResource, CommentResource
urlpatterns = patterns('',
url(r'^$', ListOrCreateModelView.as_view(resource=BlogPostResource), name='blog-posts-root'),
- url(r'^(?P<key>[^/]+)/$', InstanceModelView.as_view(resource=BlogPostResource)),
+ url(r'^(?P<key>[^/]+)/$', InstanceModelView.as_view(resource=BlogPostResource), name='blog-post'),
url(r'^(?P<blogpost>[^/]+)/comments/$', ListOrCreateModelView.as_view(resource=CommentResource), name='comments'),
url(r'^(?P<blogpost>[^/]+)/comments/(?P<id>[^/]+)/$', InstanceModelView.as_view(resource=CommentResource)),
)
diff --git a/examples/modelresourceexample/resources.py b/examples/modelresourceexample/resources.py
new file mode 100644
index 00000000..634ea6b3
--- /dev/null
+++ b/examples/modelresourceexample/resources.py
@@ -0,0 +1,7 @@
+from djangorestframework.resources import ModelResource
+from modelresourceexample.models import MyModel
+
+class MyModelResource(ModelResource):
+ model = MyModel
+ fields = ('foo', 'bar', 'baz', 'url')
+ ordering = ('created',)
diff --git a/examples/modelresourceexample/urls.py b/examples/modelresourceexample/urls.py
index bb71ddd3..b6a16542 100644
--- a/examples/modelresourceexample/urls.py
+++ b/examples/modelresourceexample/urls.py
@@ -1,14 +1,8 @@
from django.conf.urls.defaults import patterns, url
from djangorestframework.views import ListOrCreateModelView, InstanceModelView
-from djangorestframework.resources import ModelResource
-from modelresourceexample.models import MyModel
-
-class MyModelResource(ModelResource):
- model = MyModel
- fields = ('foo', 'bar', 'baz', 'url')
- ordering = ('created',)
+from modelresourceexample.resources import MyModelResource
urlpatterns = patterns('',
url(r'^$', ListOrCreateModelView.as_view(resource=MyModelResource), name='model-resource-root'),
- url(r'^([0-9]+)/$', InstanceModelView.as_view(resource=MyModelResource)),
+ url(r'^(?P<pk>[0-9]+)/$', InstanceModelView.as_view(resource=MyModelResource)),
)
diff --git a/examples/modelresourceexample/views.py b/examples/permissionsexample/__init__.py
index e69de29b..e69de29b 100644
--- a/examples/modelresourceexample/views.py
+++ b/examples/permissionsexample/__init__.py
diff --git a/examples/permissionsexample/urls.py b/examples/permissionsexample/urls.py
new file mode 100644
index 00000000..d17f5159
--- /dev/null
+++ b/examples/permissionsexample/urls.py
@@ -0,0 +1,6 @@
+from django.conf.urls.defaults import patterns, url
+from permissionsexample.views import ThrottlingExampleView
+
+urlpatterns = patterns('',
+ url(r'^$', ThrottlingExampleView.as_view(), name='throttled-resource'),
+)
diff --git a/examples/permissionsexample/views.py b/examples/permissionsexample/views.py
new file mode 100644
index 00000000..20e7cba7
--- /dev/null
+++ b/examples/permissionsexample/views.py
@@ -0,0 +1,20 @@
+from djangorestframework.views import View
+from djangorestframework.permissions import PerUserThrottling
+
+
+class ThrottlingExampleView(View):
+ """
+ A basic read-only View that has a **per-user throttle** of 10 requests per minute.
+
+ If a user exceeds the 10 requests limit within a period of one minute, the
+ throttle will be applied until 60 seconds have passed since the first request.
+ """
+
+ permissions = ( PerUserThrottling, )
+ throttle = '10/min'
+
+ def get(self, request):
+ """
+ Handle GET requests.
+ """
+ return "Successful response to GET request because throttle is not yet active." \ No newline at end of file
diff --git a/examples/pygments_api/views.py b/examples/pygments_api/views.py
index 76647107..e50029f6 100644
--- a/examples/pygments_api/views.py
+++ b/examples/pygments_api/views.py
@@ -46,19 +46,12 @@ class HTMLRenderer(BaseRenderer):
media_type = 'text/html'
-
-class PygmentsFormResource(FormResource):
- """
- """
- form = PygmentsForm
-
-
class PygmentsRoot(View):
"""
- This example demonstrates a simple RESTful Web API aound the awesome pygments library.
+ This example demonstrates a simple RESTful Web API around the awesome pygments library.
This top level resource is used to create highlighted code snippets, and to list all the existing code snippets.
"""
- resource = PygmentsFormResource
+ form = PygmentsForm
def get(self, request):
"""
diff --git a/examples/requirements.txt b/examples/requirements.txt
index 09cda945..0bcd8d43 100644
--- a/examples/requirements.txt
+++ b/examples/requirements.txt
@@ -1,8 +1,6 @@
-# For the examples we need Django, pygments and httplib2...
+# Pygments for the code highlighting example,
+# markdown for the docstring -> auto-documentation
-Django==1.2.4
-wsgiref==0.1.2
Pygments==1.4
-httplib2==0.6.0
Markdown==2.0.3
diff --git a/examples/sandbox/views.py b/examples/sandbox/views.py
index 1c55c28f..1e326f43 100644
--- a/examples/sandbox/views.py
+++ b/examples/sandbox/views.py
@@ -31,4 +31,6 @@ class Sandbox(View):
{'name': 'Simple Mixin-only example', 'url': reverse('mixin-view')},
{'name': 'Object store API', 'url': reverse('object-store-root')},
{'name': 'Code highlighting API', 'url': reverse('pygments-root')},
- {'name': 'Blog posts API', 'url': reverse('blog-posts-root')}]
+ {'name': 'Blog posts API', 'url': reverse('blog-posts-root')},
+ {'name': 'Permissions example', 'url': reverse('throttled-resource')}
+ ]
diff --git a/examples/urls.py b/examples/urls.py
index cf4d4042..08d97a14 100644
--- a/examples/urls.py
+++ b/examples/urls.py
@@ -10,6 +10,7 @@ urlpatterns = patterns('',
(r'^object-store/', include('objectstore.urls')),
(r'^pygments/', include('pygments_api.urls')),
(r'^blog-post/', include('blogpost.urls')),
+ (r'^permissions-example/', include('permissionsexample.urls')),
(r'^', include('djangorestframework.urls')),
)