diff options
| author | Ben Konrath | 2012-10-08 22:00:55 +0200 |
|---|---|---|
| committer | Ben Konrath | 2012-10-11 12:01:07 +0200 |
| commit | 1e9ece0f9353515265da9b6266dc4b39775a0257 (patch) | |
| tree | ab71c69bc46d7e48e256a2cb31128b241e0b2e0f /rest_framework | |
| parent | 83f39b3dce4028ff6b2ebe0be55c2a00d67ede00 (diff) | |
| download | django-rest-framework-1e9ece0f9353515265da9b6266dc4b39775a0257.tar.bz2 | |
First attempt at adding filter support.
The filter support uses django-filter to work its magic.
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/generics.py | 35 | ||||
| -rw-r--r-- | rest_framework/mixins.py | 2 | ||||
| -rw-r--r-- | rest_framework/tests/filterset.py | 160 | ||||
| -rw-r--r-- | rest_framework/tests/models.py | 7 | ||||
| -rw-r--r-- | rest_framework/tests/pagination.py | 69 |
5 files changed, 268 insertions, 5 deletions
diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 59739d01..3b2bea3b 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -1,12 +1,12 @@ """ -Generic views that provide commmonly needed behaviour. +Generic views that provide commonly needed behaviour. """ from rest_framework import views, mixins from rest_framework.settings import api_settings from django.views.generic.detail import SingleObjectMixin from django.views.generic.list import MultipleObjectMixin - +import django_filters ### Base classes for the generic views ### @@ -58,6 +58,37 @@ class MultipleObjectBaseView(MultipleObjectMixin, BaseView): pagination_serializer_class = api_settings.PAGINATION_SERIALIZER paginate_by = api_settings.PAGINATE_BY + filter_class = None + filter_fields = None + + def get_filter_class(self): + """ + Return the django-filters `FilterSet` used to filter the queryset. + """ + if self.filter_class: + return self.filter_class + + if self.filter_fields: + class AutoFilterSet(django_filters.FilterSet): + class Meta: + model = self.model + fields = self.filter_fields + return AutoFilterSet + + return None + + def filter_queryset(self, queryset): + filter_class = self.get_filter_class() + + if filter_class: + assert issubclass(filter_class.Meta.model, self.model), \ + "%s is not a subclass of %s" % (filter_class.Meta.model, self.model) + return filter_class(self.request.GET, queryset=queryset) + + return queryset + + def get_filtered_queryset(self): + return self.filter_queryset(self.get_queryset()) def get_pagination_serializer_class(self): """ diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 29153e18..04626fb0 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -33,7 +33,7 @@ class ListModelMixin(object): empty_error = u"Empty list and '%(class_name)s.allow_empty' is False." def list(self, request, *args, **kwargs): - self.object_list = self.get_queryset() + self.object_list = self.get_filtered_queryset() # Default is to allow empty querysets. This can be altered by setting # `.allow_empty = False`, to raise 404 errors on empty querysets. diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py new file mode 100644 index 00000000..8c857f3f --- /dev/null +++ b/rest_framework/tests/filterset.py @@ -0,0 +1,160 @@ +import datetime +from django.test import TestCase +from django.test.client import RequestFactory +from rest_framework import generics, status +from rest_framework.tests.models import FilterableItem, BasicModel +import django_filters + +factory = RequestFactory() + +# Basic filter on a list view. +class FilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_fields = ['decimal', 'date'] + + +# These class are used to test a filter class. +class SeveralFieldsFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + decimal = django_filters.NumberFilter(lookup_type='lt') + date = django_filters.DateFilter(lookup_type='gt') + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + + +class FilterClassRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = SeveralFieldsFilter + + +# These classes are used to test a misconfigured filter class. +class MisconfiguredFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + class Meta: + model = BasicModel + fields = ['text'] + + +class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = MisconfiguredFilter + + +class IntegrationTestFiltering(TestCase): + """ + Integration tests for filtered list views. + """ + + def setUp(self): + """ + Create 10 FilterableItem instances. + """ + base_data = ('a', 0.25, datetime.date(2012, 10, 8)) + for i in range(10): + text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. + decimal = base_data[1] + i + date = base_data[2] - datetime.timedelta(days=i * 2) + FilterableItem(text=text, decimal=decimal, date=date).save() + + self.objects = FilterableItem.objects + self.data = [ + {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + for obj in self.objects.all() + ] + + def test_get_filtered_fields_root_view(self): + """ + GET requests to paginated ListCreateAPIView should return paginated results. + """ + view = FilterFieldsRootView.as_view() + + # Basic test with no filter. + request = factory.get('/') + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data, self.data) + + # Tests that the decimal filter works. + search_decimal = 2.25 + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [ f for f in self.data if f['decimal'] == search_decimal ] + self.assertEquals(response.data, expected_data) + + # Tests that the date filter works. + search_date = datetime.date(2012, 9, 22) + request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [ f for f in self.data if f['date'] == search_date ] + self.assertEquals(response.data, expected_data) + + def test_get_filtered_class_root_view(self): + """ + GET requests to filtered ListCreateAPIView that have a filter_class set + should return filtered results. + """ + view = FilterClassRootView.as_view() + + # Basic test with no filter. + request = factory.get('/') + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data, self.data) + + # Tests that the decimal filter set with 'lt' in the filter class works. + search_decimal = 4.25 + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [ f for f in self.data if f['decimal'] < search_decimal ] + self.assertEquals(response.data, expected_data) + + # Tests that the date filter set with 'gt' in the filter class works. + search_date = datetime.date(2012, 10, 2) + request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [ f for f in self.data if f['date'] > search_date ] + self.assertEquals(response.data, expected_data) + + # Tests that the text filter set with 'icontains' in the filter class works. + search_text = 'ff' + request = factory.get('/?text=%s' % search_text) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [ f for f in self.data if search_text in f['text'].lower() ] + self.assertEquals(response.data, expected_data) + + # Tests that multiple filters works. + search_decimal = 5.25 + search_date = datetime.date(2012, 10, 2) + request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [ f for f in self.data if f['date'] > search_date and + f['decimal'] < search_decimal ] + self.assertEquals(response.data, expected_data) + + def test_incorrectly_configured_filter(self): + """ + An error should be displayed when the filter class is misconfigured. + """ + view = IncorrectlyConfiguredRootView.as_view() + + request = factory.get('/') + self.assertRaises(AssertionError, view, request) + + # TODO Return 400 filter paramater requested that hasn't been configured. + def test_bad_request(self): + """ + GET requests with filters that aren't configured should return 400. + """ + view = FilterFieldsRootView.as_view() + + search_integer = 10 + request = factory.get('/?integer=%s' % search_integer) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
\ No newline at end of file diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 6a758f0c..780c9dba 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -85,6 +85,13 @@ class Bookmark(RESTFrameworkModel): tags = GenericRelation(TaggedItem) +# Model to test filtering. +class FilterableItem(RESTFrameworkModel): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + # Model for regression test for #285 class Comment(RESTFrameworkModel): diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index a939c9ef..729bbfc2 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -1,8 +1,10 @@ +import datetime from django.core.paginator import Paginator from django.test import TestCase from django.test.client import RequestFactory from rest_framework import generics, status, pagination -from rest_framework.tests.models import BasicModel +from rest_framework.tests.models import BasicModel, FilterableItem +import django_filters factory = RequestFactory() @@ -15,6 +17,19 @@ class RootView(generics.ListCreateAPIView): paginate_by = 10 +class DecimalFilter(django_filters.FilterSet): + decimal = django_filters.NumberFilter(lookup_type='lt') + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + + +class FilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + paginate_by = 10 + filter_class = DecimalFilter + + class IntegrationTestPagination(TestCase): """ Integration tests for paginated list views. @@ -22,7 +37,7 @@ class IntegrationTestPagination(TestCase): def setUp(self): """ - Create 26 BasicModel intances. + Create 26 BasicModel instances. """ for char in 'abcdefghijklmnopqrstuvwxyz': BasicModel(text=char * 3).save() @@ -61,6 +76,56 @@ class IntegrationTestPagination(TestCase): self.assertEquals(response.data['next'], None) self.assertNotEquals(response.data['previous'], None) +class IntegrationTestPaginationAndFiltering(TestCase): + + def setUp(self): + """ + Create 50 FilterableItem instances. + """ + base_data = ('a', 0.25, datetime.date(2012, 10, 8)) + for i in range(26): + text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. + decimal = base_data[1] + i + date = base_data[2] - datetime.timedelta(days=i * 2) + FilterableItem(text=text, decimal=decimal, date=date).save() + + self.objects = FilterableItem.objects + self.data = [ + {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + for obj in self.objects.all() + ] + self.view = FilterFieldsRootView.as_view() + + def test_get_paginated_filtered_root_view(self): + """ + GET requests to paginated filtered ListCreateAPIView should return + paginated results. The next and previous links should preserve the + filtered parameters. + """ + request = factory.get('/?decimal=15.20') + response = self.view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data['count'], 15) + self.assertEquals(response.data['results'], self.data[:10]) + self.assertNotEquals(response.data['next'], None) + self.assertEquals(response.data['previous'], None) + + request = factory.get(response.data['next']) + response = self.view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data['count'], 15) + self.assertEquals(response.data['results'], self.data[10:15]) + self.assertEquals(response.data['next'], None) + self.assertNotEquals(response.data['previous'], None) + + request = factory.get(response.data['previous']) + response = self.view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data['count'], 15) + self.assertEquals(response.data['results'], self.data[:10]) + self.assertNotEquals(response.data['next'], None) + self.assertEquals(response.data['previous'], None) + class UnitTestPagination(TestCase): """ |
