aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTom Christie2011-12-09 13:37:53 +0000
committerTom Christie2011-12-09 13:37:53 +0000
commit5db422c9d38277789bb6d2cf214f46ed7642d395 (patch)
treed5470b1b24a7d09a08ba2f94a5d45c3b49aaf760
parent42cdd00591b01fd8c7c51276fcd09bb7a4d3c185 (diff)
downloaddjango-rest-framework-5db422c9d38277789bb6d2cf214f46ed7642d395.tar.bz2
Add pagination. Thanks @devioustree!
-rw-r--r--djangorestframework/mixins.py135
-rw-r--r--djangorestframework/tests/mixins.py162
2 files changed, 254 insertions, 43 deletions
diff --git a/djangorestframework/mixins.py b/djangorestframework/mixins.py
index 394440d3..b1a634a0 100644
--- a/djangorestframework/mixins.py
+++ b/djangorestframework/mixins.py
@@ -1,23 +1,20 @@
"""
-The :mod:`mixins` module provides a set of reusable `mixin`
+The :mod:`mixins` module provides a set of reusable `mixin`
classes that can be added to a `View`.
"""
from django.contrib.auth.models import AnonymousUser
-from django.db.models.query import QuerySet
+from django.core.paginator import Paginator
from django.db.models.fields.related import ForeignKey
from django.http import HttpResponse
from djangorestframework import status
-from djangorestframework.parsers import FormParser, MultiPartParser
from djangorestframework.renderers import BaseRenderer
from djangorestframework.resources import Resource, FormResource, ModelResource
from djangorestframework.response import Response, ErrorResponse
from djangorestframework.utils import as_tuple, MSIE_USER_AGENT_REGEX
from djangorestframework.utils.mediatypes import is_form_media_type, order_by_precedence
-from decimal import Decimal
-import re
from StringIO import StringIO
@@ -52,7 +49,7 @@ class RequestMixin(object):
"""
The set of request parsers that the view can handle.
-
+
Should be a tuple/list of classes as described in the :mod:`parsers` module.
"""
parsers = ()
@@ -158,7 +155,7 @@ class RequestMixin(object):
# We only need to use form overloading on form POST requests.
if not self._USE_FORM_OVERLOADING or self._method != 'POST' or not is_form_media_type(self._content_type):
return
-
+
# At this point we're committed to parsing the request as form data.
self._data = data = self.request.POST.copy()
self._files = self.request.FILES
@@ -203,12 +200,12 @@ class RequestMixin(object):
"""
return [parser.media_type for parser in self.parsers]
-
+
@property
def _default_parser(self):
"""
Return the view's default parser class.
- """
+ """
return self.parsers[0]
@@ -218,7 +215,7 @@ class RequestMixin(object):
class ResponseMixin(object):
"""
Adds behavior for pluggable `Renderers` to a :class:`views.View` class.
-
+
Default behavior is to use standard HTTP Accept header content negotiation.
Also supports overriding the content type by specifying an ``_accept=`` parameter in the URL.
Ignores Accept headers from Internet Explorer user agents and uses a sensible browser Accept header instead.
@@ -229,8 +226,8 @@ class ResponseMixin(object):
"""
The set of response renderers that the view can handle.
-
- Should be a tuple/list of classes as described in the :mod:`renderers` module.
+
+ Should be a tuple/list of classes as described in the :mod:`renderers` module.
"""
renderers = ()
@@ -253,7 +250,7 @@ class ResponseMixin(object):
# Set the media type of the response
# Note that the renderer *could* override it in .render() if required.
response.media_type = renderer.media_type
-
+
# Serialize the response content
if response.has_content_body:
content = renderer.render(response.cleaned_content, media_type)
@@ -317,7 +314,7 @@ class ResponseMixin(object):
Return an list of all the media types that this view can render.
"""
return [renderer.media_type for renderer in self.renderers]
-
+
@property
def _rendered_formats(self):
"""
@@ -339,18 +336,18 @@ class AuthMixin(object):
"""
Simple :class:`mixin` class to add authentication and permission checking to a :class:`View` class.
"""
-
+
"""
The set of authentication types that this view can handle.
-
- Should be a tuple/list of classes as described in the :mod:`authentication` module.
+
+ Should be a tuple/list of classes as described in the :mod:`authentication` module.
"""
authentication = ()
"""
The set of permissions that will be enforced on this view.
-
- Should be a tuple/list of classes as described in the :mod:`permissions` module.
+
+ Should be a tuple/list of classes as described in the :mod:`permissions` module.
"""
permissions = ()
@@ -359,7 +356,7 @@ class AuthMixin(object):
def user(self):
"""
Returns the :obj:`user` for the current request, as determined by the set of
- :class:`authentication` classes applied to the :class:`View`.
+ :class:`authentication` classes applied to the :class:`View`.
"""
if not hasattr(self, '_user'):
self._user = self._authenticate()
@@ -541,13 +538,13 @@ class CreateModelMixin(object):
for fieldname in m2m_data:
manager = getattr(instance, fieldname)
-
+
if hasattr(manager, 'add'):
manager.add(*m2m_data[fieldname][1])
else:
data = {}
data[manager.source_field_name] = instance
-
+
for related_item in m2m_data[fieldname][1]:
data[m2m_data[fieldname][0]] = related_item
manager.through(**data).save()
@@ -564,8 +561,8 @@ class UpdateModelMixin(object):
"""
def put(self, request, *args, **kwargs):
model = self.resource.model
-
- # TODO: update on the url of a non-existing resource url doesn't work correctly at the moment - will end up with a new url
+
+ # TODO: update on the url of a non-existing resource url doesn't work correctly at the moment - will end up with a new url
try:
if args:
# If we have any none kwargs then assume the last represents the primary key
@@ -640,3 +637,93 @@ class ListModelMixin(object):
return queryset.filter(**kwargs)
+########## Pagination Mixins ##########
+
+class PaginatorMixin(object):
+ """
+ Adds pagination support to GET requests
+ Obviously should only be used on lists :)
+
+ A default limit can be set by setting `limit` on the object. This will also
+ be used as the maximum if the client sets the `limit` GET param
+ """
+ limit = 20
+
+ def get_limit(self):
+ """ Helper method to determine what the `limit` should be """
+ try:
+ limit = int(self.request.GET.get('limit', self.limit))
+ return min(limit, self.limit)
+ except ValueError:
+ return self.limit
+
+ def url_with_page_number(self, page_number):
+ """ Constructs a url used for getting the next/previous urls """
+ url = "%s?page=%d" % (self.request.path, page_number)
+
+ limit = self.get_limit()
+ if limit != self.limit:
+ url = "%s&limit=%d" % (url, limit)
+
+ return url
+
+ def next(self, page):
+ """ Returns a url to the next page of results (if any) """
+ if not page.has_next():
+ return None
+
+ return self.url_with_page_number(page.next_page_number())
+
+ def previous(self, page):
+ """ Returns a url to the previous page of results (if any) """
+ if not page.has_previous():
+ return None
+
+ return self.url_with_page_number(page.previous_page_number())
+
+ def serialize_page_info(self, page):
+ """ This is some useful information that is added to the response """
+ return {
+ 'next': self.next(page),
+ 'page': page.number,
+ 'pages': page.paginator.num_pages,
+ 'per_page': self.get_limit(),
+ 'previous': self.previous(page),
+ 'total': page.paginator.count,
+ }
+
+ def filter_response(self, obj):
+ """
+ Given the response content, paginate and then serialize.
+
+ The response is modified to include to useful data relating to the number
+ of objects, number of pages, next/previous urls etc. etc.
+
+ The serialised objects are put into `results` on this new, modified
+ response
+ """
+
+ # We don't want to paginate responses for anything other than GET requests
+ if self.method.upper() != 'GET':
+ return self._resource.filter_response(obj)
+
+ paginator = Paginator(obj, self.get_limit())
+
+ try:
+ page_num = int(self.request.GET.get('page', '1'))
+ except ValueError:
+ raise ErrorResponse(status.HTTP_404_NOT_FOUND,
+ {'detail': 'That page contains no results'})
+
+ if page_num not in paginator.page_range:
+ raise ErrorResponse(status.HTTP_404_NOT_FOUND,
+ {'detail': 'That page contains no results'})
+
+ page = paginator.page(page_num)
+
+ serialized_object_list = self._resource.filter_response(page.object_list)
+ serialized_page_info = self.serialize_page_info(page)
+
+ serialized_page_info['results'] = serialized_object_list
+
+ return serialized_page_info
diff --git a/djangorestframework/tests/mixins.py b/djangorestframework/tests/mixins.py
index da7c4d86..65cf4a45 100644
--- a/djangorestframework/tests/mixins.py
+++ b/djangorestframework/tests/mixins.py
@@ -1,14 +1,17 @@
-"""Tests for the status module"""
+"""Tests for the mixin module"""
from django.test import TestCase
+from django.utils import simplejson as json
from djangorestframework import status
from djangorestframework.compat import RequestFactory
from django.contrib.auth.models import Group, User
-from djangorestframework.mixins import CreateModelMixin
+from djangorestframework.mixins import CreateModelMixin, PaginatorMixin
from djangorestframework.resources import ModelResource
+from djangorestframework.response import Response
from djangorestframework.tests.models import CustomUser
+from djangorestframework.views import View
-class TestModelCreation(TestCase):
+class TestModelCreation(TestCase):
"""Tests on CreateModelMixin"""
def setUp(self):
@@ -25,23 +28,26 @@ class TestModelCreation(TestCase):
mixin = CreateModelMixin()
mixin.resource = GroupResource
mixin.CONTENT = form_data
-
+
response = mixin.post(request)
self.assertEquals(1, Group.objects.count())
self.assertEquals('foo', response.cleaned_content.name)
-
def test_creation_with_m2m_relation(self):
class UserResource(ModelResource):
model = User
-
+
def url(self, instance):
return "/users/%i" % instance.id
group = Group(name='foo')
group.save()
- form_data = {'username': 'bar', 'password': 'baz', 'groups': [group.id]}
+ form_data = {
+ 'username': 'bar',
+ 'password': 'baz',
+ 'groups': [group.id]
+ }
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group]
@@ -53,18 +59,18 @@ class TestModelCreation(TestCase):
self.assertEquals(1, User.objects.count())
self.assertEquals(1, response.cleaned_content.groups.count())
self.assertEquals('foo', response.cleaned_content.groups.all()[0].name)
-
+
def test_creation_with_m2m_relation_through(self):
"""
Tests creation where the m2m relation uses a through table
"""
class UserResource(ModelResource):
model = CustomUser
-
+
def url(self, instance):
return "/customusers/%i" % instance.id
-
- form_data = {'username': 'bar0', 'groups': []}
+
+ form_data = {'username': 'bar0', 'groups': []}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = []
@@ -74,12 +80,12 @@ class TestModelCreation(TestCase):
response = mixin.post(request)
self.assertEquals(1, CustomUser.objects.count())
- self.assertEquals(0, response.cleaned_content.groups.count())
+ self.assertEquals(0, response.cleaned_content.groups.count())
group = Group(name='foo1')
group.save()
- form_data = {'username': 'bar1', 'groups': [group.id]}
+ form_data = {'username': 'bar1', 'groups': [group.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group]
@@ -91,12 +97,11 @@ class TestModelCreation(TestCase):
self.assertEquals(2, CustomUser.objects.count())
self.assertEquals(1, response.cleaned_content.groups.count())
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
-
-
+
group2 = Group(name='foo2')
- group2.save()
-
- form_data = {'username': 'bar2', 'groups': [group.id, group2.id]}
+ group2.save()
+
+ form_data = {'username': 'bar2', 'groups': [group.id, group2.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group, group2]
@@ -109,5 +114,124 @@ class TestModelCreation(TestCase):
self.assertEquals(2, response.cleaned_content.groups.count())
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
self.assertEquals('foo2', response.cleaned_content.groups.all()[1].name)
-
+
+class MockPaginatorView(PaginatorMixin, View):
+ total = 60
+
+ def get(self, request):
+ return range(0, self.total)
+
+ def post(self, request):
+ return Response(status.CREATED, {'status': 'OK'})
+
+
+class TestPagination(TestCase):
+ def setUp(self):
+ self.req = RequestFactory()
+
+ def test_default_limit(self):
+ """ Tests if pagination works without overwriting the limit """
+ request = self.req.get('/paginator')
+ response = MockPaginatorView.as_view()(request)
+
+ content = json.loads(response.content)
+
+ self.assertEqual(response.status_code, status.OK)
+ self.assertEqual(MockPaginatorView.total, content['total'])
+ self.assertEqual(MockPaginatorView.limit, content['per_page'])
+
+ self.assertEqual(range(0, MockPaginatorView.limit), content['results'])
+
+ def test_overwriting_limit(self):
+ """ Tests if the limit can be overwritten """
+ limit = 10
+
+ request = self.req.get('/paginator')
+ response = MockPaginatorView.as_view(limit=limit)(request)
+
+ content = json.loads(response.content)
+
+ self.assertEqual(response.status_code, status.OK)
+ self.assertEqual(content['per_page'], limit)
+
+ self.assertEqual(range(0, limit), content['results'])
+
+ def test_limit_param(self):
+ """ Tests if the client can set the limit """
+ from math import ceil
+
+ limit = 5
+ num_pages = int(ceil(MockPaginatorView.total / float(limit)))
+
+ request = self.req.get('/paginator/?limit=%d' % limit)
+ response = MockPaginatorView.as_view()(request)
+
+ content = json.loads(response.content)
+
+ self.assertEqual(response.status_code, status.OK)
+ self.assertEqual(MockPaginatorView.total, content['total'])
+ self.assertEqual(limit, content['per_page'])
+ self.assertEqual(num_pages, content['pages'])
+
+ def test_exceeding_limit(self):
+ """ Makes sure the client cannot exceed the default limit """
+ from math import ceil
+
+ limit = MockPaginatorView.limit + 10
+ num_pages = int(ceil(MockPaginatorView.total / float(limit)))
+
+ request = self.req.get('/paginator/?limit=%d' % limit)
+ response = MockPaginatorView.as_view()(request)
+
+ content = json.loads(response.content)
+
+ self.assertEqual(response.status_code, status.OK)
+ self.assertEqual(MockPaginatorView.total, content['total'])
+ self.assertNotEqual(limit, content['per_page'])
+ self.assertNotEqual(num_pages, content['pages'])
+ self.assertEqual(MockPaginatorView.limit, content['per_page'])
+
+ def test_only_works_for_get(self):
+ """ Pagination should only work for GET requests """
+ request = self.req.post('/paginator', data={'content': 'spam'})
+ response = MockPaginatorView.as_view()(request)
+
+ content = json.loads(response.content)
+
+ self.assertEqual(response.status_code, status.CREATED)
+ self.assertEqual(None, content.get('per_page'))
+ self.assertEqual('OK', content['status'])
+
+ def test_non_int_page(self):
+ """ Tests that it can handle invalid values """
+ request = self.req.get('/paginator/?page=spam')
+ response = MockPaginatorView.as_view()(request)
+
+ self.assertEqual(response.status_code, status.NOT_FOUND)
+
+ def test_page_range(self):
+ """ Tests that the page range is handle correctly """
+ request = self.req.get('/paginator/?page=0')
+ response = MockPaginatorView.as_view()(request)
+ content = json.loads(response.content)
+ self.assertEqual(response.status_code, status.NOT_FOUND)
+
+ request = self.req.get('/paginator/')
+ response = MockPaginatorView.as_view()(request)
+ content = json.loads(response.content)
+ self.assertEqual(response.status_code, status.OK)
+ self.assertEqual(range(0, MockPaginatorView.limit), content['results'])
+
+ num_pages = content['pages']
+
+ request = self.req.get('/paginator/?page=%d' % num_pages)
+ response = MockPaginatorView.as_view()(request)
+ content = json.loads(response.content)
+ self.assertEqual(response.status_code, status.OK)
+ self.assertEqual(range(MockPaginatorView.limit*(num_pages-1), MockPaginatorView.total), content['results'])
+
+ request = self.req.get('/paginator/?page=%d' % (num_pages + 1,))
+ response = MockPaginatorView.as_view()(request)
+ content = json.loads(response.content)
+ self.assertEqual(response.status_code, status.NOT_FOUND)