diff options
| -rw-r--r-- | djangorestframework/mixins.py | 135 | ||||
| -rw-r--r-- | djangorestframework/tests/mixins.py | 162 |
2 files changed, 254 insertions, 43 deletions
diff --git a/djangorestframework/mixins.py b/djangorestframework/mixins.py index 394440d3..b1a634a0 100644 --- a/djangorestframework/mixins.py +++ b/djangorestframework/mixins.py @@ -1,23 +1,20 @@ """ -The :mod:`mixins` module provides a set of reusable `mixin` +The :mod:`mixins` module provides a set of reusable `mixin` classes that can be added to a `View`. """ from django.contrib.auth.models import AnonymousUser -from django.db.models.query import QuerySet +from django.core.paginator import Paginator from django.db.models.fields.related import ForeignKey from django.http import HttpResponse from djangorestframework import status -from djangorestframework.parsers import FormParser, MultiPartParser from djangorestframework.renderers import BaseRenderer from djangorestframework.resources import Resource, FormResource, ModelResource from djangorestframework.response import Response, ErrorResponse from djangorestframework.utils import as_tuple, MSIE_USER_AGENT_REGEX from djangorestframework.utils.mediatypes import is_form_media_type, order_by_precedence -from decimal import Decimal -import re from StringIO import StringIO @@ -52,7 +49,7 @@ class RequestMixin(object): """ The set of request parsers that the view can handle. - + Should be a tuple/list of classes as described in the :mod:`parsers` module. """ parsers = () @@ -158,7 +155,7 @@ class RequestMixin(object): # We only need to use form overloading on form POST requests. if not self._USE_FORM_OVERLOADING or self._method != 'POST' or not is_form_media_type(self._content_type): return - + # At this point we're committed to parsing the request as form data. self._data = data = self.request.POST.copy() self._files = self.request.FILES @@ -203,12 +200,12 @@ class RequestMixin(object): """ return [parser.media_type for parser in self.parsers] - + @property def _default_parser(self): """ Return the view's default parser class. - """ + """ return self.parsers[0] @@ -218,7 +215,7 @@ class RequestMixin(object): class ResponseMixin(object): """ Adds behavior for pluggable `Renderers` to a :class:`views.View` class. - + Default behavior is to use standard HTTP Accept header content negotiation. Also supports overriding the content type by specifying an ``_accept=`` parameter in the URL. Ignores Accept headers from Internet Explorer user agents and uses a sensible browser Accept header instead. @@ -229,8 +226,8 @@ class ResponseMixin(object): """ The set of response renderers that the view can handle. - - Should be a tuple/list of classes as described in the :mod:`renderers` module. + + Should be a tuple/list of classes as described in the :mod:`renderers` module. """ renderers = () @@ -253,7 +250,7 @@ class ResponseMixin(object): # Set the media type of the response # Note that the renderer *could* override it in .render() if required. response.media_type = renderer.media_type - + # Serialize the response content if response.has_content_body: content = renderer.render(response.cleaned_content, media_type) @@ -317,7 +314,7 @@ class ResponseMixin(object): Return an list of all the media types that this view can render. """ return [renderer.media_type for renderer in self.renderers] - + @property def _rendered_formats(self): """ @@ -339,18 +336,18 @@ class AuthMixin(object): """ Simple :class:`mixin` class to add authentication and permission checking to a :class:`View` class. """ - + """ The set of authentication types that this view can handle. - - Should be a tuple/list of classes as described in the :mod:`authentication` module. + + Should be a tuple/list of classes as described in the :mod:`authentication` module. """ authentication = () """ The set of permissions that will be enforced on this view. - - Should be a tuple/list of classes as described in the :mod:`permissions` module. + + Should be a tuple/list of classes as described in the :mod:`permissions` module. """ permissions = () @@ -359,7 +356,7 @@ class AuthMixin(object): def user(self): """ Returns the :obj:`user` for the current request, as determined by the set of - :class:`authentication` classes applied to the :class:`View`. + :class:`authentication` classes applied to the :class:`View`. """ if not hasattr(self, '_user'): self._user = self._authenticate() @@ -541,13 +538,13 @@ class CreateModelMixin(object): for fieldname in m2m_data: manager = getattr(instance, fieldname) - + if hasattr(manager, 'add'): manager.add(*m2m_data[fieldname][1]) else: data = {} data[manager.source_field_name] = instance - + for related_item in m2m_data[fieldname][1]: data[m2m_data[fieldname][0]] = related_item manager.through(**data).save() @@ -564,8 +561,8 @@ class UpdateModelMixin(object): """ def put(self, request, *args, **kwargs): model = self.resource.model - - # TODO: update on the url of a non-existing resource url doesn't work correctly at the moment - will end up with a new url + + # TODO: update on the url of a non-existing resource url doesn't work correctly at the moment - will end up with a new url try: if args: # If we have any none kwargs then assume the last represents the primary key @@ -640,3 +637,93 @@ class ListModelMixin(object): return queryset.filter(**kwargs) +########## Pagination Mixins ########## + +class PaginatorMixin(object): + """ + Adds pagination support to GET requests + Obviously should only be used on lists :) + + A default limit can be set by setting `limit` on the object. This will also + be used as the maximum if the client sets the `limit` GET param + """ + limit = 20 + + def get_limit(self): + """ Helper method to determine what the `limit` should be """ + try: + limit = int(self.request.GET.get('limit', self.limit)) + return min(limit, self.limit) + except ValueError: + return self.limit + + def url_with_page_number(self, page_number): + """ Constructs a url used for getting the next/previous urls """ + url = "%s?page=%d" % (self.request.path, page_number) + + limit = self.get_limit() + if limit != self.limit: + url = "%s&limit=%d" % (url, limit) + + return url + + def next(self, page): + """ Returns a url to the next page of results (if any) """ + if not page.has_next(): + return None + + return self.url_with_page_number(page.next_page_number()) + + def previous(self, page): + """ Returns a url to the previous page of results (if any) """ + if not page.has_previous(): + return None + + return self.url_with_page_number(page.previous_page_number()) + + def serialize_page_info(self, page): + """ This is some useful information that is added to the response """ + return { + 'next': self.next(page), + 'page': page.number, + 'pages': page.paginator.num_pages, + 'per_page': self.get_limit(), + 'previous': self.previous(page), + 'total': page.paginator.count, + } + + def filter_response(self, obj): + """ + Given the response content, paginate and then serialize. + + The response is modified to include to useful data relating to the number + of objects, number of pages, next/previous urls etc. etc. + + The serialised objects are put into `results` on this new, modified + response + """ + + # We don't want to paginate responses for anything other than GET requests + if self.method.upper() != 'GET': + return self._resource.filter_response(obj) + + paginator = Paginator(obj, self.get_limit()) + + try: + page_num = int(self.request.GET.get('page', '1')) + except ValueError: + raise ErrorResponse(status.HTTP_404_NOT_FOUND, + {'detail': 'That page contains no results'}) + + if page_num not in paginator.page_range: + raise ErrorResponse(status.HTTP_404_NOT_FOUND, + {'detail': 'That page contains no results'}) + + page = paginator.page(page_num) + + serialized_object_list = self._resource.filter_response(page.object_list) + serialized_page_info = self.serialize_page_info(page) + + serialized_page_info['results'] = serialized_object_list + + return serialized_page_info diff --git a/djangorestframework/tests/mixins.py b/djangorestframework/tests/mixins.py index da7c4d86..65cf4a45 100644 --- a/djangorestframework/tests/mixins.py +++ b/djangorestframework/tests/mixins.py @@ -1,14 +1,17 @@ -"""Tests for the status module""" +"""Tests for the mixin module""" from django.test import TestCase +from django.utils import simplejson as json from djangorestframework import status from djangorestframework.compat import RequestFactory from django.contrib.auth.models import Group, User -from djangorestframework.mixins import CreateModelMixin +from djangorestframework.mixins import CreateModelMixin, PaginatorMixin from djangorestframework.resources import ModelResource +from djangorestframework.response import Response from djangorestframework.tests.models import CustomUser +from djangorestframework.views import View -class TestModelCreation(TestCase): +class TestModelCreation(TestCase): """Tests on CreateModelMixin""" def setUp(self): @@ -25,23 +28,26 @@ class TestModelCreation(TestCase): mixin = CreateModelMixin() mixin.resource = GroupResource mixin.CONTENT = form_data - + response = mixin.post(request) self.assertEquals(1, Group.objects.count()) self.assertEquals('foo', response.cleaned_content.name) - def test_creation_with_m2m_relation(self): class UserResource(ModelResource): model = User - + def url(self, instance): return "/users/%i" % instance.id group = Group(name='foo') group.save() - form_data = {'username': 'bar', 'password': 'baz', 'groups': [group.id]} + form_data = { + 'username': 'bar', + 'password': 'baz', + 'groups': [group.id] + } request = self.req.post('/groups', data=form_data) cleaned_data = dict(form_data) cleaned_data['groups'] = [group] @@ -53,18 +59,18 @@ class TestModelCreation(TestCase): self.assertEquals(1, User.objects.count()) self.assertEquals(1, response.cleaned_content.groups.count()) self.assertEquals('foo', response.cleaned_content.groups.all()[0].name) - + def test_creation_with_m2m_relation_through(self): """ Tests creation where the m2m relation uses a through table """ class UserResource(ModelResource): model = CustomUser - + def url(self, instance): return "/customusers/%i" % instance.id - - form_data = {'username': 'bar0', 'groups': []} + + form_data = {'username': 'bar0', 'groups': []} request = self.req.post('/groups', data=form_data) cleaned_data = dict(form_data) cleaned_data['groups'] = [] @@ -74,12 +80,12 @@ class TestModelCreation(TestCase): response = mixin.post(request) self.assertEquals(1, CustomUser.objects.count()) - self.assertEquals(0, response.cleaned_content.groups.count()) + self.assertEquals(0, response.cleaned_content.groups.count()) group = Group(name='foo1') group.save() - form_data = {'username': 'bar1', 'groups': [group.id]} + form_data = {'username': 'bar1', 'groups': [group.id]} request = self.req.post('/groups', data=form_data) cleaned_data = dict(form_data) cleaned_data['groups'] = [group] @@ -91,12 +97,11 @@ class TestModelCreation(TestCase): self.assertEquals(2, CustomUser.objects.count()) self.assertEquals(1, response.cleaned_content.groups.count()) self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name) - - + group2 = Group(name='foo2') - group2.save() - - form_data = {'username': 'bar2', 'groups': [group.id, group2.id]} + group2.save() + + form_data = {'username': 'bar2', 'groups': [group.id, group2.id]} request = self.req.post('/groups', data=form_data) cleaned_data = dict(form_data) cleaned_data['groups'] = [group, group2] @@ -109,5 +114,124 @@ class TestModelCreation(TestCase): self.assertEquals(2, response.cleaned_content.groups.count()) self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name) self.assertEquals('foo2', response.cleaned_content.groups.all()[1].name) - + +class MockPaginatorView(PaginatorMixin, View): + total = 60 + + def get(self, request): + return range(0, self.total) + + def post(self, request): + return Response(status.CREATED, {'status': 'OK'}) + + +class TestPagination(TestCase): + def setUp(self): + self.req = RequestFactory() + + def test_default_limit(self): + """ Tests if pagination works without overwriting the limit """ + request = self.req.get('/paginator') + response = MockPaginatorView.as_view()(request) + + content = json.loads(response.content) + + self.assertEqual(response.status_code, status.OK) + self.assertEqual(MockPaginatorView.total, content['total']) + self.assertEqual(MockPaginatorView.limit, content['per_page']) + + self.assertEqual(range(0, MockPaginatorView.limit), content['results']) + + def test_overwriting_limit(self): + """ Tests if the limit can be overwritten """ + limit = 10 + + request = self.req.get('/paginator') + response = MockPaginatorView.as_view(limit=limit)(request) + + content = json.loads(response.content) + + self.assertEqual(response.status_code, status.OK) + self.assertEqual(content['per_page'], limit) + + self.assertEqual(range(0, limit), content['results']) + + def test_limit_param(self): + """ Tests if the client can set the limit """ + from math import ceil + + limit = 5 + num_pages = int(ceil(MockPaginatorView.total / float(limit))) + + request = self.req.get('/paginator/?limit=%d' % limit) + response = MockPaginatorView.as_view()(request) + + content = json.loads(response.content) + + self.assertEqual(response.status_code, status.OK) + self.assertEqual(MockPaginatorView.total, content['total']) + self.assertEqual(limit, content['per_page']) + self.assertEqual(num_pages, content['pages']) + + def test_exceeding_limit(self): + """ Makes sure the client cannot exceed the default limit """ + from math import ceil + + limit = MockPaginatorView.limit + 10 + num_pages = int(ceil(MockPaginatorView.total / float(limit))) + + request = self.req.get('/paginator/?limit=%d' % limit) + response = MockPaginatorView.as_view()(request) + + content = json.loads(response.content) + + self.assertEqual(response.status_code, status.OK) + self.assertEqual(MockPaginatorView.total, content['total']) + self.assertNotEqual(limit, content['per_page']) + self.assertNotEqual(num_pages, content['pages']) + self.assertEqual(MockPaginatorView.limit, content['per_page']) + + def test_only_works_for_get(self): + """ Pagination should only work for GET requests """ + request = self.req.post('/paginator', data={'content': 'spam'}) + response = MockPaginatorView.as_view()(request) + + content = json.loads(response.content) + + self.assertEqual(response.status_code, status.CREATED) + self.assertEqual(None, content.get('per_page')) + self.assertEqual('OK', content['status']) + + def test_non_int_page(self): + """ Tests that it can handle invalid values """ + request = self.req.get('/paginator/?page=spam') + response = MockPaginatorView.as_view()(request) + + self.assertEqual(response.status_code, status.NOT_FOUND) + + def test_page_range(self): + """ Tests that the page range is handle correctly """ + request = self.req.get('/paginator/?page=0') + response = MockPaginatorView.as_view()(request) + content = json.loads(response.content) + self.assertEqual(response.status_code, status.NOT_FOUND) + + request = self.req.get('/paginator/') + response = MockPaginatorView.as_view()(request) + content = json.loads(response.content) + self.assertEqual(response.status_code, status.OK) + self.assertEqual(range(0, MockPaginatorView.limit), content['results']) + + num_pages = content['pages'] + + request = self.req.get('/paginator/?page=%d' % num_pages) + response = MockPaginatorView.as_view()(request) + content = json.loads(response.content) + self.assertEqual(response.status_code, status.OK) + self.assertEqual(range(MockPaginatorView.limit*(num_pages-1), MockPaginatorView.total), content['results']) + + request = self.req.get('/paginator/?page=%d' % (num_pages + 1,)) + response = MockPaginatorView.as_view()(request) + content = json.loads(response.content) + self.assertEqual(response.status_code, status.NOT_FOUND) |
