aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTom Christie2012-11-07 20:13:27 +0000
committerTom Christie2012-11-07 20:13:27 +0000
commit9fd061a0b68f0cef6683bf195911a2cc7ff2fa06 (patch)
tree60769e37be41acf4bf12ba4ad59737e57c55da6a
parent066d51faa16d1cfd3c8370c6bfe46f8494bbc26a (diff)
parent09f39bd23b3c688c89551845d665395e1aabbfab (diff)
downloaddjango-rest-framework-9fd061a0b68f0cef6683bf195911a2cc7ff2fa06.tar.bz2
Merge branch 'restframework2-filter' of git://github.com/onepercentclub/django-rest-framework into filtering
-rw-r--r--.travis.yml1
-rw-r--r--requirements.txt1
-rw-r--r--rest_framework/generics.py33
-rw-r--r--rest_framework/mixins.py2
-rw-r--r--rest_framework/pagination.py20
-rw-r--r--rest_framework/tests/filterset.py160
-rw-r--r--rest_framework/tests/models.py7
-rw-r--r--rest_framework/tests/pagination.py71
-rw-r--r--tox.ini6
9 files changed, 292 insertions, 9 deletions
diff --git a/.travis.yml b/.travis.yml
index 0e177a95..fa8693a0 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -11,6 +11,7 @@ env:
install:
- pip install $DJANGO
+ - pip install -r requirements.txt --use-mirrors
- export PYTHONPATH=.
script:
diff --git a/requirements.txt b/requirements.txt
index 730c1d07..48ff9d65 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,2 @@
Django>=1.3
+-e git+https://github.com/alex/django-filter.git@0e4b3d703b31574922ab86fc78a86164aad0c1d0#egg=django-filter
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 45cedd8b..ac02d3da 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -6,7 +6,7 @@ from rest_framework import views, mixins
from rest_framework.settings import api_settings
from django.views.generic.detail import SingleObjectMixin
from django.views.generic.list import MultipleObjectMixin
-
+import django_filters
### Base classes for the generic views ###
@@ -58,6 +58,37 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
paginate_by = api_settings.PAGINATE_BY
+ filter_class = None
+ filter_fields = None
+
+ def get_filter_class(self):
+ """
+ Return the django-filters `FilterSet` used to filter the queryset.
+ """
+ if self.filter_class:
+ return self.filter_class
+
+ if self.filter_fields:
+ class AutoFilterSet(django_filters.FilterSet):
+ class Meta:
+ model = self.model
+ fields = self.filter_fields
+ return AutoFilterSet
+
+ return None
+
+ def filter_queryset(self, queryset):
+ filter_class = self.get_filter_class()
+
+ if filter_class:
+ assert issubclass(filter_class.Meta.model, self.model), \
+ "%s is not a subclass of %s" % (filter_class.Meta.model, self.model)
+ return filter_class(self.request.GET, queryset=queryset)
+
+ return queryset
+
+ def get_filtered_queryset(self):
+ return self.filter_queryset(self.get_queryset())
def get_pagination_serializer_class(self):
"""
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 6824a4d2..c3625a88 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -34,7 +34,7 @@ class ListModelMixin(object):
empty_error = u"Empty list and '%(class_name)s.allow_empty' is False."
def list(self, request, *args, **kwargs):
- self.object_list = self.get_queryset()
+ self.object_list = self.get_filtered_queryset()
# Default is to allow empty querysets. This can be altered by setting
# `.allow_empty = False`, to raise 404 errors on empty querysets.
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index 131718fd..c77a1005 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -3,7 +3,11 @@ from rest_framework import serializers
# TODO: Support URLconf kwarg-style paging
-class NextPageField(serializers.Field):
+class PageField(serializers.Field):
+ page_field = 'page'
+
+
+class NextPageField(PageField):
"""
Field that returns a link to the next page in paginated results.
"""
@@ -12,13 +16,16 @@ class NextPageField(serializers.Field):
return None
page = value.next_page_number()
request = self.context.get('request')
- relative_url = '?page=%d' % page
+ relative_url = '?%s=%d' % (self.page_field, page)
if request:
+ for field, value in request.QUERY_PARAMS.iteritems():
+ if field != self.page_field:
+ relative_url += '&%s=%s' % (field, value)
return request.build_absolute_uri(relative_url)
return relative_url
-class PreviousPageField(serializers.Field):
+class PreviousPageField(PageField):
"""
Field that returns a link to the previous page in paginated results.
"""
@@ -27,9 +34,12 @@ class PreviousPageField(serializers.Field):
return None
page = value.previous_page_number()
request = self.context.get('request')
- relative_url = '?page=%d' % page
+ relative_url = '?%s=%d' % (self.page_field, page)
if request:
- return request.build_absolute_uri('?page=%d' % page)
+ for field, value in request.QUERY_PARAMS.iteritems():
+ if field != self.page_field:
+ relative_url += '&%s=%s' % (field, value)
+ return request.build_absolute_uri(relative_url)
return relative_url
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
new file mode 100644
index 00000000..5374eefc
--- /dev/null
+++ b/rest_framework/tests/filterset.py
@@ -0,0 +1,160 @@
+import datetime
+from decimal import Decimal
+from django.test import TestCase
+from django.test.client import RequestFactory
+from rest_framework import generics, status
+from rest_framework.tests.models import FilterableItem, BasicModel
+import django_filters
+
+factory = RequestFactory()
+
+# Basic filter on a list view.
+class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_fields = ['decimal', 'date']
+
+
+# These class are used to test a filter class.
+class SeveralFieldsFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+ date = django_filters.DateFilter(lookup_type='gt')
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+
+class FilterClassRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+
+
+# These classes are used to test a misconfigured filter class.
+class MisconfiguredFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+ class Meta:
+ model = BasicModel
+ fields = ['text']
+
+
+class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = MisconfiguredFilter
+
+
+class IntegrationTestFiltering(TestCase):
+ """
+ Integration tests for filtered list views.
+ """
+
+ def setUp(self):
+ """
+ Create 10 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(10):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+ for obj in self.objects.all()
+ ]
+
+ def test_get_filtered_fields_root_view(self):
+ """
+ GET requests to paginated ListCreateAPIView should return paginated results.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [ f for f in self.data if f['decimal'] == search_decimal ]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that the date filter works.
+ search_date = datetime.date(2012, 9, 22)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [ f for f in self.data if f['date'] == search_date ]
+ self.assertEquals(response.data, expected_data)
+
+ def test_get_filtered_class_root_view(self):
+ """
+ GET requests to filtered ListCreateAPIView that have a filter_class set
+ should return filtered results.
+ """
+ view = FilterClassRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
+
+ # Tests that the decimal filter set with 'lt' in the filter class works.
+ search_decimal = Decimal('4.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [ f for f in self.data if f['decimal'] < search_decimal ]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that the date filter set with 'gt' in the filter class works.
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [ f for f in self.data if f['date'] > search_date ]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that the text filter set with 'icontains' in the filter class works.
+ search_text = 'ff'
+ request = factory.get('/?text=%s' % search_text)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [ f for f in self.data if search_text in f['text'].lower() ]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [ f for f in self.data if f['date'] > search_date and
+ f['decimal'] < search_decimal ]
+ self.assertEquals(response.data, expected_data)
+
+ def test_incorrectly_configured_filter(self):
+ """
+ An error should be displayed when the filter class is misconfigured.
+ """
+ view = IncorrectlyConfiguredRootView.as_view()
+
+ request = factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ def test_unknown_filter(self):
+ """
+ GET requests with filters that aren't configured should return 200.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ search_integer = 10
+ request = factory.get('/?integer=%s' % search_integer)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK) \ No newline at end of file
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index 0e23734e..a2aba5be 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -95,6 +95,13 @@ class Bookmark(RESTFrameworkModel):
tags = GenericRelation(TaggedItem)
+# Model to test filtering.
+class FilterableItem(RESTFrameworkModel):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
# Model for regression test for #285
class Comment(RESTFrameworkModel):
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index 64e8d822..7a2134e0 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -1,8 +1,11 @@
+import datetime
+from decimal import Decimal
from django.core.paginator import Paginator
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import generics, status, pagination
-from rest_framework.tests.models import BasicModel
+from rest_framework.tests.models import BasicModel, FilterableItem
+import django_filters
factory = RequestFactory()
@@ -15,6 +18,19 @@ class RootView(generics.ListCreateAPIView):
paginate_by = 10
+class DecimalFilter(django_filters.FilterSet):
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+
+class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_class = DecimalFilter
+
+
class IntegrationTestPagination(TestCase):
"""
Integration tests for paginated list views.
@@ -22,7 +38,7 @@ class IntegrationTestPagination(TestCase):
def setUp(self):
"""
- Create 26 BasicModel intances.
+ Create 26 BasicModel instances.
"""
for char in 'abcdefghijklmnopqrstuvwxyz':
BasicModel(text=char * 3).save()
@@ -62,6 +78,57 @@ class IntegrationTestPagination(TestCase):
self.assertNotEquals(response.data['previous'], None)
+class IntegrationTestPaginationAndFiltering(TestCase):
+
+ def setUp(self):
+ """
+ Create 50 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(26):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+ for obj in self.objects.all()
+ ]
+ self.view = FilterFieldsRootView.as_view()
+
+ def test_get_paginated_filtered_root_view(self):
+ """
+ GET requests to paginated filtered ListCreateAPIView should return
+ paginated results. The next and previous links should preserve the
+ filtered parameters.
+ """
+ request = factory.get('/?decimal=15.20')
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 15)
+ self.assertEquals(response.data['results'], self.data[:10])
+ self.assertNotEquals(response.data['next'], None)
+ self.assertEquals(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 15)
+ self.assertEquals(response.data['results'], self.data[10:15])
+ self.assertEquals(response.data['next'], None)
+ self.assertNotEquals(response.data['previous'], None)
+
+ request = factory.get(response.data['previous'])
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 15)
+ self.assertEquals(response.data['results'], self.data[:10])
+ self.assertNotEquals(response.data['next'], None)
+ self.assertEquals(response.data['previous'], None)
+
+
class UnitTestPagination(TestCase):
"""
Unit tests for pagination of primative objects.
diff --git a/tox.ini b/tox.ini
index bcfff672..3596bbdc 100644
--- a/tox.ini
+++ b/tox.ini
@@ -8,23 +8,29 @@ commands = {envpython} rest_framework/runtests/runtests.py
[testenv:py2.7-django1.5]
basepython = python2.7
deps = https://github.com/django/django/zipball/master
+ git+https://github.com/alex/django-filter.git@0e4b3d703b31574922ab86fc78a86164aad0c1d0#egg=django-filter
[testenv:py2.7-django1.4]
basepython = python2.7
deps = django==1.4.1
+ git+https://github.com/alex/django-filter.git@0e4b3d703b31574922ab86fc78a86164aad0c1d0#egg=django-filter
[testenv:py2.7-django1.3]
basepython = python2.7
deps = django==1.3.3
+ git+https://github.com/alex/django-filter.git@0e4b3d703b31574922ab86fc78a86164aad0c1d0#egg=django-filter
[testenv:py2.6-django1.5]
basepython = python2.6
deps = https://github.com/django/django/zipball/master
+ git+https://github.com/alex/django-filter.git@0e4b3d703b31574922ab86fc78a86164aad0c1d0#egg=django-filter
[testenv:py2.6-django1.4]
basepython = python2.6
deps = django==1.4.1
+ git+https://github.com/alex/django-filter.git@0e4b3d703b31574922ab86fc78a86164aad0c1d0#egg=django-filter
[testenv:py2.6-django1.3]
basepython = python2.6
deps = django==1.3.3
+ git+https://github.com/alex/django-filter.git@0e4b3d703b31574922ab86fc78a86164aad0c1d0#egg=django-filter