From 52847a215d4e8de88e81d9ae79ce8bee9a36a9a2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 15 Jan 2013 17:50:51 +0000 Subject: Fix implementation --- rest_framework/mixins.py | 3 -- rest_framework/resources.py | 67 ++++++++++++++++----------------------------- 2 files changed, 23 insertions(+), 47 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 8873e4ae..9bd566da 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -25,9 +25,6 @@ class CreateModelMixin(object): return Response(serializer.data, status=status.HTTP_201_CREATED) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - def pre_save(self, obj): - pass - class ListModelMixin(object): """ diff --git a/rest_framework/resources.py b/rest_framework/resources.py index dd8a5471..d4019a94 100644 --- a/rest_framework/resources.py +++ b/rest_framework/resources.py @@ -1,31 +1,27 @@ ##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### from functools import update_wrapper -import inspect from django.utils.decorators import classonlymethod -from rest_framework import views, generics - - -def wrapped(source, dest): - """ - Copy public, non-method attributes from source to dest, and return dest. - """ - for attr in [attr for attr in dir(source) - if not attr.startswith('_') and not inspect.ismethod(attr)]: - setattr(dest, attr, getattr(source, attr)) - return dest +from rest_framework import views, generics, mixins ##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### class ResourceMixin(object): """ - Clone Django's `View.as_view()` behaviour *except* using REST framework's - 'method -> action' binding for resources. + This is the magic. + + Overrides `.as_view()` so that it takes an `actions` keyword that performs + the binding of HTTP methods to actions on the Resource. + + For example, to create a concrete view binding the 'GET' and 'POST' methods + to the 'list' and 'create' actions... + + my_resource = MyResource.as_view({'get': 'list', 'post': 'create'}) """ @classonlymethod - def as_view(cls, actions, **initkwargs): + def as_view(cls, actions=None, **initkwargs): """ Main entry point for a request-response process. """ @@ -61,36 +57,19 @@ class ResourceMixin(object): return view -##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### - class Resource(ResourceMixin, views.APIView): pass -##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### - -class ModelResource(ResourceMixin, views.APIView): - # TODO: Actually delegation won't work - root_class = generics.ListCreateAPIView - detail_class = generics.RetrieveUpdateDestroyAPIView - - def root_view(self): - return wrapped(self, self.root_class()) - - def detail_view(self): - return wrapped(self, self.detail_class()) - - def list(self, request, *args, **kwargs): - return self.root_view().list(request, args, kwargs) - - def create(self, request, *args, **kwargs): - return self.root_view().create(request, args, kwargs) - - def retrieve(self, request, *args, **kwargs): - return self.detail_view().retrieve(request, args, kwargs) - - def update(self, request, *args, **kwargs): - return self.detail_view().update(request, args, kwargs) - - def destroy(self, request, *args, **kwargs): - return self.detail_view().destroy(request, args, kwargs) +# Note the inheritence of both MultipleObjectAPIView *and* SingleObjectAPIView +# is a bit weird given the diamond inheritence, but it will work for now. +# There's some implementation clean up that can happen later. +class ModelResource(mixins.CreateModelMixin, + mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + mixins.DestroyModelMixin, + mixins.ListModelMixin, + ResourceMixin, + generics.MultipleObjectAPIView, + generics.SingleObjectAPIView): + pass -- cgit v1.2.3 From 4a7139e41d2500776c30e663c1cebce74b49270d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 15 Jan 2013 21:49:24 +0000 Subject: Tweaks --- rest_framework/routers.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 rest_framework/routers.py (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py new file mode 100644 index 00000000..a5aef5b7 --- /dev/null +++ b/rest_framework/routers.py @@ -0,0 +1,33 @@ +# Not properly implemented yet, just the basic idea + + +class BaseRouter(object): + def __init__(self): + self.resources = [] + + def register(self, name, resource): + self.resources.append((name, resource)) + + @property + def urlpatterns(self): + ret = [] + + for name, resource in self.resources: + list_actions = { + 'get': getattr(resource, 'list', None), + 'post': getattr(resource, 'create', None) + } + detail_actions = { + 'get': getattr(resource, 'retrieve', None), + 'put': getattr(resource, 'update', None), + 'delete': getattr(resource, 'destroy', None) + } + list_regex = r'^%s/$' % name + detail_regex = r'^%s/(?P[0-9]+)/$' % name + list_name = '%s-list' + detail_name = '%s-detail' + + ret += url(list_regex, resource.as_view(list_actions), list_name) + ret += url(detail_regex, resource.as_view(detail_actions), detail_name) + + return ret -- cgit v1.2.3 From 922ee61d8611b41e2944b6503af736b1790abe83 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 18 Mar 2013 21:05:13 +0000 Subject: Remove erronous pre_save --- rest_framework/generics.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 55918267..36ecf915 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -82,9 +82,6 @@ class GenericAPIView(views.APIView): """ pass - def pre_save(self, obj): - pass - class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): """ -- cgit v1.2.3 From ec076a00786c6b89a55b6ffe2556bb3b777100f5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sun, 31 Mar 2013 11:36:58 +0100 Subject: Add viewsets/routers to indexs etc --- rest_framework/routers.py | 33 --------------------------------- rest_framework/viewsets.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 33 deletions(-) delete mode 100644 rest_framework/routers.py create mode 100644 rest_framework/viewsets.py (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py deleted file mode 100644 index a5aef5b7..00000000 --- a/rest_framework/routers.py +++ /dev/null @@ -1,33 +0,0 @@ -# Not properly implemented yet, just the basic idea - - -class BaseRouter(object): - def __init__(self): - self.resources = [] - - def register(self, name, resource): - self.resources.append((name, resource)) - - @property - def urlpatterns(self): - ret = [] - - for name, resource in self.resources: - list_actions = { - 'get': getattr(resource, 'list', None), - 'post': getattr(resource, 'create', None) - } - detail_actions = { - 'get': getattr(resource, 'retrieve', None), - 'put': getattr(resource, 'update', None), - 'delete': getattr(resource, 'destroy', None) - } - list_regex = r'^%s/$' % name - detail_regex = r'^%s/(?P[0-9]+)/$' % name - list_name = '%s-list' - detail_name = '%s-detail' - - ret += url(list_regex, resource.as_view(list_actions), list_name) - ret += url(detail_regex, resource.as_view(detail_actions), detail_name) - - return ret diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py new file mode 100644 index 00000000..a5aef5b7 --- /dev/null +++ b/rest_framework/viewsets.py @@ -0,0 +1,33 @@ +# Not properly implemented yet, just the basic idea + + +class BaseRouter(object): + def __init__(self): + self.resources = [] + + def register(self, name, resource): + self.resources.append((name, resource)) + + @property + def urlpatterns(self): + ret = [] + + for name, resource in self.resources: + list_actions = { + 'get': getattr(resource, 'list', None), + 'post': getattr(resource, 'create', None) + } + detail_actions = { + 'get': getattr(resource, 'retrieve', None), + 'put': getattr(resource, 'update', None), + 'delete': getattr(resource, 'destroy', None) + } + list_regex = r'^%s/$' % name + detail_regex = r'^%s/(?P[0-9]+)/$' % name + list_name = '%s-list' + detail_name = '%s-detail' + + ret += url(list_regex, resource.as_view(list_actions), list_name) + ret += url(detail_regex, resource.as_view(detail_actions), detail_name) + + return ret -- cgit v1.2.3 From c785628300d2b7cce63862a18915c537f8a3ab24 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Apr 2013 20:00:44 +0100 Subject: Fleshing out viewsets/routers --- rest_framework/resources.py | 75 ---------------------------- rest_framework/routers.py | 43 ++++++++++++++++ rest_framework/viewsets.py | 119 ++++++++++++++++++++++++++++++++------------ 3 files changed, 129 insertions(+), 108 deletions(-) delete mode 100644 rest_framework/resources.py create mode 100644 rest_framework/routers.py (limited to 'rest_framework') diff --git a/rest_framework/resources.py b/rest_framework/resources.py deleted file mode 100644 index d4019a94..00000000 --- a/rest_framework/resources.py +++ /dev/null @@ -1,75 +0,0 @@ -##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### - -from functools import update_wrapper -from django.utils.decorators import classonlymethod -from rest_framework import views, generics, mixins - - -##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### - -class ResourceMixin(object): - """ - This is the magic. - - Overrides `.as_view()` so that it takes an `actions` keyword that performs - the binding of HTTP methods to actions on the Resource. - - For example, to create a concrete view binding the 'GET' and 'POST' methods - to the 'list' and 'create' actions... - - my_resource = MyResource.as_view({'get': 'list', 'post': 'create'}) - """ - - @classonlymethod - def as_view(cls, actions=None, **initkwargs): - """ - Main entry point for a request-response process. - """ - # sanitize keyword arguments - for key in initkwargs: - if key in cls.http_method_names: - raise TypeError("You tried to pass in the %s method name as a " - "keyword argument to %s(). Don't do that." - % (key, cls.__name__)) - if not hasattr(cls, key): - raise TypeError("%s() received an invalid keyword %r" % ( - cls.__name__, key)) - - def view(request, *args, **kwargs): - self = cls(**initkwargs) - - # Bind methods to actions - for method, action in actions.items(): - handler = getattr(self, action) - setattr(self, method, handler) - - # As you were, solider. - if hasattr(self, 'get') and not hasattr(self, 'head'): - self.head = self.get - return self.dispatch(request, *args, **kwargs) - - # take name and docstring from class - update_wrapper(view, cls, updated=()) - - # and possible attributes set by decorators - # like csrf_exempt from dispatch - update_wrapper(view, cls.dispatch, assigned=()) - return view - - -class Resource(ResourceMixin, views.APIView): - pass - - -# Note the inheritence of both MultipleObjectAPIView *and* SingleObjectAPIView -# is a bit weird given the diamond inheritence, but it will work for now. -# There's some implementation clean up that can happen later. -class ModelResource(mixins.CreateModelMixin, - mixins.RetrieveModelMixin, - mixins.UpdateModelMixin, - mixins.DestroyModelMixin, - mixins.ListModelMixin, - ResourceMixin, - generics.MultipleObjectAPIView, - generics.SingleObjectAPIView): - pass diff --git a/rest_framework/routers.py b/rest_framework/routers.py new file mode 100644 index 00000000..63eae5d7 --- /dev/null +++ b/rest_framework/routers.py @@ -0,0 +1,43 @@ +from django.conf.urls import url, patterns + + +class BaseRouter(object): + def __init__(self): + self.registry = [] + + def register(self, prefix, viewset, base_name): + self.registry.append((prefix, viewset, base_name)) + + def get_urlpatterns(self): + raise NotImplemented('get_urlpatterns must be overridden') + + @property + def urlpatterns(self): + if not hasattr(self, '_urlpatterns'): + print self.get_urlpatterns() + self._urlpatterns = patterns('', *self.get_urlpatterns()) + return self._urlpatterns + + +class DefaultRouter(BaseRouter): + route_list = [ + (r'$', {'get': 'list', 'post': 'create'}, '%s-list'), + (r'(?P[^/]+)/$', {'get': 'retrieve', 'put': 'update', 'delete': 'destroy'}, '%s-detail'), + ] + + def get_urlpatterns(self): + ret = [] + for prefix, viewset, base_name in self.registry: + for suffix, action_mapping, name_format in self.route_list: + + # Only actions which actually exist on the viewset will be bound + bound_actions = {} + for method, action in action_mapping.items(): + if hasattr(viewset, action): + bound_actions[method] = action + + regex = prefix + suffix + view = viewset.as_view(bound_actions) + name = name_format % base_name + ret.append(url(regex, view, name=name)) + return ret diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index a5aef5b7..887a9722 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -1,33 +1,86 @@ -# Not properly implemented yet, just the basic idea - - -class BaseRouter(object): - def __init__(self): - self.resources = [] - - def register(self, name, resource): - self.resources.append((name, resource)) - - @property - def urlpatterns(self): - ret = [] - - for name, resource in self.resources: - list_actions = { - 'get': getattr(resource, 'list', None), - 'post': getattr(resource, 'create', None) - } - detail_actions = { - 'get': getattr(resource, 'retrieve', None), - 'put': getattr(resource, 'update', None), - 'delete': getattr(resource, 'destroy', None) - } - list_regex = r'^%s/$' % name - detail_regex = r'^%s/(?P[0-9]+)/$' % name - list_name = '%s-list' - detail_name = '%s-detail' - - ret += url(list_regex, resource.as_view(list_actions), list_name) - ret += url(detail_regex, resource.as_view(detail_actions), detail_name) - - return ret +from functools import update_wrapper +from django.utils.decorators import classonlymethod +from rest_framework import views, generics, mixins + + +class ViewSetMixin(object): + """ + This is the magic. + + Overrides `.as_view()` so that it takes an `actions` keyword that performs + the binding of HTTP methods to actions on the Resource. + + For example, to create a concrete view binding the 'GET' and 'POST' methods + to the 'list' and 'create' actions... + + view = MyViewSet.as_view({'get': 'list', 'post': 'create'}) + """ + + @classonlymethod + def as_view(cls, actions=None, **initkwargs): + """ + Main entry point for a request-response process. + + Because of the way class based views create a closure around the + instantiated view, we need to totally reimplement `.as_view`, + and slightly modify the view function that is created and returned. + """ + # sanitize keyword arguments + for key in initkwargs: + if key in cls.http_method_names: + raise TypeError("You tried to pass in the %s method name as a " + "keyword argument to %s(). Don't do that." + % (key, cls.__name__)) + if not hasattr(cls, key): + raise TypeError("%s() received an invalid keyword %r" % ( + cls.__name__, key)) + + def view(request, *args, **kwargs): + self = cls(**initkwargs) + + # Bind methods to actions + # This is the bit that's different to a standard view + for method, action in actions.items(): + handler = getattr(self, action) + setattr(self, method, handler) + + # Patch this in as it's otherwise only present from 1.5 onwards + if hasattr(self, 'get') and not hasattr(self, 'head'): + self.head = self.get + + # And continue as usual + return self.dispatch(request, *args, **kwargs) + + # take name and docstring from class + update_wrapper(view, cls, updated=()) + + # and possible attributes set by decorators + # like csrf_exempt from dispatch + update_wrapper(view, cls.dispatch, assigned=()) + return view + + +class ViewSet(ViewSetMixin, views.APIView): + pass + + +# Note the inheritence of both MultipleObjectAPIView *and* SingleObjectAPIView +# is a bit weird given the diamond inheritence, but it will work for now. +# There's some implementation clean up that can happen later. +class ModelViewSet(mixins.CreateModelMixin, + mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + mixins.DestroyModelMixin, + mixins.ListModelMixin, + ViewSetMixin, + generics.MultipleObjectAPIView, + generics.SingleObjectAPIView): + pass + + +class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, + mixins.ListModelMixin, + ViewSetMixin, + generics.MultipleObjectAPIView, + generics.SingleObjectAPIView): + pass -- cgit v1.2.3 From fb41d2ac8f495ae0728e3f38c6a21306f0507316 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Apr 2013 20:35:40 +0100 Subject: Add support for action and link routing --- rest_framework/decorators.py | 22 ++++++++++++++++++++++ rest_framework/routers.py | 20 ++++++++++++++++++++ 2 files changed, 42 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 8250cd3b..00b37f8b 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -97,3 +97,25 @@ def permission_classes(permission_classes): func.permission_classes = permission_classes return func return decorator + + +def link(**kwargs): + """ + Used to mark a method on a ViewSet that should be routed for GET requests. + """ + def decorator(func): + func.bind_to_method = 'get' + func.kwargs = kwargs + return func + return decorator + + +def action(**kwargs): + """ + Used to mark a method on a ViewSet that should be routed for POST requests. + """ + def decorator(func): + func.bind_to_method = 'post' + func.kwargs = kwargs + return func + return decorator diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 63eae5d7..d1e96156 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -24,10 +24,12 @@ class DefaultRouter(BaseRouter): (r'$', {'get': 'list', 'post': 'create'}, '%s-list'), (r'(?P[^/]+)/$', {'get': 'retrieve', 'put': 'update', 'delete': 'destroy'}, '%s-detail'), ] + extra_routes = (r'(?P[^/]+)/%s/$', '%s-%s') def get_urlpatterns(self): ret = [] for prefix, viewset, base_name in self.registry: + # Bind standard routes for suffix, action_mapping, name_format in self.route_list: # Only actions which actually exist on the viewset will be bound @@ -36,8 +38,26 @@ class DefaultRouter(BaseRouter): if hasattr(viewset, action): bound_actions[method] = action + # Build the url pattern regex = prefix + suffix view = viewset.as_view(bound_actions) name = name_format % base_name ret.append(url(regex, view, name=name)) + + # Bind any extra @action or @link routes + for attr in dir(viewset): + func = getattr(viewset, attr) + http_method = getattr(func, 'bind_to_method', None) + if not http_method: + continue + + regex_format, name_format = self.extra_routes + + # Build the url pattern + regex = regex_format % attr + view = viewset.as_view({http_method: attr}, **func.kwargs) + name = name_format % (base_name, attr) + ret.append(url(regex, view, name=name)) + + # Return a list of url patterns return ret -- cgit v1.2.3 From 9e24db022cd8da1a588dd43e6239e07798881c02 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Apr 2013 20:38:42 +0100 Subject: Commenting --- rest_framework/routers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py index d1e96156..283add8d 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -29,7 +29,7 @@ class DefaultRouter(BaseRouter): def get_urlpatterns(self): ret = [] for prefix, viewset, base_name in self.registry: - # Bind standard routes + # Bind standard CRUD routes for suffix, action_mapping, name_format in self.route_list: # Only actions which actually exist on the viewset will be bound @@ -44,10 +44,12 @@ class DefaultRouter(BaseRouter): name = name_format % base_name ret.append(url(regex, view, name=name)) - # Bind any extra @action or @link routes + # Bind any extra `@action` or `@link` routes for attr in dir(viewset): func = getattr(viewset, attr) http_method = getattr(func, 'bind_to_method', None) + + # Skip if this is not an @action or @link method if not http_method: continue -- cgit v1.2.3 From f68721ade8d66806296323116ff9a61773ad2be1 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Apr 2013 21:42:26 +0100 Subject: Factor view names/descriptions out of View class --- rest_framework/renderers.py | 11 ++--- rest_framework/routers.py | 34 ++++++++------ rest_framework/utils/breadcrumbs.py | 5 ++- rest_framework/utils/formatting.py | 77 ++++++++++++++++++++++++++++++++ rest_framework/views.py | 89 ++++--------------------------------- rest_framework/viewsets.py | 5 ++- 6 files changed, 117 insertions(+), 104 deletions(-) create mode 100644 rest_framework/utils/formatting.py (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 4c15e0db..752306ad 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -24,6 +24,7 @@ from rest_framework.settings import api_settings from rest_framework.request import clone_request from rest_framework.utils import encoders from rest_framework.utils.breadcrumbs import get_breadcrumbs +from rest_framework.utils.formatting import get_view_name, get_view_description from rest_framework import exceptions, parsers, status, VERSION @@ -438,16 +439,10 @@ class BrowsableAPIRenderer(BaseRenderer): return GenericContentForm() def get_name(self, view): - try: - return view.get_name() - except AttributeError: - return smart_text(view.__class__.__name__) + return get_view_name(view.__class__) def get_description(self, view): - try: - return view.get_description(html=True) - except AttributeError: - return smart_text(view.__doc__ or '') + return get_view_description(view.__class__, html=True) def render(self, data, accepted_media_type=None, renderer_context=None): """ diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 283add8d..c37909ff 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -14,23 +14,31 @@ class BaseRouter(object): @property def urlpatterns(self): if not hasattr(self, '_urlpatterns'): - print self.get_urlpatterns() self._urlpatterns = patterns('', *self.get_urlpatterns()) return self._urlpatterns class DefaultRouter(BaseRouter): route_list = [ - (r'$', {'get': 'list', 'post': 'create'}, '%s-list'), - (r'(?P[^/]+)/$', {'get': 'retrieve', 'put': 'update', 'delete': 'destroy'}, '%s-detail'), + (r'$', {'get': 'list', 'post': 'create'}, 'list'), + (r'(?P[^/]+)/$', {'get': 'retrieve', 'put': 'update', 'delete': 'destroy'}, 'detail'), ] - extra_routes = (r'(?P[^/]+)/%s/$', '%s-%s') + extra_routes = r'(?P[^/]+)/%s/$' + name_format = '%s-%s' def get_urlpatterns(self): ret = [] for prefix, viewset, base_name in self.registry: + # Bind regular views + if not getattr(viewset, '_is_viewset', False): + regex = prefix + view = viewset + name = base_name + ret.append(url(regex, view, name=name)) + continue + # Bind standard CRUD routes - for suffix, action_mapping, name_format in self.route_list: + for suffix, action_mapping, action_name in self.route_list: # Only actions which actually exist on the viewset will be bound bound_actions = {} @@ -40,25 +48,25 @@ class DefaultRouter(BaseRouter): # Build the url pattern regex = prefix + suffix - view = viewset.as_view(bound_actions) - name = name_format % base_name + view = viewset.as_view(bound_actions, name_suffix=action_name) + name = self.name_format % (base_name, action_name) ret.append(url(regex, view, name=name)) # Bind any extra `@action` or `@link` routes - for attr in dir(viewset): - func = getattr(viewset, attr) + for action_name in dir(viewset): + func = getattr(viewset, action_name) http_method = getattr(func, 'bind_to_method', None) # Skip if this is not an @action or @link method if not http_method: continue - regex_format, name_format = self.extra_routes + suffix = self.extra_routes % action_name # Build the url pattern - regex = regex_format % attr - view = viewset.as_view({http_method: attr}, **func.kwargs) - name = name_format % (base_name, attr) + regex = prefix + suffix + view = viewset.as_view({http_method: action_name}, **func.kwargs) + name = self.name_format % (base_name, action_name) ret.append(url(regex, view, name=name)) # Return a list of url patterns diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index af21ac79..18b3b207 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals from django.core.urlresolvers import resolve, get_script_prefix +from rest_framework.utils.formatting import get_view_name def get_breadcrumbs(url): @@ -16,11 +17,11 @@ def get_breadcrumbs(url): pass else: # Check if this is a REST framework view, and if so add it to the breadcrumbs - if isinstance(getattr(view, 'cls_instance', None), APIView): + if issubclass(getattr(view, 'cls', None), APIView): # Don't list the same view twice in a row. # Probably an optional trailing slash. if not seen or seen[-1] != view: - breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url)) + breadcrumbs_list.insert(0, (get_view_name(view.cls), prefix + url)) seen.append(view) if url == '': diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py new file mode 100644 index 00000000..79566db1 --- /dev/null +++ b/rest_framework/utils/formatting.py @@ -0,0 +1,77 @@ +""" +Utility functions to return a formatted name and description for a given view. +""" +from __future__ import unicode_literals + +from django.utils.html import escape +from django.utils.safestring import mark_safe +from rest_framework.compat import apply_markdown +import re + + +def _remove_trailing_string(content, trailing): + """ + Strip trailing component `trailing` from `content` if it exists. + Used when generating names from view classes. + """ + if content.endswith(trailing) and content != trailing: + return content[:-len(trailing)] + return content + + +def _remove_leading_indent(content): + """ + Remove leading indent from a block of text. + Used when generating descriptions from docstrings. + """ + whitespace_counts = [len(line) - len(line.lstrip(' ')) + for line in content.splitlines()[1:] if line.lstrip()] + + # unindent the content if needed + if whitespace_counts: + whitespace_pattern = '^' + (' ' * min(whitespace_counts)) + content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) + content = content.strip('\n') + return content + + +def _camelcase_to_spaces(content): + """ + Translate 'CamelCaseNames' to 'Camel Case Names'. + Used when generating names from view classes. + """ + camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))' + content = re.sub(camelcase_boundry, ' \\1', content).strip() + return ' '.join(content.split('_')).title() + + +def get_view_name(cls): + """ + Return a formatted name for an `APIView` class or `@api_view` function. + """ + name = cls.__name__ + name = _remove_trailing_string(name, 'View') + name = _remove_trailing_string(name, 'ViewSet') + return _camelcase_to_spaces(name) + + +def get_view_description(cls, html=False): + """ + Return a description for an `APIView` class or `@api_view` function. + """ + description = cls.__doc__ or '' + description = _remove_leading_indent(description) + if html: + return markup_description(description) + return description + + +def markup_description(description): + """ + Apply HTML markup to the given description. + """ + if apply_markdown: + description = apply_markdown(description) + else: + description = escape(description).replace('\n', '
') + return mark_safe(description) diff --git a/rest_framework/views.py b/rest_framework/views.py index 81cbdcbb..12298ca5 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -4,51 +4,13 @@ Provides an APIView class that is used as the base of all class-based views. from __future__ import unicode_literals from django.core.exceptions import PermissionDenied from django.http import Http404 -from django.utils.html import escape -from django.utils.safestring import mark_safe from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions -from rest_framework.compat import View, apply_markdown +from rest_framework.compat import View from rest_framework.response import Response from rest_framework.request import Request from rest_framework.settings import api_settings -import re - - -def _remove_trailing_string(content, trailing): - """ - Strip trailing component `trailing` from `content` if it exists. - Used when generating names from view classes. - """ - if content.endswith(trailing) and content != trailing: - return content[:-len(trailing)] - return content - - -def _remove_leading_indent(content): - """ - Remove leading indent from a block of text. - Used when generating descriptions from docstrings. - """ - whitespace_counts = [len(line) - len(line.lstrip(' ')) - for line in content.splitlines()[1:] if line.lstrip()] - - # unindent the content if needed - if whitespace_counts: - whitespace_pattern = '^' + (' ' * min(whitespace_counts)) - content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) - content = content.strip('\n') - return content - - -def _camelcase_to_spaces(content): - """ - Translate 'CamelCaseNames' to 'Camel Case Names'. - Used when generating names from view classes. - """ - camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))' - content = re.sub(camelcase_boundry, ' \\1', content).strip() - return ' '.join(content.split('_')).title() +from rest_framework.utils.formatting import get_view_name, get_view_description class APIView(View): @@ -64,13 +26,13 @@ class APIView(View): @classmethod def as_view(cls, **initkwargs): """ - Override the default :meth:`as_view` to store an instance of the view - as an attribute on the callable function. This allows us to discover - information about the view when we do URL reverse lookups. + Store the original class on the view function. + + This allows us to discover information about the view when we do URL + reverse lookups. Used for breadcrumb generation. """ - # TODO: deprecate? view = super(APIView, cls).as_view(**initkwargs) - view.cls_instance = cls(**initkwargs) + view.cls = cls return view @property @@ -90,43 +52,10 @@ class APIView(View): 'Vary': 'Accept' } - def get_name(self): - """ - Return the resource or view class name for use as this view's name. - Override to customize. - """ - # TODO: deprecate? - name = self.__class__.__name__ - name = _remove_trailing_string(name, 'View') - return _camelcase_to_spaces(name) - - def get_description(self, html=False): - """ - Return the resource or view docstring for use as this view's description. - Override to customize. - """ - # TODO: deprecate? - description = self.__doc__ or '' - description = _remove_leading_indent(description) - if html: - return self.markup_description(description) - return description - - def markup_description(self, description): - """ - Apply HTML markup to the description of this view. - """ - # TODO: deprecate? - if apply_markdown: - description = apply_markdown(description) - else: - description = escape(description).replace('\n', '
') - return mark_safe(description) - def metadata(self, request): return { - 'name': self.get_name(), - 'description': self.get_description(), + 'name': get_view_name(self.__class__), + 'description': get_view_description(self.__class__), 'renders': [renderer.media_type for renderer in self.renderer_classes], 'parses': [parser.media_type for parser in self.parser_classes], } diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 887a9722..0818c0d9 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -15,9 +15,10 @@ class ViewSetMixin(object): view = MyViewSet.as_view({'get': 'list', 'post': 'create'}) """ + _is_viewset = True @classonlymethod - def as_view(cls, actions=None, **initkwargs): + def as_view(cls, actions=None, name_suffix=None, **initkwargs): """ Main entry point for a request-response process. @@ -57,6 +58,8 @@ class ViewSetMixin(object): # and possible attributes set by decorators # like csrf_exempt from dispatch update_wrapper(view, cls.dispatch, assigned=()) + + view.cls = cls return view -- cgit v1.2.3 From fd3f538e9f9ef5d4d929c107b9619e0735e426f1 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Apr 2013 21:48:23 +0100 Subject: Fix up view name/description tests --- rest_framework/tests/description.py | 63 ++++++++++++++----------------------- 1 file changed, 23 insertions(+), 40 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py index 5b3315bc..52c1a34c 100644 --- a/rest_framework/tests/description.py +++ b/rest_framework/tests/description.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals from django.test import TestCase from rest_framework.views import APIView from rest_framework.compat import apply_markdown +from rest_framework.utils.formatting import get_view_name, get_view_description # We check that docstrings get nicely un-indented. DESCRIPTION = """an example docstring @@ -49,22 +50,16 @@ MARKED_DOWN_gte_21 = """

an example docstring

class TestViewNamesAndDescriptions(TestCase): - def test_resource_name_uses_classname_by_default(self): - """Ensure Resource names are based on the classname by default.""" + def test_view_name_uses_class_name(self): + """ + Ensure view names are based on the class name. + """ class MockView(APIView): pass - self.assertEqual(MockView().get_name(), 'Mock') + self.assertEqual(get_view_name(MockView), 'Mock') - def test_resource_name_can_be_set_explicitly(self): - """Ensure Resource names can be set using the 'get_name' method.""" - example = 'Some Other Name' - class MockView(APIView): - def get_name(self): - return example - self.assertEqual(MockView().get_name(), example) - - def test_resource_description_uses_docstring_by_default(self): - """Ensure Resource names are based on the docstring by default.""" + def test_view_description_uses_docstring(self): + """Ensure view descriptions are based on the docstring.""" class MockView(APIView): """an example docstring ==================== @@ -81,44 +76,32 @@ class TestViewNamesAndDescriptions(TestCase): # hash style header #""" - self.assertEqual(MockView().get_description(), DESCRIPTION) - - def test_resource_description_can_be_set_explicitly(self): - """Ensure Resource descriptions can be set using the 'get_description' method.""" - example = 'Some other description' - - class MockView(APIView): - """docstring""" - def get_description(self): - return example - self.assertEqual(MockView().get_description(), example) + self.assertEqual(get_view_description(MockView), DESCRIPTION) - def test_resource_description_supports_unicode(self): + def test_view_description_supports_unicode(self): + """ + Unicode in docstrings should be respected. + """ class MockView(APIView): """Проверка""" pass - self.assertEqual(MockView().get_description(), "Проверка") - - - def test_resource_description_does_not_require_docstring(self): - """Ensure that empty docstrings do not affect the Resource's description if it has been set using the 'get_description' method.""" - example = 'Some other description' - - class MockView(APIView): - def get_description(self): - return example - self.assertEqual(MockView().get_description(), example) + self.assertEqual(get_view_description(MockView), "Проверка") - def test_resource_description_can_be_empty(self): - """Ensure that if a resource has no doctring or 'description' class attribute, then it's description is the empty string.""" + def test_view_description_can_be_empty(self): + """ + Ensure that if a view has no docstring, + then it's description is the empty string. + """ class MockView(APIView): pass - self.assertEqual(MockView().get_description(), '') + self.assertEqual(get_view_description(MockView), '') def test_markdown(self): - """Ensure markdown to HTML works as expected""" + """ + Ensure markdown to HTML works as expected. + """ if apply_markdown: gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21 lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21 -- cgit v1.2.3 From 371698331c979305b5684f864ee6bf5b6d11a44e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Apr 2013 22:24:30 +0100 Subject: Tweaks --- rest_framework/generics.py | 9 +++------ rest_framework/mixins.py | 4 ++++ rest_framework/routers.py | 12 ++++++++++-- 3 files changed, 17 insertions(+), 8 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 36ecf915..dea980a5 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -187,8 +187,7 @@ class UpdateAPIView(mixins.UpdateModelMixin, return self.update(request, *args, **kwargs) def patch(self, request, *args, **kwargs): - kwargs['partial'] = True - return self.update(request, *args, **kwargs) + return self.partial_update(request, *args, **kwargs) class ListCreateAPIView(mixins.ListModelMixin, @@ -217,8 +216,7 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, return self.update(request, *args, **kwargs) def patch(self, request, *args, **kwargs): - kwargs['partial'] = True - return self.update(request, *args, **kwargs) + return self.partial_update(request, *args, **kwargs) class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, @@ -248,8 +246,7 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, return self.update(request, *args, **kwargs) def patch(self, request, *args, **kwargs): - kwargs['partial'] = True - return self.update(request, *args, **kwargs) + return self.partial_update(request, *args, **kwargs) def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 7d9a6e65..c700602e 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -137,6 +137,10 @@ class UpdateModelMixin(object): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + def partial_update(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + def pre_save(self, obj): """ Set any attributes on the object that are implicit in the request. diff --git a/rest_framework/routers.py b/rest_framework/routers.py index c37909ff..afc51f3b 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -20,8 +20,16 @@ class BaseRouter(object): class DefaultRouter(BaseRouter): route_list = [ - (r'$', {'get': 'list', 'post': 'create'}, 'list'), - (r'(?P[^/]+)/$', {'get': 'retrieve', 'put': 'update', 'delete': 'destroy'}, 'detail'), + (r'$', { + 'get': 'list', + 'post': 'create' + }, 'list'), + (r'(?P[^/]+)/$', { + 'get': 'retrieve', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy' + }, 'detail'), ] extra_routes = r'(?P[^/]+)/%s/$' name_format = '%s-%s' -- cgit v1.2.3 From c73d0e1e39e661c7324eb0df8c3ce6e18f57915b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 9 Apr 2013 18:22:39 +0100 Subject: Minor cleaning up on View --- rest_framework/compat.py | 20 ++++++++++++-------- rest_framework/views.py | 8 ++++---- 2 files changed, 16 insertions(+), 12 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 6551723a..8bfebe68 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -87,9 +87,7 @@ else: raise ImportError("User model is not to be found.") -# First implementation of Django class-based views did not include head method -# in base View class - https://code.djangoproject.com/ticket/15668 -if django.VERSION >= (1, 4): +if django.VERSION >= (1, 5): from django.views.generic import View else: from django.views.generic import View as _View @@ -97,6 +95,8 @@ else: from django.utils.functional import update_wrapper class View(_View): + # 1.3 does not include head method in base View class + # See: https://code.djangoproject.com/ticket/15668 @classonlymethod def as_view(cls, **initkwargs): """ @@ -126,11 +126,15 @@ else: update_wrapper(view, cls.dispatch, assigned=()) return view -# Taken from @markotibold's attempt at supporting PATCH. -# https://github.com/markotibold/django-rest-framework/tree/patch -http_method_names = set(View.http_method_names) -http_method_names.add('patch') -View.http_method_names = list(http_method_names) # PATCH method is not implemented by Django + # _allowed_methods only present from 1.5 onwards + def _allowed_methods(self): + return [m.upper() for m in self.http_method_names if hasattr(self, m)] + + +# PATCH method is not implemented by Django +if 'patch' not in View.http_method_names: + View.http_method_names = View.http_method_names + ['patch'] + # PUT, DELETE do not require CSRF until 1.4. They should. Make it better. if django.VERSION >= (1, 4): diff --git a/rest_framework/views.py b/rest_framework/views.py index 12298ca5..d7d3a2e2 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -38,10 +38,9 @@ class APIView(View): @property def allowed_methods(self): """ - Return the list of allowed HTTP methods, uppercased. + Wrap Django's private `_allowed_methods` interface in a public property. """ - return [method.upper() for method in self.http_method_names - if hasattr(self, method)] + return self._allowed_methods() @property def default_response_headers(self): @@ -69,7 +68,8 @@ class APIView(View): def http_method_not_allowed(self, request, *args, **kwargs): """ - Called if `request.method` does not correspond to a handler method. + If `request.method` does not correspond to a handler method, + determine what kind of exception to raise. """ raise exceptions.MethodNotAllowed(request.method) -- cgit v1.2.3 From 099163f81f9d89746de50f3aed2955ead54dba4e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 9 Apr 2013 18:45:15 +0100 Subject: Removed SingleObjectMixin and MultipleObjectMixin --- rest_framework/generics.py | 139 +++++++++++++++++++++++++++++++++------------ rest_framework/mixins.py | 5 +- 2 files changed, 106 insertions(+), 38 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index dea980a5..af3b69da 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -4,21 +4,35 @@ Generic views that provide commonly needed behaviour. from __future__ import unicode_literals from rest_framework import views, mixins from rest_framework.settings import api_settings -from django.views.generic.detail import SingleObjectMixin -from django.views.generic.list import MultipleObjectMixin - +from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist +from django.core.paginator import Paginator, InvalidPage +from django.http import Http404 +from django.utils.translation import ugettext as _ ### Base classes for the generic views ### + class GenericAPIView(views.APIView): """ Base class for all other generic views. """ - model = None + queryset = None serializer_class = None - model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS + filter_backend = api_settings.FILTER_BACKEND + paginate_by = api_settings.PAGINATE_BY + paginate_by_param = api_settings.PAGINATE_BY_PARAM + pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS + allow_empty = True + page_kwarg = 'page' + + # Pending deprecation + model = None + model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS + pk_url_kwarg = 'pk' # Not provided in Django 1.3 + slug_url_kwarg = 'slug' # Not provided in Django 1.3 + slug_field = 'slug' def filter_queryset(self, queryset): """ @@ -82,15 +96,7 @@ class GenericAPIView(views.APIView): """ pass - -class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): - """ - Base class for generic views onto a queryset. - """ - - paginate_by = api_settings.PAGINATE_BY - paginate_by_param = api_settings.PAGINATE_BY_PARAM - pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS + # Pagination def get_pagination_serializer(self, page=None): """ @@ -116,28 +122,81 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): pass return self.paginate_by - -class SingleObjectAPIView(SingleObjectMixin, GenericAPIView): - """ - Base class for generic views onto a model instance. - """ - - pk_url_kwarg = 'pk' # Not provided in Django 1.3 - slug_url_kwarg = 'slug' # Not provided in Django 1.3 - slug_field = 'slug' + def paginate_queryset(self, queryset, page_size, paginator_class=Paginator): + """ + Paginate a queryset. + """ + paginator = paginator_class(queryset, page_size, allow_empty_first_page=self.allow_empty) + page_kwarg = self.page_kwarg + page = self.kwargs.get(page_kwarg) or self.request.GET.get(page_kwarg) or 1 + try: + page_number = int(page) + except ValueError: + if page == 'last': + page_number = paginator.num_pages + else: + raise Http404(_("Page is not 'last', nor can it be converted to an int.")) + try: + page = paginator.page(page_number) + return (paginator, page, page.object_list, page.has_other_pages()) + except InvalidPage as e: + raise Http404(_('Invalid page (%(page_number)s): %(message)s') % { + 'page_number': page_number, + 'message': str(e) + }) + + def get_queryset(self): + """ + Get the list of items for this view. This must be an iterable, and may + be a queryset (in which qs-specific behavior will be enabled). + """ + if self.queryset is not None: + queryset = self.queryset + if hasattr(queryset, '_clone'): + queryset = queryset._clone() + elif self.model is not None: + queryset = self.model._default_manager.all() + else: + raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" + % self.__class__.__name__) + return queryset def get_object(self, queryset=None): """ - Override default to add support for object-level permissions. + Returns the object the view is displaying. + By default this requires `self.queryset` and a `pk` or `slug` argument + in the URLconf, but subclasses can override this to return any object. """ - obj = super(SingleObjectAPIView, self).get_object(queryset) + # Use a custom queryset if provided; this is required for subclasses + # like DateDetailView + if queryset is None: + queryset = self.get_queryset() + # Next, try looking up by primary key. + pk = self.kwargs.get(self.pk_url_kwarg, None) + slug = self.kwargs.get(self.slug_url_kwarg, None) + if pk is not None: + queryset = queryset.filter(pk=pk) + # Next, try looking up by slug. + elif slug is not None: + queryset = queryset.filter(**{self.slug_field: slug}) + # If none of those are defined, it's an error. + else: + raise AttributeError("Generic detail view %s must be called with " + "either an object pk or a slug." + % self.__class__.__name__) + try: + # Get the single item from the filtered queryset + obj = queryset.get() + except ObjectDoesNotExist: + raise Http404(_("No %(verbose_name)s found matching the query") % + {'verbose_name': queryset.model._meta.verbose_name}) + self.check_object_permissions(self.request, obj) return obj ### Concrete view classes that provide method handlers ### -### by composing the mixin classes with a base view. ### - +### by composing the mixin classes with the base view. ### class CreateAPIView(mixins.CreateModelMixin, GenericAPIView): @@ -150,7 +209,7 @@ class CreateAPIView(mixins.CreateModelMixin, class ListAPIView(mixins.ListModelMixin, - MultipleObjectAPIView): + GenericAPIView): """ Concrete view for listing a queryset. """ @@ -159,7 +218,7 @@ class ListAPIView(mixins.ListModelMixin, class RetrieveAPIView(mixins.RetrieveModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for retrieving a model instance. """ @@ -168,7 +227,7 @@ class RetrieveAPIView(mixins.RetrieveModelMixin, class DestroyAPIView(mixins.DestroyModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for deleting a model instance. @@ -178,7 +237,7 @@ class DestroyAPIView(mixins.DestroyModelMixin, class UpdateAPIView(mixins.UpdateModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for updating a model instance. @@ -192,7 +251,7 @@ class UpdateAPIView(mixins.UpdateModelMixin, class ListCreateAPIView(mixins.ListModelMixin, mixins.CreateModelMixin, - MultipleObjectAPIView): + GenericAPIView): """ Concrete view for listing a queryset or creating a model instance. """ @@ -205,7 +264,7 @@ class ListCreateAPIView(mixins.ListModelMixin, class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for retrieving, updating a model instance. """ @@ -221,7 +280,7 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, mixins.DestroyModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for retrieving or deleting a model instance. """ @@ -235,7 +294,7 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, mixins.DestroyModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for retrieving, updating or deleting a model instance. """ @@ -250,3 +309,13 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) + + +### Deprecated classes ### + +class MultipleObjectAPIView(GenericAPIView): + pass + + +class SingleObjectAPIView(GenericAPIView): + pass diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index c700602e..b15cb11f 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -72,8 +72,7 @@ class ListModelMixin(object): # Default is to allow empty querysets. This can be altered by setting # `.allow_empty = False`, to raise 404 errors on empty querysets. - allow_empty = self.get_allow_empty() - if not allow_empty and not self.object_list: + if not self.allow_empty and not self.object_list: class_name = self.__class__.__name__ error_msg = self.empty_error % {'class_name': class_name} raise Http404(error_msg) @@ -148,7 +147,7 @@ class UpdateModelMixin(object): # pk and/or slug attributes are implicit in the URL. pk = self.kwargs.get(self.pk_url_kwarg, None) slug = self.kwargs.get(self.slug_url_kwarg, None) - slug_field = slug and self.get_slug_field() or None + slug_field = slug and self.slug_field or None if pk: setattr(obj, 'pk', pk) -- cgit v1.2.3 From dc45bc7bfad64a17f3e5ed0f5a487bccc379aac2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 9 Apr 2013 19:01:01 +0100 Subject: Add lookup_kwarg --- rest_framework/generics.py | 18 ++++++++++++------ rest_framework/tests/filterset.py | 6 +++--- 2 files changed, 15 insertions(+), 9 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index af3b69da..d4a50dcd 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -26,6 +26,7 @@ class GenericAPIView(views.APIView): pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS allow_empty = True page_kwarg = 'page' + lookup_kwarg = 'pk' # Pending deprecation model = None @@ -167,23 +168,26 @@ class GenericAPIView(views.APIView): By default this requires `self.queryset` and a `pk` or `slug` argument in the URLconf, but subclasses can override this to return any object. """ - # Use a custom queryset if provided; this is required for subclasses - # like DateDetailView + # Determine the base queryset to use. if queryset is None: queryset = self.get_queryset() - # Next, try looking up by primary key. + + # Perform the lookup filtering. pk = self.kwargs.get(self.pk_url_kwarg, None) slug = self.kwargs.get(self.slug_url_kwarg, None) - if pk is not None: + lookup = self.kwargs.get(self.lookup_kwarg, None) + + if lookup is not None: + queryset = queryset.filter(**{self.lookup_kwarg: lookup}) + elif pk is not None: queryset = queryset.filter(pk=pk) - # Next, try looking up by slug. elif slug is not None: queryset = queryset.filter(**{self.slug_field: slug}) - # If none of those are defined, it's an error. else: raise AttributeError("Generic detail view %s must be called with " "either an object pk or a slug." % self.__class__.__name__) + try: # Get the single item from the filtered queryset obj = queryset.get() @@ -191,7 +195,9 @@ class GenericAPIView(views.APIView): raise Http404(_("No %(verbose_name)s found matching the query") % {'verbose_name': queryset.model._meta.verbose_name}) + # May raise a permission denied self.check_object_permissions(self.request, obj) + return obj diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 1a71558c..1e53a5cd 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -61,7 +61,7 @@ if django_filters: class CommonFilteringTestCase(TestCase): def _serialize_object(self, obj): return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} - + def setUp(self): """ Create 10 FilterableItem instances. @@ -190,7 +190,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): Integration tests for filtered detail views. """ urls = 'rest_framework.tests.filterset' - + def _get_url(self, item): return reverse('detail-view', kwargs=dict(pk=item.pk)) @@ -221,7 +221,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, low_item_data) - + # Tests that multiple filters works. search_decimal = Decimal('5.25') search_date = datetime.date(2012, 10, 2) -- cgit v1.2.3 From 1de6cff11b71e4aaa7b76219d4d2118021e23a00 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 9 Apr 2013 19:06:49 +0100 Subject: Cleaning up get_object and get_queryset --- rest_framework/generics.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index d4a50dcd..4ae2ac8e 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -148,25 +148,22 @@ class GenericAPIView(views.APIView): def get_queryset(self): """ - Get the list of items for this view. This must be an iterable, and may - be a queryset (in which qs-specific behavior will be enabled). + Get the list of items for this view. + + This must be an iterable, and may be a queryset. """ if self.queryset is not None: - queryset = self.queryset - if hasattr(queryset, '_clone'): - queryset = queryset._clone() - elif self.model is not None: - queryset = self.model._default_manager.all() - else: - raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" - % self.__class__.__name__) - return queryset + return self.queryset._clone() + + if self.model is not None: + return self.model._default_manager.all() + + raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" + % self.__class__.__name__) def get_object(self, queryset=None): """ Returns the object the view is displaying. - By default this requires `self.queryset` and a `pk` or `slug` argument - in the URLconf, but subclasses can override this to return any object. """ # Determine the base queryset to use. if queryset is None: -- cgit v1.2.3 From 9bb1277e512a88e6c11c52457d0c24e73f30bb98 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 9 Apr 2013 19:37:19 +0100 Subject: Cleaning up around bits of API that will be pending deprecation --- rest_framework/generics.py | 116 +++++++++++++++++++++++++++------------------ rest_framework/mixins.py | 9 ++-- rest_framework/viewsets.py | 6 +-- 3 files changed, 75 insertions(+), 56 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 4ae2ac8e..124dba38 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -35,15 +35,6 @@ class GenericAPIView(views.APIView): slug_url_kwarg = 'slug' # Not provided in Django 1.3 slug_field = 'slug' - def filter_queryset(self, queryset): - """ - Given a queryset, filter it with whichever filter backend is in use. - """ - if not self.filter_backend: - return queryset - backend = self.filter_backend() - return backend.filter_queryset(self.request, queryset, self) - def get_serializer_context(self): """ Extra context provided to the serializer class. @@ -54,24 +45,6 @@ class GenericAPIView(views.APIView): 'view': self } - def get_serializer_class(self): - """ - Return the class to use for the serializer. - - Defaults to using `self.serializer_class`, falls back to constructing a - model serializer class using `self.model_serializer_class`, with - `self.model` as the model. - """ - serializer_class = self.serializer_class - - if serializer_class is None: - class DefaultSerializer(self.model_serializer_class): - class Meta: - model = self.model - serializer_class = DefaultSerializer - - return serializer_class - def get_serializer(self, instance=None, data=None, files=None, many=False, partial=False): """ @@ -83,22 +56,6 @@ class GenericAPIView(views.APIView): return serializer_class(instance, data=data, files=files, many=many, partial=partial, context=context) - def pre_save(self, obj): - """ - Placeholder method for calling before saving an object. - May be used eg. to set attributes on the object that are implicit - in either the request, or the url. - """ - pass - - def post_save(self, obj, created=False): - """ - Placeholder method for calling after saving an object. - """ - pass - - # Pagination - def get_pagination_serializer(self, page=None): """ Return a serializer instance to use with paginated data. @@ -111,9 +68,14 @@ class GenericAPIView(views.APIView): context = self.get_serializer_context() return pagination_serializer_class(instance=page, context=context) - def get_paginate_by(self, queryset): + def get_paginate_by(self, queryset=None): """ Return the size of pages to use with pagination. + + If `PAGINATE_BY_PARAM` is set it will attempt to get the page size + from a named query parameter in the url, eg. ?page_size=100 + + Otherwise defaults to using `self.paginate_by`. """ if self.paginate_by_param: query_params = self.request.QUERY_PARAMS @@ -121,6 +83,7 @@ class GenericAPIView(views.APIView): return int(query_params[self.paginate_by_param]) except (KeyError, ValueError): pass + return self.paginate_by def paginate_queryset(self, queryset, page_size, paginator_class=Paginator): @@ -146,16 +109,54 @@ class GenericAPIView(views.APIView): 'message': str(e) }) + def filter_queryset(self, queryset): + """ + Given a queryset, filter it with whichever filter backend is in use. + """ + if not self.filter_backend: + return queryset + backend = self.filter_backend() + return backend.filter_queryset(self.request, queryset, self) + + ### The following methods provide default implementations + ### that you may want to override for more complex cases. + + def get_serializer_class(self): + """ + Return the class to use for the serializer. + Defaults to using `self.serializer_class`. + + You may want to override this if you need to provide different + serializations depending on the incoming request. + + (Eg. admins get full serialization, others get basic serilization) + """ + serializer_class = self.serializer_class + if serializer_class is not None: + return serializer_class + + # TODO: Deprecation warning + class DefaultSerializer(self.model_serializer_class): + class Meta: + model = self.model + return DefaultSerializer + def get_queryset(self): """ Get the list of items for this view. - This must be an iterable, and may be a queryset. + Defaults to using `self.queryset`. + + You may want to override this if you need to provide different + querysets depending on the incoming request. + + (Eg. return a list of items that is specific to the user) """ if self.queryset is not None: return self.queryset._clone() if self.model is not None: + # TODO: Deprecation warning return self.model._default_manager.all() raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" @@ -164,10 +165,14 @@ class GenericAPIView(views.APIView): def get_object(self, queryset=None): """ Returns the object the view is displaying. + + You may want to override this if you need to provide non-standard + queryset lookups. Eg if objects are referenced using multiple + keyword arguments in the url conf. """ # Determine the base queryset to use. if queryset is None: - queryset = self.get_queryset() + queryset = self.filter_queryset(self.get_queryset()) # Perform the lookup filtering. pk = self.kwargs.get(self.pk_url_kwarg, None) @@ -177,8 +182,10 @@ class GenericAPIView(views.APIView): if lookup is not None: queryset = queryset.filter(**{self.lookup_kwarg: lookup}) elif pk is not None: + # TODO: Deprecation warning queryset = queryset.filter(pk=pk) elif slug is not None: + # TODO: Deprecation warning queryset = queryset.filter(**{self.slug_field: slug}) else: raise AttributeError("Generic detail view %s must be called with " @@ -197,6 +204,23 @@ class GenericAPIView(views.APIView): return obj + ### The following methods are intended to be overridden. + + def pre_save(self, obj): + """ + Placeholder method for calling before saving an object. + + May be used to set attributes on the object that are implicit + in either the request, or the url. + """ + pass + + def post_save(self, obj, created=False): + """ + Placeholder method for calling after saving an object. + """ + pass + ### Concrete view classes that provide method handlers ### ### by composing the mixin classes with the base view. ### diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index b15cb11f..6e40b5c4 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -67,8 +67,7 @@ class ListModelMixin(object): empty_error = "Empty list and '%(class_name)s.allow_empty' is False." def list(self, request, *args, **kwargs): - queryset = self.get_queryset() - self.object_list = self.filter_queryset(queryset) + self.object_list = self.filter_queryset(self.get_queryset()) # Default is to allow empty querysets. This can be altered by setting # `.allow_empty = False`, to raise 404 errors on empty querysets. @@ -79,7 +78,7 @@ class ListModelMixin(object): # Pagination size is set by the `.paginate_by` attribute, # which may be `None` to disable pagination. - page_size = self.get_paginate_by(self.object_list) + page_size = self.get_paginate_by() if page_size: packed = self.paginate_queryset(self.object_list, page_size) paginator, page, queryset, is_paginated = packed @@ -96,9 +95,7 @@ class RetrieveModelMixin(object): Should be mixed in with `SingleObjectAPIView`. """ def retrieve(self, request, *args, **kwargs): - queryset = self.get_queryset() - filtered_queryset = self.filter_queryset(queryset) - self.object = self.get_object(filtered_queryset) + self.object = self.get_object() serializer = self.get_serializer(self.object) return Response(serializer.data) diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 0818c0d9..28ab30e2 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -76,14 +76,12 @@ class ModelViewSet(mixins.CreateModelMixin, mixins.DestroyModelMixin, mixins.ListModelMixin, ViewSetMixin, - generics.MultipleObjectAPIView, - generics.SingleObjectAPIView): + generics.GenericAPIView): pass class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, mixins.ListModelMixin, ViewSetMixin, - generics.MultipleObjectAPIView, - generics.SingleObjectAPIView): + generics.GenericAPIView): pass -- cgit v1.2.3 From 07af4373616c28e7600ee2ec7981b5a1d0a92f7d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 9 Apr 2013 19:47:16 +0100 Subject: Cleaning up around bits of API that will be pending deprecation --- rest_framework/generics.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 124dba38..ba7d1f43 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -24,15 +24,17 @@ class GenericAPIView(views.APIView): paginate_by = api_settings.PAGINATE_BY paginate_by_param = api_settings.PAGINATE_BY_PARAM pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS - allow_empty = True page_kwarg = 'page' lookup_kwarg = 'pk' + allow_empty = True + + ###################################### + # These are all pending deprecation... - # Pending deprecation model = None model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS - pk_url_kwarg = 'pk' # Not provided in Django 1.3 - slug_url_kwarg = 'slug' # Not provided in Django 1.3 + pk_url_kwarg = 'pk' + slug_url_kwarg = 'slug' slug_field = 'slug' def get_serializer_context(self): @@ -90,7 +92,8 @@ class GenericAPIView(views.APIView): """ Paginate a queryset. """ - paginator = paginator_class(queryset, page_size, allow_empty_first_page=self.allow_empty) + paginator = paginator_class(queryset, page_size, + allow_empty_first_page=self.allow_empty) page_kwarg = self.page_kwarg page = self.kwargs.get(page_kwarg) or self.request.GET.get(page_kwarg) or 1 try: @@ -118,6 +121,7 @@ class GenericAPIView(views.APIView): backend = self.filter_backend() return backend.filter_queryset(self.request, queryset, self) + ######################## ### The following methods provide default implementations ### that you may want to override for more complex cases. @@ -204,7 +208,9 @@ class GenericAPIView(views.APIView): return obj - ### The following methods are intended to be overridden. + ######################## + ### The following are placeholder methods, + ### and are intended to be overridden. def pre_save(self, obj): """ @@ -222,8 +228,10 @@ class GenericAPIView(views.APIView): pass +########################################################## ### Concrete view classes that provide method handlers ### ### by composing the mixin classes with the base view. ### +########################################################## class CreateAPIView(mixins.CreateModelMixin, GenericAPIView): @@ -338,7 +346,9 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, return self.destroy(request, *args, **kwargs) +########################## ### Deprecated classes ### +########################## class MultipleObjectAPIView(GenericAPIView): pass -- cgit v1.2.3 From 76e039d70e8fc7f1d5c65180cb544abab81e600e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 10 Apr 2013 22:38:02 +0100 Subject: First pass on automatically including reverse relationship --- rest_framework/serializers.py | 43 ++++++++++++++++++++++++++++++++------ rest_framework/tests/serializer.py | 37 ++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 6 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index e28bbe81..eac909c7 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -598,6 +598,24 @@ class ModelSerializer(Serializer): if field: ret[model_field.name] = field + # Reverse relationships are only included if they are explicitly + # present in `Meta.fields`. + if self.opts.fields: + reverse = opts.get_all_related_objects() + reverse += opts.get_all_related_many_to_many_objects() + for rel in reverse: + name = rel.get_accessor_name() + if name not in self.opts.fields: + continue + + if nested: + field = self.get_nested_field(None, rel) + else: + field = self.get_related_field(None, rel, to_many=True) + + if field: + ret[name] = field + for field_name in self.opts.read_only_fields: assert field_name in ret, \ "read_only_fields on '%s' included invalid item '%s'" % \ @@ -612,24 +630,36 @@ class ModelSerializer(Serializer): """ return self.get_field(model_field) - def get_nested_field(self, model_field): + def get_nested_field(self, model_field, rel=None): """ Creates a default instance of a nested relational field. """ + if rel: + model_class = rel.model + else: + model_class = model_field.rel.to + class NestedModelSerializer(ModelSerializer): class Meta: - model = model_field.rel.to + model = model_class return NestedModelSerializer() - def get_related_field(self, model_field, to_many=False): + def get_related_field(self, model_field, rel=None, to_many=False): """ Creates a default instance of a flat relational field. """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) + if rel: + model_class = rel.model + required = True + else: + model_class = model_field.rel.to + required = not(model_field.null or model_field.blank) + kwargs = { - 'required': not(model_field.null or model_field.blank), - 'queryset': model_field.rel.to._default_manager, + 'required': required, + 'queryset': model_class._default_manager, 'many': to_many } @@ -797,7 +827,8 @@ class HyperlinkedModelSerializer(ModelSerializer): return self._default_view_name % format_kwargs def get_pk_field(self, model_field): - return None + if self.opts.fields and model_field.name in self.opts.fields: + return self.get_field(model_field) def get_related_field(self, model_field, to_many): """ diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 05217f35..3a94fad5 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -738,6 +738,43 @@ class ManyRelatedTests(TestCase): self.assertEqual(serializer.data, expected) + def test_include_reverse_relations(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I hate this blog post") + post.blogpostcomment_set.create(text="I love this blog post") + + class BlogPostSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPost + fields = ('id', 'title', 'blogpostcomment_set') + + serializer = BlogPostSerializer(instance=post) + expected = { + 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2] + } + self.assertEqual(serializer.data, expected) + + def test_depth_include_reverse_relations(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I hate this blog post") + post.blogpostcomment_set.create(text="I love this blog post") + + class BlogPostSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPost + fields = ('id', 'title', 'blogpostcomment_set') + depth = 1 + + serializer = BlogPostSerializer(instance=post) + expected = { + 'id': 1, 'title': 'Test blog post', + 'blogpostcomment_set': [ + {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1}, + {'id': 2, 'text': 'I love this blog post', 'blog_post': 1} + ] + } + self.assertEqual(serializer.data, expected) + def test_callable_source(self): post = BlogPost.objects.create(title="Test blog post") post.blogpostcomment_set.create(text="I love this blog post") -- cgit v1.2.3 From e0020c5b033308cd789408a8823d6707deed8032 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 11 Apr 2013 15:48:18 +0100 Subject: Simplify get_object --- rest_framework/generics.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index ba7d1f43..ea62123d 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -4,9 +4,10 @@ Generic views that provide commonly needed behaviour. from __future__ import unicode_literals from rest_framework import views, mixins from rest_framework.settings import api_settings -from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist +from django.core.exceptions import ImproperlyConfigured from django.core.paginator import Paginator, InvalidPage from django.http import Http404 +from django.shortcuts import get_object_or_404 from django.utils.translation import ugettext as _ ### Base classes for the generic views ### @@ -163,7 +164,7 @@ class GenericAPIView(views.APIView): # TODO: Deprecation warning return self.model._default_manager.all() - raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" + raise ImproperlyConfigured("'%s' must define 'queryset'" % self.__class__.__name__) def get_object(self, queryset=None): @@ -177,6 +178,8 @@ class GenericAPIView(views.APIView): # Determine the base queryset to use. if queryset is None: queryset = self.filter_queryset(self.get_queryset()) + else: + pass # Deprecation warning # Perform the lookup filtering. pk = self.kwargs.get(self.pk_url_kwarg, None) @@ -184,24 +187,19 @@ class GenericAPIView(views.APIView): lookup = self.kwargs.get(self.lookup_kwarg, None) if lookup is not None: - queryset = queryset.filter(**{self.lookup_kwarg: lookup}) + filter_kwargs = {self.lookup_kwarg: lookup} elif pk is not None: # TODO: Deprecation warning - queryset = queryset.filter(pk=pk) + filter_kwargs = {'pk': pk} elif slug is not None: # TODO: Deprecation warning - queryset = queryset.filter(**{self.slug_field: slug}) + filter_kwargs = {self.slug_field: slug} else: raise AttributeError("Generic detail view %s must be called with " "either an object pk or a slug." % self.__class__.__name__) - try: - # Get the single item from the filtered queryset - obj = queryset.get() - except ObjectDoesNotExist: - raise Http404(_("No %(verbose_name)s found matching the query") % - {'verbose_name': queryset.model._meta.verbose_name}) + obj = get_object_or_404(queryset, **filter_kwargs) # May raise a permission denied self.check_object_permissions(self.request, obj) -- cgit v1.2.3 From 750451f5b4de61684f4a4e69dd5776bd84ac054c Mon Sep 17 00:00:00 2001 From: Johannes Spielmann Date: Sun, 14 Apr 2013 18:30:44 +0200 Subject: adding test case for generic view with overriden get_object() --- rest_framework/tests/generics.py | 173 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index f564890c..b40b0102 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -24,6 +24,28 @@ class InstanceView(generics.RetrieveUpdateDestroyAPIView): model = BasicModel +class InstanceDetailView(generics.RetrieveUpdateDestroyAPIView): + """ + Example detail view for override of get_object(). + """ + + # we have to implement this too, otherwise we can't be sure that get_object + # will be called + def get_serializer(self, instance=None, data=None, files=None, partial=None): + class InstanceDetailSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel + return InstanceDetailSerializer(instance=instance, data=data, files=files, partial=partial) + + def get_object(self): + try: + pk = int(self.kwargs['pk']) + self.object = BasicModel.objects.get(id=pk) + return self.object + except BasicModel.DoesNotExist: + return self.permission_denied(self.request) + + class SlugSerializer(serializers.ModelSerializer): slug = serializers.Field() # read only @@ -301,6 +323,157 @@ class TestInstanceView(TestCase): new_obj = SlugBasedModel.objects.get(slug='test_slug') self.assertEqual(new_obj.text, 'foobar') +class TestInstanceDetailView(TestCase): + """ + Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the + queryset/model mechanism but instead overrides get_object() + """ + def setUp(self): + """ + Create 3 BasicModel intances. + """ + items = ['foo', 'bar', 'baz'] + for item in items: + BasicModel(text=item).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.view_class = InstanceDetailView + self.view = InstanceDetailView.as_view() + + def test_get_instance_view(self): + """ + GET requests to RetrieveUpdateDestroyAPIView should return a single object. + """ + request = factory.get('/1') + with self.assertNumQueries(1): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[0]) + + def test_post_instance_view(self): + """ + POST requests to RetrieveUpdateDestroyAPIView should not be allowed + """ + content = {'text': 'foobar'} + request = factory.post('/', json.dumps(content), + content_type='application/json') + with self.assertNumQueries(0): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."}) + + def test_put_instance_view(self): + """ + PUT requests to RetrieveUpdateDestroyAPIView should update an object. + """ + content = {'text': 'foobar'} + request = factory.put('/1', json.dumps(content), + content_type='application/json') + with self.assertNumQueries(2): + response = self.view(request, pk='1').render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) + updated = self.objects.get(id=1) + self.assertEqual(updated.text, 'foobar') + + def test_patch_instance_view(self): + """ + PATCH requests to RetrieveUpdateDestroyAPIView should update an object. + """ + content = {'text': 'foobar'} + request = factory.patch('/1', json.dumps(content), + content_type='application/json') + + with self.assertNumQueries(2): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) + updated = self.objects.get(id=1) + self.assertEqual(updated.text, 'foobar') + + def test_delete_instance_view(self): + """ + DELETE requests to RetrieveUpdateDestroyAPIView should delete an object. + """ + request = factory.delete('/1') + with self.assertNumQueries(2): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertEqual(response.content, six.b('')) + ids = [obj.id for obj in self.objects.all()] + self.assertEqual(ids, [2, 3]) + + def test_options_instance_view(self): + """ + OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata + """ + request = factory.options('/') + with self.assertNumQueries(0): + response = self.view(request).render() + expected = { + 'parses': [ + 'application/json', + 'application/x-www-form-urlencoded', + 'multipart/form-data' + ], + 'renders': [ + 'application/json', + 'text/html' + ], + 'name': 'Instance Detail', + 'description': 'Example detail view for override of get_object().' + } + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, expected) + + def test_put_cannot_set_id(self): + """ + PUT requests to create a new object should not be able to set the id. + """ + content = {'id': 999, 'text': 'foobar'} + request = factory.put('/1', json.dumps(content), + content_type='application/json') + with self.assertNumQueries(2): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) + updated = self.objects.get(id=1) + self.assertEqual(updated.text, 'foobar') + + def test_put_to_deleted_instance(self): + """ + PUT requests to RetrieveUpdateDestroyAPIView should create an object + if it does not currently exist. In our DetailView, however, + we cannot access any other id's than those that already exist. + See the InstanceView for the normal behaviour. + """ + self.objects.get(id=1).delete() + content = {'text': 'foobar'} + request = factory.put('/1', json.dumps(content), + content_type='application/json') + with self.assertNumQueries(1): + response = self.view(request, pk=5).render() + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_put_as_create_on_id_based_url(self): + """ + PUT requests to RetrieveUpdateDestroyAPIView should create an object + at the requested url if it doesn't exist. In our DetailView, however, + we cannot access any other id's than those that already exist. + See the InstanceView for the normal behaviour. + """ + content = {'text': 'foobar'} + # pk fields can not be created on demand, only the database can set the pk for a new object + request = factory.put('/5', json.dumps(content), + content_type='application/json') + with self.assertNumQueries(1): + response = self.view(request, pk=5).render() + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + # Regression test for #285 -- cgit v1.2.3 From ad436d966fa9ee2f5817aa5c26612c82558c4262 Mon Sep 17 00:00:00 2001 From: Stephan Groß Date: Mon, 15 Apr 2013 12:40:18 +0200 Subject: Add DecimalField support --- rest_framework/fields.py | 75 +++++++++++++++++++ rest_framework/tests/fields.py | 165 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 240 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f3496b53..a1b9f546 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals import copy import datetime +from decimal import Decimal, DecimalException import inspect import re import warnings @@ -721,6 +722,80 @@ class FloatField(WritableField): raise ValidationError(msg) +class DecimalField(WritableField): + type_name = 'DecimalField' + form_field_class = forms.DecimalField + + default_error_messages = { + 'invalid': _('Enter a number.'), + 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), + 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), + 'max_digits': _('Ensure that there are no more than %s digits in total.'), + 'max_decimal_places': _('Ensure that there are no more than %s decimal places.'), + 'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.') + } + + def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): + self.max_value, self.min_value = max_value, min_value + self.max_digits, self.decimal_places = max_digits, decimal_places + super(DecimalField, self).__init__(self, *args, **kwargs) + + if max_value is not None: + self.validators.append(validators.MaxValueValidator(max_value)) + if min_value is not None: + self.validators.append(validators.MinValueValidator(min_value)) + + def from_native(self, value): + """ + Validates that the input is a decimal number. Returns a Decimal + instance. Returns None for empty values. Ensures that there are no more + than max_digits in the number, and no more than decimal_places digits + after the decimal point. + """ + if value in validators.EMPTY_VALUES: + return None + value = smart_text(value).strip() + try: + value = Decimal(value) + except DecimalException: + raise ValidationError(self.error_messages['invalid']) + return value + + def to_native(self, value): + if value is not None: + return str(value) + return value + + def validate(self, value): + super(DecimalField, self).validate(value) + if value in validators.EMPTY_VALUES: + return + # Check for NaN, Inf and -Inf values. We can't compare directly for NaN, + # since it is never equal to itself. However, NaN is the only value that + # isn't equal to itself, so we can use this to identify NaN + if value != value or value == Decimal("Inf") or value == Decimal("-Inf"): + raise ValidationError(self.error_messages['invalid']) + sign, digittuple, exponent = value.as_tuple() + decimals = abs(exponent) + # digittuple doesn't include any leading zeros. + digits = len(digittuple) + if decimals > digits: + # We have leading zeros up to or past the decimal point. Count + # everything past the decimal point as a digit. We do not count + # 0 before the decimal point as a digit since that would mean + # we would not allow max_digits = decimal_places. + digits = decimals + whole_digits = digits - decimals + + if self.max_digits is not None and digits > self.max_digits: + raise ValidationError(self.error_messages['max_digits'] % self.max_digits) + if self.decimal_places is not None and decimals > self.decimal_places: + raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places) + if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): + raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places)) + return value + + class FileField(WritableField): use_files = True type_name = 'FileField' diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index 19c663d8..f833aa32 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -3,12 +3,14 @@ General serializer field tests. """ from __future__ import unicode_literals import datetime +from decimal import Decimal from django.db import models from django.test import TestCase from django.core import validators from rest_framework import serializers +from rest_framework.serializers import Serializer class TimestampedModel(models.Model): @@ -481,3 +483,166 @@ class TimeFieldTest(TestCase): self.assertEqual('04 - 00 [000000]', result_1) self.assertEqual('04 - 59 [000000]', result_2) self.assertEqual('04 - 59 [000200]', result_3) + + +class DecimalFieldTest(TestCase): + """ + Tests for the DecimalField from_native() and to_native() behavior + """ + + def test_from_native_string(self): + """ + Make sure from_native() accepts string values + """ + f = serializers.DecimalField() + result_1 = f.from_native('9000') + result_2 = f.from_native('1.00000001') + + self.assertEqual(Decimal('9000'), result_1) + self.assertEqual(Decimal('1.00000001'), result_2) + + def test_from_native_invalid_string(self): + """ + Make sure from_native() raises ValidationError on passing invalid string + """ + f = serializers.DecimalField() + + try: + f.from_native('123.45.6') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Enter a number."]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_integer(self): + """ + Make sure from_native() accepts integer values + """ + f = serializers.DecimalField() + result = f.from_native(9000) + + self.assertEqual(Decimal('9000'), result) + + def test_from_native_float(self): + """ + Make sure from_native() accepts float values + """ + f = serializers.DecimalField() + result = f.from_native(1.00000001) + + self.assertEqual(Decimal('1.00000001'), result) + + def test_from_native_empty(self): + """ + Make sure from_native() returns None on empty param. + """ + f = serializers.DecimalField() + result = f.from_native('') + + self.assertEqual(result, None) + + def test_from_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DecimalField() + result = f.from_native(None) + + self.assertEqual(result, None) + + def test_to_native(self): + """ + Make sure to_native() returns Decimal as string. + """ + f = serializers.DecimalField() + + result_1 = f.to_native(Decimal('9000')) + result_2 = f.to_native(Decimal('1.00000001')) + + self.assertEqual('9000', result_1) + self.assertEqual('1.00000001', result_2) + + def test_to_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DecimalField(required=False) + self.assertEqual(None, f.to_native(None)) + + def test_valid_serialization(self): + """ + Make sure the serializer works correctly + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(max_value=9010, + min_value=9000, + max_digits=6, + decimal_places=2) + + self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid()) + self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid()) + self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid()) + + self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid()) + self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid()) + self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid()) + + def test_raise_max_value(self): + """ + Make sure max_value violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(max_value=100) + + s = DecimalSerializer(data={'decimal_field': '123'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': [u'Ensure this value is less than or equal to 100.']}) + + def test_raise_min_value(self): + """ + Make sure min_value violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(min_value=100) + + s = DecimalSerializer(data={'decimal_field': '99'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': [u'Ensure this value is greater than or equal to 100.']}) + + def test_raise_max_digits(self): + """ + Make sure max_digits violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(max_digits=5) + + s = DecimalSerializer(data={'decimal_field': '123.456'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': [u'Ensure that there are no more than 5 digits in total.']}) + + def test_raise_max_decimal_places(self): + """ + Make sure max_decimal_places violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(decimal_places=3) + + s = DecimalSerializer(data={'decimal_field': '123.4567'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': [u'Ensure that there are no more than 3 decimal places.']}) + + def test_raise_max_whole_digits(self): + """ + Make sure max_whole_digits violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3) + + s = DecimalSerializer(data={'decimal_field': '12345.6'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': [u'Ensure that there are no more than 4 digits in total.']}) \ No newline at end of file -- cgit v1.2.3 From 37f7d8bc0f00feb1a4d23c0e163eab8b47faaec3 Mon Sep 17 00:00:00 2001 From: Stephan Groß Date: Mon, 15 Apr 2013 12:55:29 +0200 Subject: Fix unicodes --- rest_framework/tests/fields.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index f833aa32..597180b4 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -597,7 +597,7 @@ class DecimalFieldTest(TestCase): s = DecimalSerializer(data={'decimal_field': '123'}) self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': [u'Ensure this value is less than or equal to 100.']}) + self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']}) def test_raise_min_value(self): """ @@ -609,7 +609,7 @@ class DecimalFieldTest(TestCase): s = DecimalSerializer(data={'decimal_field': '99'}) self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': [u'Ensure this value is greater than or equal to 100.']}) + self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']}) def test_raise_max_digits(self): """ @@ -621,7 +621,7 @@ class DecimalFieldTest(TestCase): s = DecimalSerializer(data={'decimal_field': '123.456'}) self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': [u'Ensure that there are no more than 5 digits in total.']}) + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']}) def test_raise_max_decimal_places(self): """ @@ -633,7 +633,7 @@ class DecimalFieldTest(TestCase): s = DecimalSerializer(data={'decimal_field': '123.4567'}) self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': [u'Ensure that there are no more than 3 decimal places.']}) + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']}) def test_raise_max_whole_digits(self): """ @@ -645,4 +645,4 @@ class DecimalFieldTest(TestCase): s = DecimalSerializer(data={'decimal_field': '12345.6'}) self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': [u'Ensure that there are no more than 4 digits in total.']}) \ No newline at end of file + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) \ No newline at end of file -- cgit v1.2.3 From c329d2f08511dbc7660af9b8fc94e92d97c015cc Mon Sep 17 00:00:00 2001 From: Stephan Groß Date: Mon, 15 Apr 2013 13:11:41 +0200 Subject: Add DecimalField to field_mapping --- rest_framework/serializers.py | 1 + 1 file changed, 1 insertion(+) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index e28bbe81..cbc6586d 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -548,6 +548,7 @@ class ModelSerializer(Serializer): models.DateTimeField: DateTimeField, models.DateField: DateField, models.TimeField: TimeField, + models.DecimalField: DecimalField, models.EmailField: EmailField, models.CharField: CharField, models.URLField: URLField, -- cgit v1.2.3 From 9d80f01bced913dae0859be525b39eaa9df1fdbf Mon Sep 17 00:00:00 2001 From: Stephan Groß Date: Mon, 15 Apr 2013 15:15:55 +0200 Subject: Fix init call --- rest_framework/fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index a1b9f546..6be633db 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -738,7 +738,7 @@ class DecimalField(WritableField): def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): self.max_value, self.min_value = max_value, min_value self.max_digits, self.decimal_places = max_digits, decimal_places - super(DecimalField, self).__init__(self, *args, **kwargs) + super(DecimalField, self).__init__(*args, **kwargs) if max_value is not None: self.validators.append(validators.MaxValueValidator(max_value)) -- cgit v1.2.3 From cac669702596cdf768971267e6355fb9223a69e8 Mon Sep 17 00:00:00 2001 From: Stephan Groß Date: Mon, 15 Apr 2013 15:24:14 +0200 Subject: Return Decimal instance instead of string --- rest_framework/fields.py | 5 ----- rest_framework/tests/fields.py | 4 ++-- 2 files changed, 2 insertions(+), 7 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 6be633db..926195be 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -761,11 +761,6 @@ class DecimalField(WritableField): raise ValidationError(self.error_messages['invalid']) return value - def to_native(self, value): - if value is not None: - return str(value) - return value - def validate(self, value): super(DecimalField, self).validate(value) if value in validators.EMPTY_VALUES: diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index 597180b4..3cdfa0f6 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -559,8 +559,8 @@ class DecimalFieldTest(TestCase): result_1 = f.to_native(Decimal('9000')) result_2 = f.to_native(Decimal('1.00000001')) - self.assertEqual('9000', result_1) - self.assertEqual('1.00000001', result_2) + self.assertEqual(Decimal('9000'), result_1) + self.assertEqual(Decimal('1.00000001'), result_2) def test_to_native_none(self): """ -- cgit v1.2.3 From 37fe0bf0de25d28d792a291d5a84987ab71c4cb6 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Apr 2013 09:03:24 +0100 Subject: Remove unneccessary tests from #789, and bit of cleanup. --- rest_framework/tests/generics.py | 165 ++++----------------------------------- 1 file changed, 17 insertions(+), 148 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index b40b0102..4a13389a 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals from django.db import models +from django.shortcuts import get_object_or_404 from django.test import TestCase from rest_framework import generics, serializers, status from rest_framework.tests.utils import RequestFactory @@ -24,28 +25,6 @@ class InstanceView(generics.RetrieveUpdateDestroyAPIView): model = BasicModel -class InstanceDetailView(generics.RetrieveUpdateDestroyAPIView): - """ - Example detail view for override of get_object(). - """ - - # we have to implement this too, otherwise we can't be sure that get_object - # will be called - def get_serializer(self, instance=None, data=None, files=None, partial=None): - class InstanceDetailSerializer(serializers.ModelSerializer): - class Meta: - model = BasicModel - return InstanceDetailSerializer(instance=instance, data=data, files=files, partial=partial) - - def get_object(self): - try: - pk = int(self.kwargs['pk']) - self.object = BasicModel.objects.get(id=pk) - return self.object - except BasicModel.DoesNotExist: - return self.permission_denied(self.request) - - class SlugSerializer(serializers.ModelSerializer): slug = serializers.Field() # read only @@ -323,7 +302,8 @@ class TestInstanceView(TestCase): new_obj = SlugBasedModel.objects.get(slug='test_slug') self.assertEqual(new_obj.text, 'foobar') -class TestInstanceDetailView(TestCase): + +class TestOverriddenGetObject(TestCase): """ Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the queryset/model mechanism but instead overrides get_object() @@ -340,139 +320,28 @@ class TestInstanceDetailView(TestCase): {'id': obj.id, 'text': obj.text} for obj in self.objects.all() ] - self.view_class = InstanceDetailView - self.view = InstanceDetailView.as_view() - def test_get_instance_view(self): - """ - GET requests to RetrieveUpdateDestroyAPIView should return a single object. - """ - request = factory.get('/1') - with self.assertNumQueries(1): - response = self.view(request, pk=1).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, self.data[0]) + class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView): + """ + Example detail view for override of get_object(). + """ + model = BasicModel - def test_post_instance_view(self): - """ - POST requests to RetrieveUpdateDestroyAPIView should not be allowed - """ - content = {'text': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') - with self.assertNumQueries(0): - response = self.view(request).render() - self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."}) - - def test_put_instance_view(self): - """ - PUT requests to RetrieveUpdateDestroyAPIView should update an object. - """ - content = {'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') - with self.assertNumQueries(2): - response = self.view(request, pk='1').render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) - updated = self.objects.get(id=1) - self.assertEqual(updated.text, 'foobar') - - def test_patch_instance_view(self): - """ - PATCH requests to RetrieveUpdateDestroyAPIView should update an object. - """ - content = {'text': 'foobar'} - request = factory.patch('/1', json.dumps(content), - content_type='application/json') - - with self.assertNumQueries(2): - response = self.view(request, pk=1).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) - updated = self.objects.get(id=1) - self.assertEqual(updated.text, 'foobar') - - def test_delete_instance_view(self): - """ - DELETE requests to RetrieveUpdateDestroyAPIView should delete an object. - """ - request = factory.delete('/1') - with self.assertNumQueries(2): - response = self.view(request, pk=1).render() - self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - self.assertEqual(response.content, six.b('')) - ids = [obj.id for obj in self.objects.all()] - self.assertEqual(ids, [2, 3]) + def get_object(self): + pk = int(self.kwargs['pk']) + return get_object_or_404(BasicModel.objects.all(), id=pk) - def test_options_instance_view(self): - """ - OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata - """ - request = factory.options('/') - with self.assertNumQueries(0): - response = self.view(request).render() - expected = { - 'parses': [ - 'application/json', - 'application/x-www-form-urlencoded', - 'multipart/form-data' - ], - 'renders': [ - 'application/json', - 'text/html' - ], - 'name': 'Instance Detail', - 'description': 'Example detail view for override of get_object().' - } - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, expected) + self.view = OverriddenGetObjectView.as_view() - def test_put_cannot_set_id(self): + def test_overridden_get_object_view(self): """ - PUT requests to create a new object should not be able to set the id. + GET requests to RetrieveUpdateDestroyAPIView should return a single object. """ - content = {'id': 999, 'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') - with self.assertNumQueries(2): + request = factory.get('/1') + with self.assertNumQueries(1): response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) - updated = self.objects.get(id=1) - self.assertEqual(updated.text, 'foobar') - - def test_put_to_deleted_instance(self): - """ - PUT requests to RetrieveUpdateDestroyAPIView should create an object - if it does not currently exist. In our DetailView, however, - we cannot access any other id's than those that already exist. - See the InstanceView for the normal behaviour. - """ - self.objects.get(id=1).delete() - content = {'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') - with self.assertNumQueries(1): - response = self.view(request, pk=5).render() - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - - def test_put_as_create_on_id_based_url(self): - """ - PUT requests to RetrieveUpdateDestroyAPIView should create an object - at the requested url if it doesn't exist. In our DetailView, however, - we cannot access any other id's than those that already exist. - See the InstanceView for the normal behaviour. - """ - content = {'text': 'foobar'} - # pk fields can not be created on demand, only the database can set the pk for a new object - request = factory.put('/5', json.dumps(content), - content_type='application/json') - with self.assertNumQueries(1): - response = self.view(request, pk=5).render() - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - + self.assertEqual(response.data, self.data[0]) # Regression test for #285 -- cgit v1.2.3 From ea55143a2308b396c8df6f59a0f6d663c1067163 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Apr 2013 09:07:20 +0100 Subject: Version 2.2.7 --- rest_framework/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 7ac12058..856badc6 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.2.6' +__version__ = '2.2.7' VERSION = __version__ # synonym -- cgit v1.2.3 From 33f494fcc89711ab7e97f47fe8d9b287aac4730f Mon Sep 17 00:00:00 2001 From: forgingdestiny Date: Wed, 17 Apr 2013 10:14:36 -0400 Subject: add branding and style blocks --- .../templates/rest_framework/login_base.html | 55 ++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 rest_framework/templates/rest_framework/login_base.html (limited to 'rest_framework') diff --git a/rest_framework/templates/rest_framework/login_base.html b/rest_framework/templates/rest_framework/login_base.html new file mode 100644 index 00000000..380d5820 --- /dev/null +++ b/rest_framework/templates/rest_framework/login_base.html @@ -0,0 +1,55 @@ +{% load url from future %} +{% load rest_framework %} + + + + {% block style %} + {% block bootstrap_theme %}{% endblock %} + + + {% endblock %} + + + + +
+
+ +
+
+
+ {% block branding %}

Django REST framework

{% endblock %} +
+
+ +
+
+
+ {% csrf_token %} +
+
+ + +
+
+
+
+ + +
+
+ +
+ +
+
+
+
+
+ +
+
+ + + + -- cgit v1.2.3 From 03c736338fa04092da99d7d9ea202c8778998b38 Mon Sep 17 00:00:00 2001 From: forgingdestiny Date: Wed, 17 Apr 2013 10:15:02 -0400 Subject: extend base login template --- rest_framework/templates/rest_framework/login.html | 54 +--------------------- 1 file changed, 2 insertions(+), 52 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html index e10ce20f..b7629327 100644 --- a/rest_framework/templates/rest_framework/login.html +++ b/rest_framework/templates/rest_framework/login.html @@ -1,53 +1,3 @@ -{% load url from future %} -{% load rest_framework %} - +{% extends "rest_framework/login_base.html" %} - - - - - - - - -
-
- -
-
-
-

Django REST framework

-
-
- -
-
-
- {% csrf_token %} -
-
- - -
-
-
-
- - -
-
- -
- -
-
-
-
-
- -
-
- - - - +{# Override this template in your own templates directory to customize #} -- cgit v1.2.3 From 4bf1a09baeb885863e6028b97c2d51b26fb18534 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 23 Apr 2013 11:31:38 +0100 Subject: Ensure implementation of reverse relations in 'fields' is backwards compatible --- rest_framework/permissions.py | 2 +- rest_framework/serializers.py | 122 +++++++++++++++------------- rest_framework/tests/relations_hyperlink.py | 16 ++-- rest_framework/tests/relations_nested.py | 24 ++---- rest_framework/tests/relations_pk.py | 17 ++-- 5 files changed, 95 insertions(+), 86 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index ae895f39..2aa45c71 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -25,7 +25,7 @@ class BasePermission(object): """ Return `True` if permission is granted, `False` otherwise. """ - if len(inspect.getargspec(self.has_permission)[0]) == 4: + if len(inspect.getargspec(self.has_permission).args) == 4: warnings.warn('The `obj` argument in `has_permission` is due to be deprecated. ' 'Use `has_object_permission()` instead for object permissions.', PendingDeprecationWarning, stacklevel=2) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index eac909c7..b4327af1 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -568,54 +568,73 @@ class ModelSerializer(Serializer): assert cls is not None, \ "Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__ opts = get_concrete_model(cls)._meta - pk_field = opts.pk + ret = SortedDict() + nested = bool(self.opts.depth) - # If model is a child via multitable inheritance, use parent's pk + # Deal with adding the primary key field + pk_field = opts.pk while pk_field.rel and pk_field.rel.parent_link: + # If model is a child via multitable inheritance, use parent's pk pk_field = pk_field.rel.to._meta.pk - fields = [pk_field] - fields += [field for field in opts.fields if field.serialize] - fields += [field for field in opts.many_to_many if field.serialize] + field = self.get_pk_field(pk_field) + if field: + ret[pk_field.name] = field - ret = SortedDict() - nested = bool(self.opts.depth) - is_pk = True # First field in the list is the pk - - for model_field in fields: - if is_pk: - field = self.get_pk_field(model_field) - is_pk = False - elif model_field.rel and nested: - field = self.get_nested_field(model_field) - elif model_field.rel: + # Deal with forward relationships + forward_rels = [field for field in opts.fields if field.serialize] + forward_rels += [field for field in opts.many_to_many if field.serialize] + + for model_field in forward_rels: + if model_field.rel: to_many = isinstance(model_field, models.fields.related.ManyToManyField) - field = self.get_related_field(model_field, to_many=to_many) + related_model = model_field.rel.to + + if model_field.rel and nested: + if len(inspect.getargspec(self.get_nested_field).args) == 2: + # TODO: deprecation warning + field = self.get_nested_field(model_field) + else: + field = self.get_nested_field(model_field, related_model, to_many) + elif model_field.rel: + if len(inspect.getargspec(self.get_nested_field).args) == 3: + # TODO: deprecation warning + field = self.get_related_field(model_field, to_many=to_many) + else: + field = self.get_related_field(model_field, related_model, to_many) else: field = self.get_field(model_field) if field: ret[model_field.name] = field - # Reverse relationships are only included if they are explicitly - # present in `Meta.fields`. - if self.opts.fields: - reverse = opts.get_all_related_objects() - reverse += opts.get_all_related_many_to_many_objects() - for rel in reverse: - name = rel.get_accessor_name() - if name not in self.opts.fields: - continue - - if nested: - field = self.get_nested_field(None, rel) - else: - field = self.get_related_field(None, rel, to_many=True) + # Deal with reverse relationships + if not self.opts.fields: + reverse_rels = [] + else: + # Reverse relationships are only included if they are explicitly + # present in the `fields` option on the serializer + reverse_rels = opts.get_all_related_objects() + reverse_rels += opts.get_all_related_many_to_many_objects() + + for relation in reverse_rels: + accessor_name = relation.get_accessor_name() + if accessor_name not in self.opts.fields: + continue + related_model = relation.model + to_many = relation.field.rel.multiple - if field: - ret[name] = field + if nested: + field = self.get_nested_field(None, related_model, to_many) + else: + field = self.get_related_field(None, related_model, to_many) + + if field: + ret[accessor_name] = field + # Add the `read_only` flag to any fields that have bee specified + # in the `read_only_fields` option for field_name in self.opts.read_only_fields: assert field_name in ret, \ "read_only_fields on '%s' included invalid item '%s'" % \ @@ -630,39 +649,30 @@ class ModelSerializer(Serializer): """ return self.get_field(model_field) - def get_nested_field(self, model_field, rel=None): + def get_nested_field(self, model_field, related_model, to_many): """ Creates a default instance of a nested relational field. """ - if rel: - model_class = rel.model - else: - model_class = model_field.rel.to - class NestedModelSerializer(ModelSerializer): class Meta: - model = model_class - return NestedModelSerializer() + model = related_model + return NestedModelSerializer(many=to_many) - def get_related_field(self, model_field, rel=None, to_many=False): + def get_related_field(self, model_field, related_model, to_many): """ Creates a default instance of a flat relational field. """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) - if rel: - model_class = rel.model - required = True - else: - model_class = model_field.rel.to - required = not(model_field.null or model_field.blank) kwargs = { - 'required': required, - 'queryset': model_class._default_manager, + 'queryset': related_model._default_manager, 'many': to_many } + if model_field: + kwargs['required'] = not(model_field.null or model_field.blank) + return PrimaryKeyRelatedField(**kwargs) def get_field(self, model_field): @@ -830,19 +840,21 @@ class HyperlinkedModelSerializer(ModelSerializer): if self.opts.fields and model_field.name in self.opts.fields: return self.get_field(model_field) - def get_related_field(self, model_field, to_many): + def get_related_field(self, model_field, related_model, to_many): """ Creates a default instance of a flat relational field. """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) - rel = model_field.rel.to kwargs = { - 'required': not(model_field.null or model_field.blank), - 'queryset': rel._default_manager, - 'view_name': self._get_default_view_name(rel), + 'queryset': related_model._default_manager, + 'view_name': self._get_default_view_name(related_model), 'many': to_many } + + if model_field: + kwargs['required'] = not(model_field.null or model_field.blank) + return HyperlinkedRelatedField(**kwargs) def get_identity(self, data): diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py index b5702a48..b1eed9a7 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/relations_hyperlink.py @@ -26,42 +26,44 @@ urlpatterns = patterns('', ) +# ManyToMany class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): - sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail') - class Meta: model = ManyToManyTarget + fields = ('url', 'name', 'sources') class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = ManyToManySource + fields = ('url', 'name', 'targets') +# ForeignKey class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): - sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail') - class Meta: model = ForeignKeyTarget + fields = ('url', 'name', 'sources') class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = ForeignKeySource + fields = ('url', 'name', 'target') # Nullable ForeignKey class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = NullableForeignKeySource + fields = ('url', 'name', 'target') -# OneToOne +# Nullable OneToOne class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): - nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail') - class Meta: model = OneToOneTarget + fields = ('url', 'name', 'nullable_source') # TODO: Add test that .data cannot be accessed prior to .is_valid diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index a125ba65..f6d006b3 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -5,39 +5,31 @@ from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, Null class ForeignKeySourceSerializer(serializers.ModelSerializer): - class Meta: - depth = 1 - model = ForeignKeySource - - -class FlatForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: model = ForeignKeySource + fields = ('id', 'name', 'target') + depth = 1 class ForeignKeyTargetSerializer(serializers.ModelSerializer): - sources = FlatForeignKeySourceSerializer(many=True) - class Meta: model = ForeignKeyTarget + fields = ('id', 'name', 'sources') + depth = 1 class NullableForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: - depth = 1 model = NullableForeignKeySource - - -class NullableOneToOneSourceSerializer(serializers.ModelSerializer): - class Meta: - model = NullableOneToOneSource + fields = ('id', 'name', 'target') + depth = 1 class NullableOneToOneTargetSerializer(serializers.ModelSerializer): - nullable_source = NullableOneToOneSourceSerializer() - class Meta: model = OneToOneTarget + fields = ('id', 'name', 'nullable_source') + depth = 1 class ReverseForeignKeyTests(TestCase): diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py index f08e1808..5ce8b567 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/relations_pk.py @@ -5,41 +5,44 @@ from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, Fore from rest_framework.compat import six +# ManyToMany class ManyToManyTargetSerializer(serializers.ModelSerializer): - sources = serializers.PrimaryKeyRelatedField(many=True) - class Meta: model = ManyToManyTarget + fields = ('id', 'name', 'sources') class ManyToManySourceSerializer(serializers.ModelSerializer): class Meta: model = ManyToManySource + fields = ('id', 'name', 'targets') +# ForeignKey class ForeignKeyTargetSerializer(serializers.ModelSerializer): - sources = serializers.PrimaryKeyRelatedField(many=True) - class Meta: model = ForeignKeyTarget + fields = ('id', 'name', 'sources') class ForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: model = ForeignKeySource + fields = ('id', 'name', 'target') +# Nullable ForeignKey class NullableForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: model = NullableForeignKeySource + fields = ('id', 'name', 'target') -# OneToOne +# Nullable OneToOne class NullableOneToOneTargetSerializer(serializers.ModelSerializer): - nullable_source = serializers.PrimaryKeyRelatedField() - class Meta: model = OneToOneTarget + fields = ('id', 'name', 'nullable_source') # TODO: Add test that .data cannot be accessed prior to .is_valid -- cgit v1.2.3 From b94da2468cdda6b0ad491574d35097d0e336ea7f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 24 Apr 2013 22:40:24 +0100 Subject: Various clean up and lots of docs --- rest_framework/generics.py | 74 ++++++++++----- rest_framework/routers.py | 231 ++++++++++++++++++++++++++++++++++----------- 2 files changed, 222 insertions(+), 83 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index ae03060b..3440c01d 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -18,21 +18,35 @@ class GenericAPIView(views.APIView): Base class for all other generic views. """ + # You'll need to either set these attributes, + # or override `get_queryset`/`get_serializer_class`. queryset = None serializer_class = None - # Shortcut which may be used in place of `queryset`/`serializer_class` - model = None + # If you want to use object lookups other than pk, set this attribute. + lookup_field = 'pk' - filter_backend = api_settings.FILTER_BACKEND + # Pagination settings paginate_by = api_settings.PAGINATE_BY paginate_by_param = api_settings.PAGINATE_BY_PARAM pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS - model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS page_kwarg = 'page' - lookup_field = 'pk' + + # The filter backend class to use for queryset filtering + filter_backend = api_settings.FILTER_BACKEND + + # Determines if the view will return 200 or 404 responses for empty lists. allow_empty = True + # This shortcut may be used instead of setting either (or both) + # of the `queryset`/`serializer_class` attributes, although using + # the explicit style is generally preferred. + model = None + + # If the `model` shortcut is used instead of `serializer_class`, then the + # serializer class will be constructed using this class as the base. + model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS + ###################################### # These are pending deprecation... @@ -61,7 +75,7 @@ class GenericAPIView(views.APIView): return serializer_class(instance, data=data, files=files, many=many, partial=partial, context=context) - def get_pagination_serializer(self, page=None): + def get_pagination_serializer(self, page): """ Return a serializer instance to use with paginated data. """ @@ -73,32 +87,15 @@ class GenericAPIView(views.APIView): context = self.get_serializer_context() return pagination_serializer_class(instance=page, context=context) - def get_paginate_by(self, queryset=None): - """ - Return the size of pages to use with pagination. - - If `PAGINATE_BY_PARAM` is set it will attempt to get the page size - from a named query parameter in the url, eg. ?page_size=100 - - Otherwise defaults to using `self.paginate_by`. - """ - if self.paginate_by_param: - query_params = self.request.QUERY_PARAMS - try: - return int(query_params[self.paginate_by_param]) - except (KeyError, ValueError): - pass - - return self.paginate_by - def paginate_queryset(self, queryset, page_size, paginator_class=Paginator): """ Paginate a queryset. """ paginator = paginator_class(queryset, page_size, allow_empty_first_page=self.allow_empty) - page_kwarg = self.page_kwarg - page = self.kwargs.get(page_kwarg) or self.request.GET.get(page_kwarg) or 1 + page_kwarg = self.kwargs.get(self.page_kwarg) + page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) + page = page_kwarg or page_query_param or 1 try: page_number = int(page) except ValueError: @@ -133,6 +130,27 @@ class GenericAPIView(views.APIView): ### The following methods provide default implementations ### that you may want to override for more complex cases. + def get_paginate_by(self, queryset=None): + """ + Return the size of pages to use with pagination. + + If `PAGINATE_BY_PARAM` is set it will attempt to get the page size + from a named query parameter in the url, eg. ?page_size=100 + + Otherwise defaults to using `self.paginate_by`. + """ + if queryset is not None: + pass # TODO: Deprecation warning + + if self.paginate_by_param: + query_params = self.request.QUERY_PARAMS + try: + return int(query_params[self.paginate_by_param]) + except (KeyError, ValueError): + pass + + return self.paginate_by + def get_serializer_class(self): """ Return the class to use for the serializer. @@ -202,6 +220,7 @@ class GenericAPIView(views.APIView): # TODO: Deprecation warning filter_kwargs = {self.slug_field: slug} else: + # TODO: Fix error message raise AttributeError("Generic detail view %s must be called with " "either an object pk or a slug." % self.__class__.__name__) @@ -216,6 +235,9 @@ class GenericAPIView(views.APIView): ######################## ### The following are placeholder methods, ### and are intended to be overridden. + ### + ### The are not called by GenericAPIView directly, + ### but are used by the mixin methods. def pre_save(self, obj): """ diff --git a/rest_framework/routers.py b/rest_framework/routers.py index afc51f3b..febb02b3 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -1,81 +1,198 @@ +""" +Routers provide a convenient and consistent way of automatically +determining the URL conf for your API. + +They are used by simply instantiating a Router class, and then registering +all the required ViewSets with that router. + +For example, you might have a `urls.py` that looks something like this: + + router = routers.DefaultRouter() + router.register('users', UserViewSet, 'user') + router.register('accounts', AccountViewSet, 'account') + + urlpatterns = router.urls +""" from django.conf.urls import url, patterns +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.reverse import reverse +from rest_framework.urlpatterns import format_suffix_patterns class BaseRouter(object): def __init__(self): self.registry = [] - def register(self, prefix, viewset, base_name): - self.registry.append((prefix, viewset, base_name)) + def register(self, prefix, viewset, basename): + self.registry.append((prefix, viewset, basename)) - def get_urlpatterns(self): - raise NotImplemented('get_urlpatterns must be overridden') + def get_urls(self): + raise NotImplemented('get_urls must be overridden') @property - def urlpatterns(self): - if not hasattr(self, '_urlpatterns'): - self._urlpatterns = patterns('', *self.get_urlpatterns()) - return self._urlpatterns - - -class DefaultRouter(BaseRouter): - route_list = [ - (r'$', { - 'get': 'list', - 'post': 'create' - }, 'list'), - (r'(?P[^/]+)/$', { - 'get': 'retrieve', - 'put': 'update', - 'patch': 'partial_update', - 'delete': 'destroy' - }, 'detail'), + def urls(self): + if not hasattr(self, '_urls'): + self._urls = patterns('', *self.get_urls()) + return self._urls + + +class SimpleRouter(BaseRouter): + routes = [ + # List route. + ( + r'^{prefix}/$', + { + 'get': 'list', + 'post': 'create' + }, + '{basename}-list' + ), + # Detail route. + ( + r'^{prefix}/{lookup}/$', + { + 'get': 'retrieve', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy' + }, + '{basename}-detail' + ), + # Dynamically generated routes. + # Generated using @action or @link decorators on methods of the viewset. + ( + r'^{prefix}/{lookup}/{methodname}/$', + { + '{httpmethod}': '{methodname}', + }, + '{basename}-{methodname}' + ), ] - extra_routes = r'(?P[^/]+)/%s/$' - name_format = '%s-%s' - def get_urlpatterns(self): + def get_routes(self, viewset): + """ + Augment `self.routes` with any dynamically generated routes. + + Returns a list of 4-tuples, of the form: + `(url_format, method_map, name_format, extra_kwargs)` + """ + + # Determine any `@action` or `@link` decorated methods on the viewset + dynamic_routes = {} + for methodname in dir(viewset): + attr = getattr(viewset, methodname) + httpmethod = getattr(attr, 'bind_to_method', None) + if httpmethod: + dynamic_routes[httpmethod] = methodname + ret = [] - for prefix, viewset, base_name in self.registry: - # Bind regular views - if not getattr(viewset, '_is_viewset', False): - regex = prefix - view = viewset - name = base_name - ret.append(url(regex, view, name=name)) - continue + for url_format, method_map, name_format in self.routes: + if method_map == {'{httpmethod}': '{methodname}'}: + # Dynamic routes + for httpmethod, methodname in dynamic_routes.items(): + extra_kwargs = getattr(viewset, methodname).kwargs + ret.append(( + url_format.replace('{methodname}', methodname), + {httpmethod: methodname}, + name_format.replace('{methodname}', methodname), + extra_kwargs + )) + else: + # Standard route + extra_kwargs = {} + ret.append((url_format, method_map, name_format, extra_kwargs)) - # Bind standard CRUD routes - for suffix, action_mapping, action_name in self.route_list: + return ret - # Only actions which actually exist on the viewset will be bound - bound_actions = {} - for method, action in action_mapping.items(): - if hasattr(viewset, action): - bound_actions[method] = action + def get_method_map(self, viewset, method_map): + """ + Given a viewset, and a mapping of http methods to actions, + return a new mapping which only includes any mappings that + are actually implemented by the viewset. + """ + bound_methods = {} + for method, action in method_map.items(): + if hasattr(viewset, action): + bound_methods[method] = action + return bound_methods + + def get_lookup_regex(self, viewset): + """ + Given a viewset, return the portion of URL regex that is used + to match against a single instance. + """ + base_regex = '(?P<{lookup_field}>[^/]+)' + lookup_field = getattr(viewset, 'lookup_field', 'pk') + return base_regex.format(lookup_field=lookup_field) + + def get_urls(self): + """ + Use the registered viewsets to generate a list of URL patterns. + """ + ret = [] - # Build the url pattern - regex = prefix + suffix - view = viewset.as_view(bound_actions, name_suffix=action_name) - name = self.name_format % (base_name, action_name) - ret.append(url(regex, view, name=name)) + for prefix, viewset, basename in self.registry: + lookup = self.get_lookup_regex(viewset) + routes = self.get_routes(viewset) - # Bind any extra `@action` or `@link` routes - for action_name in dir(viewset): - func = getattr(viewset, action_name) - http_method = getattr(func, 'bind_to_method', None) + for url_format, method_map, name_format, extra_kwargs in routes: - # Skip if this is not an @action or @link method - if not http_method: + # Only actions which actually exist on the viewset will be bound + method_map = self.get_method_map(viewset, method_map) + if not method_map: continue - suffix = self.extra_routes % action_name - # Build the url pattern - regex = prefix + suffix - view = viewset.as_view({http_method: action_name}, **func.kwargs) - name = self.name_format % (base_name, action_name) + regex = url_format.format(prefix=prefix, lookup=lookup) + view = viewset.as_view(method_map, **extra_kwargs) + name = name_format.format(basename=basename) ret.append(url(regex, view, name=name)) - # Return a list of url patterns return ret + + +class DefaultRouter(SimpleRouter): + """ + The default router extends the SimpleRouter, but also adds in a default + API root view, and adds format suffix patterns to the URLs. + """ + include_root_view = True + include_format_suffixes = True + + def get_api_root_view(self): + """ + Return a view to use as the API root. + """ + api_root_dict = {} + list_name = self.routes[0][-1] + for prefix, viewset, basename in self.registry: + api_root_dict[prefix] = list_name.format(basename=basename) + + @api_view(('GET',)) + def api_root(request, format=None): + ret = {} + for key, url_name in api_root_dict.items(): + ret[key] = reverse(url_name, request=request, format=format) + return Response(ret) + + return api_root + + def get_urls(self): + """ + Generate the list of URL patterns, including a default root view + for the API, and appending `.json` style format suffixes. + """ + urls = [] + + if self.include_root_view: + root_url = url(r'^$', self.get_api_root_view(), name='api-root') + urls.append(root_url) + + default_urls = super(DefaultRouter, self).get_urls() + urls.extend(default_urls) + + if self.include_format_suffixes: + urls = format_suffix_patterns(urls) + + return urls -- cgit v1.2.3 From 95abe6e8445f59f9e52609b0c54d9276830dbfd3 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 25 Apr 2013 12:47:34 +0100 Subject: Cleanup docstrings --- rest_framework/authentication.py | 2 +- rest_framework/decorators.py | 8 +++++++ rest_framework/fields.py | 5 +++++ rest_framework/filters.py | 4 ++++ rest_framework/generics.py | 2 -- rest_framework/negotiation.py | 4 ++++ rest_framework/pagination.py | 6 +++-- rest_framework/relations.py | 6 +++++ rest_framework/request.py | 5 ++--- rest_framework/response.py | 6 +++++ rest_framework/serializers.py | 12 ++++++++++ rest_framework/throttling.py | 8 ++++--- rest_framework/views.py | 2 +- rest_framework/viewsets.py | 48 +++++++++++++++++++++++++++++----------- 14 files changed, 93 insertions(+), 25 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 1eebb5b9..9caca788 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -1,5 +1,5 @@ """ -Provides a set of pluggable authentication policies. +Provides various authentication policies. """ from __future__ import unicode_literals import base64 diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 00b37f8b..81e585e1 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -1,3 +1,11 @@ +""" +The most imporant decorator in this module is `@api_view`, which is used +for writing function-based views with REST framework. + +There are also various decorators for setting the API policies on function +based views, as well as the `@action` and `@link` decorators, which are +used to annotate methods on viewsets that should be included by routers. +""" from __future__ import unicode_literals from rest_framework.compat import six from rest_framework.views import APIView diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f3496b53..949f68d6 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,3 +1,8 @@ +""" +Serializer fields perform validation on incoming data. + +They are very similar to Django's form fields. +""" from __future__ import unicode_literals import copy diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 413fa0d2..5e1cdbac 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -1,3 +1,7 @@ +""" +Provides generic filtering backends that can be used to filter the results +returned by list views. +""" from __future__ import unicode_literals from rest_framework.compat import django_filters diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 3440c01d..56471cfa 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -10,8 +10,6 @@ from django.http import Http404 from django.shortcuts import get_object_or_404 from django.utils.translation import ugettext as _ -### Base classes for the generic views ### - class GenericAPIView(views.APIView): """ diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py index 0694d35f..4d205c0e 100644 --- a/rest_framework/negotiation.py +++ b/rest_framework/negotiation.py @@ -1,3 +1,7 @@ +""" +Content negotiation deals with selecting an appropriate renderer given the +incoming request. Typically this will be based on the request's Accept header. +""" from __future__ import unicode_literals from django.http import Http404 from rest_framework import exceptions diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 03a7a30f..d51ea929 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -1,9 +1,11 @@ +""" +Pagination serializers determine the structure of the output that should +be used for paginated responses. +""" from __future__ import unicode_literals from rest_framework import serializers from rest_framework.templatetags.rest_framework import replace_query_param -# TODO: Support URLconf kwarg-style paging - class NextPageField(serializers.Field): """ diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 2a10e9af..6bda7418 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,3 +1,9 @@ +""" +Serializer fields that deal with relationships. + +These fields allow you to specify the style that should be used to represent +model relationships, including hyperlinks, primary keys, or slugs. +""" from __future__ import unicode_literals from django.core.exceptions import ObjectDoesNotExist, ValidationError from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch diff --git a/rest_framework/request.py b/rest_framework/request.py index ffbbab33..a434659c 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -1,11 +1,10 @@ """ -The :mod:`request` module provides a :class:`Request` class used to wrap the standard `request` -object received in all the views. +The Request class is used as a wrapper around the standard request object. The wrapped request then offers a richer API, in particular : - content automatically parsed according to `Content-Type` header, - and available as :meth:`.DATA` + and available as `request.DATA` - full support of PUT method, including support for file uploads - form overloading of HTTP method, content type and content """ diff --git a/rest_framework/response.py b/rest_framework/response.py index 5e1bf46e..26e4ab37 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -1,3 +1,9 @@ +""" +The Response class in REST framework is similiar to HTTPResponse, except that +it is initialized with unrendered data, instead of a pre-rendered string. + +The appropriate renderer is called during Django's template response rendering. +""" from __future__ import unicode_literals from django.core.handlers.wsgi import STATUS_CODE_TEXT from django.template.response import SimpleTemplateResponse diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index b4327af1..fb438b12 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1,3 +1,15 @@ +""" +Serializers and ModelSerializers are similar to Forms and ModelForms. +Unlike forms, they are not constrained to dealing with HTML output, and +form encoded input. + +Serialization in REST framework is a two-phase process: + +1. Serializers marshal between complex types like model instances, and +python primatives. +2. The process of marshalling between python primatives and request and +response content is handled by parsers and renderers. +""" from __future__ import unicode_literals import copy import datetime diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 810cad63..93ea9816 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -1,3 +1,6 @@ +""" +Provides various throttling policies. +""" from __future__ import unicode_literals from django.core.cache import cache from rest_framework import exceptions @@ -28,9 +31,8 @@ class SimpleRateThrottle(BaseThrottle): A simple cache implementation, that only requires `.get_cache_key()` to be overridden. - The rate (requests / seconds) is set by a :attr:`throttle` attribute - on the :class:`.View` class. The attribute is a string of the form 'number of - requests/period'. + The rate (requests / seconds) is set by a `throttle` attribute on the View + class. The attribute is a string of the form 'number_of_requests/period'. Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day') diff --git a/rest_framework/views.py b/rest_framework/views.py index b8e948e0..555fa2f4 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -1,5 +1,5 @@ """ -Provides an APIView class that is used as the base of all class-based views. +Provides an APIView class that is the base of all views in REST framework. """ from __future__ import unicode_literals from django.core.exceptions import PermissionDenied diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 28ab30e2..9133fd44 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -1,3 +1,21 @@ +""" +ViewSets are essentially just a type of class based view, that doesn't provide +any method handlers, such as `get()`, `post()`, etc... but instead has actions, +such as `list()`, `retrieve()`, `create()`, etc... + +Actions are only bound to methods at the point of instantiating the views. + + user_list = UserViewSet.as_view({'get': 'list'}) + user_detail = UserViewSet.as_view({'get': 'retrieve'}) + +Typically, rather than instantiate views from viewsets directly, you'll +regsiter the viewset with a router and let the URL conf be determined +automatically. + + router = DefaultRouter() + router.register(r'users', UserViewSet, 'user') + urlpatterns = router.urls +""" from functools import update_wrapper from django.utils.decorators import classonlymethod from rest_framework import views, generics, mixins @@ -15,13 +33,10 @@ class ViewSetMixin(object): view = MyViewSet.as_view({'get': 'list', 'post': 'create'}) """ - _is_viewset = True @classonlymethod def as_view(cls, actions=None, name_suffix=None, **initkwargs): """ - Main entry point for a request-response process. - Because of the way class based views create a closure around the instantiated view, we need to totally reimplement `.as_view`, and slightly modify the view function that is created and returned. @@ -64,12 +79,22 @@ class ViewSetMixin(object): class ViewSet(ViewSetMixin, views.APIView): + """ + The base ViewSet class does not provide any actions by default. + """ + pass + + +class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, + mixins.ListModelMixin, + ViewSetMixin, + generics.GenericAPIView): + """ + A viewset that provides default `list()` and `retrieve()` actions. + """ pass -# Note the inheritence of both MultipleObjectAPIView *and* SingleObjectAPIView -# is a bit weird given the diamond inheritence, but it will work for now. -# There's some implementation clean up that can happen later. class ModelViewSet(mixins.CreateModelMixin, mixins.RetrieveModelMixin, mixins.UpdateModelMixin, @@ -77,11 +102,8 @@ class ModelViewSet(mixins.CreateModelMixin, mixins.ListModelMixin, ViewSetMixin, generics.GenericAPIView): - pass - - -class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, - mixins.ListModelMixin, - ViewSetMixin, - generics.GenericAPIView): + """ + A viewset that provides default `create()`, `retrieve()`, `update()`, + `partial_update()`, `destroy()` and `list()` actions. + """ pass -- cgit v1.2.3 From 5d01ae661fcf85016718041e021b4bca524dfcdc Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 25 Apr 2013 17:40:17 +0100 Subject: Simplify paginate_queryset method --- rest_framework/generics.py | 29 ++++++++++++++++++++++------- rest_framework/mixins.py | 9 +++------ 2 files changed, 25 insertions(+), 13 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 56471cfa..a18584d4 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -45,6 +45,8 @@ class GenericAPIView(views.APIView): # serializer class will be constructed using this class as the base. model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS + _paginator_class = Paginator + ###################################### # These are pending deprecation... @@ -85,12 +87,24 @@ class GenericAPIView(views.APIView): context = self.get_serializer_context() return pagination_serializer_class(instance=page, context=context) - def paginate_queryset(self, queryset, page_size, paginator_class=Paginator): + def paginate_queryset(self, queryset, page_size=None): """ - Paginate a queryset. + Paginate a queryset if required, either returning a page object, + or `None` if pagination is not configured for this view. """ - paginator = paginator_class(queryset, page_size, - allow_empty_first_page=self.allow_empty) + deprecated_style = False + if page_size is not None: + # TODO: Deperecation warning + deprecated_style = True + else: + # Determine the required page size. + # If pagination is not configured, simply return None. + page_size = self.get_paginate_by() + if not page_size: + return None + + paginator = self._paginator_class(queryset, page_size, + allow_empty_first_page=self.allow_empty) page_kwarg = self.kwargs.get(self.page_kwarg) page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) page = page_kwarg or page_query_param or 1 @@ -103,13 +117,16 @@ class GenericAPIView(views.APIView): raise Http404(_("Page is not 'last', nor can it be converted to an int.")) try: page = paginator.page(page_number) - return (paginator, page, page.object_list, page.has_other_pages()) except InvalidPage as e: raise Http404(_('Invalid page (%(page_number)s): %(message)s') % { 'page_number': page_number, 'message': str(e) }) + if deprecated_style: + return (paginator, page, page.object_list, page.has_other_pages()) + return page + def filter_queryset(self, queryset): """ Given a queryset, filter it with whichever filter backend is in use. @@ -163,7 +180,6 @@ class GenericAPIView(views.APIView): if serializer_class is not None: return serializer_class - # TODO: Deprecation warning class DefaultSerializer(self.model_serializer_class): class Meta: model = self.model @@ -184,7 +200,6 @@ class GenericAPIView(views.APIView): return self.queryset._clone() if self.model is not None: - # TODO: Deprecation warning return self.model._default_manager.all() raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 6e40b5c4..ec751e24 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -76,12 +76,9 @@ class ListModelMixin(object): error_msg = self.empty_error % {'class_name': class_name} raise Http404(error_msg) - # Pagination size is set by the `.paginate_by` attribute, - # which may be `None` to disable pagination. - page_size = self.get_paginate_by() - if page_size: - packed = self.paginate_queryset(self.object_list, page_size) - paginator, page, queryset, is_paginated = packed + # Switch between paginated or standard style responses + page = self.paginate_queryset(self.object_list) + if page is not None: serializer = self.get_pagination_serializer(page) else: serializer = self.get_serializer(self.object_list, many=True) -- cgit v1.2.3 From 7268a5c571bce323ccc75eb039b7c3f1b2b32391 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 25 Apr 2013 17:41:47 +0100 Subject: Added AutoRouter. Don't know if this is a good idea. --- rest_framework/routers.py | 46 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py index febb02b3..b7052218 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -14,12 +14,26 @@ For example, you might have a `urls.py` that looks something like this: urlpatterns = router.urls """ from django.conf.urls import url, patterns +from django.db import models from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.reverse import reverse +from rest_framework.viewsets import ModelViewSet from rest_framework.urlpatterns import format_suffix_patterns +def replace_methodname(format_string, methodname): + """ + Partially format a format_string, swapping out any + '{methodname}'' or '{methodnamehyphen}'' components. + """ + methodnamehyphen = methodname.replace('_', '-') + ret = format_string + ret = ret.replace('{methodname}', methodname) + ret = ret.replace('{methodnamehyphen}', methodnamehyphen) + return ret + + class BaseRouter(object): def __init__(self): self.registry = [] @@ -66,7 +80,7 @@ class SimpleRouter(BaseRouter): { '{httpmethod}': '{methodname}', }, - '{basename}-{methodname}' + '{basename}-{methodnamehyphen}' ), ] @@ -89,13 +103,13 @@ class SimpleRouter(BaseRouter): ret = [] for url_format, method_map, name_format in self.routes: if method_map == {'{httpmethod}': '{methodname}'}: - # Dynamic routes + # Dynamic routes (@link or @action decorator) for httpmethod, methodname in dynamic_routes.items(): extra_kwargs = getattr(viewset, methodname).kwargs ret.append(( - url_format.replace('{methodname}', methodname), + replace_methodname(url_format, methodname), {httpmethod: methodname}, - name_format.replace('{methodname}', methodname), + replace_methodname(name_format, methodname), extra_kwargs )) else: @@ -196,3 +210,27 @@ class DefaultRouter(SimpleRouter): urls = format_suffix_patterns(urls) return urls + + +class AutoRouter(DefaultRouter): + """ + A router class that doesn't require you to register any viewsets, + but instead automatically creates routes for all installed models. + + Useful for quick and dirty prototyping. + """ + def __init__(self): + super(AutoRouter, self).__init__() + for model in models.get_models(): + prefix = model._meta.verbose_name_plural.replace(' ', '_') + basename = model._meta.object_name.lower() + classname = model.__name__ + + DynamicViewSet = type( + classname, + (ModelViewSet,), + {} + ) + DynamicViewSet.model = model + + self.register(prefix, DynamicViewSet, basename) -- cgit v1.2.3 From 8fa79a7fd38dda015afa658084361c6da2856e46 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 26 Apr 2013 14:59:21 +0100 Subject: Deal with List/Instance suffixes for viewsets --- rest_framework/renderers.py | 2 +- rest_framework/routers.py | 72 ++++++++++++++++++++----------------- rest_framework/utils/breadcrumbs.py | 3 +- rest_framework/utils/formatting.py | 7 ++-- rest_framework/viewsets.py | 10 +++++- 5 files changed, 56 insertions(+), 38 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 752306ad..a0829c8f 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -439,7 +439,7 @@ class BrowsableAPIRenderer(BaseRenderer): return GenericContentForm() def get_name(self, view): - return get_view_name(view.__class__) + return get_view_name(view.__class__, getattr(view, 'suffix', None)) def get_description(self, view): return get_view_description(view.__class__, html=True) diff --git a/rest_framework/routers.py b/rest_framework/routers.py index b7052218..3a8c4508 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -13,6 +13,7 @@ For example, you might have a `urls.py` that looks something like this: urlpatterns = router.urls """ +from collections import namedtuple from django.conf.urls import url, patterns from django.db import models from rest_framework.decorators import api_view @@ -22,6 +23,9 @@ from rest_framework.viewsets import ModelViewSet from rest_framework.urlpatterns import format_suffix_patterns +Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs']) + + def replace_methodname(format_string, methodname): """ Partially format a format_string, swapping out any @@ -38,8 +42,8 @@ class BaseRouter(object): def __init__(self): self.registry = [] - def register(self, prefix, viewset, basename): - self.registry.append((prefix, viewset, basename)) + def register(self, prefix, viewset, name): + self.registry.append((prefix, viewset, name)) def get_urls(self): raise NotImplemented('get_urls must be overridden') @@ -54,33 +58,36 @@ class BaseRouter(object): class SimpleRouter(BaseRouter): routes = [ # List route. - ( - r'^{prefix}/$', - { + Route( + url=r'^{prefix}/$', + mapping={ 'get': 'list', 'post': 'create' }, - '{basename}-list' + name='{basename}-list', + initkwargs={'suffix': 'List'} ), # Detail route. - ( - r'^{prefix}/{lookup}/$', - { + Route( + url=r'^{prefix}/{lookup}/$', + mapping={ 'get': 'retrieve', 'put': 'update', 'patch': 'partial_update', 'delete': 'destroy' }, - '{basename}-detail' + name='{basename}-detail', + initkwargs={'suffix': 'Instance'} ), # Dynamically generated routes. # Generated using @action or @link decorators on methods of the viewset. - ( - r'^{prefix}/{lookup}/{methodname}/$', - { + Route( + url=r'^{prefix}/{lookup}/{methodname}/$', + mapping={ '{httpmethod}': '{methodname}', }, - '{basename}-{methodnamehyphen}' + name='{basename}-{methodnamehyphen}', + initkwargs={} ), ] @@ -88,8 +95,7 @@ class SimpleRouter(BaseRouter): """ Augment `self.routes` with any dynamically generated routes. - Returns a list of 4-tuples, of the form: - `(url_format, method_map, name_format, extra_kwargs)` + Returns a list of the Route namedtuple. """ # Determine any `@action` or `@link` decorated methods on the viewset @@ -101,21 +107,21 @@ class SimpleRouter(BaseRouter): dynamic_routes[httpmethod] = methodname ret = [] - for url_format, method_map, name_format in self.routes: - if method_map == {'{httpmethod}': '{methodname}'}: + for route in self.routes: + if route.mapping == {'{httpmethod}': '{methodname}'}: # Dynamic routes (@link or @action decorator) for httpmethod, methodname in dynamic_routes.items(): - extra_kwargs = getattr(viewset, methodname).kwargs - ret.append(( - replace_methodname(url_format, methodname), - {httpmethod: methodname}, - replace_methodname(name_format, methodname), - extra_kwargs + initkwargs = route.initkwargs.copy() + initkwargs.update(getattr(viewset, methodname).kwargs) + ret.append(Route( + url=replace_methodname(route.url, methodname), + mapping={httpmethod: methodname}, + name=replace_methodname(route.name, methodname), + initkwargs=initkwargs, )) else: # Standard route - extra_kwargs = {} - ret.append((url_format, method_map, name_format, extra_kwargs)) + ret.append(route) return ret @@ -150,17 +156,17 @@ class SimpleRouter(BaseRouter): lookup = self.get_lookup_regex(viewset) routes = self.get_routes(viewset) - for url_format, method_map, name_format, extra_kwargs in routes: + for route in routes: # Only actions which actually exist on the viewset will be bound - method_map = self.get_method_map(viewset, method_map) - if not method_map: + mapping = self.get_method_map(viewset, route.mapping) + if not mapping: continue # Build the url pattern - regex = url_format.format(prefix=prefix, lookup=lookup) - view = viewset.as_view(method_map, **extra_kwargs) - name = name_format.format(basename=basename) + regex = route.url.format(prefix=prefix, lookup=lookup) + view = viewset.as_view(mapping, **route.initkwargs) + name = route.name.format(basename=basename) ret.append(url(regex, view, name=name)) return ret @@ -179,7 +185,7 @@ class DefaultRouter(SimpleRouter): Return a view to use as the API root. """ api_root_dict = {} - list_name = self.routes[0][-1] + list_name = self.routes[0].name for prefix, viewset, basename in self.registry: api_root_dict[prefix] = list_name.format(basename=basename) diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index 18b3b207..8f8e5710 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -21,7 +21,8 @@ def get_breadcrumbs(url): # Don't list the same view twice in a row. # Probably an optional trailing slash. if not seen or seen[-1] != view: - breadcrumbs_list.insert(0, (get_view_name(view.cls), prefix + url)) + suffix = getattr(view, 'suffix', None) + breadcrumbs_list.insert(0, (get_view_name(view.cls, suffix), prefix + url)) seen.append(view) if url == '': diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index 79566db1..ebadb3a6 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -45,14 +45,17 @@ def _camelcase_to_spaces(content): return ' '.join(content.split('_')).title() -def get_view_name(cls): +def get_view_name(cls, suffix=None): """ Return a formatted name for an `APIView` class or `@api_view` function. """ name = cls.__name__ name = _remove_trailing_string(name, 'View') name = _remove_trailing_string(name, 'ViewSet') - return _camelcase_to_spaces(name) + name = _camelcase_to_spaces(name) + if suffix: + name += ' ' + suffix + return name def get_view_description(cls, html=False): diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 9133fd44..bd25df77 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -35,12 +35,16 @@ class ViewSetMixin(object): """ @classonlymethod - def as_view(cls, actions=None, name_suffix=None, **initkwargs): + def as_view(cls, actions=None, **initkwargs): """ Because of the way class based views create a closure around the instantiated view, we need to totally reimplement `.as_view`, and slightly modify the view function that is created and returned. """ + # The suffix initkwarg is reserved for identifing the viewset type + # eg. 'List' or 'Instance'. + cls.suffix = None + # sanitize keyword arguments for key in initkwargs: if key in cls.http_method_names: @@ -74,7 +78,11 @@ class ViewSetMixin(object): # like csrf_exempt from dispatch update_wrapper(view, cls.dispatch, assigned=()) + # We need to set these on the view function, so that breadcrumb + # generation can pick out these bits of information from a + # resolved URL. view.cls = cls + view.suffix = initkwargs.get('suffix', None) return view -- cgit v1.2.3 From 018d8b8dced31309196496e625cf8a746b98d65e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 26 Apr 2013 15:09:55 +0100 Subject: Bits of cleanup --- rest_framework/routers.py | 2 +- rest_framework/utils/breadcrumbs.py | 30 +++++++++++++++++++++--------- 2 files changed, 22 insertions(+), 10 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 3a8c4508..33e88a81 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -29,7 +29,7 @@ Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs']) def replace_methodname(format_string, methodname): """ Partially format a format_string, swapping out any - '{methodname}'' or '{methodnamehyphen}'' components. + '{methodname}' or '{methodnamehyphen}' components. """ methodnamehyphen = methodname.replace('_', '-') ret = format_string diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index 8f8e5710..28801d09 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -4,25 +4,33 @@ from rest_framework.utils.formatting import get_view_name def get_breadcrumbs(url): - """Given a url returns a list of breadcrumbs, which are each a tuple of (name, url).""" + """ + Given a url returns a list of breadcrumbs, which are each a + tuple of (name, url). + """ from rest_framework.views import APIView def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen): - """Add tuples of (name, url) to the breadcrumbs list, progressively chomping off parts of the url.""" + """ + Add tuples of (name, url) to the breadcrumbs list, + progressively chomping off parts of the url. + """ try: (view, unused_args, unused_kwargs) = resolve(url) except Exception: pass else: - # Check if this is a REST framework view, and if so add it to the breadcrumbs + # Check if this is a REST framework view, + # and if so add it to the breadcrumbs if issubclass(getattr(view, 'cls', None), APIView): # Don't list the same view twice in a row. # Probably an optional trailing slash. if not seen or seen[-1] != view: suffix = getattr(view, 'suffix', None) - breadcrumbs_list.insert(0, (get_view_name(view.cls, suffix), prefix + url)) + name = get_view_name(view.cls, suffix) + breadcrumbs_list.insert(0, (name, prefix + url)) seen.append(view) if url == '': @@ -30,11 +38,15 @@ def get_breadcrumbs(url): return breadcrumbs_list elif url.endswith('/'): - # Drop trailing slash off the end and continue to try to resolve more breadcrumbs - return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix, seen) - - # Drop trailing non-slash off the end and continue to try to resolve more breadcrumbs - return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix, seen) + # Drop trailing slash off the end and continue to try to + # resolve more breadcrumbs + url = url.rstrip('/') + return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen) + + # Drop trailing non-slash off the end and continue to try to + # resolve more breadcrumbs + url = url[:url.rfind('/') + 1] + return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen) prefix = get_script_prefix().rstrip('/') url = url[len(prefix):] -- cgit v1.2.3 From 3b0fa3ebaa9d42723d970bb88be0dfe2586d1a5e Mon Sep 17 00:00:00 2001 From: JC Date: Sat, 27 Apr 2013 13:10:39 -0700 Subject: Changed DepthTest to have depth=2 --- rest_framework/tests/serializer.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 05217f35..bd874253 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -3,7 +3,7 @@ from django.utils.datastructures import MultiValueDict from django.test import TestCase from rest_framework import serializers from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, - BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel, + BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel, ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo) import datetime import pickle @@ -767,8 +767,6 @@ class RelatedTraversalTest(TestCase): post = BlogPost.objects.create(title="Test blog post", writer=user) post.blogpostcomment_set.create(text="I love this blog post") - from rest_framework.tests.models import BlogPostComment - class PersonSerializer(serializers.ModelSerializer): class Meta: model = Person @@ -968,23 +966,26 @@ class SerializerPickleTests(TestCase): class DepthTest(TestCase): def test_implicit_nesting(self): + writer = Person.objects.create(name="django", age=1) post = BlogPost.objects.create(title="Test blog post", writer=writer) + comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post) - class BlogPostSerializer(serializers.ModelSerializer): + class BlogPostCommentSerializer(serializers.ModelSerializer): class Meta: - model = BlogPost - depth = 1 + model = BlogPostComment + depth = 2 - serializer = BlogPostSerializer(instance=post) - expected = {'id': 1, 'title': 'Test blog post', - 'writer': {'id': 1, 'name': 'django', 'age': 1}} + serializer = BlogPostCommentSerializer(instance=comment) + expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post', + 'writer': {'id': 1, 'name': 'django', 'age': 1}}} self.assertEqual(serializer.data, expected) def test_explicit_nesting(self): writer = Person.objects.create(name="django", age=1) post = BlogPost.objects.create(title="Test blog post", writer=writer) + comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post) class PersonSerializer(serializers.ModelSerializer): class Meta: @@ -996,9 +997,15 @@ class DepthTest(TestCase): class Meta: model = BlogPost - serializer = BlogPostSerializer(instance=post) - expected = {'id': 1, 'title': 'Test blog post', - 'writer': {'id': 1, 'name': 'django', 'age': 1}} + class BlogPostCommentSerializer(serializers.ModelSerializer): + blog_post = BlogPostSerializer() + + class Meta: + model = BlogPostComment + + serializer = BlogPostCommentSerializer(instance=comment) + expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post', + 'writer': {'id': 1, 'name': 'django', 'age': 1}}} self.assertEqual(serializer.data, expected) -- cgit v1.2.3 From 8cbb715f4c5550d76e397828608a31a4f254a37d Mon Sep 17 00:00:00 2001 From: JC Date: Sat, 27 Apr 2013 13:23:55 -0700 Subject: Changed definition of NestedModelSerializer to correct depth handling --- rest_framework/serializers.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index e28bbe81..add46566 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -205,18 +205,6 @@ class BaseSerializer(WritableField): return ret - ##### - # Field methods - used when the serializer class is itself used as a field. - - def initialize(self, parent, field_name): - """ - Same behaviour as usual Field, except that we need to keep track - of state so that we can deal with handling maximum depth. - """ - super(BaseSerializer, self).initialize(parent, field_name) - if parent.opts.depth: - self.opts.depth = parent.opts.depth - 1 - ##### # Methods to convert or revert from objects <--> primitive representations. @@ -619,6 +607,8 @@ class ModelSerializer(Serializer): class NestedModelSerializer(ModelSerializer): class Meta: model = model_field.rel.to + depth = self.opts.depth - 1 + return NestedModelSerializer() def get_related_field(self, model_field, to_many=False): -- cgit v1.2.3 From dc7b1d643020cac5d585aac42f98962cc7aa6bf7 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 29 Apr 2013 12:45:00 +0100 Subject: 2.2's PendingDeprecationWarnings now become DeprecationWarnings. 2.3's PendingDeprecationWarnings added. --- rest_framework/fields.py | 4 +-- rest_framework/generics.py | 60 +++++++++++++++++++++++-------- rest_framework/permissions.py | 8 +++-- rest_framework/relations.py | 28 +++++++-------- rest_framework/routers.py | 2 ++ rest_framework/serializers.py | 24 +++++++++---- rest_framework/tests/serializer.py | 5 ++- rest_framework/tests/serializer_nested.py | 4 +-- rest_framework/viewsets.py | 2 ++ 9 files changed, 93 insertions(+), 44 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 38fe025d..f934fc39 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -200,9 +200,9 @@ class WritableField(Field): # 'blank' is to be deprecated in favor of 'required' if blank is not None: - warnings.warn('The `blank` keyword argument is due to deprecated. ' + warnings.warn('The `blank` keyword argument is deprecated. ' 'Use the `required` keyword argument instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) required = not(blank) super(WritableField, self).__init__(source=source) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index a18584d4..972424e6 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -2,13 +2,16 @@ Generic views that provide commonly needed behaviour. """ from __future__ import unicode_literals -from rest_framework import views, mixins -from rest_framework.settings import api_settings + from django.core.exceptions import ImproperlyConfigured from django.core.paginator import Paginator, InvalidPage from django.http import Http404 from django.shortcuts import get_object_or_404 from django.utils.translation import ugettext as _ +from rest_framework import views, mixins +from rest_framework.exceptions import ConfigurationError +from rest_framework.settings import api_settings +import warnings class GenericAPIView(views.APIView): @@ -94,7 +97,12 @@ class GenericAPIView(views.APIView): """ deprecated_style = False if page_size is not None: - # TODO: Deperecation warning + warnings.warn('The `page_size` parameter to `paginate_queryset()` ' + 'is due to be deprecated. ' + 'Note that the return style of this method is also ' + 'changed, and will simply return a page object ' + 'when called without a `page_size` argument.', + PendingDeprecationWarning, stacklevel=2) deprecated_style = True else: # Determine the required page size. @@ -155,7 +163,9 @@ class GenericAPIView(views.APIView): Otherwise defaults to using `self.paginate_by`. """ if queryset is not None: - pass # TODO: Deprecation warning + warnings.warn('The `queryset` parameter to `get_paginate_by()` ' + 'is due to be deprecated.', + PendingDeprecationWarning, stacklevel=2) if self.paginate_by_param: query_params = self.request.QUERY_PARAMS @@ -226,17 +236,27 @@ class GenericAPIView(views.APIView): if lookup is not None: filter_kwargs = {self.lookup_field: lookup} - elif pk is not None: - # TODO: Deprecation warning + elif pk is not None and self.lookup_field == 'pk': + warnings.warn( + 'The `pk_url_kwarg` attribute is due to be deprecated. ' + 'Use the `lookup_field` attribute instead', + PendingDeprecationWarning + ) filter_kwargs = {'pk': pk} - elif slug is not None: - # TODO: Deprecation warning + elif slug is not None and self.lookup_field == 'pk': + warnings.warn( + 'The `slug_url_kwarg` attribute is due to be deprecated. ' + 'Use the `lookup_field` attribute instead', + PendingDeprecationWarning + ) filter_kwargs = {self.slug_field: slug} else: - # TODO: Fix error message - raise AttributeError("Generic detail view %s must be called with " - "either an object pk or a slug." - % self.__class__.__name__) + raise ConfigurationError( + 'Expected view %s to be called with a URL keyword argument ' + 'named "%s". Fix your URL conf, or set the `.lookup_field` ' + 'attribute on the view correctly.' % + (self.__class__.__name__, self.lookup_field) + ) obj = get_object_or_404(queryset, **filter_kwargs) @@ -391,8 +411,20 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, ########################## class MultipleObjectAPIView(GenericAPIView): - pass + def __init__(self, *args, **kwargs): + warnings.warn( + 'Subclassing `MultipleObjectAPIView` is due to be deprecated. ' + 'You should simply subclass `GenericAPIView` instead.', + PendingDeprecationWarning, stacklevel=2 + ) + super(MultipleObjectAPIView, self).__init__(*args, **kwargs) class SingleObjectAPIView(GenericAPIView): - pass + def __init__(self, *args, **kwargs): + warnings.warn( + 'Subclassing `SingleObjectAPIView` is due to be deprecated. ' + 'You should simply subclass `GenericAPIView` instead.', + PendingDeprecationWarning, stacklevel=2 + ) + super(SingleObjectAPIView, self).__init__(*args, **kwargs) diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 2aa45c71..91bf5ad6 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -26,9 +26,11 @@ class BasePermission(object): Return `True` if permission is granted, `False` otherwise. """ if len(inspect.getargspec(self.has_permission).args) == 4: - warnings.warn('The `obj` argument in `has_permission` is due to be deprecated. ' - 'Use `has_object_permission()` instead for object permissions.', - PendingDeprecationWarning, stacklevel=2) + warnings.warn( + 'The `obj` argument in `has_permission` is deprecated. ' + 'Use `has_object_permission()` instead for object permissions.', + DeprecationWarning, stacklevel=2 + ) return self.has_permission(request, view, obj) return True diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 6bda7418..abe5203b 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -42,9 +42,9 @@ class RelatedField(WritableField): # 'null' is to be deprecated in favor of 'required' if 'null' in kwargs: - warnings.warn('The `null` keyword argument is due to be deprecated. ' + warnings.warn('The `null` keyword argument is deprecated. ' 'Use the `required` keyword argument instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) kwargs['required'] = not kwargs.pop('null') self.queryset = kwargs.pop('queryset', None) @@ -328,9 +328,9 @@ class HyperlinkedRelatedField(RelatedField): if request is None: warnings.warn("Using `HyperlinkedRelatedField` without including the " - "request in the serializer context is due to be deprecated. " + "request in the serializer context is deprecated. " "Add `context={'request': request}` when instantiating the serializer.", - PendingDeprecationWarning, stacklevel=4) + DeprecationWarning, stacklevel=4) pk = getattr(obj, 'pk', None) if pk is None: @@ -443,9 +443,9 @@ class HyperlinkedIdentityField(Field): if request is None: warnings.warn("Using `HyperlinkedIdentityField` without including the " - "request in the serializer context is due to be deprecated. " + "request in the serializer context is deprecated. " "Add `context={'request': request}` when instantiating the serializer.", - PendingDeprecationWarning, stacklevel=4) + DeprecationWarning, stacklevel=4) # By default use whatever format is given for the current context # unless the target is a different type to the source. @@ -488,35 +488,35 @@ class HyperlinkedIdentityField(Field): class ManyRelatedField(RelatedField): def __init__(self, *args, **kwargs): - warnings.warn('`ManyRelatedField()` is due to be deprecated. ' + warnings.warn('`ManyRelatedField()` is deprecated. ' 'Use `RelatedField(many=True)` instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) kwargs['many'] = True super(ManyRelatedField, self).__init__(*args, **kwargs) class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField): def __init__(self, *args, **kwargs): - warnings.warn('`ManyPrimaryKeyRelatedField()` is due to be deprecated. ' + warnings.warn('`ManyPrimaryKeyRelatedField()` is deprecated. ' 'Use `PrimaryKeyRelatedField(many=True)` instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) kwargs['many'] = True super(ManyPrimaryKeyRelatedField, self).__init__(*args, **kwargs) class ManySlugRelatedField(SlugRelatedField): def __init__(self, *args, **kwargs): - warnings.warn('`ManySlugRelatedField()` is due to be deprecated. ' + warnings.warn('`ManySlugRelatedField()` is deprecated. ' 'Use `SlugRelatedField(many=True)` instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) kwargs['many'] = True super(ManySlugRelatedField, self).__init__(*args, **kwargs) class ManyHyperlinkedRelatedField(HyperlinkedRelatedField): def __init__(self, *args, **kwargs): - warnings.warn('`ManyHyperlinkedRelatedField()` is due to be deprecated. ' + warnings.warn('`ManyHyperlinkedRelatedField()` is deprecated. ' 'Use `HyperlinkedRelatedField(many=True)` instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) kwargs['many'] = True super(ManyHyperlinkedRelatedField, self).__init__(*args, **kwargs) diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 33e88a81..2bbf519c 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -13,6 +13,8 @@ For example, you might have a `urls.py` that looks something like this: urlpatterns = router.urls """ +from __future__ import unicode_literals + from collections import namedtuple from django.conf.urls import url, patterns from django.db import models diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 3d956e4d..3afb7475 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -436,9 +436,9 @@ class BaseSerializer(WritableField): else: many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type)) if many: - warnings.warn('Implict list/queryset serialization is due to be deprecated. ' + warnings.warn('Implict list/queryset serialization is deprecated. ' 'Use the `many=True` flag when instantiating the serializer.', - PendingDeprecationWarning, stacklevel=3) + DeprecationWarning, stacklevel=3) if many: ret = [] @@ -498,9 +498,9 @@ class BaseSerializer(WritableField): else: many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict)) if many: - warnings.warn('Implict list/queryset serialization is due to be deprecated. ' + warnings.warn('Implict list/queryset serialization is deprecated. ' 'Use the `many=True` flag when instantiating the serializer.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) if many: self._data = [self.to_native(item) for item in obj] @@ -606,13 +606,25 @@ class ModelSerializer(Serializer): if model_field.rel and nested: if len(inspect.getargspec(self.get_nested_field).args) == 2: - # TODO: deprecation warning + warnings.warn( + 'The `get_nested_field(model_field)` call signature ' + 'is due to be deprecated. ' + 'Use `get_nested_field(model_field, related_model, ' + 'to_many) instead', + PendingDeprecationWarning + ) field = self.get_nested_field(model_field) else: field = self.get_nested_field(model_field, related_model, to_many) elif model_field.rel: if len(inspect.getargspec(self.get_nested_field).args) == 3: - # TODO: deprecation warning + warnings.warn( + 'The `get_related_field(model_field, to_many)` call ' + 'signature is due to be deprecated. ' + 'Use `get_related_field(model_field, related_model, ' + 'to_many) instead', + PendingDeprecationWarning + ) field = self.get_related_field(model_field, to_many=to_many) else: field = self.get_related_field(model_field, related_model, to_many) diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 3a94fad5..ae8d09dc 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -357,7 +357,6 @@ class CustomValidationTests(TestCase): def validate_email(self, attrs, source): value = attrs[source] - return attrs def validate_content(self, attrs, source): @@ -1103,7 +1102,7 @@ class DeserializeListTestCase(TestCase): def test_no_errors(self): data = [self.data.copy() for x in range(0, 3)] - serializer = CommentSerializer(data=data) + serializer = CommentSerializer(data=data, many=True) self.assertTrue(serializer.is_valid()) self.assertTrue(isinstance(serializer.object, list)) self.assertTrue( @@ -1115,7 +1114,7 @@ class DeserializeListTestCase(TestCase): invalid_item['email'] = '' data = [self.data.copy(), invalid_item, self.data.copy()] - serializer = CommentSerializer(data=data) + serializer = CommentSerializer(data=data, many=True) self.assertFalse(serializer.is_valid()) expected = [{}, {'email': ['This field is required.']}, {}] self.assertEqual(serializer.errors, expected) diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/serializer_nested.py index 6a29c652..71d0e24b 100644 --- a/rest_framework/tests/serializer_nested.py +++ b/rest_framework/tests/serializer_nested.py @@ -109,7 +109,7 @@ class WritableNestedSerializerBasicTests(TestCase): } ] - serializer = self.AlbumSerializer(data=data) + serializer = self.AlbumSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), False) self.assertEqual(serializer.errors, expected_errors) @@ -241,6 +241,6 @@ class WritableNestedSerializerObjectTests(TestCase): ) ] - serializer = self.AlbumSerializer(data=data) + serializer = self.AlbumSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.object, expected_object) diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index bd25df77..a54467d7 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -16,6 +16,8 @@ automatically. router.register(r'users', UserViewSet, 'user') urlpatterns = router.urls """ +from __future__ import unicode_literals + from functools import update_wrapper from django.utils.decorators import classonlymethod from rest_framework import views, generics, mixins -- cgit v1.2.3 From d17e2d852fc6ebc738e324b8797d390dc0287d37 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 29 Apr 2013 12:46:57 +0100 Subject: Remove AutoRouter. (Adding shortcut to generic views/viewsets means it's unneccessary) --- rest_framework/routers.py | 26 -------------------------- 1 file changed, 26 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 2bbf519c..923405e8 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -17,11 +17,9 @@ from __future__ import unicode_literals from collections import namedtuple from django.conf.urls import url, patterns -from django.db import models from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.reverse import reverse -from rest_framework.viewsets import ModelViewSet from rest_framework.urlpatterns import format_suffix_patterns @@ -218,27 +216,3 @@ class DefaultRouter(SimpleRouter): urls = format_suffix_patterns(urls) return urls - - -class AutoRouter(DefaultRouter): - """ - A router class that doesn't require you to register any viewsets, - but instead automatically creates routes for all installed models. - - Useful for quick and dirty prototyping. - """ - def __init__(self): - super(AutoRouter, self).__init__() - for model in models.get_models(): - prefix = model._meta.verbose_name_plural.replace(' ', '_') - basename = model._meta.object_name.lower() - classname = model.__name__ - - DynamicViewSet = type( - classname, - (ModelViewSet,), - {} - ) - DynamicViewSet.model = model - - self.register(prefix, DynamicViewSet, basename) -- cgit v1.2.3 From 53f9d4a380ee0066cbee8382ae265ea6005d8c88 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 29 Apr 2013 13:20:15 +0100 Subject: fields shortcut on views --- rest_framework/generics.py | 5 +++++ rest_framework/serializers.py | 2 +- rest_framework/tests/generics.py | 24 ++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 972424e6..0b8e4a15 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -44,6 +44,10 @@ class GenericAPIView(views.APIView): # the explicit style is generally preferred. model = None + # This shortcut may be used instead of setting the `serializer_class` + # attribute, although using the explicit style is generally preferred. + fields = None + # If the `model` shortcut is used instead of `serializer_class`, then the # serializer class will be constructed using this class as the base. model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS @@ -193,6 +197,7 @@ class GenericAPIView(views.APIView): class DefaultSerializer(self.model_serializer_class): class Meta: model = self.model + fields = self.fields return DefaultSerializer def get_queryset(self): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 3afb7475..f4a20097 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -645,7 +645,7 @@ class ModelSerializer(Serializer): for relation in reverse_rels: accessor_name = relation.get_accessor_name() - if accessor_name not in self.opts.fields: + if not self.opts.fields or accessor_name not in self.opts.fields: continue related_model = relation.model to_many = relation.field.rel.multiple diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 4a13389a..12c9b677 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -344,6 +344,30 @@ class TestOverriddenGetObject(TestCase): self.assertEqual(response.data, self.data[0]) +class TestFieldsShortcut(TestCase): + """ + Test cases for setting the `fields` attribute on a view. + """ + def setUp(self): + class OverriddenFieldsView(generics.RetrieveUpdateDestroyAPIView): + model = BasicModel + fields = ('text',) + + class RegularView(generics.RetrieveUpdateDestroyAPIView): + model = BasicModel + + self.overridden_fields_view = OverriddenFieldsView() + self.regular_view = RegularView() + + def test_overridden_fields_view(self): + Serializer = self.overridden_fields_view.get_serializer_class() + self.assertEqual(Serializer().fields.keys(), ['text']) + + def test_not_overridden_fields_view(self): + Serializer = self.regular_view.get_serializer_class() + self.assertEqual(Serializer().fields.keys(), ['id', 'text']) + + # Regression test for #285 class CommentSerializer(serializers.ModelSerializer): -- cgit v1.2.3 From 0c1ab584d3d0898d47e0bce6beb5d7c39a55dd52 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 29 Apr 2013 14:08:38 +0100 Subject: Tweaks for preferring .queryset over .model --- rest_framework/generics.py | 19 ++++++++++++------- rest_framework/tests/generics.py | 6 ++++-- 2 files changed, 16 insertions(+), 9 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 0b8e4a15..3ea78b5d 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -48,11 +48,10 @@ class GenericAPIView(views.APIView): # attribute, although using the explicit style is generally preferred. fields = None - # If the `model` shortcut is used instead of `serializer_class`, then the - # serializer class will be constructed using this class as the base. + # The following attributes may be subject to change, + # and should be considered private API. model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS - - _paginator_class = Paginator + paginator_class = Paginator ###################################### # These are pending deprecation... @@ -115,8 +114,8 @@ class GenericAPIView(views.APIView): if not page_size: return None - paginator = self._paginator_class(queryset, page_size, - allow_empty_first_page=self.allow_empty) + paginator = self.paginator_class(queryset, page_size, + allow_empty_first_page=self.allow_empty) page_kwarg = self.kwargs.get(self.page_kwarg) page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) page = page_kwarg or page_query_param or 1 @@ -194,9 +193,15 @@ class GenericAPIView(views.APIView): if serializer_class is not None: return serializer_class + assert self.model is not None or self.queryset is not None, \ + "'%s' should either include a 'serializer_class' attribute, " \ + "or use the 'queryset' or 'model' attribute as a shortcut for " \ + "automatically generating a serializer class." \ + % self.__class__.__name__ + class DefaultSerializer(self.model_serializer_class): class Meta: - model = self.model + model = self.model or self.queryset.model fields = self.fields return DefaultSerializer diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 12c9b677..63ff1fc3 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -350,11 +350,11 @@ class TestFieldsShortcut(TestCase): """ def setUp(self): class OverriddenFieldsView(generics.RetrieveUpdateDestroyAPIView): - model = BasicModel + queryset = BasicModel.objects.all() fields = ('text',) class RegularView(generics.RetrieveUpdateDestroyAPIView): - model = BasicModel + queryset = BasicModel.objects.all() self.overridden_fields_view = OverriddenFieldsView() self.regular_view = RegularView() @@ -362,10 +362,12 @@ class TestFieldsShortcut(TestCase): def test_overridden_fields_view(self): Serializer = self.overridden_fields_view.get_serializer_class() self.assertEqual(Serializer().fields.keys(), ['text']) + self.assertEqual(Serializer().opts.model, BasicModel) def test_not_overridden_fields_view(self): Serializer = self.regular_view.get_serializer_class() self.assertEqual(Serializer().fields.keys(), ['id', 'text']) + self.assertEqual(Serializer().opts.model, BasicModel) # Regression test for #285 -- cgit v1.2.3 From 21ae3a66917acf4ea57e8f7940ce1a6823a2ce92 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 30 Apr 2013 08:24:33 +0100 Subject: Drop out attribute --- rest_framework/generics.py | 24 ++++++++++-------------- rest_framework/serializers.py | 4 ++++ rest_framework/tests/generics.py | 26 -------------------------- 3 files changed, 14 insertions(+), 40 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 3ea78b5d..62129dcc 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -20,11 +20,17 @@ class GenericAPIView(views.APIView): """ # You'll need to either set these attributes, - # or override `get_queryset`/`get_serializer_class`. + # or override `get_queryset()`/`get_serializer_class()`. queryset = None serializer_class = None + # This shortcut may be used instead of setting either or both + # of the `queryset`/`serializer_class` attributes, although using + # the explicit style is generally preferred. + model = None + # If you want to use object lookups other than pk, set this attribute. + # For more complex lookup requirements override `get_object()`. lookup_field = 'pk' # Pagination settings @@ -39,15 +45,6 @@ class GenericAPIView(views.APIView): # Determines if the view will return 200 or 404 responses for empty lists. allow_empty = True - # This shortcut may be used instead of setting either (or both) - # of the `queryset`/`serializer_class` attributes, although using - # the explicit style is generally preferred. - model = None - - # This shortcut may be used instead of setting the `serializer_class` - # attribute, although using the explicit style is generally preferred. - fields = None - # The following attributes may be subject to change, # and should be considered private API. model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS @@ -193,16 +190,15 @@ class GenericAPIView(views.APIView): if serializer_class is not None: return serializer_class - assert self.model is not None or self.queryset is not None, \ + assert self.model is not None, \ "'%s' should either include a 'serializer_class' attribute, " \ - "or use the 'queryset' or 'model' attribute as a shortcut for " \ + "or use the 'model' attribute as a shortcut for " \ "automatically generating a serializer class." \ % self.__class__.__name__ class DefaultSerializer(self.model_serializer_class): class Meta: - model = self.model or self.queryset.model - fields = self.fields + model = self.model return DefaultSerializer def get_queryset(self): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index f4a20097..0f943d79 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -677,6 +677,8 @@ class ModelSerializer(Serializer): def get_nested_field(self, model_field, related_model, to_many): """ Creates a default instance of a nested relational field. + + Note that model_field will be `None` for reverse relationships. """ class NestedModelSerializer(ModelSerializer): class Meta: @@ -686,6 +688,8 @@ class ModelSerializer(Serializer): def get_related_field(self, model_field, related_model, to_many): """ Creates a default instance of a flat relational field. + + Note that model_field will be `None` for reverse relationships. """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 63ff1fc3..4a13389a 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -344,32 +344,6 @@ class TestOverriddenGetObject(TestCase): self.assertEqual(response.data, self.data[0]) -class TestFieldsShortcut(TestCase): - """ - Test cases for setting the `fields` attribute on a view. - """ - def setUp(self): - class OverriddenFieldsView(generics.RetrieveUpdateDestroyAPIView): - queryset = BasicModel.objects.all() - fields = ('text',) - - class RegularView(generics.RetrieveUpdateDestroyAPIView): - queryset = BasicModel.objects.all() - - self.overridden_fields_view = OverriddenFieldsView() - self.regular_view = RegularView() - - def test_overridden_fields_view(self): - Serializer = self.overridden_fields_view.get_serializer_class() - self.assertEqual(Serializer().fields.keys(), ['text']) - self.assertEqual(Serializer().opts.model, BasicModel) - - def test_not_overridden_fields_view(self): - Serializer = self.regular_view.get_serializer_class() - self.assertEqual(Serializer().fields.keys(), ['id', 'text']) - self.assertEqual(Serializer().opts.model, BasicModel) - - # Regression test for #285 class CommentSerializer(serializers.ModelSerializer): -- cgit v1.2.3 From 8dff8d2fdcfcee356c134f4be8235d2a4f122d1a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 30 Apr 2013 14:34:03 +0100 Subject: Add get_breadcrumbs hook to BrowseableAPIRenderer. Closes #733. --- rest_framework/renderers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index a0829c8f..c457ec73 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -444,6 +444,9 @@ class BrowsableAPIRenderer(BaseRenderer): def get_description(self, view): return get_view_description(view.__class__, html=True) + def get_breadcrumbs(self, request): + return get_breadcrumbs(request.path) + def render(self, data, accepted_media_type=None, renderer_context=None): """ Renders *obj* using the :attr:`template` set on the class. @@ -475,7 +478,7 @@ class BrowsableAPIRenderer(BaseRenderer): name = self.get_name(view) description = self.get_description(view) - breadcrumb_list = get_breadcrumbs(request.path) + breadcrumb_list = self.get_breadcrumbs(request) template = loader.get_template(self.template) context = RequestContext(request, { -- cgit v1.2.3 From b65b065375796919a57f4bd6f1dd8187ef0eb165 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 30 Apr 2013 14:34:28 +0100 Subject: Add DjangoModelPermissionsOrAnonReadOnly --- rest_framework/permissions.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 91bf5ad6..751f31a7 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -89,8 +89,8 @@ class DjangoModelPermissions(BasePermission): It ensures that the user is authenticated, and has the appropriate `add`/`change`/`delete` permissions on the model. - This permission will only be applied against view classes that - provide a `.model` attribute, such as the generic class-based views. + This permission can only be applied against view classes that + provide a `.model` or `.queryset` attribute. """ # Map methods into required permission codes. @@ -138,6 +138,14 @@ class DjangoModelPermissions(BasePermission): return False +class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions): + """ + Similar to DjangoModelPermissions, except that anonymous users are + allowed read-only access. + """ + authenticated_users_only = False + + class TokenHasReadWriteScope(BasePermission): """ The request is authenticated as a user and the token used has the right scope -- cgit v1.2.3 From e5040fbf942e021444f629a371bc71c9d47d052f Mon Sep 17 00:00:00 2001 From: Danilo Bargen Date: Tue, 30 Apr 2013 23:24:20 +0200 Subject: Catch ImproperlyConfigured exception in compat.py (fixes #803) --- rest_framework/compat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 067e9018..f8e4e7ca 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -6,6 +6,7 @@ versions of django/python, and compatibility wrappers around optional packages. from __future__ import unicode_literals import django +from django.core.exceptions import ImproperlyConfigured # Try to import six from Django, fallback to included `six`. try: @@ -473,7 +474,7 @@ except ImportError: try: import oauth_provider from oauth_provider.store import store as oauth_provider_store -except ImportError: +except (ImportError, ImproperlyConfigured): oauth_provider = None oauth_provider_store = None -- cgit v1.2.3 From 35f99cddc4a098547389fab7d9f397ad442dfff1 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 1 May 2013 09:03:09 +0100 Subject: lookup_field on hyperlinked fields, and overriddable hyperlinked fields. Closes #688 --- rest_framework/relations.py | 147 +++++++++++++++++++++++++----------------- rest_framework/serializers.py | 3 +- 2 files changed, 91 insertions(+), 59 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index abe5203b..6d8deec1 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -288,10 +288,8 @@ class HyperlinkedRelatedField(RelatedField): """ Represents a relationship using hyperlinking. """ - pk_url_kwarg = 'pk' - slug_field = 'slug' - slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden read_only = False + lookup_field = 'pk' default_error_messages = { 'no_match': _('Invalid hyperlink - No URL match'), @@ -301,69 +299,120 @@ class HyperlinkedRelatedField(RelatedField): 'incorrect_type': _('Incorrect type. Expected url string, received %s.'), } + # These are all pending deprecation + pk_url_kwarg = 'pk' + slug_field = 'slug' + slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden + def __init__(self, *args, **kwargs): try: self.view_name = kwargs.pop('view_name') except KeyError: raise ValueError("Hyperlinked field requires 'view_name' kwarg") + self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) + self.format = kwargs.pop('format', None) + + # These are pending deprecation + self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) self.slug_field = kwargs.pop('slug_field', self.slug_field) default_slug_kwarg = self.slug_url_kwarg or self.slug_field - self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) - self.format = kwargs.pop('format', None) super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) - def get_slug_field(self): + def get_url(self, obj, view_name, request, format): """ - Get the name of a slug field to be used to look up by slug. - """ - return self.slug_field - - def to_native(self, obj): - view_name = self.view_name - request = self.context.get('request', None) - format = self.format or self.context.get('format', None) - - if request is None: - warnings.warn("Using `HyperlinkedRelatedField` without including the " - "request in the serializer context is deprecated. " - "Add `context={'request': request}` when instantiating the serializer.", - DeprecationWarning, stacklevel=4) + Given an object, return the URL that hyperlinks to the object. - pk = getattr(obj, 'pk', None) - if pk is None: - return - kwargs = {self.pk_url_kwarg: pk} + May raise a `NoReverseMatch` if the `view_name` and `lookup_field` + attributes are not configured to correctly match the URL conf. + """ + lookup_field = getattr(obj, self.lookup_field) + kwargs = {self.lookup_field: lookup_field} try: return reverse(view_name, kwargs=kwargs, request=request, format=format) except NoReverseMatch: pass + if self.pk_url_kwarg != 'pk': + # Only try pk if it has been explicitly set. + # Otherwise, the default `lookup_field = 'pk'` has us covered. + pk = obj.pk + kwargs = {self.pk_url_kwarg: pk} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except NoReverseMatch: + pass + slug = getattr(obj, self.slug_field, None) + if slug is not None: + # Only try slug if it corresponds to an attribute on the object. + kwargs = {self.slug_url_kwarg: slug} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except NoReverseMatch: + pass - if not slug: - raise Exception('Could not resolve URL for field using view name "%s"' % view_name) + raise NoReverseMatch() - kwargs = {self.slug_url_kwarg: slug} - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except NoReverseMatch: - pass + def get_object(self, queryset, view_name, view_args, view_kwargs): + """ + Return the object corresponding to a matched URL. - kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} + Takes the matched URL conf arguments, and the queryset, and should + return an object instance, or raise an `ObjectDoesNotExist` exception. + """ + lookup = view_kwargs.get(self.lookup_field, None) + pk = view_kwargs.get(self.pk_url_kwarg, None) + slug = view_kwargs.get(self.slug_url_kwarg, None) + + if lookup is not None: + filter_kwargs = {self.lookup_field: lookup} + elif pk is not None: + filter_kwargs = {'pk': pk} + elif slug is not None: + filter_kwargs = {self.slug_field: slug} + else: + raise ObjectDoesNotExist() + + return queryset.get(**filter_kwargs) + + def to_native(self, obj): + view_name = self.view_name + request = self.context.get('request', None) + format = self.format or self.context.get('format', None) + + if request is None: + msg = ( + "Using `HyperlinkedRelatedField` without including the request " + "in the serializer context is deprecated. " + "Add `context={'request': request}` when instantiating " + "the serializer." + ) + warnings.warn(msg, DeprecationWarning, stacklevel=4) + + # If the object has not yet been saved then we cannot hyperlink to it. + if getattr(obj, 'pk', None) is None: + return + + # Return the hyperlink, or error if incorrectly configured. try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) + return self.get_url(obj, view_name, request, format) except NoReverseMatch: - pass - - raise Exception('Could not resolve URL for field using view name "%s"' % view_name) + msg = ( + 'Could not resolve URL for hyperlinked relationship using ' + 'view name "%s". You may have failed to include the related ' + 'model in your API, or incorrectly configured the ' + '`lookup_field` attribute on this field.' + ) + raise Exception(msg % view_name) def from_native(self, value): # Convert URL -> model instance pk # TODO: Use values_list - if self.queryset is None: + queryset = self.queryset + if queryset is None: raise Exception('Writable related fields must include a `queryset` argument') try: @@ -387,29 +436,11 @@ class HyperlinkedRelatedField(RelatedField): if match.view_name != self.view_name: raise ValidationError(self.error_messages['incorrect_match']) - pk = match.kwargs.get(self.pk_url_kwarg, None) - slug = match.kwargs.get(self.slug_url_kwarg, None) - - # Try explicit primary key. - if pk is not None: - queryset = self.queryset.filter(pk=pk) - # Next, try looking up by slug. - elif slug is not None: - slug_field = self.get_slug_field() - queryset = self.queryset.filter(**{slug_field: slug}) - # If none of those are defined, it's probably a configuation error. - else: - raise ValidationError(self.error_messages['configuration_error']) - try: - obj = queryset.get() - except ObjectDoesNotExist: + return self.get_object(queryset, match.view_name, + match.args, match.kwargs) + except (ObjectDoesNotExist, TypeError, ValueError): raise ValidationError(self.error_messages['does_not_exist']) - except (TypeError, ValueError): - msg = self.error_messages['incorrect_type'] - raise ValidationError(msg % type(value).__name__) - - return obj class HyperlinkedIdentityField(Field): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index b589eca8..d4b34c01 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -836,6 +836,7 @@ class HyperlinkedModelSerializer(ModelSerializer): """ _options_class = HyperlinkedModelSerializerOptions _default_view_name = '%(model_name)s-detail' + _hyperlink_field_class = HyperlinkedRelatedField url = HyperlinkedIdentityField() @@ -874,7 +875,7 @@ class HyperlinkedModelSerializer(ModelSerializer): if model_field: kwargs['required'] = not(model_field.null or model_field.blank) - return HyperlinkedRelatedField(**kwargs) + return self._hyperlink_field_class(**kwargs) def get_identity(self, data): """ -- cgit v1.2.3 From 8cabae22c5330da2e0a15a6d61ef038a6447756a Mon Sep 17 00:00:00 2001 From: Victor Shih Date: Wed, 1 May 2013 21:26:40 -0700 Subject: Example and spelling fixes. Change "browseable" to "browsable" for consistency. --- rest_framework/renderers.py | 2 +- rest_framework/tests/generics.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 4c15e0db..83bbc5b8 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -57,7 +57,7 @@ class JSONRenderer(BaseRenderer): return '' # If 'indent' is provided in the context, then pretty print the result. - # E.g. If we're being called by the BrowseableAPIRenderer. + # E.g. If we're being called by the BrowsableAPIRenderer. renderer_context = renderer_context or {} indent = renderer_context.get('indent', None) diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 4a13389a..eca50d82 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -377,7 +377,7 @@ class TestCreateModelWithAutoNowAddField(TestCase): self.assertEqual(created.content, 'foobar') -# Test for particularly ugly regression with m2m in browseable API +# Test for particularly ugly regression with m2m in browsable API class ClassB(models.Model): name = models.CharField(max_length=255) @@ -402,7 +402,7 @@ class ExampleView(generics.ListCreateAPIView): class TestM2MBrowseableAPI(TestCase): def test_m2m_in_browseable_api(self): """ - Test for particularly ugly regression with m2m in browseable API + Test for particularly ugly regression with m2m in browsable API """ request = factory.get('/', HTTP_ACCEPT='text/html') view = ExampleView().as_view() -- cgit v1.2.3 From e4067bfb75a38851ea865719ebfbb65708187b4e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 2 May 2013 12:07:18 +0100 Subject: introduce lookup_field and add pendingdeprecationwarnings --- rest_framework/mixins.py | 13 +++++++++++-- rest_framework/relations.py | 10 ++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index ec751e24..ae703771 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -12,7 +12,7 @@ from rest_framework.response import Response from rest_framework.request import clone_request -def _get_validation_exclusions(obj, pk=None, slug_field=None): +def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None): """ Given a model instance, and an optional pk and slug field, return the full list of all other field names on that model. @@ -23,14 +23,19 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None): include = [] if pk: + # Pending deprecation pk_field = obj._meta.pk while pk_field.rel: pk_field = pk_field.rel.to._meta.pk include.append(pk_field.name) if slug_field: + # Pending deprecation include.append(slug_field) + if lookup_field and lookup_field != 'pk': + include.append(lookup_field) + return [field.name for field in obj._meta.fields if field.name not in include] @@ -139,10 +144,14 @@ class UpdateModelMixin(object): Set any attributes on the object that are implicit in the request. """ # pk and/or slug attributes are implicit in the URL. + lookup = self.kwargs.get(self.lookup_field, None) pk = self.kwargs.get(self.pk_url_kwarg, None) slug = self.kwargs.get(self.slug_url_kwarg, None) slug_field = slug and self.slug_field or None + if lookup: + setattr(obj, self.lookup_field, lookup) + if pk: setattr(obj, 'pk', pk) @@ -152,7 +161,7 @@ class UpdateModelMixin(object): # Ensure we clean the attributes so that we don't eg return integer # pk using a string representation, as provided by the url conf kwarg. if hasattr(obj, 'full_clean'): - exclude = _get_validation_exclusions(obj, pk, slug_field) + exclude = _get_validation_exclusions(obj, pk, slug_field, self.lookup_field) obj.full_clean(exclude) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 6d8deec1..bc7f112c 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -314,6 +314,16 @@ class HyperlinkedRelatedField(RelatedField): self.format = kwargs.pop('format', None) # These are pending deprecation + if 'pk_url_kwarg' in kwargs: + msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_url_kwarg' in kwargs: + msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_field' in kwargs: + msg = 'slug_field is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) self.slug_field = kwargs.pop('slug_field', self.slug_field) default_slug_kwarg = self.slug_url_kwarg or self.slug_field -- cgit v1.2.3 From 387250bee438a3826191b2d0d196d0c11373f7f3 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 2 May 2013 12:07:37 +0100 Subject: Automagically determine base_name in router class --- rest_framework/routers.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 923405e8..0707635a 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -42,10 +42,22 @@ class BaseRouter(object): def __init__(self): self.registry = [] - def register(self, prefix, viewset, name): - self.registry.append((prefix, viewset, name)) + def register(self, prefix, viewset, base_name=None): + if base_name is None: + base_name = self.get_default_base_name(viewset) + self.registry.append((prefix, viewset, base_name)) + + def get_default_base_name(self, viewset): + """ + If `base_name` is not specified, attempt to automatically determine + it from the viewset. + """ + raise NotImplemented('get_default_base_name must be overridden') def get_urls(self): + """ + Return a list of URL patterns, given the registered viewsets. + """ raise NotImplemented('get_urls must be overridden') @property @@ -91,6 +103,22 @@ class SimpleRouter(BaseRouter): ), ] + def get_default_base_name(self, viewset): + """ + If `base_name` is not specified, attempt to automatically determine + it from the viewset. + """ + model_cls = getattr(viewset, 'model', None) + queryset = getattr(viewset, 'queryset', None) + if model_cls is None and queryset is not None: + model_cls = queryset.model + + assert model_cls, '`name` not argument not specified, and could ' \ + 'not automatically determine the name from the viewset, as ' \ + 'it does not have a `.model` or `.queryset` attribute.' + + return model_cls._meta.object_name.lower() + def get_routes(self, viewset): """ Augment `self.routes` with any dynamically generated routes. -- cgit v1.2.3 From 0c85768435e67133ff219aaddb4ea3bf122bd360 Mon Sep 17 00:00:00 2001 From: Michael Elovskikh Date: Fri, 3 May 2013 01:37:25 +0600 Subject: Added FileUploadParser refs #7 --- rest_framework/parsers.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 491acd68..6ba05aef 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -6,9 +6,10 @@ on the request, such as form content or json encoded data. """ from __future__ import unicode_literals from django.conf import settings +from django.core.files.uploadhandler import StopFutureHandlers from django.http import QueryDict from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser -from django.http.multipartparser import MultiPartParserError +from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter from rest_framework.compat import yaml, etree from rest_framework.exceptions import ParseError from rest_framework.compat import six @@ -205,3 +206,63 @@ class XMLParser(BaseParser): pass return value + + +class FileUploadParser(BaseParser): + """ + Parser for file upload data. + """ + media_type = '*/*' + + def parse(self, stream, media_type=None, parser_context=None): + parser_context = parser_context or {} + request = parser_context['request'] + encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) + meta = request.META + + try: + disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION']) + filename = disposition[1]['filename'] + except KeyError: + filename = None + + content_type = meta.get('HTTP_CONTENT_TYPE', meta.get('CONTENT_TYPE', '')) + try: + content_length = int(meta.get('HTTP_CONTENT_LENGTH', meta.get('CONTENT_LENGTH', 0))) + except (ValueError, TypeError): + content_length = None + + # See if the handler will want to take care of the parsing. + for handler in request.upload_handlers: + result = handler.handle_raw_input(None, + meta, + content_length, + None, + encoding) + if result is not None: + return DataAndFiles(result[0], {'file': result[1]}) + + possible_sizes = [x.chunk_size for x in request.upload_handlers if x.chunk_size] + chunk_size = min([2**31-4] + possible_sizes) + chunks = ChunkIter(stream, chunk_size) + counters = [0] * len(request.upload_handlers) + + for handler in request.upload_handlers: + try: + handler.new_file(None, filename, content_type, content_length, encoding) + except StopFutureHandlers: + break + + for chunk in chunks: + for i, handler in enumerate(request.upload_handlers): + chunk_length = len(chunk) + chunk = handler.receive_data_chunk(chunk, counters[i]) + counters[i] += chunk_length + if chunk is None: + # If the chunk received by the handler is None, then don't continue. + break + + for i, handler in enumerate(request.upload_handlers): + file_obj = handler.file_complete(counters[i]) + if file_obj: + return DataAndFiles(None, {'file': file_obj}) -- cgit v1.2.3 From 318fdaabe560c99de4983e0a3cdcb79756baaf01 Mon Sep 17 00:00:00 2001 From: Michael Elovskikh Date: Fri, 3 May 2013 01:39:08 +0600 Subject: Tests for FileUploadParser --- rest_framework/tests/parsers.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/parsers.py index 539c5b44..b18ecbf2 100644 --- a/rest_framework/tests/parsers.py +++ b/rest_framework/tests/parsers.py @@ -1,10 +1,11 @@ from __future__ import unicode_literals from rest_framework.compat import StringIO from django import forms +from django.core.files.uploadhandler import MemoryFileUploadHandler from django.test import TestCase from django.utils import unittest from rest_framework.compat import etree -from rest_framework.parsers import FormParser +from rest_framework.parsers import FormParser, FileUploadParser from rest_framework.parsers import XMLParser import datetime @@ -82,3 +83,27 @@ class TestXMLParser(TestCase): parser = XMLParser() data = parser.parse(self._complex_data_input) self.assertEqual(data, self._complex_data) + + +class TestFileUploadParser(TestCase): + def setUp(self): + class MockRequest(object): + pass + from io import BytesIO + self.stream = BytesIO( + "Test text file".encode('utf-8') + ) + request = MockRequest() + request.upload_handlers = (MemoryFileUploadHandler(),) + request.META = { + 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'), + 'HTTP_CONTENT_LENGTH': 14, + } + self.parser_context = {'request': request} + + def test_parse(self): + """ Make sure the `QueryDict` works OK """ + parser = FileUploadParser() + data_and_files = parser.parse(self.stream, parser_context=self.parser_context) + file_obj = data_and_files.files['file'] + self.assertEqual(file_obj._size, 14) -- cgit v1.2.3 From e36e4f48ad481b4303e68ed524677add07b224f7 Mon Sep 17 00:00:00 2001 From: Michael Elovskikh Date: Sat, 4 May 2013 14:58:21 +0600 Subject: Codebase improvements on FileUploadParser * Added docstrings. * Added `FileUploadParser.get_filename` to make it easier to override. * Added url kwargs filename detection step. * Updated tests corresponding to these changes. --- rest_framework/parsers.py | 45 +++++++++++++++++++++++++++++------------ rest_framework/tests/parsers.py | 10 +++++++-- 2 files changed, 40 insertions(+), 15 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 6ba05aef..7eb92184 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -215,16 +215,19 @@ class FileUploadParser(BaseParser): media_type = '*/*' def parse(self, stream, media_type=None, parser_context=None): + """ + Returns a DataAndFiles object. + + `.data` will be None (we expect request body to be a file content). + `.files` will be a `QueryDict` containing one 'file' elemnt - a parsed file. + """ + parser_context = parser_context or {} request = parser_context['request'] encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) meta = request.META - - try: - disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION']) - filename = disposition[1]['filename'] - except KeyError: - filename = None + upload_handlers = request.upload_handlers + filename = self.get_filename(stream, media_type, parser_context) content_type = meta.get('HTTP_CONTENT_TYPE', meta.get('CONTENT_TYPE', '')) try: @@ -233,28 +236,28 @@ class FileUploadParser(BaseParser): content_length = None # See if the handler will want to take care of the parsing. - for handler in request.upload_handlers: + for handler in upload_handlers: result = handler.handle_raw_input(None, meta, content_length, None, encoding) if result is not None: - return DataAndFiles(result[0], {'file': result[1]}) + return DataAndFiles(None, {'file': result[1]}) - possible_sizes = [x.chunk_size for x in request.upload_handlers if x.chunk_size] + possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size] chunk_size = min([2**31-4] + possible_sizes) chunks = ChunkIter(stream, chunk_size) - counters = [0] * len(request.upload_handlers) + counters = [0] * len(upload_handlers) - for handler in request.upload_handlers: + for handler in upload_handlers: try: handler.new_file(None, filename, content_type, content_length, encoding) except StopFutureHandlers: break for chunk in chunks: - for i, handler in enumerate(request.upload_handlers): + for i, handler in enumerate(upload_handlers): chunk_length = len(chunk) chunk = handler.receive_data_chunk(chunk, counters[i]) counters[i] += chunk_length @@ -262,7 +265,23 @@ class FileUploadParser(BaseParser): # If the chunk received by the handler is None, then don't continue. break - for i, handler in enumerate(request.upload_handlers): + for i, handler in enumerate(upload_handlers): file_obj = handler.file_complete(counters[i]) if file_obj: return DataAndFiles(None, {'file': file_obj}) + + def get_filename(self, stream, media_type, parser_context): + """ + Detects the uploaded file name. First searches a 'filename' url kwarg. + Then tries to parse Content-Disposition header. + """ + try: + return parser_context['kwargs']['filename'] + except KeyError: + pass + try: + meta = parser_context['request'].META + disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION']) + return disposition[1]['filename'] + except (AttributeError, KeyError): + pass diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/parsers.py index b18ecbf2..7699e10c 100644 --- a/rest_framework/tests/parsers.py +++ b/rest_framework/tests/parsers.py @@ -99,11 +99,17 @@ class TestFileUploadParser(TestCase): 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'), 'HTTP_CONTENT_LENGTH': 14, } - self.parser_context = {'request': request} + self.parser_context = {'request': request, 'kwargs': {}} def test_parse(self): """ Make sure the `QueryDict` works OK """ parser = FileUploadParser() - data_and_files = parser.parse(self.stream, parser_context=self.parser_context) + self.stream.seek(0) + data_and_files = parser.parse(self.stream, None, self.parser_context) file_obj = data_and_files.files['file'] self.assertEqual(file_obj._size, 14) + + def test_get_filename(self): + parser = FileUploadParser() + filename = parser.get_filename(self.stream, None, self.parser_context) + self.assertEqual(filename, 'file.txt'.encode('utf-8')) -- cgit v1.2.3 From a514232815a82ad8a4dc1819afa0d62f9bab1323 Mon Sep 17 00:00:00 2001 From: Michael Elovskikh Date: Sat, 4 May 2013 17:18:10 +0600 Subject: Raise ParseError if can't handle the uploaded file --- rest_framework/parsers.py | 1 + 1 file changed, 1 insertion(+) (limited to 'rest_framework') diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 7eb92184..27a0db65 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -269,6 +269,7 @@ class FileUploadParser(BaseParser): file_obj = handler.file_complete(counters[i]) if file_obj: return DataAndFiles(None, {'file': file_obj}) + raise ParseError("FileUpload parse error - none of upload handlers can handle the stream") def get_filename(self, stream, media_type, parser_context): """ -- cgit v1.2.3 From 538d2e35e7f1e4623a215d1b8c684b284f951c09 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sun, 5 May 2013 16:47:45 +0100 Subject: lookup_field on hyperlink serializers --- rest_framework/relations.py | 10 +++++++++- rest_framework/serializers.py | 4 ++++ 2 files changed, 13 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index bc7f112c..fc5054b2 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -360,7 +360,15 @@ class HyperlinkedRelatedField(RelatedField): # Only try slug if it corresponds to an attribute on the object. kwargs = {self.slug_url_kwarg: slug} try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) + ret = reverse(view_name, kwargs=kwargs, request=request, format=format) + if self.slug_field == 'slug' and self.slug_url_kwarg == 'slug': + # If the lookup succeeds using the default slug params, + # then `slug_field` is being used implicitly, and we + # we need to warn about the pending deprecation. + msg = 'Implicit slug field hyperlinked fields are pending deprecation.' \ + 'You should set `lookup_field=slug` on the HyperlinkedRelatedField.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + return ret except NoReverseMatch: pass diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index d4b34c01..ea5175e2 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -827,6 +827,7 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): def __init__(self, meta): super(HyperlinkedModelSerializerOptions, self).__init__(meta) self.view_name = getattr(meta, 'view_name', None) + self.lookup_field = getattr(meta, 'slug_field', None) class HyperlinkedModelSerializer(ModelSerializer): @@ -875,6 +876,9 @@ class HyperlinkedModelSerializer(ModelSerializer): if model_field: kwargs['required'] = not(model_field.null or model_field.blank) + if self.opts.lookup_field: + kwargs['lookup_field'] = self.opts.lookup_field + return self._hyperlink_field_class(**kwargs) def get_identity(self, data): -- cgit v1.2.3 From 660d2405174519628c72ed84a69ae37531df12f3 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sun, 5 May 2013 16:48:00 +0100 Subject: .action attribute on viewsets --- rest_framework/viewsets.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index a54467d7..0eb3e86d 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -59,6 +59,10 @@ class ViewSetMixin(object): def view(request, *args, **kwargs): self = cls(**initkwargs) + # We also store the mapping of request methods to actions, + # so that we can later set the action attribute. + # eg. `self.action = 'list'` on an incoming GET request. + self.action_map = actions # Bind methods to actions # This is the bit that's different to a standard view @@ -87,6 +91,15 @@ class ViewSetMixin(object): view.suffix = initkwargs.get('suffix', None) return view + def initialize_request(self, request, *args, **kargs): + """ + Set the `.action` attribute on the view, + depending on the request method. + """ + request = super(ViewSetMixin, self).initialize_request(request, *args, **kargs) + self.action = self.action_map.get(request.method.lower()) + return request + class ViewSet(ViewSetMixin, views.APIView): """ -- cgit v1.2.3 From d71a5533f9a8787652244dfb16af37fb7d9059fb Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 7 May 2013 12:25:41 +0100 Subject: allow_empty -> pending deprecation in preference of overridden get_queryset. --- rest_framework/filters.py | 29 +++++++++++++++++++++++++++++ rest_framework/generics.py | 12 +++++++++--- 2 files changed, 38 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 5e1cdbac..571704dc 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -3,7 +3,10 @@ Provides generic filtering backends that can be used to filter the results returned by list views. """ from __future__ import unicode_literals + +from django.db import models from rest_framework.compat import django_filters +import operator FilterSet = django_filters and django_filters.FilterSet or None @@ -62,3 +65,29 @@ class DjangoFilterBackend(BaseFilterBackend): return filter_class(request.QUERY_PARAMS, queryset=queryset).qs return queryset + + +class SearchFilter(BaseFilterBackend): + def construct_search(self, field_name): + if field_name.startswith('^'): + return "%s__istartswith" % field_name[1:] + elif field_name.startswith('='): + return "%s__iexact" % field_name[1:] + elif field_name.startswith('@'): + return "%s__search" % field_name[1:] + else: + return "%s__icontains" % field_name + + def filter_queryset(self, request, queryset, view): + search_fields = getattr(view, 'search_fields', None) + + if not search_fields: + return None + + orm_lookups = [self.construct_search(str(search_field)) + for search_field in self.search_fields] + for bit in self.query.split(): + or_queries = [models.Q(**{orm_lookup: bit}) + for orm_lookup in orm_lookups] + queryset = queryset.filter(reduce(operator.or_, or_queries)) + return queryset diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 62129dcc..2bb23a89 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -42,9 +42,6 @@ class GenericAPIView(views.APIView): # The filter backend class to use for queryset filtering filter_backend = api_settings.FILTER_BACKEND - # Determines if the view will return 200 or 404 responses for empty lists. - allow_empty = True - # The following attributes may be subject to change, # and should be considered private API. model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS @@ -56,6 +53,7 @@ class GenericAPIView(views.APIView): pk_url_kwarg = 'pk' slug_url_kwarg = 'slug' slug_field = 'slug' + allow_empty = True def get_serializer_context(self): """ @@ -111,6 +109,14 @@ class GenericAPIView(views.APIView): if not page_size: return None + if not self.allow_empty: + warnings.warn( + 'The `allow_empty` parameter is due to be deprecated. ' + 'To use `allow_empty=False` style behavior, You should override ' + '`get_queryset()` and explicitly raise a 404 on empty querysets.', + PendingDeprecationWarning, stacklevel=2 + ) + paginator = self.paginator_class(queryset, page_size, allow_empty_first_page=self.allow_empty) page_kwarg = self.kwargs.get(self.page_kwarg) -- cgit v1.2.3 From 3c2bb0666063917707bfbfedf056e5692bfcc471 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 7 May 2013 13:00:44 +0100 Subject: Support for multiple filter classes --- rest_framework/generics.py | 23 +++++++++++++++++------ rest_framework/settings.py | 12 +++++++++--- 2 files changed, 26 insertions(+), 9 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 2bb23a89..05ec93d3 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -39,8 +39,8 @@ class GenericAPIView(views.APIView): pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS page_kwarg = 'page' - # The filter backend class to use for queryset filtering - filter_backend = api_settings.FILTER_BACKEND + # The filter backend classes to use for queryset filtering + filter_backends = api_settings.DEFAULT_FILTER_BACKENDS # The following attributes may be subject to change, # and should be considered private API. @@ -54,6 +54,7 @@ class GenericAPIView(views.APIView): slug_url_kwarg = 'slug' slug_field = 'slug' allow_empty = True + filter_backend = api_settings.FILTER_BACKEND def get_serializer_context(self): """ @@ -150,10 +151,20 @@ class GenericAPIView(views.APIView): method if you want to apply the configured filtering backend to the default queryset. """ - if not self.filter_backend: - return queryset - backend = self.filter_backend() - return backend.filter_queryset(self.request, queryset, self) + filter_backends = self.filter_backends or [] + if not filter_backends and self.filter_backend: + warnings.warn( + 'The `filter_backend` attribute and `FILTER_BACKEND` setting ' + 'are due to be deprecated in favor of a `filter_backends` ' + 'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take ' + 'a *list* of filter backend classes.', + PendingDeprecationWarning, stacklevel=2 + ) + filter_backends = [self.filter_backend] + + for backend in filter_backends: + queryset = backend().filter_queryset(self.request, queryset, self) + return queryset ######################## ### The following methods provide default implementations diff --git a/rest_framework/settings.py b/rest_framework/settings.py index eede0c5a..734d8478 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -29,6 +29,7 @@ from rest_framework.compat import six USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None) DEFAULTS = { + # Base API policies 'DEFAULT_RENDERER_CLASSES': ( 'rest_framework.renderers.JSONRenderer', 'rest_framework.renderers.BrowsableAPIRenderer', @@ -50,11 +51,15 @@ DEFAULTS = { 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', + + # Genric view behavior 'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.serializers.ModelSerializer', 'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer', + 'DEFAULT_FILTER_BACKENDS': (), + # Throttling 'DEFAULT_THROTTLE_RATES': { 'user': None, 'anon': None, @@ -64,9 +69,6 @@ DEFAULTS = { 'PAGINATE_BY': None, 'PAGINATE_BY_PARAM': None, - # Filtering - 'FILTER_BACKEND': None, - # Authentication 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, @@ -95,6 +97,9 @@ DEFAULTS = { ISO_8601, ), 'TIME_FORMAT': ISO_8601, + + # Pending deprecation + 'FILTER_BACKEND': None, } @@ -108,6 +113,7 @@ IMPORT_STRINGS = ( 'DEFAULT_CONTENT_NEGOTIATION_CLASS', 'DEFAULT_MODEL_SERIALIZER_CLASS', 'DEFAULT_PAGINATION_SERIALIZER_CLASS', + 'DEFAULT_FILTER_BACKENDS', 'FILTER_BACKEND', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', -- cgit v1.2.3 From 3353889ae85cc21890469cf00f7073d1ea5c2070 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 7 May 2013 13:27:27 +0100 Subject: Docs for FileUploadParser --- rest_framework/parsers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 27a0db65..614531a1 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -246,7 +246,7 @@ class FileUploadParser(BaseParser): return DataAndFiles(None, {'file': result[1]}) possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size] - chunk_size = min([2**31-4] + possible_sizes) + chunk_size = min([2 ** 31 - 4] + possible_sizes) chunks = ChunkIter(stream, chunk_size) counters = [0] * len(upload_handlers) @@ -280,9 +280,10 @@ class FileUploadParser(BaseParser): return parser_context['kwargs']['filename'] except KeyError: pass + try: meta = parser_context['request'].META disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION']) return disposition[1]['filename'] except (AttributeError, KeyError): - pass + raise ParseError("Filename must be set in Content-Disposition header.") -- cgit v1.2.3 From ed2cf180c961bb337c5d3ab7e5f74a1539c33ae4 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 7 May 2013 13:29:38 +0100 Subject: Version 2.3.0 --- rest_framework/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 856badc6..35196c74 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.2.7' +__version__ = '2.3.0' VERSION = __version__ # synonym -- cgit v1.2.3 From d7c08222f14389b4d61e5ca9032c49b8b917d251 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 7 May 2013 14:11:48 +0100 Subject: Fix breadcrumb rendering issue --- rest_framework/__init__.py | 2 +- rest_framework/utils/breadcrumbs.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 35196c74..819558b5 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.3.0' +__version__ = '2.3.1' VERSION = __version__ # synonym diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index 28801d09..d51374b0 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -24,7 +24,8 @@ def get_breadcrumbs(url): else: # Check if this is a REST framework view, # and if so add it to the breadcrumbs - if issubclass(getattr(view, 'cls', None), APIView): + cls = getattr(view, 'cls', None) + if cls is not None and issubclass(cls, APIView): # Don't list the same view twice in a row. # Probably an optional trailing slash. if not seen or seen[-1] != view: -- cgit v1.2.3 From 429e078eee63a120c408946cf7c1460d4ca9e9b4 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 May 2013 20:07:51 +0100 Subject: Allow None filename on uploaded files --- rest_framework/parsers.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 614531a1..25be2e6a 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -219,7 +219,7 @@ class FileUploadParser(BaseParser): Returns a DataAndFiles object. `.data` will be None (we expect request body to be a file content). - `.files` will be a `QueryDict` containing one 'file' elemnt - a parsed file. + `.files` will be a `QueryDict` containing one 'file' element. """ parser_context = parser_context or {} @@ -229,9 +229,13 @@ class FileUploadParser(BaseParser): upload_handlers = request.upload_handlers filename = self.get_filename(stream, media_type, parser_context) - content_type = meta.get('HTTP_CONTENT_TYPE', meta.get('CONTENT_TYPE', '')) + # Note that this code is extracted from Django's handling of + # file uploads in MultiPartParser. + content_type = meta.get('HTTP_CONTENT_TYPE', + meta.get('CONTENT_TYPE', '')) try: - content_length = int(meta.get('HTTP_CONTENT_LENGTH', meta.get('CONTENT_LENGTH', 0))) + content_length = int(meta.get('HTTP_CONTENT_LENGTH', + meta.get('CONTENT_LENGTH', 0))) except (ValueError, TypeError): content_length = None @@ -245,6 +249,7 @@ class FileUploadParser(BaseParser): if result is not None: return DataAndFiles(None, {'file': result[1]}) + # This is the standard case. possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size] chunk_size = min([2 ** 31 - 4] + possible_sizes) chunks = ChunkIter(stream, chunk_size) @@ -252,7 +257,8 @@ class FileUploadParser(BaseParser): for handler in upload_handlers: try: - handler.new_file(None, filename, content_type, content_length, encoding) + handler.new_file(None, filename, content_type, + content_length, encoding) except StopFutureHandlers: break @@ -262,14 +268,14 @@ class FileUploadParser(BaseParser): chunk = handler.receive_data_chunk(chunk, counters[i]) counters[i] += chunk_length if chunk is None: - # If the chunk received by the handler is None, then don't continue. break for i, handler in enumerate(upload_handlers): file_obj = handler.file_complete(counters[i]) if file_obj: return DataAndFiles(None, {'file': file_obj}) - raise ParseError("FileUpload parse error - none of upload handlers can handle the stream") + raise ParseError("FileUpload parse error - " + "none of upload handlers can handle the stream") def get_filename(self, stream, media_type, parser_context): """ @@ -286,4 +292,4 @@ class FileUploadParser(BaseParser): disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION']) return disposition[1]['filename'] except (AttributeError, KeyError): - raise ParseError("Filename must be set in Content-Disposition header.") + pass -- cgit v1.2.3 From de69a28b9e786b8c759cda4acedb0a1b8542298b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 May 2013 20:18:01 +0100 Subject: Test and fix for #814. --- rest_framework/filters.py | 14 ++++++++++---- rest_framework/tests/filterset.py | 28 +++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 5 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 571704dc..f2163f6f 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -38,21 +38,27 @@ class DjangoFilterBackend(BaseFilterBackend): """ filter_class = getattr(view, 'filter_class', None) filter_fields = getattr(view, 'filter_fields', None) - view_model = getattr(view, 'model', None) + model_cls = getattr(view, 'model', None) + queryset = getattr(view, 'queryset', None) + if model_cls is None and queryset is not None: + model_cls = queryset.model if filter_class: filter_model = filter_class.Meta.model - assert issubclass(filter_model, view_model), \ + assert issubclass(filter_model, model_cls), \ 'FilterSet model %s does not match view model %s' % \ - (filter_model, view_model) + (filter_model, model_cls) return filter_class if filter_fields: + assert model_cls is not None, 'Cannot use DjangoFilterBackend ' \ + 'on a view which does not have a .model or .queryset attribute.' + class AutoFilterSet(self.default_filter_set): class Meta: - model = view_model + model = model_cls fields = filter_fields return AutoFilterSet diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 1e53a5cd..023bd016 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -5,7 +5,7 @@ from django.core.urlresolvers import reverse from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest -from rest_framework import generics, status, filters +from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters, patterns, url from rest_framework.tests.models import FilterableItem, BasicModel @@ -52,6 +52,17 @@ if django_filters: filter_class = SeveralFieldsFilter filter_backend = filters.DjangoFilterBackend + # Regression test for #814 + class FilterableItemSerializer(serializers.ModelSerializer): + class Meta: + model = FilterableItem + + class FilterFieldsQuerysetView(generics.ListCreateAPIView): + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer + filter_fields = ['decimal', 'date'] + filter_backend = filters.DjangoFilterBackend + urlpatterns = patterns('', url(r'^(?P\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), url(r'^$', FilterClassRootView.as_view(), name='root-view'), @@ -114,6 +125,21 @@ class IntegrationTestFiltering(CommonFilteringTestCase): expected_data = [f for f in self.data if f['date'] == search_date] self.assertEqual(response.data, expected_data) + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_filter_with_queryset(self): + """ + Regression test for #814. + """ + view = FilterFieldsQuerysetView.as_view() + + # Tests that the decimal filter works. + search_decimal = Decimal('2.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] == search_decimal] + self.assertEqual(response.data, expected_data) + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_filtered_class_root_view(self): """ -- cgit v1.2.3 From b443560080a20d52a3dd49f625a103810935affd Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 May 2013 20:38:50 +0100 Subject: Fix DATETIME_FORMAT, DATE_FORMAT, TIME_FORMAT settings. Closes #798 --- rest_framework/fields.py | 6 +++--- rest_framework/settings.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f934fc39..c83ee5ec 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -500,7 +500,7 @@ class DateField(WritableField): } empty = None input_formats = api_settings.DATE_INPUT_FORMATS - format = None + format = api_settings.DATE_FORMAT def __init__(self, input_formats=None, format=None, *args, **kwargs): self.input_formats = input_formats if input_formats is not None else self.input_formats @@ -563,7 +563,7 @@ class DateTimeField(WritableField): } empty = None input_formats = api_settings.DATETIME_INPUT_FORMATS - format = None + format = api_settings.DATETIME_FORMAT def __init__(self, input_formats=None, format=None, *args, **kwargs): self.input_formats = input_formats if input_formats is not None else self.input_formats @@ -632,7 +632,7 @@ class TimeField(WritableField): } empty = None input_formats = api_settings.TIME_INPUT_FORMATS - format = None + format = api_settings.TIME_FORMAT def __init__(self, input_formats=None, format=None, *args, **kwargs): self.input_formats = input_formats if input_formats is not None else self.input_formats diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 734d8478..beb511ac 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -86,17 +86,17 @@ DEFAULTS = { 'DATE_INPUT_FORMATS': ( ISO_8601, ), - 'DATE_FORMAT': ISO_8601, + 'DATE_FORMAT': None, 'DATETIME_INPUT_FORMATS': ( ISO_8601, ), - 'DATETIME_FORMAT': ISO_8601, + 'DATETIME_FORMAT': None, 'TIME_INPUT_FORMATS': ( ISO_8601, ), - 'TIME_FORMAT': ISO_8601, + 'TIME_FORMAT': None, # Pending deprecation 'FILTER_BACKEND': None, -- cgit v1.2.3 From 4ab7b8f257f9d3a1b35d34d0f90f0103b0cc6369 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 May 2013 20:49:49 +0100 Subject: Version 2.3.2 --- rest_framework/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 819558b5..b4961e2f 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.3.1' +__version__ = '2.3.2' VERSION = __version__ # synonym -- cgit v1.2.3 From 14482a966168a98d43099d00c163d1c8c3b6471b Mon Sep 17 00:00:00 2001 From: Mark Aaron Shirley Date: Wed, 8 May 2013 22:44:23 -0700 Subject: Fix deprecation warnings in relations_nested tests --- rest_framework/tests/relations_nested.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index 22c98e7f..8325580f 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -46,7 +46,7 @@ class ReverseNestedOneToOneTests(TestCase): def test_one_to_one_retrieve(self): queryset = OneToOneTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, @@ -65,7 +65,7 @@ class ReverseNestedOneToOneTests(TestCase): # Ensure (target 4, target_source 4, source 4) are added, and # everything else is as expected. queryset = OneToOneTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, @@ -92,7 +92,7 @@ class ReverseNestedOneToOneTests(TestCase): # Ensure (target 3, target_source 3, source 3) are updated, # and everything else is as expected. queryset = OneToOneTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, @@ -125,7 +125,7 @@ class ForwardNestedOneToOneTests(TestCase): def test_one_to_one_retrieve(self): queryset = OneToOneSource.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, @@ -144,7 +144,7 @@ class ForwardNestedOneToOneTests(TestCase): # Ensure (target 4, target_source 4, source 4) are added, and # everything else is as expected. queryset = OneToOneSource.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, @@ -171,7 +171,7 @@ class ForwardNestedOneToOneTests(TestCase): # Ensure (target 3, target_source 3, source 3) are updated, # and everything else is as expected. queryset = OneToOneSource.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, @@ -224,7 +224,7 @@ class ReverseNestedOneToManyTests(TestCase): def test_one_to_many_retrieve(self): queryset = OneToManyTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, {'id': 2, 'name': 'source-2'}, @@ -247,7 +247,7 @@ class ReverseNestedOneToManyTests(TestCase): # Ensure source 4 is added, and everything else is as # expected. queryset = OneToManyTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, {'id': 2, 'name': 'source-2'}, @@ -279,7 +279,7 @@ class ReverseNestedOneToManyTests(TestCase): # Ensure (target 1, source 1) are updated, # and everything else is as expected. queryset = OneToManyTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, {'id': 2, 'name': 'source-2'}, @@ -299,7 +299,7 @@ class ReverseNestedOneToManyTests(TestCase): # Ensure source 2 is deleted, and everything else is as # expected. queryset = OneToManyTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, {'id': 3, 'name': 'source-3'}]} -- cgit v1.2.3