aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTom Christie2012-09-30 17:31:28 +0100
committerTom Christie2012-09-30 17:31:28 +0100
commit6fa589fefd48d98e4f0a11548b6c3e5ced58e31e (patch)
treea5d1dd75ce6d5c5be7bf81b386a29c13235d33ff
parent43d3634e892e303ca377265d3176e8313f19563f (diff)
downloaddjango-rest-framework-6fa589fefd48d98e4f0a11548b6c3e5ced58e31e.tar.bz2
Pagination support
-rw-r--r--rest_framework/fields.py14
-rw-r--r--rest_framework/generics.py31
-rw-r--r--rest_framework/mixins.py21
-rw-r--r--rest_framework/pagination.py34
-rw-r--r--rest_framework/settings.py6
-rw-r--r--rest_framework/templatetags/rest_framework.py2
-rw-r--r--rest_framework/tests/pagination.py57
7 files changed, 151 insertions, 14 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index eab90617..74675ee9 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -139,7 +139,13 @@ class Field(object):
if hasattr(self, 'model_field'):
return self.to_native(self.model_field._get_val_from_obj(obj))
- return self.to_native(getattr(obj, self.source or field_name))
+ if self.source:
+ value = obj
+ for component in self.source.split('.'):
+ value = getattr(value, component)
+ else:
+ value = getattr(obj, field_name)
+ return self.to_native(value)
def to_native(self, value):
"""
@@ -175,7 +181,7 @@ class RelatedField(Field):
"""
def field_to_native(self, obj, field_name):
- obj = getattr(obj, field_name)
+ obj = getattr(obj, self.source or field_name)
if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
return [self.to_native(item) for item in obj.all()]
return self.to_native(obj)
@@ -215,10 +221,10 @@ class PrimaryKeyRelatedField(RelatedField):
def field_to_native(self, obj, field_name):
try:
- obj = obj.serializable_value(field_name)
+ obj = obj.serializable_value(self.source or field_name)
except AttributeError:
field = obj._meta.get_field_by_name(field_name)[0]
- obj = getattr(obj, field_name)
+ obj = getattr(obj, self.source or field_name)
if obj.__class__.__name__ == 'RelatedManager':
return [self.to_native(item.pk) for item in obj.all()]
elif isinstance(field, RelatedObject):
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 4240e33e..1e547b32 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -2,7 +2,8 @@
Generic views that provide commmonly needed behaviour.
"""
-from rest_framework import views, mixins, serializers
+from rest_framework import views, mixins
+from rest_framework.settings import api_settings
from django.views.generic.detail import SingleObjectMixin
from django.views.generic.list import MultipleObjectMixin
@@ -14,23 +15,37 @@ class BaseView(views.APIView):
Base class for all other generic views.
"""
serializer_class = None
+ model_serializer_class = api_settings.MODEL_SERIALIZER
+ pagination_serializer_class = api_settings.PAGINATION_SERIALIZER
+ paginate_by = api_settings.PAGINATE_BY
- def get_serializer(self, data=None, files=None, instance=None):
+ def get_serializer_context(self):
+ return {
+ 'request': self.request,
+ 'format': self.kwargs.get('format', None)
+ }
+
+ def get_serializer(self, data=None, files=None, instance=None, kwargs=None):
# TODO: add support for files
# TODO: add support for seperate serializer/deserializer
serializer_class = self.serializer_class
+ kwargs = kwargs or {}
if serializer_class is None:
- class DefaultSerializer(serializers.ModelSerializer):
+ class DefaultSerializer(self.model_serializer_class):
class Meta:
model = self.model
serializer_class = DefaultSerializer
- context = {
- 'request': self.request,
- 'format': self.kwargs.get('format', None)
- }
- return serializer_class(data, instance=instance, context=context)
+ context = self.get_serializer_context()
+ return serializer_class(data, instance=instance, context=context, **kwargs)
+
+ def get_pagination_serializer(self, page=None):
+ serializer_class = self.pagination_serializer_class
+ context = self.get_serializer_context()
+ ret = serializer_class(instance=page, context=context)
+ ret.fields['results'] = self.get_serializer(kwargs={'source': 'object_list'})
+ return ret
class MultipleObjectBaseView(MultipleObjectMixin, BaseView):
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index fe12dc8f..167cd89a 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -7,6 +7,7 @@ which allows mixin classes to be composed in interesting ways.
Eg. Use mixins to build a Resource class, and have a Router class
perform the binding of http methods to actions for us.
"""
+from django.http import Http404
from rest_framework import status
from rest_framework.response import Response
@@ -30,9 +31,27 @@ class ListModelMixin(object):
List a queryset.
Should be mixed in with `MultipleObjectBaseView`.
"""
+ empty_error = u"Empty list and '%(class_name)s.allow_empty' is False."
+
def list(self, request, *args, **kwargs):
self.object_list = self.get_queryset()
- serializer = self.get_serializer(instance=self.object_list)
+
+ # Default is to allow empty querysets. This can be altered by setting
+ # `.allow_empty = False`, to raise 404 errors on empty querysets.
+ allow_empty = self.get_allow_empty()
+ if not allow_empty and len(self.object_list) == 0:
+ error_args = {'class_name': self.__class__.__name__}
+ raise Http404(self.empty_error % error_args)
+
+ # Pagination size is set by the `.paginate_by` attribute,
+ # which may be `None` to disable pagination.
+ page_size = self.get_paginate_by(self.object_list)
+ if page_size:
+ paginator, page, queryset, is_paginated = self.paginate_queryset(self.object_list, page_size)
+ serializer = self.get_pagination_serializer(page)
+ else:
+ serializer = self.get_serializer(instance=self.object_list)
+
return Response(serializer.data)
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
new file mode 100644
index 00000000..398e6f3d
--- /dev/null
+++ b/rest_framework/pagination.py
@@ -0,0 +1,34 @@
+from rest_framework import serializers
+
+# TODO: Support URLconf kwarg-style paging
+
+
+class NextPageField(serializers.Field):
+ def to_native(self, value):
+ if not value.has_next():
+ return None
+ page = value.next_page_number()
+ request = self.context['request']
+ return request.build_absolute_uri('?page=%d' % page)
+
+
+class PreviousPageField(serializers.Field):
+ def to_native(self, value):
+ if not value.has_previous():
+ return None
+ page = value.previous_page_number()
+ request = self.context['request']
+ return request.build_absolute_uri('?page=%d' % page)
+
+
+class PaginationSerializer(serializers.Serializer):
+ count = serializers.Field(source='paginator.count')
+ next = NextPageField(source='*')
+ previous = PreviousPageField(source='*')
+
+ def to_native(self, obj):
+ """
+ Prevent default behaviour of iterating over elements, and serializing
+ each in turn.
+ """
+ return self.convert_object(obj)
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index cfc89fe1..8387fd29 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -44,6 +44,10 @@ DEFAULTS = {
'anon': None,
},
+ 'MODEL_SERIALIZER': 'rest_framework.serializers.ModelSerializer',
+ 'PAGINATION_SERIALIZER': 'rest_framework.pagination.PaginationSerializer',
+ 'PAGINATE_BY': 20,
+
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
@@ -65,6 +69,8 @@ IMPORT_STRINGS = (
'DEFAULT_PERMISSIONS',
'DEFAULT_THROTTLES',
'DEFAULT_CONTENT_NEGOTIATION',
+ 'MODEL_SERIALIZER',
+ 'PAGINATION_SERIALIZER',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
)
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index 377fd489..c9b6eb10 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -1,5 +1,5 @@
from django import template
-from django.core.urlresolvers import reverse, NoReverseMatch
+from django.core.urlresolvers import reverse
from django.http import QueryDict
from django.utils.encoding import force_unicode
from django.utils.html import escape
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
new file mode 100644
index 00000000..4ddfc915
--- /dev/null
+++ b/rest_framework/tests/pagination.py
@@ -0,0 +1,57 @@
+from django.test import TestCase
+from django.test.client import RequestFactory
+from rest_framework import generics, status
+from rest_framework.tests.models import BasicModel
+
+factory = RequestFactory()
+
+
+class RootView(generics.RootAPIView):
+ """
+ Example description for OPTIONS.
+ """
+ model = BasicModel
+ paginate_by = 10
+
+
+class TestPaginatedView(TestCase):
+ def setUp(self):
+ """
+ Create 26 BasicModel intances.
+ """
+ for char in 'abcdefghijklmnopqrstuvwxyz':
+ BasicModel(text=char * 3).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = RootView.as_view()
+
+ def test_get_paginated_root_view(self):
+ """
+ GET requests to paginated RootAPIView should return paginated results.
+ """
+ request = factory.get('/')
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 26)
+ self.assertEquals(response.data['results'], self.data[:10])
+ self.assertNotEquals(response.data['next'], None)
+ self.assertEquals(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 26)
+ self.assertEquals(response.data['results'], self.data[10:20])
+ self.assertNotEquals(response.data['next'], None)
+ self.assertNotEquals(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 26)
+ self.assertEquals(response.data['results'], self.data[20:])
+ self.assertEquals(response.data['next'], None)
+ self.assertNotEquals(response.data['previous'], None)