diff options
43 files changed, 1328 insertions, 370 deletions
@@ -1,11 +1,16 @@ -Tom Christie <tomchristie> - tom@tomchristie.com, @thisneonsoul - +Tom Christie <tomchristie> - tom@tomchristie.com, @thisneonsoul - Author. Paul Bagwell <pbgwl> - Suggestions & bugfixes. -Marko Tibold <markotibold> - Contributions & Providing the Hudson CI Server. +Marko Tibold <markotibold> - Contributions & Providing the Jenkins CI Server. Sébastien Piquemal <sebpiq> - Contributions. Carmen Wick <cwick> - Bugfixes. Alex Ehlke <aehlke> - Design Contributions. +Alen Mujezinovic <flashingpumpkin> - Contributions. +Carles Barrobés <txels> - HEAD support. +Michael Fötsch <mfoetsch> - File format support. +David Larlet <david> - OAuth support. +Andrew Straw <astraw> - Bugfixes. THANKS TO: + Jesper Noehr <jespern> & the django-piston contributors for providing the starting point for this project. -And of course, to the Django core team and the Django community at large. +And of course, to the Django core team and the Django community at large. You guys rock. @@ -52,6 +52,16 @@ To create the sdist packages Release Notes ============= +0.2.3 + + * Fix some throttling bugs + * X-Throttle header on throttling + * Support for nesting resources on related models + +0.2.2 + + * Throttling support complete + 0.2.1 * Couple of simple bugfixes over 0.2.0 diff --git a/djangorestframework/__init__.py b/djangorestframework/__init__.py index 1ee96d36..b1ef6dda 100644 --- a/djangorestframework/__init__.py +++ b/djangorestframework/__init__.py @@ -1,3 +1,3 @@ -__version__ = '0.2.1' +__version__ = '0.2.3' VERSION = __version__ # synonym diff --git a/djangorestframework/compat.py b/djangorestframework/compat.py index 0274511a..827b4adf 100644 --- a/djangorestframework/compat.py +++ b/djangorestframework/compat.py @@ -67,6 +67,14 @@ except ImportError: # django.views.generic.View (Django >= 1.3) try: from django.views.generic import View + if not hasattr(View, 'head'): + # First implementation of Django class-based views did not include head method + # in base View class - https://code.djangoproject.com/ticket/15668 + class ViewPlusHead(View): + def head(self, request, *args, **kwargs): + return self.get(request, *args, **kwargs) + View = ViewPlusHead + except ImportError: from django import http from django.utils.functional import update_wrapper @@ -145,6 +153,8 @@ except ImportError: #) return http.HttpResponseNotAllowed(allowed_methods) + def head(self, request, *args, **kwargs): + return self.get(request, *args, **kwargs) try: import markdown @@ -193,4 +203,4 @@ try: return md.convert(text) except ImportError: - apply_markdown = None
\ No newline at end of file + apply_markdown = None diff --git a/djangorestframework/mixins.py b/djangorestframework/mixins.py index 11e3bb38..1b3aa241 100644 --- a/djangorestframework/mixins.py +++ b/djangorestframework/mixins.py @@ -11,6 +11,7 @@ from django.http.multipartparser import LimitBytes from djangorestframework import status from djangorestframework.parsers import FormParser, MultiPartParser +from djangorestframework.renderers import BaseRenderer from djangorestframework.resources import Resource, FormResource, ModelResource from djangorestframework.response import Response, ErrorResponse from djangorestframework.utils import as_tuple, MSIE_USER_AGENT_REGEX @@ -290,7 +291,7 @@ class ResponseMixin(object): accept_list = [token.strip() for token in request.META["HTTP_ACCEPT"].split(',')] else: # No accept header specified - return (self._default_renderer(self), self._default_renderer.media_type) + accept_list = ['*/*'] # Check the acceptable media types against each renderer, # attempting more specific media types first @@ -298,12 +299,12 @@ class ResponseMixin(object): # Worst case is we're looping over len(accept_list) * len(self.renderers) renderers = [renderer_cls(self) for renderer_cls in self.renderers] - for media_type_lst in order_by_precedence(accept_list): + for accepted_media_type_lst in order_by_precedence(accept_list): for renderer in renderers: - for media_type in media_type_lst: - if renderer.can_handle_response(media_type): - return renderer, media_type - + for accepted_media_type in accepted_media_type_lst: + if renderer.can_handle_response(accepted_media_type): + return renderer, accepted_media_type + # No acceptable renderers were found raise ErrorResponse(status.HTTP_406_NOT_ACCEPTABLE, {'detail': 'Could not satisfy the client\'s Accept header', @@ -316,6 +317,13 @@ class ResponseMixin(object): Return an list of all the media types that this view can render. """ return [renderer.media_type for renderer in self.renderers] + + @property + def _rendered_formats(self): + """ + Return a list of all the formats that this view can render. + """ + return [renderer.format for renderer in self.renderers] @property def _default_renderer(self): @@ -466,7 +474,7 @@ class InstanceMixin(object): # We do a little dance when we store the view callable... # we need to store it wrapped in a 1-tuple, so that inspect will treat it # as a function when we later look it up (rather than turning it into a method). - # This makes sure our URL reversing works ok. + # This makes sure our URL reversing works ok. resource.view_callable = (view,) return view @@ -479,13 +487,17 @@ class ReadModelMixin(object): """ def get(self, request, *args, **kwargs): model = self.resource.model + try: if args: # If we have any none kwargs then assume the last represents the primrary key instance = model.objects.get(pk=args[-1], **kwargs) else: # Otherwise assume the kwargs uniquely identify the model - instance = model.objects.get(**kwargs) + filtered_keywords = kwargs.copy() + if BaseRenderer._FORMAT_QUERY_PARAM in filtered_keywords: + del filtered_keywords[BaseRenderer._FORMAT_QUERY_PARAM] + instance = model.objects.get(**filtered_keywords) except model.DoesNotExist: raise ErrorResponse(status.HTTP_404_NOT_FOUND) @@ -498,6 +510,7 @@ class CreateModelMixin(object): """ def post(self, request, *args, **kwargs): model = self.resource.model + # translated 'related_field' kwargs into 'related_field_id' for related_name in [field.name for field in model._meta.fields if isinstance(field, RelatedField)]: if kwargs.has_key(related_name): @@ -522,6 +535,7 @@ class UpdateModelMixin(object): """ def put(self, request, *args, **kwargs): model = self.resource.model + # TODO: update on the url of a non-existing resource url doesn't work correctly at the moment - will end up with a new url try: if args: @@ -547,6 +561,7 @@ class DeleteModelMixin(object): """ def delete(self, request, *args, **kwargs): model = self.resource.model + try: if args: # If we have any none kwargs then assume the last represents the primrary key @@ -581,8 +596,15 @@ class ListModelMixin(object): queryset = None def get(self, request, *args, **kwargs): - queryset = self.queryset if self.queryset else self.resource.model.objects.all() - ordering = getattr(self.resource, 'ordering', None) + model = self.resource.model + + queryset = self.queryset if self.queryset else model.objects.all() + + if hasattr(self, 'resource'): + ordering = getattr(self.resource, 'ordering', None) + else: + ordering = None + if ordering: args = as_tuple(ordering) queryset = queryset.order_by(*args) diff --git a/djangorestframework/parsers.py b/djangorestframework/parsers.py index 726e09e9..37882984 100644 --- a/djangorestframework/parsers.py +++ b/djangorestframework/parsers.py @@ -11,12 +11,12 @@ We need a method to be able to: and multipart/form-data. (eg also handle multipart/json) """ +from django.http import QueryDict from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser +from django.http.multipartparser import MultiPartParserError from django.utils import simplejson as json from djangorestframework import status -from djangorestframework.compat import parse_qs from djangorestframework.response import ErrorResponse -from djangorestframework.utils import as_tuple from djangorestframework.utils.mediatypes import media_type_matches __all__ = ( @@ -117,7 +117,7 @@ class FormParser(BaseParser): `data` will be a :class:`QueryDict` containing all the form parameters. `files` will always be :const:`None`. """ - data = parse_qs(stream.read(), keep_blank_values=True) + data = QueryDict(stream.read()) return (data, None) @@ -136,6 +136,10 @@ class MultiPartParser(BaseParser): `files` will be a :class:`QueryDict` containing all the form files. """ upload_handlers = self.view.request._get_upload_handlers() - django_parser = DjangoMultiPartParser(self.view.request.META, stream, upload_handlers) + try: + django_parser = DjangoMultiPartParser(self.view.request.META, stream, upload_handlers) + except MultiPartParserError, exc: + raise ErrorResponse(status.HTTP_400_BAD_REQUEST, + {'detail': 'multipart parse error - %s' % unicode(exc)}) return django_parser.parse() diff --git a/djangorestframework/permissions.py b/djangorestframework/permissions.py index 1f6151f8..7dcabcf0 100644 --- a/djangorestframework/permissions.py +++ b/djangorestframework/permissions.py @@ -1,6 +1,6 @@ """ The :mod:`permissions` module bundles a set of permission classes that are used -for checking if a request passes a certain set of constraints. You can assign a permision +for checking if a request passes a certain set of constraints. You can assign a permission class to your view by setting your View's :attr:`permissions` class attribute. """ @@ -15,7 +15,9 @@ __all__ = ( 'IsAuthenticated', 'IsAdminUser', 'IsUserOrIsAnonReadOnly', - 'PerUserThrottling' + 'PerUserThrottling', + 'PerViewThrottling', + 'PerResourceThrottling' ) @@ -24,12 +26,11 @@ _403_FORBIDDEN_RESPONSE = ErrorResponse( {'detail': 'You do not have permission to access this resource. ' + 'You may need to login or otherwise authenticate the request.'}) -_503_THROTTLED_RESPONSE = ErrorResponse( +_503_SERVICE_UNAVAILABLE = ErrorResponse( status.HTTP_503_SERVICE_UNAVAILABLE, {'detail': 'request was throttled'}) - class BasePermission(object): """ A base class from which all permission classes should inherit. @@ -88,37 +89,131 @@ class IsUserOrIsAnonReadOnly(BasePermission): raise _403_FORBIDDEN_RESPONSE -class PerUserThrottling(BasePermission): +class BaseThrottle(BasePermission): """ - Rate throttling of requests on a per-user basis. + Rate throttling of requests. - The rate (requests / seconds) is set by a :attr:`throttle` attribute on the ``View`` class. - The attribute is a two tuple of the form (number of requests, duration in seconds). + The rate (requests / seconds) is set by a :attr:`throttle` attribute + on the :class:`.View` class. The attribute is a string of the form 'number of + requests/period'. - The user id will be used as a unique identifier if the user is authenticated. - For anonymous requests, the IP address of the client will be used. + Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day') Previous request information used for throttling is stored in the cache. + """ + + attr_name = 'throttle' + default = '0/sec' + timer = time.time + + def get_cache_key(self): + """ + Should return a unique cache-key which can be used for throttling. + Muse be overridden. + """ + pass + + def check_permission(self, auth): + """ + Check the throttling. + Return `None` or raise an :exc:`.ErrorResponse`. + """ + num, period = getattr(self.view, self.attr_name, self.default).split('/') + self.num_requests = int(num) + self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] + self.auth = auth + self.check_throttle() + + def check_throttle(self): + """ + Implement the check to see if the request should be throttled. + + On success calls :meth:`throttle_success`. + On failure calls :meth:`throttle_failure`. + """ + self.key = self.get_cache_key() + self.history = cache.get(self.key, []) + self.now = self.timer() + + # Drop any requests from the history which have now passed the + # throttle duration + while self.history and self.history[-1] <= self.now - self.duration: + self.history.pop() + if len(self.history) >= self.num_requests: + self.throttle_failure() + else: + self.throttle_success() + + def throttle_success(self): + """ + Inserts the current request's timestamp along with the key + into the cache. + """ + self.history.insert(0, self.now) + cache.set(self.key, self.history, self.duration) + header = 'status=SUCCESS; next=%s sec' % self.next() + self.view.add_header('X-Throttle', header) + + def throttle_failure(self): + """ + Called when a request to the API has failed due to throttling. + Raises a '503 service unavailable' response. + """ + header = 'status=FAILURE; next=%s sec' % self.next() + self.view.add_header('X-Throttle', header) + raise _503_SERVICE_UNAVAILABLE + + def next(self): + """ + Returns the recommended next request time in seconds. + """ + if self.history: + remaining_duration = self.duration - (self.now - self.history[-1]) + else: + remaining_duration = self.duration + + available_requests = self.num_requests - len(self.history) + 1 + + return '%.2f' % (remaining_duration / float(available_requests)) + + +class PerUserThrottling(BaseThrottle): """ + Limits the rate of API calls that may be made by a given user. - def check_permission(self, user): - (num_requests, duration) = getattr(self.view, 'throttle', (0, 0)) + The user id will be used as a unique identifier if the user is + authenticated. For anonymous requests, the IP address of the client will + be used. + """ - if user.is_authenticated(): - ident = str(auth) + def get_cache_key(self): + if self.auth.is_authenticated(): + ident = str(self.auth) else: ident = self.view.request.META.get('REMOTE_ADDR', None) + return 'throttle_user_%s' % ident - key = 'throttle_%s' % ident - history = cache.get(key, []) - now = time.time() - - # Drop any requests from the history which have now passed the throttle duration - while history and history[0] < now - duration: - history.pop() - if len(history) >= num_requests: - raise _503_THROTTLED_RESPONSE +class PerViewThrottling(BaseThrottle): + """ + Limits the rate of API calls that may be used on a given view. + + The class name of the view is used as a unique identifier to + throttle against. + """ + + def get_cache_key(self): + return 'throttle_view_%s' % self.view.__class__.__name__ + + +class PerResourceThrottling(BaseThrottle): + """ + Limits the rate of API calls that may be used against all views on + a given resource. + + The class name of the resource is used as a unique identifier to + throttle against. + """ - history.insert(0, now) - cache.set(key, history, duration) + def get_cache_key(self): + return 'throttle_resource_%s' % self.view.resource.__class__.__name__ diff --git a/djangorestframework/renderers.py b/djangorestframework/renderers.py index 9834ba5e..e09e2abc 100644 --- a/djangorestframework/renderers.py +++ b/djangorestframework/renderers.py @@ -17,6 +17,7 @@ from djangorestframework.utils import dict2xml, url_resolves from djangorestframework.utils.breadcrumbs import get_breadcrumbs from djangorestframework.utils.description import get_name, get_description from djangorestframework.utils.mediatypes import get_media_type_params, add_media_type_param, media_type_matches +from djangorestframework import VERSION from decimal import Decimal import re @@ -39,8 +40,11 @@ class BaseRenderer(object): All renderers must extend this class, set the :attr:`media_type` attribute, and override the :meth:`render` method. """ + + _FORMAT_QUERY_PARAM = 'format' media_type = None + format = None def __init__(self, view): self.view = view @@ -57,6 +61,11 @@ class BaseRenderer(object): This may be overridden to provide for other behavior, but typically you'll instead want to just set the :attr:`media_type` attribute on the class. """ + format = self.view.kwargs.get(self._FORMAT_QUERY_PARAM, None) + if format is None: + format = self.view.request.GET.get(self._FORMAT_QUERY_PARAM, None) + if format is not None: + return format == self.format return media_type_matches(self.media_type, accept) def render(self, obj=None, media_type=None): @@ -83,6 +92,7 @@ class JSONRenderer(BaseRenderer): """ media_type = 'application/json' + format = 'json' def render(self, obj=None, media_type=None): """ @@ -108,7 +118,9 @@ class XMLRenderer(BaseRenderer): """ Renderer which serializes to XML. """ + media_type = 'application/xml' + format = 'xml' def render(self, obj=None, media_type=None): """ @@ -181,7 +193,7 @@ class DocumentingTemplateRenderer(BaseRenderer): # Get the form instance if we have one bound to the input form_instance = None - if method == view.method.lower(): + if method == getattr(view, 'method', view.request.method).lower(): form_instance = getattr(view, 'bound_form_instance', None) if not form_instance and hasattr(view, 'get_bound_form'): @@ -251,6 +263,7 @@ class DocumentingTemplateRenderer(BaseRenderer): The context used in the template contains all the information needed to self-document the response to this request. """ + content = self._get_content(self.view, self.view.request, obj, media_type) put_form_instance = self._get_form_instance(self.view, 'put') @@ -283,14 +296,15 @@ class DocumentingTemplateRenderer(BaseRenderer): 'response': self.view.response, 'description': description, 'name': name, + 'version': VERSION, 'markeddown': markeddown, 'breadcrumblist': breadcrumb_list, - 'available_media_types': self.view._rendered_media_types, + 'available_formats': self.view._rendered_formats, 'put_form': put_form_instance, 'post_form': post_form_instance, 'login_url': login_url, 'logout_url': logout_url, - 'ACCEPT_PARAM': getattr(self.view, '_ACCEPT_QUERY_PARAM', None), + 'FORMAT_PARAM': self._FORMAT_QUERY_PARAM, 'METHOD_PARAM': getattr(self.view, '_METHOD_PARAM', None), 'ADMIN_MEDIA_PREFIX': settings.ADMIN_MEDIA_PREFIX }) @@ -313,6 +327,7 @@ class DocumentingHTMLRenderer(DocumentingTemplateRenderer): """ media_type = 'text/html' + format = 'html' template = 'renderer.html' @@ -324,6 +339,7 @@ class DocumentingXHTMLRenderer(DocumentingTemplateRenderer): """ media_type = 'application/xhtml+xml' + format = 'xhtml' template = 'renderer.html' @@ -335,6 +351,7 @@ class DocumentingPlainTextRenderer(DocumentingTemplateRenderer): """ media_type = 'text/plain' + format = 'txt' template = 'renderer.txt' diff --git a/djangorestframework/resources.py b/djangorestframework/resources.py index 4b81bf79..b42bd952 100644 --- a/djangorestframework/resources.py +++ b/djangorestframework/resources.py @@ -6,6 +6,7 @@ from django.db.models.fields.related import RelatedField from django.utils.encoding import smart_unicode from djangorestframework.response import ErrorResponse +from djangorestframework.serializer import Serializer, _SkipField from djangorestframework.utils import as_tuple import decimal @@ -13,122 +14,9 @@ import inspect import re -# TODO: _IgnoreFieldException -# Map model classes to resource classes -#_model_to_resource = {} - -def _model_to_dict(instance, resource=None): - """ - Given a model instance, return a ``dict`` representing the model. - - The implementation is similar to Django's ``django.forms.model_to_dict``, except: - - * It doesn't coerce related objects into primary keys. - * It doesn't drop ``editable=False`` fields. - * It also supports attribute or method fields on the instance or resource. - """ - opts = instance._meta - data = {} - - #print [rel.name for rel in opts.get_all_related_objects()] - #related = [rel.get_accessor_name() for rel in opts.get_all_related_objects()] - #print [getattr(instance, rel) for rel in related] - #if resource.fields: - # fields = resource.fields - #else: - # fields = set(opts.fields + opts.many_to_many) - - fields = resource and resource.fields or () - include = resource and resource.include or () - exclude = resource and resource.exclude or () - - extra_fields = fields and list(fields) or list(include) - - # Model fields - for f in opts.fields + opts.many_to_many: - if fields and not f.name in fields: - continue - if exclude and f.name in exclude: - continue - if isinstance(f, models.ForeignKey): - data[f.name] = getattr(instance, f.name) - else: - data[f.name] = f.value_from_object(instance) - - if extra_fields and f.name in extra_fields: - extra_fields.remove(f.name) - - # Method fields - for fname in extra_fields: - - if isinstance(fname, (tuple, list)): - fname, fields = fname - else: - fname, fields = fname, False - - try: - if hasattr(resource, fname): - # check the resource first, to allow it to override fields - obj = getattr(resource, fname) - # if it's a method like foo(self, instance), then call it - if inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) == 2: - obj = obj(instance) - elif hasattr(instance, fname): - # now check the object instance - obj = getattr(instance, fname) - else: - continue - - # TODO: It would be nicer if this didn't recurse here. - # Let's keep _model_to_dict flat, and _object_to_data recursive. - if fields: - Resource = type('Resource', (object,), {'fields': fields, - 'include': (), - 'exclude': ()}) - data[fname] = _object_to_data(obj, Resource()) - else: - data[fname] = _object_to_data(obj) - - except NoReverseMatch: - # Ug, bit of a hack for now - pass - - return data - - -def _object_to_data(obj, resource=None): - """ - Convert an object into a serializable representation. - """ - if isinstance(obj, dict): - # dictionaries - # TODO: apply same _model_to_dict logic fields/exclude here - return dict([ (key, _object_to_data(val)) for key, val in obj.iteritems() ]) - if isinstance(obj, (tuple, list, set, QuerySet)): - # basic iterables - return [_object_to_data(item, resource) for item in obj] - if isinstance(obj, models.Manager): - # Manager objects - return [_object_to_data(item, resource) for item in obj.all()] - if isinstance(obj, models.Model): - # Model instances - return _object_to_data(_model_to_dict(obj, resource)) - if isinstance(obj, decimal.Decimal): - # Decimals (force to string representation) - return str(obj) - if inspect.isfunction(obj) and not inspect.getargspec(obj)[0]: - # function with no args - return _object_to_data(obj(), resource) - if inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) <= 1: - # bound method - return _object_to_data(obj(), resource) - - return smart_unicode(obj, strings_only=True) - - -class BaseResource(object): +class BaseResource(Serializer): """ Base class for all Resource classes, which simply defines the interface they provide. """ @@ -136,7 +24,8 @@ class BaseResource(object): include = None exclude = None - def __init__(self, view): + def __init__(self, view=None, depth=None, stack=[], **kwargs): + super(BaseResource, self).__init__(depth, stack, **kwargs) self.view = view def validate_request(self, data, files=None): @@ -150,7 +39,7 @@ class BaseResource(object): """ Given the response content, filter it into a serializable object. """ - return _object_to_data(obj, self) + return self.serialize(obj) class Resource(BaseResource): @@ -241,14 +130,12 @@ class FormResource(Resource): # In addition to regular validation we also ensure no additional fields are being passed in... unknown_fields = seen_fields_set - (form_fields_set | allowed_extra_fields_set) unknown_fields = unknown_fields - set(('csrfmiddlewaretoken', '_accept', '_method')) # TODO: Ugh. - + # Check using both regular validation, and our stricter no additional fields rule if bound_form.is_valid() and not unknown_fields: # Validation succeeded... cleaned_data = bound_form.cleaned_data - cleaned_data.update(bound_form.files) - # Add in any extra fields to the cleaned content... for key in (allowed_extra_fields_set & seen_fields_set) - set(cleaned_data.keys()): cleaned_data[key] = data[key] @@ -299,7 +186,7 @@ class FormResource(Resource): """ # A form on the view overrides a form on the resource. - form = getattr(self.view, 'form', self.form) + form = getattr(self.view, 'form', None) or self.form # Use the requested method or determine the request method if method is None and hasattr(self.view, 'request') and hasattr(self.view, 'method'): @@ -316,7 +203,7 @@ class FormResource(Resource): if not form: return None - if data is not None: + if data is not None or files is not None: return form(data, files) return form() @@ -392,8 +279,8 @@ class ModelResource(FormResource): """ super(ModelResource, self).__init__(view) - if getattr(view, 'model', None): - self.model = view.model + self.model = getattr(view, 'model', None) or self.model + def validate_request(self, data, files=None): """ @@ -455,7 +342,7 @@ class ModelResource(FormResource): """ if not hasattr(self, 'view_callable'): - raise NoReverseMatch + raise _SkipField # dis does teh magicks... urlconf = get_urlconf() @@ -478,13 +365,13 @@ class ModelResource(FormResource): if isinstance(attr, models.Model): instance_attrs[param] = attr.pk else: - instance_attrs[param] = attr + instance_attrs[param] = attr try: return reverse(self.view_callable[0], kwargs=instance_attrs) except NoReverseMatch: pass - raise NoReverseMatch + raise _SkipField @property diff --git a/djangorestframework/response.py b/djangorestframework/response.py index d68ececf..311e0bb7 100644 --- a/djangorestframework/response.py +++ b/djangorestframework/response.py @@ -16,13 +16,13 @@ class Response(object): An HttpResponse that may include content that hasn't yet been serialized. """ - def __init__(self, status=200, content=None, headers={}): + def __init__(self, status=200, content=None, headers=None): self.status = status self.media_type = None self.has_content_body = content is not None self.raw_content = content # content prior to filtering self.cleaned_content = content # content after filtering - self.headers = headers + self.headers = headers or {} @property def status_text(self): diff --git a/djangorestframework/runtests/runtests.py b/djangorestframework/runtests/runtests.py index a3cdfa67..1da918f5 100644 --- a/djangorestframework/runtests/runtests.py +++ b/djangorestframework/runtests/runtests.py @@ -13,21 +13,27 @@ os.environ['DJANGO_SETTINGS_MODULE'] = 'djangorestframework.runtests.settings' from django.conf import settings from django.test.utils import get_runner +def usage(): + return """ + Usage: python runtests.py [UnitTestClass].[method] + + You can pass the Class name of the `UnitTestClass` you want to test. + + Append a method name if you only want to test a specific method of that class. + """ + def main(): TestRunner = get_runner(settings) - if hasattr(TestRunner, 'func_name'): - # Pre 1.2 test runners were just functions, - # and did not support the 'failfast' option. - import warnings - warnings.warn( - 'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.', - DeprecationWarning - ) - failures = TestRunner(['djangorestframework']) + test_runner = TestRunner() + if len(sys.argv) == 2: + test_case = '.' + sys.argv[1] + elif len(sys.argv) == 1: + test_case = '' else: - test_runner = TestRunner() - failures = test_runner.run_tests(['djangorestframework']) + print usage() + sys.exit(1) + failures = test_runner.run_tests(['djangorestframework' + test_case]) sys.exit(failures) diff --git a/djangorestframework/runtests/settings.py b/djangorestframework/runtests/settings.py index 0cc7f4e3..006727bc 100644 --- a/djangorestframework/runtests/settings.py +++ b/djangorestframework/runtests/settings.py @@ -2,6 +2,7 @@ DEBUG = True TEMPLATE_DEBUG = DEBUG +DEBUG_PROPAGATE_EXCEPTIONS = True ADMINS = ( # ('Your Name', 'your_email@domain.com'), @@ -83,7 +84,7 @@ TEMPLATE_DIRS = ( # Don't forget to use absolute paths, not relative paths. ) -INSTALLED_APPS = ( +INSTALLED_APPS = [ 'django.contrib.auth', 'django.contrib.contenttypes', 'django.contrib.sessions', @@ -94,8 +95,17 @@ INSTALLED_APPS = ( # Uncomment the next line to enable admin documentation: # 'django.contrib.admindocs', 'djangorestframework', -) +] + +# OAuth support is optional, so we only test oauth if it's installed. +try: + import oauth_provider +except: + pass +else: + INSTALLED_APPS.append('oauth_provider') +# If we're running on the Jenkins server we want to archive the coverage reports as XML. import os if os.environ.get('HUDSON_URL', None): TEST_RUNNER = 'xmlrunner.extra.djangotestrunner.XMLTestRunner' diff --git a/djangorestframework/serializer.py b/djangorestframework/serializer.py new file mode 100644 index 00000000..da8036e9 --- /dev/null +++ b/djangorestframework/serializer.py @@ -0,0 +1,310 @@ +""" +Customizable serialization. +""" +from django.db import models +from django.db.models.query import QuerySet +from django.db.models.fields.related import RelatedField +from django.utils.encoding import smart_unicode, is_protected_type + +import decimal +import inspect +import types + + +# We register serializer classes, so that we can refer to them by their +# class names, if there are cyclical serialization heirachys. +_serializers = {} + + +def _field_to_tuple(field): + """ + Convert an item in the `fields` attribute into a 2-tuple. + """ + if isinstance(field, (tuple, list)): + return (field[0], field[1]) + return (field, None) + +def _fields_to_list(fields): + """ + Return a list of field names. + """ + return [_field_to_tuple(field)[0] for field in fields or ()] + +def _fields_to_dict(fields): + """ + Return a `dict` of field name -> None, or tuple of fields, or Serializer class + """ + return dict([_field_to_tuple(field) for field in fields or ()]) + + +class _SkipField(Exception): + """ + Signals that a serialized field should be ignored. + We use this mechanism as the default behavior for ensuring + that we don't infinitely recurse when dealing with nested data. + """ + pass + + +class _RegisterSerializer(type): + """ + Metaclass to register serializers. + """ + def __new__(cls, name, bases, attrs): + # Build the class and register it. + ret = super(_RegisterSerializer, cls).__new__(cls, name, bases, attrs) + _serializers[name] = ret + return ret + + +class Serializer(object): + """ + Converts python objects into plain old native types suitable for + serialization. In particular it handles models and querysets. + + The output format is specified by setting a number of attributes + on the class. + + You may also override any of the serialization methods, to provide + for more flexible behavior. + + Valid output types include anything that may be directly rendered into + json, xml etc... + """ + __metaclass__ = _RegisterSerializer + + fields = () + """ + Specify the fields to be serialized on a model or dict. + Overrides `include` and `exclude`. + """ + + include = () + """ + Fields to add to the default set to be serialized on a model/dict. + """ + + exclude = () + """ + Fields to remove from the default set to be serialized on a model/dict. + """ + + rename = {} + """ + A dict of key->name to use for the field keys. + """ + + related_serializer = None + """ + The default serializer class to use for any related models. + """ + + depth = None + """ + The maximum depth to serialize to, or `None`. + """ + + + def __init__(self, depth=None, stack=[], **kwargs): + self.depth = depth or self.depth + self.stack = stack + + + def get_fields(self, obj): + """ + Return the set of field names/keys to use for a model instance/dict. + """ + fields = self.fields + + # If `fields` is not set, we use the default fields and modify + # them with `include` and `exclude` + if not fields: + default = self.get_default_fields(obj) + include = self.include or () + exclude = self.exclude or () + fields = set(default + list(include)) - set(exclude) + + else: + fields = _fields_to_list(self.fields) + + return fields + + + def get_default_fields(self, obj): + """ + Return the default list of field names/keys for a model instance/dict. + These are used if `fields` is not given. + """ + if isinstance(obj, models.Model): + opts = obj._meta + return [field.name for field in opts.fields + opts.many_to_many] + else: + return obj.keys() + + + def get_related_serializer(self, key): + info = _fields_to_dict(self.fields).get(key, None) + + # If an element in `fields` is a 2-tuple of (str, tuple) + # then the second element of the tuple is the fields to + # set on the related serializer + if isinstance(info, (list, tuple)): + class OnTheFlySerializer(Serializer): + fields = info + return OnTheFlySerializer + + # If an element in `fields` is a 2-tuple of (str, Serializer) + # then the second element of the tuple is the Serializer + # class to use for that field. + elif isinstance(info, type) and issubclass(info, Serializer): + return info + + # If an element in `fields` is a 2-tuple of (str, str) + # then the second element of the tuple is the name of the Serializer + # class to use for that field. + # + # Black magic to deal with cyclical Serializer dependancies. + # Similar to what Django does for cyclically related models. + elif isinstance(info, str) and info in _serializers: + return _serializers[info] + + # Otherwise use `related_serializer` or fall back to `Serializer` + return getattr(self, 'related_serializer') or Serializer + + + def serialize_key(self, key): + """ + Keys serialize to their string value, + unless they exist in the `rename` dict. + """ + return getattr(self.rename, key, key) + + + def serialize_val(self, key, obj): + """ + Convert a model field or dict value into a serializable representation. + """ + related_serializer = self.get_related_serializer(key) + + if self.depth is None: + depth = None + elif self.depth <= 0: + return self.serialize_max_depth(obj) + else: + depth = self.depth - 1 + + if any([obj is elem for elem in self.stack]): + return self.serialize_recursion(obj) + else: + stack = self.stack[:] + stack.append(obj) + + return related_serializer(depth=depth, stack=stack).serialize(obj) + + + def serialize_max_depth(self, obj): + """ + Determine how objects should be serialized once `depth` is exceeded. + The default behavior is to ignore the field. + """ + raise _SkipField + + + def serialize_recursion(self, obj): + """ + Determine how objects should be serialized if recursion occurs. + The default behavior is to ignore the field. + """ + raise _SkipField + + + def serialize_model(self, instance): + """ + Given a model instance or dict, serialize it to a dict.. + """ + data = {} + + fields = self.get_fields(instance) + + # serialize each required field + for fname in fields: + if hasattr(self, fname): + # check for a method 'fname' on self first + meth = getattr(self, fname) + if inspect.ismethod(meth) and len(inspect.getargspec(meth)[0]) == 2: + obj = meth(instance) + elif hasattr(instance, fname): + # now check for an attribute 'fname' on the instance + obj = getattr(instance, fname) + elif fname in instance: + # finally check for a key 'fname' on the instance + obj = instance[fname] + else: + continue + + try: + key = self.serialize_key(fname) + val = self.serialize_val(fname, obj) + data[key] = val + except _SkipField: + pass + + return data + + + def serialize_iter(self, obj): + """ + Convert iterables into a serializable representation. + """ + return [self.serialize(item) for item in obj] + + + def serialize_func(self, obj): + """ + Convert no-arg methods and functions into a serializable representation. + """ + return self.serialize(obj()) + + + def serialize_manager(self, obj): + """ + Convert a model manager into a serializable representation. + """ + return self.serialize_iter(obj.all()) + + + def serialize_fallback(self, obj): + """ + Convert any unhandled object into a serializable representation. + """ + return smart_unicode(obj, strings_only=True) + + + def serialize(self, obj): + """ + Convert any object into a serializable representation. + """ + + if isinstance(obj, (dict, models.Model)): + # Model instances & dictionaries + return self.serialize_model(obj) + elif isinstance(obj, (tuple, list, set, QuerySet, types.GeneratorType)): + # basic iterables + return self.serialize_iter(obj) + elif isinstance(obj, models.Manager): + # Manager objects + return self.serialize_manager(obj) + elif inspect.isfunction(obj) and not inspect.getargspec(obj)[0]: + # function with no args + return self.serialize_func(obj) + elif inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) <= 1: + # bound method + return self.serialize_func(obj) + + # Protected types are passed through as is. + # (i.e. Primitives like None, numbers, dates, and Decimals.) + if is_protected_type(obj): + return obj + + # All other values are converted to string. + return self.serialize_fallback(obj) diff --git a/djangorestframework/templates/renderer.html b/djangorestframework/templates/renderer.html index 94748d28..5b32d1ec 100644 --- a/djangorestframework/templates/renderer.html +++ b/djangorestframework/templates/renderer.html @@ -18,7 +18,7 @@ <div id="header"> <div id="branding"> - <h1 id="site-name"><a href='http://django-rest-framework.org'>Django REST framework</a></h1> + <h1 id="site-name"><a href='http://django-rest-framework.org'>Django REST framework</a> <small>{{ version }}</small></h1> </div> <div id="user-tools"> {% if user.is_active %}Welcome, {{ user }}.{% if logout_url %} <a href='{{ logout_url }}'>Log out</a>{% endif %}{% else %}Anonymous {% if login_url %}<a href='{{ login_url }}'>Log in</a>{% endif %}{% endif %} @@ -48,9 +48,9 @@ <h2>GET {{ name }}</h2> <div class='submit-row' style='margin: 0; border: 0'> <a href='{{ request.get_full_path }}' rel="nofollow" style='float: left'>GET</a> - {% for media_type in available_media_types %} - {% with ACCEPT_PARAM|add:"="|add:media_type as param %} - [<a href='{{ request.get_full_path|add_query_param:param }}' rel="nofollow">{{ media_type }}</a>] + {% for format in available_formats %} + {% with FORMAT_PARAM|add:"="|add:format as param %} + [<a href='{{ request.get_full_path|add_query_param:param }}' rel="nofollow">{{ format }}</a>] {% endwith %} {% endfor %} </div> @@ -122,4 +122,4 @@ </div> </div> </body> -</html>
\ No newline at end of file +</html> diff --git a/djangorestframework/tests/content.py b/djangorestframework/tests/content.py index ee3597a4..83ad72d0 100644 --- a/djangorestframework/tests/content.py +++ b/djangorestframework/tests/content.py @@ -6,7 +6,6 @@ from djangorestframework.compat import RequestFactory from djangorestframework.mixins import RequestMixin from djangorestframework.parsers import FormParser, MultiPartParser, PlainTextParser - class TestContentParsing(TestCase): def setUp(self): self.req = RequestFactory() @@ -16,6 +15,11 @@ class TestContentParsing(TestCase): view.request = self.req.get('/') self.assertEqual(view.DATA, None) + def ensure_determines_no_content_HEAD(self, view): + """Ensure view.DATA returns None for HEAD request.""" + view.request = self.req.head('/') + self.assertEqual(view.DATA, None) + def ensure_determines_form_content_POST(self, view): """Ensure view.DATA returns content for POST request with form content.""" form_data = {'qwerty': 'uiop'} @@ -50,6 +54,10 @@ class TestContentParsing(TestCase): """Ensure view.DATA returns None for GET request with no content.""" self.ensure_determines_no_content_GET(RequestMixin()) + def test_standard_behaviour_determines_no_content_HEAD(self): + """Ensure view.DATA returns None for HEAD request.""" + self.ensure_determines_no_content_HEAD(RequestMixin()) + def test_standard_behaviour_determines_form_content_POST(self): """Ensure view.DATA returns content for POST request with form content.""" self.ensure_determines_form_content_POST(RequestMixin()) diff --git a/djangorestframework/tests/files.py b/djangorestframework/tests/files.py index 25aad9b4..992d3cba 100644 --- a/djangorestframework/tests/files.py +++ b/djangorestframework/tests/files.py @@ -12,20 +12,16 @@ class UploadFilesTests(TestCase): def test_upload_file(self): - class FileForm(forms.Form): - file = forms.FileField - - class MockResource(FormResource): - form = FileForm + file = forms.FileField() class MockView(View): permissions = () - resource = MockResource + form = FileForm def post(self, request, *args, **kwargs): - return {'FILE_NAME': self.CONTENT['file'][0].name, - 'FILE_CONTENT': self.CONTENT['file'][0].read()} + return {'FILE_NAME': self.CONTENT['file'].name, + 'FILE_CONTENT': self.CONTENT['file'].read()} file = StringIO.StringIO('stuff') file.name = 'stuff.txt' diff --git a/djangorestframework/tests/methods.py b/djangorestframework/tests/methods.py index d8f0d919..c3a3a28d 100644 --- a/djangorestframework/tests/methods.py +++ b/djangorestframework/tests/methods.py @@ -24,3 +24,9 @@ class TestMethodOverloading(TestCase): view = RequestMixin() view.request = self.req.post('/', {view._METHOD_PARAM: 'DELETE'}) self.assertEqual(view.method, 'DELETE') + + def test_HEAD_is_a_valid_method(self): + """HEAD requests identified""" + view = RequestMixin() + view.request = self.req.head('/') + self.assertEqual(view.method, 'HEAD') diff --git a/djangorestframework/tests/oauthentication.py b/djangorestframework/tests/oauthentication.py new file mode 100644 index 00000000..7f74b804 --- /dev/null +++ b/djangorestframework/tests/oauthentication.py @@ -0,0 +1,212 @@ +import time + +from django.conf.urls.defaults import patterns, url, include +from django.contrib.auth.models import User +from django.test import Client, TestCase + +from djangorestframework.views import View + +# Since oauth2 / django-oauth-plus are optional dependancies, we don't want to +# always run these tests. + +# Unfortunatly we can't skip tests easily until 2.7, se we'll just do this for now. +try: + import oauth2 as oauth + from oauth_provider.decorators import oauth_required + from oauth_provider.models import Resource, Consumer, Token + +except: + pass + +else: + # Alrighty, we're good to go here. + class ClientView(View): + def get(self, request): + return {'resource': 'Protected!'} + + urlpatterns = patterns('', + url(r'^$', oauth_required(ClientView.as_view())), + url(r'^oauth/', include('oauth_provider.urls')), + url(r'^accounts/login/$', 'djangorestframework.utils.staticviews.api_login'), + ) + + + class OAuthTests(TestCase): + """ + OAuth authentication: + * the user would like to access his API data from a third-party website + * the third-party website proposes a link to get that API data + * the user is redirected to the API and must log in if not authenticated + * the API displays a webpage to confirm that the user trusts the third-party website + * if confirmed, the user is redirected to the third-party website through the callback view + * the third-party website is able to retrieve data from the API + """ + urls = 'djangorestframework.tests.oauthentication' + + def setUp(self): + self.client = Client() + self.username = 'john' + self.email = 'lennon@thebeatles.com' + self.password = 'password' + self.user = User.objects.create_user(self.username, self.email, self.password) + + # OAuth requirements + self.resource = Resource(name='data', url='/') + self.resource.save() + self.CONSUMER_KEY = 'dpf43f3p2l4k3l03' + self.CONSUMER_SECRET = 'kd94hf93k423kf44' + self.consumer = Consumer(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET, + name='api.example.com', user=self.user) + self.consumer.save() + + def test_oauth_invalid_and_anonymous_access(self): + """ + Verify that the resource is protected and the OAuth authorization view + require the user to be logged in. + """ + response = self.client.get('/') + self.assertEqual(response.content, 'Invalid request parameters.') + self.assertEqual(response.status_code, 401) + response = self.client.get('/oauth/authorize/', follow=True) + self.assertRedirects(response, '/accounts/login/?next=/oauth/authorize/') + + def test_oauth_authorize_access(self): + """ + Verify that once logged in, the user can access the authorization page + but can't display the page because the request token is not specified. + """ + self.client.login(username=self.username, password=self.password) + response = self.client.get('/oauth/authorize/', follow=True) + self.assertEqual(response.content, 'No request token specified.') + + def _create_request_token_parameters(self): + """ + A shortcut to create request's token parameters. + """ + return { + 'oauth_consumer_key': self.CONSUMER_KEY, + 'oauth_signature_method': 'PLAINTEXT', + 'oauth_signature': '%s&' % self.CONSUMER_SECRET, + 'oauth_timestamp': str(int(time.time())), + 'oauth_nonce': 'requestnonce', + 'oauth_version': '1.0', + 'oauth_callback': 'http://api.example.com/request_token_ready', + 'scope': 'data', + } + + def test_oauth_request_token_retrieval(self): + """ + Verify that the request token can be retrieved by the server. + """ + response = self.client.get("/oauth/request_token/", + self._create_request_token_parameters()) + self.assertEqual(response.status_code, 200) + token = list(Token.objects.all())[-1] + self.failIf(token.key not in response.content) + self.failIf(token.secret not in response.content) + + def test_oauth_user_request_authorization(self): + """ + Verify that the user can access the authorization page once logged in + and the request token has been retrieved. + """ + # Setup + response = self.client.get("/oauth/request_token/", + self._create_request_token_parameters()) + token = list(Token.objects.all())[-1] + + # Starting the test here + self.client.login(username=self.username, password=self.password) + parameters = {'oauth_token': token.key,} + response = self.client.get("/oauth/authorize/", parameters) + self.assertEqual(response.status_code, 200) + self.failIf(not response.content.startswith('Fake authorize view for api.example.com with params: oauth_token=')) + self.assertEqual(token.is_approved, 0) + parameters['authorize_access'] = 1 # fake authorization by the user + response = self.client.post("/oauth/authorize/", parameters) + self.assertEqual(response.status_code, 302) + self.failIf(not response['Location'].startswith('http://api.example.com/request_token_ready?oauth_verifier=')) + token = Token.objects.get(key=token.key) + self.failIf(token.key not in response['Location']) + self.assertEqual(token.is_approved, 1) + + def _create_access_token_parameters(self, token): + """ + A shortcut to create access' token parameters. + """ + return { + 'oauth_consumer_key': self.CONSUMER_KEY, + 'oauth_token': token.key, + 'oauth_signature_method': 'PLAINTEXT', + 'oauth_signature': '%s&%s' % (self.CONSUMER_SECRET, token.secret), + 'oauth_timestamp': str(int(time.time())), + 'oauth_nonce': 'accessnonce', + 'oauth_version': '1.0', + 'oauth_verifier': token.verifier, + 'scope': 'data', + } + + def test_oauth_access_token_retrieval(self): + """ + Verify that the request token can be retrieved by the server. + """ + # Setup + response = self.client.get("/oauth/request_token/", + self._create_request_token_parameters()) + token = list(Token.objects.all())[-1] + self.client.login(username=self.username, password=self.password) + parameters = {'oauth_token': token.key,} + response = self.client.get("/oauth/authorize/", parameters) + parameters['authorize_access'] = 1 # fake authorization by the user + response = self.client.post("/oauth/authorize/", parameters) + token = Token.objects.get(key=token.key) + + # Starting the test here + response = self.client.get("/oauth/access_token/", self._create_access_token_parameters(token)) + self.assertEqual(response.status_code, 200) + self.failIf(not response.content.startswith('oauth_token_secret=')) + access_token = list(Token.objects.filter(token_type=Token.ACCESS))[-1] + self.failIf(access_token.key not in response.content) + self.failIf(access_token.secret not in response.content) + self.assertEqual(access_token.user.username, 'john') + + def _create_access_parameters(self, access_token): + """ + A shortcut to create access' parameters. + """ + parameters = { + 'oauth_consumer_key': self.CONSUMER_KEY, + 'oauth_token': access_token.key, + 'oauth_signature_method': 'HMAC-SHA1', + 'oauth_timestamp': str(int(time.time())), + 'oauth_nonce': 'accessresourcenonce', + 'oauth_version': '1.0', + } + oauth_request = oauth.Request.from_token_and_callback(access_token, + http_url='http://testserver/', parameters=parameters) + signature_method = oauth.SignatureMethod_HMAC_SHA1() + signature = signature_method.sign(oauth_request, self.consumer, access_token) + parameters['oauth_signature'] = signature + return parameters + + def test_oauth_protected_resource_access(self): + """ + Verify that the request token can be retrieved by the server. + """ + # Setup + response = self.client.get("/oauth/request_token/", + self._create_request_token_parameters()) + token = list(Token.objects.all())[-1] + self.client.login(username=self.username, password=self.password) + parameters = {'oauth_token': token.key,} + response = self.client.get("/oauth/authorize/", parameters) + parameters['authorize_access'] = 1 # fake authorization by the user + response = self.client.post("/oauth/authorize/", parameters) + token = Token.objects.get(key=token.key) + response = self.client.get("/oauth/access_token/", self._create_access_token_parameters(token)) + access_token = list(Token.objects.filter(token_type=Token.ACCESS))[-1] + + # Starting the test here + response = self.client.get("/", self._create_access_token_parameters(access_token)) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, '{"resource": "Protected!"}') diff --git a/djangorestframework/tests/parsers.py b/djangorestframework/tests/parsers.py index 3ab1a61c..deba688e 100644 --- a/djangorestframework/tests/parsers.py +++ b/djangorestframework/tests/parsers.py @@ -131,3 +131,25 @@ # self.assertEqual(data['key1'], 'val1') # self.assertEqual(files['file1'].read(), 'blablabla') +from StringIO import StringIO +from cgi import parse_qs +from django import forms +from django.test import TestCase +from djangorestframework.parsers import FormParser + +class Form(forms.Form): + field1 = forms.CharField(max_length=3) + field2 = forms.CharField() + +class TestFormParser(TestCase): + def setUp(self): + self.string = "field1=abc&field2=defghijk" + + def test_parse(self): + """ Make sure the `QueryDict` works OK """ + parser = FormParser(None) + + stream = StringIO(self.string) + (data, files) = parser.parse(stream) + + self.assertEqual(Form(data).is_valid(), True) diff --git a/djangorestframework/tests/renderers.py b/djangorestframework/tests/renderers.py index 54276993..bf135e55 100644 --- a/djangorestframework/tests/renderers.py +++ b/djangorestframework/tests/renderers.py @@ -1,13 +1,18 @@ from django.conf.urls.defaults import patterns, url from django import http from django.test import TestCase + +from djangorestframework import status from djangorestframework.compat import View as DjangoView from djangorestframework.renderers import BaseRenderer, JSONRenderer +from djangorestframework.parsers import JSONParser from djangorestframework.mixins import ResponseMixin from djangorestframework.response import Response from djangorestframework.utils.mediatypes import add_media_type_param -DUMMYSTATUS = 200 +from StringIO import StringIO + +DUMMYSTATUS = status.HTTP_200_OK DUMMYCONTENT = 'dummycontent' RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x @@ -15,12 +20,14 @@ RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x class RendererA(BaseRenderer): media_type = 'mock/renderera' + format="formata" def render(self, obj=None, media_type=None): return RENDERER_A_SERIALIZER(obj) class RendererB(BaseRenderer): media_type = 'mock/rendererb' + format="formatb" def render(self, obj=None, media_type=None): return RENDERER_B_SERIALIZER(obj) @@ -28,11 +35,13 @@ class RendererB(BaseRenderer): class MockView(ResponseMixin, DjangoView): renderers = (RendererA, RendererB) - def get(self, request): + def get(self, request, **kwargs): response = Response(DUMMYSTATUS, DUMMYCONTENT) return self.render(response) + urlpatterns = patterns('', + url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderers=[RendererA, RendererB])), url(r'^$', MockView.as_view(renderers=[RendererA, RendererB])), ) @@ -51,6 +60,13 @@ class RendererIntegrationTests(TestCase): self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEquals(resp.status_code, DUMMYSTATUS) + def test_head_method_serializes_no_content(self): + """No response must be included in HEAD requests.""" + resp = self.client.head('/') + self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEquals(resp['Content-Type'], RendererA.media_type) + self.assertEquals(resp.content, '') + def test_default_renderer_serializes_content_on_accept_any(self): """If the Accept header is set to */* the default renderer should serialize the response.""" resp = self.client.get('/', HTTP_ACCEPT='*/*') @@ -74,12 +90,58 @@ class RendererIntegrationTests(TestCase): self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEquals(resp.status_code, DUMMYSTATUS) + def test_specified_renderer_serializes_content_on_accept_query(self): + """The '_accept' query string should behave in the same way as the Accept header.""" + resp = self.client.get('/?_accept=%s' % RendererB.media_type) + self.assertEquals(resp['Content-Type'], RendererB.media_type) + self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + def test_unsatisfiable_accept_header_on_request_returns_406_status(self): """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" resp = self.client.get('/', HTTP_ACCEPT='foo/bar') - self.assertEquals(resp.status_code, 406) + self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) + def test_specified_renderer_serializes_content_on_format_query(self): + """If a 'format' query is specified, the renderer with the matching + format attribute should serialize the response.""" + resp = self.client.get('/?format=%s' % RendererB.format) + self.assertEquals(resp['Content-Type'], RendererB.media_type) + self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + def test_specified_renderer_serializes_content_on_format_kwargs(self): + """If a 'format' keyword arg is specified, the renderer with the matching + format attribute should serialize the response.""" + resp = self.client.get('/something.formatb') + self.assertEquals(resp['Content-Type'], RendererB.media_type) + self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): + """If both a 'format' query and a matching Accept header specified, + the renderer with the matching format attribute should serialize the response.""" + resp = self.client.get('/?format=%s' % RendererB.format, + HTTP_ACCEPT=RendererB.media_type) + self.assertEquals(resp['Content-Type'], RendererB.media_type) + self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + def test_conflicting_format_query_and_accept_ignores_accept(self): + """If a 'format' query is specified that does not match the Accept + header, we should only honor the 'format' query string.""" + resp = self.client.get('/?format=%s' % RendererB.format, + HTTP_ACCEPT='dummy') + self.assertEquals(resp['Content-Type'], RendererB.media_type) + self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + def test_bla(self): + resp = self.client.get('/?format=formatb', + HTTP_ACCEPT='text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8') + self.assertEquals(resp['Content-Type'], RendererB.media_type) + self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) _flat_repr = '{"foo": ["bar", "baz"]}' @@ -95,14 +157,35 @@ class JSONRendererTests(TestCase): """ Tests specific to the JSON Renderer """ + def test_without_content_type_args(self): + """ + Test basic JSON rendering. + """ obj = {'foo':['bar','baz']} renderer = JSONRenderer(None) content = renderer.render(obj, 'application/json') self.assertEquals(content, _flat_repr) def test_with_content_type_args(self): + """ + Test JSON rendering with additional content type arguments supplied. + """ obj = {'foo':['bar','baz']} renderer = JSONRenderer(None) content = renderer.render(obj, 'application/json; indent=2') self.assertEquals(content, _indented_repr) + + def test_render_and_parse(self): + """ + Test rendering and then parsing returns the original object. + IE obj -> render -> parse -> obj. + """ + obj = {'foo':['bar','baz']} + + renderer = JSONRenderer(None) + parser = JSONParser(None) + + content = renderer.render(obj, 'application/json') + (data, files) = parser.parse(StringIO(content)) + self.assertEquals(obj, data) diff --git a/djangorestframework/tests/resources.py b/djangorestframework/tests/resources.py deleted file mode 100644 index 088e3159..00000000 --- a/djangorestframework/tests/resources.py +++ /dev/null @@ -1,60 +0,0 @@ -"""Tests for the resource module""" -from django.test import TestCase -from djangorestframework.resources import _object_to_data - -from django.db import models - -import datetime -import decimal - -class TestObjectToData(TestCase): - """Tests for the _object_to_data function""" - - def test_decimal(self): - """Decimals need to be converted to a string representation.""" - self.assertEquals(_object_to_data(decimal.Decimal('1.5')), '1.5') - - def test_function(self): - """Functions with no arguments should be called.""" - def foo(): - return 1 - self.assertEquals(_object_to_data(foo), 1) - - def test_method(self): - """Methods with only a ``self`` argument should be called.""" - class Foo(object): - def foo(self): - return 1 - self.assertEquals(_object_to_data(Foo().foo), 1) - - def test_datetime(self): - """datetime objects are left as-is.""" - now = datetime.datetime.now() - self.assertEquals(_object_to_data(now), now) - - def test_tuples(self): - """ Test tuple serialisation """ - class M1(models.Model): - field1 = models.CharField() - field2 = models.CharField() - - class M2(models.Model): - field = models.OneToOneField(M1) - - class M3(models.Model): - field = models.ForeignKey(M1) - - m1 = M1(field1='foo', field2='bar') - m2 = M2(field=m1) - m3 = M3(field=m1) - - Resource = type('Resource', (object,), {'fields':(), 'include':(), 'exclude':()}) - - r = Resource() - r.fields = (('field', ('field1')),) - - self.assertEqual(_object_to_data(m2, r), dict(field=dict(field1=u'foo'))) - - r.fields = (('field', ('field2')),) - self.assertEqual(_object_to_data(m3, r), dict(field=dict(field2=u'bar'))) - diff --git a/djangorestframework/tests/reverse.py b/djangorestframework/tests/reverse.py index b4b0a793..2d1ca79e 100644 --- a/djangorestframework/tests/reverse.py +++ b/djangorestframework/tests/reverse.py @@ -24,9 +24,5 @@ class ReverseTests(TestCase): urls = 'djangorestframework.tests.reverse' def test_reversed_urls_are_fully_qualified(self): - try: - response = self.client.get('/') - except: - import traceback - traceback.print_exc() + response = self.client.get('/') self.assertEqual(json.loads(response.content), 'http://testserver/another') diff --git a/djangorestframework/tests/serializer.py b/djangorestframework/tests/serializer.py new file mode 100644 index 00000000..9f629050 --- /dev/null +++ b/djangorestframework/tests/serializer.py @@ -0,0 +1,117 @@ +"""Tests for the resource module""" +from django.test import TestCase +from djangorestframework.serializer import Serializer + +from django.db import models + +import datetime +import decimal + +class TestObjectToData(TestCase): + """ + Tests for the Serializer class. + """ + + def setUp(self): + self.serializer = Serializer() + self.serialize = self.serializer.serialize + + def test_decimal(self): + """Decimals need to be converted to a string representation.""" + self.assertEquals(self.serialize(decimal.Decimal('1.5')), decimal.Decimal('1.5')) + + def test_function(self): + """Functions with no arguments should be called.""" + def foo(): + return 1 + self.assertEquals(self.serialize(foo), 1) + + def test_method(self): + """Methods with only a ``self`` argument should be called.""" + class Foo(object): + def foo(self): + return 1 + self.assertEquals(self.serialize(Foo().foo), 1) + + def test_datetime(self): + """ + datetime objects are left as-is. + """ + now = datetime.datetime.now() + self.assertEquals(self.serialize(now), now) + + +class TestFieldNesting(TestCase): + """ + Test nesting the fields in the Serializer class + """ + def setUp(self): + self.serializer = Serializer() + self.serialize = self.serializer.serialize + + class M1(models.Model): + field1 = models.CharField() + field2 = models.CharField() + + class M2(models.Model): + field = models.OneToOneField(M1) + + class M3(models.Model): + field = models.ForeignKey(M1) + + self.m1 = M1(field1='foo', field2='bar') + self.m2 = M2(field=self.m1) + self.m3 = M3(field=self.m1) + + + def test_tuple_nesting(self): + """ + Test tuple nesting on `fields` attr + """ + class SerializerM2(Serializer): + fields = (('field', ('field1',)),) + + class SerializerM3(Serializer): + fields = (('field', ('field2',)),) + + self.assertEqual(SerializerM2().serialize(self.m2), {'field': {'field1': u'foo'}}) + self.assertEqual(SerializerM3().serialize(self.m3), {'field': {'field2': u'bar'}}) + + + def test_serializer_class_nesting(self): + """ + Test related model serialization + """ + class NestedM2(Serializer): + fields = ('field1', ) + + class NestedM3(Serializer): + fields = ('field2', ) + + class SerializerM2(Serializer): + fields = [('field', NestedM2)] + + class SerializerM3(Serializer): + fields = [('field', NestedM3)] + + self.assertEqual(SerializerM2().serialize(self.m2), {'field': {'field1': u'foo'}}) + self.assertEqual(SerializerM3().serialize(self.m3), {'field': {'field2': u'bar'}}) + + def test_serializer_classname_nesting(self): + """ + Test related model serialization + """ + class SerializerM2(Serializer): + fields = [('field', 'NestedM2')] + + class SerializerM3(Serializer): + fields = [('field', 'NestedM3')] + + class NestedM2(Serializer): + fields = ('field1', ) + + class NestedM3(Serializer): + fields = ('field2', ) + + self.assertEqual(SerializerM2().serialize(self.m2), {'field': {'field1': u'foo'}}) + self.assertEqual(SerializerM3().serialize(self.m3), {'field': {'field2': u'bar'}}) diff --git a/djangorestframework/tests/throttling.py b/djangorestframework/tests/throttling.py index a8f08b18..b620ee24 100644 --- a/djangorestframework/tests/throttling.py +++ b/djangorestframework/tests/throttling.py @@ -1,38 +1,148 @@ -from django.conf.urls.defaults import patterns +""" +Tests for the throttling implementations in the permissions module. +""" + from django.test import TestCase -from django.utils import simplejson as json +from django.contrib.auth.models import User +from django.core.cache import cache from djangorestframework.compat import RequestFactory from djangorestframework.views import View -from djangorestframework.permissions import PerUserThrottling - +from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling +from djangorestframework.resources import FormResource class MockView(View): permissions = ( PerUserThrottling, ) - throttle = (3, 1) # 3 requests per second + throttle = '3/sec' def get(self, request): return 'foo' -urlpatterns = patterns('', - (r'^$', MockView.as_view()), -) - - -#class ThrottlingTests(TestCase): -# """Basic authentication""" -# urls = 'djangorestframework.tests.throttling' -# -# def test_requests_are_throttled(self): -# """Ensure request rate is limited""" -# for dummy in range(3): -# response = self.client.get('/') -# response = self.client.get('/') -# -# def test_request_throttling_is_per_user(self): -# """Ensure request rate is only limited per user, not globally""" -# pass -# -# def test_request_throttling_expires(self): -# """Ensure request rate is limited for a limited duration only""" -# pass +class MockView_PerViewThrottling(MockView): + permissions = ( PerViewThrottling, ) + +class MockView_PerResourceThrottling(MockView): + permissions = ( PerResourceThrottling, ) + resource = FormResource + +class MockView_MinuteThrottling(MockView): + throttle = '3/min' + + + +class ThrottlingTests(TestCase): + urls = 'djangorestframework.tests.throttling' + + def setUp(self): + """ + Reset the cache so that no throttles will be active + """ + cache.clear() + self.factory = RequestFactory() + + def test_requests_are_throttled(self): + """ + Ensure request rate is limited + """ + request = self.factory.get('/') + for dummy in range(4): + response = MockView.as_view()(request) + self.assertEqual(503, response.status_code) + + def set_throttle_timer(self, view, value): + """ + Explicitly set the timer, overriding time.time() + """ + view.permissions[0].timer = lambda self: value + + def test_request_throttling_expires(self): + """ + Ensure request rate is limited for a limited duration only + """ + self.set_throttle_timer(MockView, 0) + + request = self.factory.get('/') + for dummy in range(4): + response = MockView.as_view()(request) + self.assertEqual(503, response.status_code) + + # Advance the timer by one second + self.set_throttle_timer(MockView, 1) + + response = MockView.as_view()(request) + self.assertEqual(200, response.status_code) + + def ensure_is_throttled(self, view, expect): + request = self.factory.get('/') + request.user = User.objects.create(username='a') + for dummy in range(3): + view.as_view()(request) + request.user = User.objects.create(username='b') + response = view.as_view()(request) + self.assertEqual(expect, response.status_code) + + def test_request_throttling_is_per_user(self): + """ + Ensure request rate is only limited per user, not globally for + PerUserThrottles + """ + self.ensure_is_throttled(MockView, 200) + + def test_request_throttling_is_per_view(self): + """ + Ensure request rate is limited globally per View for PerViewThrottles + """ + self.ensure_is_throttled(MockView_PerViewThrottling, 503) + + def test_request_throttling_is_per_resource(self): + """ + Ensure request rate is limited globally per Resource for PerResourceThrottles + """ + self.ensure_is_throttled(MockView_PerResourceThrottling, 503) + + + def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): + """ + Ensure the response returns an X-Throttle field with status and next attributes + set properly. + """ + request = self.factory.get('/') + for timer, expect in expected_headers: + self.set_throttle_timer(view, timer) + response = view.as_view()(request) + self.assertEquals(response['X-Throttle'], expect) + + def test_seconds_fields(self): + """ + Ensure for second based throttles. + """ + self.ensure_response_header_contains_proper_throttle_field(MockView, + ((0, 'status=SUCCESS; next=0.33 sec'), + (0, 'status=SUCCESS; next=0.50 sec'), + (0, 'status=SUCCESS; next=1.00 sec'), + (0, 'status=FAILURE; next=1.00 sec') + )) + + def test_minutes_fields(self): + """ + Ensure for minute based throttles. + """ + self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, + ((0, 'status=SUCCESS; next=20.00 sec'), + (0, 'status=SUCCESS; next=30.00 sec'), + (0, 'status=SUCCESS; next=60.00 sec'), + (0, 'status=FAILURE; next=60.00 sec') + )) + + def test_next_rate_remains_constant_if_followed(self): + """ + If a client follows the recommended next request rate, + the throttling rate should stay constant. + """ + self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, + ((0, 'status=SUCCESS; next=20.00 sec'), + (20, 'status=SUCCESS; next=20.00 sec'), + (40, 'status=SUCCESS; next=20.00 sec'), + (60, 'status=SUCCESS; next=20.00 sec'), + (80, 'status=SUCCESS; next=20.00 sec') + )) diff --git a/djangorestframework/views.py b/djangorestframework/views.py index 6f2ab5b7..18d064e1 100644 --- a/djangorestframework/views.py +++ b/djangorestframework/views.py @@ -64,7 +64,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): """ permissions = ( permissions.FullAnonAccess, ) - + @classmethod def as_view(cls, **initkwargs): """ @@ -101,6 +101,14 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): """ pass + + def add_header(self, field, value): + """ + Add *field* and *value* to the :attr:`headers` attribute of the :class:`View` class. + """ + self.headers[field] = value + + # Note: session based authentication is explicitly CSRF validated, # all other authentication is CSRF exempt. @csrf_exempt @@ -108,6 +116,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): self.request = request self.args = args self.kwargs = kwargs + self.headers = {} # Calls to 'reverse' will not be fully qualified unless we set the scheme/host/port here. prefix = '%s://%s' % (request.is_secure() and 'https' or 'http', request.get_host()) @@ -149,7 +158,10 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView): # also it's currently sub-obtimal for HTTP caching - need to sort that out. response.headers['Allow'] = ', '.join(self.allowed_methods) response.headers['Vary'] = 'Authenticate, Accept' - + + # merge with headers possibly set at some point in the view + response.headers.update(self.headers) + return self.render(response) diff --git a/docs/examples/blogpost.rst b/docs/examples/blogpost.rst index 36b9d982..be91913d 100644 --- a/docs/examples/blogpost.rst +++ b/docs/examples/blogpost.rst @@ -18,9 +18,19 @@ In this example we're working from two related models: Creating the resources ---------------------- -Once we have some existing models there's very little we need to do to create the API. -Firstly create a resource for each model that defines which fields we want to expose on the model. -Secondly we map a base view and an instance view for each resource. +We need to create two resources that we map to our two existing models, in order to describe how the models should be serialized. +Our resource descriptions will typically go into a module called something like 'resources.py' + +``resources.py`` + +.. include:: ../../examples/blogpost/resources.py + :literal: + +Creating views for our resources +-------------------------------- + +Once we've created the resources there's very little we need to do to create the API. +For each resource we'll create a base view, and an instance view. The generic views :class:`.ListOrCreateModelView` and :class:`.InstanceModelView` provide default operations for listing, creating and updating our models via the API, and also automatically provide input validation using default ModelForms for each model. ``urls.py`` diff --git a/docs/examples/modelviews.rst b/docs/examples/modelviews.rst index 7cc78d39..c60c9f24 100644 --- a/docs/examples/modelviews.rst +++ b/docs/examples/modelviews.rst @@ -25,7 +25,14 @@ Here's the model we're working from in this example: .. include:: ../../examples/modelresourceexample/models.py :literal: -To add an API for the model, all we need to do is create a Resource for the model, and map a couple of views to it in our urlconf. +To add an API for the model, first we need to create a Resource for the model. + +``resources.py`` + +.. include:: ../../examples/modelresourceexample/resources.py + :literal: + +Then we simply map a couple of views to the Resource in our urlconf. ``urls.py`` diff --git a/docs/howto/alternativeframeworks.rst b/docs/howto/alternativeframeworks.rst index c6eba1dd..dc8d1ea6 100644 --- a/docs/howto/alternativeframeworks.rst +++ b/docs/howto/alternativeframeworks.rst @@ -1,6 +1,35 @@ -Alternative Frameworks -====================== +Alternative frameworks & Why Django REST framework +================================================== -#. `django-piston <https://bitbucket.org/jespern/django-piston/wiki/Home>`_ is excellent, and has a great community behind it. This project is based on piston code in parts. +Alternative frameworks +---------------------- -#. `django-tasypie <https://github.com/toastdriven/django-tastypie>`_ is also well worth looking at. +There are a number of alternative REST frameworks for Django: + +* `django-piston <https://bitbucket.org/jespern/django-piston/wiki/Home>`_ is very mature, and has a large community behind it. This project was originally based on piston code in parts. +* `django-tasypie <https://github.com/toastdriven/django-tastypie>`_ is also very good, and has a very active and helpful developer community and maintainers. +* Other interesting projects include `dagny <https://github.com/zacharyvoase/dagny>`_ and `dj-webmachine <http://benoitc.github.com/dj-webmachine/>`_ + + +Why use Django REST framework? +------------------------------ + +The big benefits of using Django REST framework come down to: + +1. It's based on Django's class based views, which makes it simple, modular, and future-proof. +2. It stays as close as possible to Django idioms and language throughout. +3. The browse-able API makes working with the APIs extremely quick and easy. + + +Why was this project created? +----------------------------- + +For me the browse-able API is the most important aspect of Django REST framework. + +I wanted to show that Web APIs could easily be made Web browse-able, +and demonstrate how much better browse-able Web APIs are to work with. + +Being able to navigate and use a Web API directly in the browser is a huge win over only having command line and programmatic +access to the API. It enables the API to be properly self-describing, and it makes it much much quicker and easier to work with. +There's no fundamental reason why the Web APIs we're creating shouldn't be able to render to HTML as well as JSON/XML/whatever, +and I really think that more Web API frameworks *in whatever language* ought to be taking a similar approach. diff --git a/docs/index.rst b/docs/index.rst index dfa361bd..8a285271 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -31,7 +31,7 @@ Resources * The ``djangorestframework`` package is `available on PyPI <http://pypi.python.org/pypi/djangorestframework>`_. * We have an active `discussion group <http://groups.google.com/group/django-rest-framework>`_ and a `project blog <http://blog.django-rest-framework.org>`_. * Bug reports are handled on the `issue tracker <https://github.com/tomchristie/django-rest-framework/issues>`_. -* There is a `Jenkins CI server <http://datacenter.tibold.nl/job/djangorestframework/>`_ which tracks test status and coverage reporting. (Thanks Marko!) +* There is a `Jenkins CI server <http://jenkins.tibold.nl/job/djangorestframework/>`_ which tracks test status and coverage reporting. (Thanks Marko!) Any and all questions, thoughts, bug reports and contributions are *hugely appreciated*. @@ -83,8 +83,8 @@ Using Django REST framework can be as simple as adding a few lines to your urlco model = MyModel urlpatterns = patterns('', - url(r'^$', RootModelResource.as_view(resource=MyResource)), - url(r'^(?P<pk>[^/]+)/$', ModelResource.as_view(resource=MyResource)), + url(r'^$', ListOrCreateModelView.as_view(resource=MyResource)), + url(r'^(?P<pk>[^/]+)/$', InstanceModelView.as_view(resource=MyResource)), ) Django REST framework comes with two "getting started" examples. @@ -134,6 +134,7 @@ Library Reference library/renderers library/resource library/response + library/serializer library/status library/views diff --git a/docs/library/serializer.rst b/docs/library/serializer.rst new file mode 100644 index 00000000..63dd3308 --- /dev/null +++ b/docs/library/serializer.rst @@ -0,0 +1,5 @@ +:mod:`serializer` +================= + +.. automodule:: serializer + :members: diff --git a/examples/blogpost/models.py b/examples/blogpost/models.py index c4925a15..d77f530d 100644 --- a/examples/blogpost/models.py +++ b/examples/blogpost/models.py @@ -22,6 +22,9 @@ class BlogPost(models.Model): slug = models.SlugField(editable=False, default='') def save(self, *args, **kwargs): + """ + For the purposes of the sandbox, limit the maximum number of stored models. + """ self.slug = slugify(self.title) super(self.__class__, self).save(*args, **kwargs) for obj in self.__class__.objects.order_by('-created')[MAX_POSTS:]: diff --git a/examples/blogpost/resources.py b/examples/blogpost/resources.py new file mode 100644 index 00000000..9b91ed73 --- /dev/null +++ b/examples/blogpost/resources.py @@ -0,0 +1,27 @@ +from django.core.urlresolvers import reverse +from djangorestframework.resources import ModelResource +from blogpost.models import BlogPost, Comment + + +class BlogPostResource(ModelResource): + """ + A Blog Post has a *title* and *content*, and can be associated with zero or more comments. + """ + model = BlogPost + fields = ('created', 'title', 'slug', 'content', 'url', 'comments') + ordering = ('-created',) + + def comments(self, instance): + return reverse('comments', kwargs={'blogpost': instance.key}) + + +class CommentResource(ModelResource): + """ + A Comment is associated with a given Blog Post and has a *username* and *comment*, and optionally a *rating*. + """ + model = Comment + fields = ('username', 'comment', 'created', 'rating', 'url', 'blogpost') + ordering = ('-created',) + + def blogpost(self, instance): + return reverse('blog-post', kwargs={'key': instance.blogpost.key})
\ No newline at end of file diff --git a/examples/blogpost/tests.py b/examples/blogpost/tests.py index 30b152fa..e55f0f90 100644 --- a/examples/blogpost/tests.py +++ b/examples/blogpost/tests.py @@ -7,9 +7,10 @@ from django.core.urlresolvers import reverse from django.utils import simplejson as json from djangorestframework.compat import RequestFactory +from djangorestframework.views import InstanceModelView, ListOrCreateModelView -from blogpost import models -import blogpost +from blogpost import models, urls +#import blogpost # class AcceptHeaderTests(TestCase): @@ -178,32 +179,33 @@ class TestRotation(TestCase): models.BlogPost.objects.all().delete() def test_get_to_root(self): - '''Simple test to demonstrate how the requestfactory needs to be used''' + '''Simple get to the *root* url of blogposts''' request = self.factory.get('/blog-post') - view = views.BlogPosts.as_view() + view = ListOrCreateModelView.as_view(resource=urls.BlogPostResource) response = view(request) self.assertEqual(response.status_code, 200) def test_blogposts_not_exceed_MAX_POSTS(self): '''Posting blog-posts should not result in more than MAX_POSTS items stored.''' - for post in range(views.MAX_POSTS + 5): + for post in range(models.MAX_POSTS + 5): form_data = {'title': 'This is post #%s' % post, 'content': 'This is the content of post #%s' % post} request = self.factory.post('/blog-post', data=form_data) - view = views.BlogPosts.as_view() + view = ListOrCreateModelView.as_view(resource=urls.BlogPostResource) view(request) - self.assertEquals(len(models.BlogPost.objects.all()),views.MAX_POSTS) + self.assertEquals(len(models.BlogPost.objects.all()),models.MAX_POSTS) def test_fifo_behaviour(self): '''It's fine that the Blogposts are capped off at MAX_POSTS. But we want to make sure we see FIFO behaviour.''' for post in range(15): form_data = {'title': '%s' % post, 'content': 'This is the content of post #%s' % post} request = self.factory.post('/blog-post', data=form_data) - view = views.BlogPosts.as_view() + view = ListOrCreateModelView.as_view(resource=urls.BlogPostResource) view(request) request = self.factory.get('/blog-post') - view = views.BlogPosts.as_view() + view = ListOrCreateModelView.as_view(resource=urls.BlogPostResource) response = view(request) response_posts = json.loads(response.content) response_titles = [d['title'] for d in response_posts] - self.assertEquals(response_titles, ['%s' % i for i in range(views.MAX_POSTS - 5, views.MAX_POSTS + 5)]) + response_titles.reverse() + self.assertEquals(response_titles, ['%s' % i for i in range(models.MAX_POSTS - 5, models.MAX_POSTS + 5)])
\ No newline at end of file diff --git a/examples/blogpost/urls.py b/examples/blogpost/urls.py index c677b8fa..e9bd2754 100644 --- a/examples/blogpost/urls.py +++ b/examples/blogpost/urls.py @@ -1,36 +1,11 @@ from django.conf.urls.defaults import patterns, url -from django.core.urlresolvers import reverse - from djangorestframework.views import ListOrCreateModelView, InstanceModelView -from djangorestframework.resources import ModelResource - -from blogpost.models import BlogPost, Comment - - -class BlogPostResource(ModelResource): - """ - A Blog Post has a *title* and *content*, and can be associated with zero or more comments. - """ - model = BlogPost - fields = ('created', 'title', 'slug', 'content', 'url', 'comments') - ordering = ('-created',) - - def comments(self, instance): - return reverse('comments', kwargs={'blogpost': instance.key}) - - -class CommentResource(ModelResource): - """ - A Comment is associated with a given Blog Post and has a *username* and *comment*, and optionally a *rating*. - """ - model = Comment - fields = ('username', 'comment', 'created', 'rating', 'url', 'blogpost') - ordering = ('-created',) +from blogpost.resources import BlogPostResource, CommentResource urlpatterns = patterns('', url(r'^$', ListOrCreateModelView.as_view(resource=BlogPostResource), name='blog-posts-root'), - url(r'^(?P<key>[^/]+)/$', InstanceModelView.as_view(resource=BlogPostResource)), + url(r'^(?P<key>[^/]+)/$', InstanceModelView.as_view(resource=BlogPostResource), name='blog-post'), url(r'^(?P<blogpost>[^/]+)/comments/$', ListOrCreateModelView.as_view(resource=CommentResource), name='comments'), url(r'^(?P<blogpost>[^/]+)/comments/(?P<id>[^/]+)/$', InstanceModelView.as_view(resource=CommentResource)), ) diff --git a/examples/modelresourceexample/resources.py b/examples/modelresourceexample/resources.py new file mode 100644 index 00000000..634ea6b3 --- /dev/null +++ b/examples/modelresourceexample/resources.py @@ -0,0 +1,7 @@ +from djangorestframework.resources import ModelResource +from modelresourceexample.models import MyModel + +class MyModelResource(ModelResource): + model = MyModel + fields = ('foo', 'bar', 'baz', 'url') + ordering = ('created',) diff --git a/examples/modelresourceexample/urls.py b/examples/modelresourceexample/urls.py index bb71ddd3..b6a16542 100644 --- a/examples/modelresourceexample/urls.py +++ b/examples/modelresourceexample/urls.py @@ -1,14 +1,8 @@ from django.conf.urls.defaults import patterns, url from djangorestframework.views import ListOrCreateModelView, InstanceModelView -from djangorestframework.resources import ModelResource -from modelresourceexample.models import MyModel - -class MyModelResource(ModelResource): - model = MyModel - fields = ('foo', 'bar', 'baz', 'url') - ordering = ('created',) +from modelresourceexample.resources import MyModelResource urlpatterns = patterns('', url(r'^$', ListOrCreateModelView.as_view(resource=MyModelResource), name='model-resource-root'), - url(r'^([0-9]+)/$', InstanceModelView.as_view(resource=MyModelResource)), + url(r'^(?P<pk>[0-9]+)/$', InstanceModelView.as_view(resource=MyModelResource)), ) diff --git a/examples/modelresourceexample/views.py b/examples/permissionsexample/__init__.py index e69de29b..e69de29b 100644 --- a/examples/modelresourceexample/views.py +++ b/examples/permissionsexample/__init__.py diff --git a/examples/permissionsexample/urls.py b/examples/permissionsexample/urls.py new file mode 100644 index 00000000..d17f5159 --- /dev/null +++ b/examples/permissionsexample/urls.py @@ -0,0 +1,6 @@ +from django.conf.urls.defaults import patterns, url +from permissionsexample.views import ThrottlingExampleView + +urlpatterns = patterns('', + url(r'^$', ThrottlingExampleView.as_view(), name='throttled-resource'), +) diff --git a/examples/permissionsexample/views.py b/examples/permissionsexample/views.py new file mode 100644 index 00000000..20e7cba7 --- /dev/null +++ b/examples/permissionsexample/views.py @@ -0,0 +1,20 @@ +from djangorestframework.views import View +from djangorestframework.permissions import PerUserThrottling + + +class ThrottlingExampleView(View): + """ + A basic read-only View that has a **per-user throttle** of 10 requests per minute. + + If a user exceeds the 10 requests limit within a period of one minute, the + throttle will be applied until 60 seconds have passed since the first request. + """ + + permissions = ( PerUserThrottling, ) + throttle = '10/min' + + def get(self, request): + """ + Handle GET requests. + """ + return "Successful response to GET request because throttle is not yet active."
\ No newline at end of file diff --git a/examples/pygments_api/views.py b/examples/pygments_api/views.py index 76647107..e50029f6 100644 --- a/examples/pygments_api/views.py +++ b/examples/pygments_api/views.py @@ -46,19 +46,12 @@ class HTMLRenderer(BaseRenderer): media_type = 'text/html' - -class PygmentsFormResource(FormResource): - """ - """ - form = PygmentsForm - - class PygmentsRoot(View): """ - This example demonstrates a simple RESTful Web API aound the awesome pygments library. + This example demonstrates a simple RESTful Web API around the awesome pygments library. This top level resource is used to create highlighted code snippets, and to list all the existing code snippets. """ - resource = PygmentsFormResource + form = PygmentsForm def get(self, request): """ diff --git a/examples/requirements.txt b/examples/requirements.txt index 09cda945..0bcd8d43 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -1,8 +1,6 @@ -# For the examples we need Django, pygments and httplib2... +# Pygments for the code highlighting example, +# markdown for the docstring -> auto-documentation -Django==1.2.4 -wsgiref==0.1.2 Pygments==1.4 -httplib2==0.6.0 Markdown==2.0.3 diff --git a/examples/sandbox/views.py b/examples/sandbox/views.py index 1c55c28f..1e326f43 100644 --- a/examples/sandbox/views.py +++ b/examples/sandbox/views.py @@ -31,4 +31,6 @@ class Sandbox(View): {'name': 'Simple Mixin-only example', 'url': reverse('mixin-view')}, {'name': 'Object store API', 'url': reverse('object-store-root')}, {'name': 'Code highlighting API', 'url': reverse('pygments-root')}, - {'name': 'Blog posts API', 'url': reverse('blog-posts-root')}] + {'name': 'Blog posts API', 'url': reverse('blog-posts-root')}, + {'name': 'Permissions example', 'url': reverse('throttled-resource')} + ] diff --git a/examples/urls.py b/examples/urls.py index cf4d4042..08d97a14 100644 --- a/examples/urls.py +++ b/examples/urls.py @@ -10,6 +10,7 @@ urlpatterns = patterns('', (r'^object-store/', include('objectstore.urls')), (r'^pygments/', include('pygments_api.urls')), (r'^blog-post/', include('blogpost.urls')), + (r'^permissions-example/', include('permissionsexample.urls')), (r'^', include('djangorestframework.urls')), ) |
