diff options
| author | Tom Christie | 2012-11-09 05:07:34 -0800 |
|---|---|---|
| committer | Tom Christie | 2012-11-09 05:07:34 -0800 |
| commit | c7df9694b5a7a7931161f74a7c5c16d5c98d87d9 (patch) | |
| tree | d2f832ad883a51ce2bde6b1d44b0156f300612c3 /rest_framework | |
| parent | 0089f0faa716bd37ca29f9f2db98b4ab273e01f1 (diff) | |
| parent | ff1234b711b8dfb7dc1cc539fa9d2b6fd2477825 (diff) | |
| download | django-rest-framework-c7df9694b5a7a7931161f74a7c5c16d5c98d87d9.tar.bz2 | |
Merge pull request #383 from tomchristie/filtering
Support for filtering backends
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/compat.py | 34 | ||||
| -rw-r--r-- | rest_framework/filters.py | 59 | ||||
| -rw-r--r-- | rest_framework/generics.py | 10 | ||||
| -rw-r--r-- | rest_framework/mixins.py | 2 | ||||
| -rw-r--r-- | rest_framework/pagination.py | 17 | ||||
| -rw-r--r-- | rest_framework/runtests/settings.py | 1 | ||||
| -rw-r--r-- | rest_framework/settings.py | 9 | ||||
| -rw-r--r-- | rest_framework/templatetags/rest_framework.py | 25 | ||||
| -rw-r--r-- | rest_framework/tests/filterset.py | 168 | ||||
| -rw-r--r-- | rest_framework/tests/models.py | 7 | ||||
| -rw-r--r-- | rest_framework/tests/pagination.py | 77 | ||||
| -rw-r--r-- | rest_framework/tests/response.py | 6 | ||||
| -rw-r--r-- | rest_framework/utils/__init__.py | 1 |
13 files changed, 357 insertions, 59 deletions
diff --git a/rest_framework/compat.py b/rest_framework/compat.py index b0367a32..02e50604 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -5,6 +5,13 @@ versions of django/python, and compatbility wrappers around optional packages. # flake8: noqa import django +# django-filter is optional +try: + import django_filters +except: + django_filters = None + + # cStringIO only if it's available, otherwise StringIO try: import cStringIO as StringIO @@ -348,33 +355,6 @@ except ImportError: yaml = None -import unittest -try: - import unittest.skip -except ImportError: # python < 2.7 - from unittest import TestCase - import functools - - def skip(reason): - # Pasted from py27/lib/unittest/case.py - """ - Unconditionally skip a test. - """ - def decorator(test_item): - if not (isinstance(test_item, type) and issubclass(test_item, TestCase)): - @functools.wraps(test_item) - def skip_wrapper(*args, **kwargs): - pass - test_item = skip_wrapper - - test_item.__unittest_skip__ = True - test_item.__unittest_skip_why__ = reason - return test_item - return decorator - - unittest.skip = skip - - # xml.etree.parse only throws ParseError for python >= 2.7 try: from xml.etree import ParseError as ETParseError diff --git a/rest_framework/filters.py b/rest_framework/filters.py new file mode 100644 index 00000000..ccae4825 --- /dev/null +++ b/rest_framework/filters.py @@ -0,0 +1,59 @@ +from rest_framework.compat import django_filters + +FilterSet = django_filters and django_filters.FilterSet or None + + +class BaseFilterBackend(object): + """ + A base class from which all filter backend classes should inherit. + """ + + def filter_queryset(self, request, queryset, view): + """ + Return a filtered queryset. + """ + raise NotImplementedError(".filter_queryset() must be overridden.") + + +class DjangoFilterBackend(BaseFilterBackend): + """ + A filter backend that uses django-filter. + """ + default_filter_set = FilterSet + + def __init__(self): + assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed' + + def get_filter_class(self, view): + """ + Return the django-filters `FilterSet` used to filter the queryset. + """ + filter_class = getattr(view, 'filter_class', None) + filter_fields = getattr(view, 'filter_fields', None) + view_model = getattr(view, 'model', None) + + if filter_class: + filter_model = filter_class.Meta.model + + assert issubclass(filter_model, view_model), \ + 'FilterSet model %s does not match view model %s' % \ + (filter_model, view_model) + + return filter_class + + if filter_fields: + class AutoFilterSet(self.default_filter_set): + class Meta: + model = view_model + fields = filter_fields + return AutoFilterSet + + return None + + def filter_queryset(self, request, queryset, view): + filter_class = self.get_filter_class(view) + + if filter_class: + return filter_class(request.GET, queryset=queryset) + + return queryset diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 45cedd8b..ebd06e45 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -58,6 +58,16 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS paginate_by = api_settings.PAGINATE_BY + filter_backend = api_settings.FILTER_BACKEND + + def filter_queryset(self, queryset): + if not self.filter_backend: + return queryset + backend = self.filter_backend() + return backend.filter_queryset(self.request, queryset, self) + + def get_filtered_queryset(self): + return self.filter_queryset(self.get_queryset()) def get_pagination_serializer_class(self): """ diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 6824a4d2..c3625a88 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -34,7 +34,7 @@ class ListModelMixin(object): empty_error = u"Empty list and '%(class_name)s.allow_empty' is False." def list(self, request, *args, **kwargs): - self.object_list = self.get_queryset() + self.object_list = self.get_filtered_queryset() # Default is to allow empty querysets. This can be altered by setting # `.allow_empty = False`, to raise 404 errors on empty querysets. diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 131718fd..d241ade7 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -1,4 +1,5 @@ from rest_framework import serializers +from rest_framework.templatetags.rest_framework import replace_query_param # TODO: Support URLconf kwarg-style paging @@ -7,30 +8,30 @@ class NextPageField(serializers.Field): """ Field that returns a link to the next page in paginated results. """ + page_field = 'page' + def to_native(self, value): if not value.has_next(): return None page = value.next_page_number() request = self.context.get('request') - relative_url = '?page=%d' % page - if request: - return request.build_absolute_uri(relative_url) - return relative_url + url = request and request.build_absolute_uri() or '' + return replace_query_param(url, self.page_field, page) class PreviousPageField(serializers.Field): """ Field that returns a link to the previous page in paginated results. """ + page_field = 'page' + def to_native(self, value): if not value.has_previous(): return None page = value.previous_page_number() request = self.context.get('request') - relative_url = '?page=%d' % page - if request: - return request.build_absolute_uri('?page=%d' % page) - return relative_url + url = request and request.build_absolute_uri() or '' + return replace_query_param(url, self.page_field, page) class PaginationSerializerOptions(serializers.SerializerOptions): diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py index b48f85e4..dd5d9dc3 100644 --- a/rest_framework/runtests/settings.py +++ b/rest_framework/runtests/settings.py @@ -107,6 +107,7 @@ import django if django.VERSION < (1, 3): INSTALLED_APPS += ('staticfiles',) + # If we're running on the Jenkins server we want to archive the coverage reports as XML. import os if os.environ.get('HUDSON_URL', None): diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 9c40a214..906a7cf6 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -55,6 +55,7 @@ DEFAULTS = { 'anon': None, }, 'PAGINATE_BY': None, + 'FILTER_BACKEND': None, 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, @@ -79,6 +80,7 @@ IMPORT_STRINGS = ( 'DEFAULT_CONTENT_NEGOTIATION_CLASS', 'DEFAULT_MODEL_SERIALIZER_CLASS', 'DEFAULT_PAGINATION_SERIALIZER_CLASS', + 'FILTER_BACKEND', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', ) @@ -142,8 +144,15 @@ class APISettings(object): if val and attr in self.import_strings: val = perform_import(val, attr) + self.validate_setting(attr, val) + # Cache the result setattr(self, attr, val) return val + def validate_setting(self, attr, val): + if attr == 'FILTER_BACKEND' and val is not None: + # Make sure we can initilize the class + val() + api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS) diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index c9b6eb10..4e0181ee 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -11,6 +11,18 @@ import string register = template.Library() +def replace_query_param(url, key, val): + """ + Given a URL and a key/val pair, set or replace an item in the query + parameters of the URL, and return the new URL. + """ + (scheme, netloc, path, query, fragment) = urlsplit(url) + query_dict = QueryDict(query).copy() + query_dict[key] = val + query = query_dict.urlencode() + return urlunsplit((scheme, netloc, path, query, fragment)) + + # Regex for adding classes to html snippets class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') @@ -31,19 +43,6 @@ hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '| trailing_empty_content_re = re.compile(r'(?:<p>(?: |\s|<br \/>)*?</p>\s*)+\Z') -# Helper function for 'add_query_param' -def replace_query_param(url, key, val): - """ - Given a URL and a key/val pair, set or replace an item in the query - parameters of the URL, and return the new URL. - """ - (scheme, netloc, path, query, fragment) = urlsplit(url) - query_dict = QueryDict(query).copy() - query_dict[key] = val - query = query_dict.urlencode() - return urlunsplit((scheme, netloc, path, query, fragment)) - - # And the template tags themselves... @register.simple_tag diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py new file mode 100644 index 00000000..af2e6c2e --- /dev/null +++ b/rest_framework/tests/filterset.py @@ -0,0 +1,168 @@ +import datetime +from decimal import Decimal +from django.test import TestCase +from django.test.client import RequestFactory +from django.utils import unittest +from rest_framework import generics, status, filters +from rest_framework.compat import django_filters +from rest_framework.tests.models import FilterableItem, BasicModel + +factory = RequestFactory() + + +if django_filters: + # Basic filter on a list view. + class FilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_fields = ['decimal', 'date'] + filter_backend = filters.DjangoFilterBackend + + # These class are used to test a filter class. + class SeveralFieldsFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + decimal = django_filters.NumberFilter(lookup_type='lt') + date = django_filters.DateFilter(lookup_type='gt') + + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + + class FilterClassRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = SeveralFieldsFilter + filter_backend = filters.DjangoFilterBackend + + # These classes are used to test a misconfigured filter class. + class MisconfiguredFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + + class Meta: + model = BasicModel + fields = ['text'] + + class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = MisconfiguredFilter + filter_backend = filters.DjangoFilterBackend + + +class IntegrationTestFiltering(TestCase): + """ + Integration tests for filtered list views. + """ + + def setUp(self): + """ + Create 10 FilterableItem instances. + """ + base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) + for i in range(10): + text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. + decimal = base_data[1] + i + date = base_data[2] - datetime.timedelta(days=i * 2) + FilterableItem(text=text, decimal=decimal, date=date).save() + + self.objects = FilterableItem.objects + self.data = [ + {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + for obj in self.objects.all() + ] + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_fields_root_view(self): + """ + GET requests to paginated ListCreateAPIView should return paginated results. + """ + view = FilterFieldsRootView.as_view() + + # Basic test with no filter. + request = factory.get('/') + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data, self.data) + + # Tests that the decimal filter works. + search_decimal = Decimal('2.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] == search_decimal] + self.assertEquals(response.data, expected_data) + + # Tests that the date filter works. + search_date = datetime.date(2012, 9, 22) + request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] == search_date] + self.assertEquals(response.data, expected_data) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_class_root_view(self): + """ + GET requests to filtered ListCreateAPIView that have a filter_class set + should return filtered results. + """ + view = FilterClassRootView.as_view() + + # Basic test with no filter. + request = factory.get('/') + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data, self.data) + + # Tests that the decimal filter set with 'lt' in the filter class works. + search_decimal = Decimal('4.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] < search_decimal] + self.assertEquals(response.data, expected_data) + + # Tests that the date filter set with 'gt' in the filter class works. + search_date = datetime.date(2012, 10, 2) + request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] > search_date] + self.assertEquals(response.data, expected_data) + + # Tests that the text filter set with 'icontains' in the filter class works. + search_text = 'ff' + request = factory.get('/?text=%s' % search_text) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if search_text in f['text'].lower()] + self.assertEquals(response.data, expected_data) + + # Tests that multiple filters works. + search_decimal = Decimal('5.25') + search_date = datetime.date(2012, 10, 2) + request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] > search_date and + f['decimal'] < search_decimal] + self.assertEquals(response.data, expected_data) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_incorrectly_configured_filter(self): + """ + An error should be displayed when the filter class is misconfigured. + """ + view = IncorrectlyConfiguredRootView.as_view() + + request = factory.get('/') + self.assertRaises(AssertionError, view, request) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_unknown_filter(self): + """ + GET requests with filters that aren't configured should return 200. + """ + view = FilterFieldsRootView.as_view() + + search_integer = 10 + request = factory.get('/?integer=%s' % search_integer) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 0e23734e..a2aba5be 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -95,6 +95,13 @@ class Bookmark(RESTFrameworkModel): tags = GenericRelation(TaggedItem) +# Model to test filtering. +class FilterableItem(RESTFrameworkModel): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + # Model for regression test for #285 class Comment(RESTFrameworkModel): diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 64e8d822..713a7255 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -1,8 +1,12 @@ +import datetime +from decimal import Decimal from django.core.paginator import Paginator from django.test import TestCase from django.test.client import RequestFactory -from rest_framework import generics, status, pagination -from rest_framework.tests.models import BasicModel +from django.utils import unittest +from rest_framework import generics, status, pagination, filters +from rest_framework.compat import django_filters +from rest_framework.tests.models import BasicModel, FilterableItem factory = RequestFactory() @@ -15,6 +19,21 @@ class RootView(generics.ListCreateAPIView): paginate_by = 10 +if django_filters: + class DecimalFilter(django_filters.FilterSet): + decimal = django_filters.NumberFilter(lookup_type='lt') + + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + + class FilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + paginate_by = 10 + filter_class = DecimalFilter + filter_backend = filters.DjangoFilterBackend + + class IntegrationTestPagination(TestCase): """ Integration tests for paginated list views. @@ -22,7 +41,7 @@ class IntegrationTestPagination(TestCase): def setUp(self): """ - Create 26 BasicModel intances. + Create 26 BasicModel instances. """ for char in 'abcdefghijklmnopqrstuvwxyz': BasicModel(text=char * 3).save() @@ -62,6 +81,58 @@ class IntegrationTestPagination(TestCase): self.assertNotEquals(response.data['previous'], None) +class IntegrationTestPaginationAndFiltering(TestCase): + + def setUp(self): + """ + Create 50 FilterableItem instances. + """ + base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) + for i in range(26): + text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. + decimal = base_data[1] + i + date = base_data[2] - datetime.timedelta(days=i * 2) + FilterableItem(text=text, decimal=decimal, date=date).save() + + self.objects = FilterableItem.objects + self.data = [ + {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + for obj in self.objects.all() + ] + self.view = FilterFieldsRootView.as_view() + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_paginated_filtered_root_view(self): + """ + GET requests to paginated filtered ListCreateAPIView should return + paginated results. The next and previous links should preserve the + filtered parameters. + """ + request = factory.get('/?decimal=15.20') + response = self.view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data['count'], 15) + self.assertEquals(response.data['results'], self.data[:10]) + self.assertNotEquals(response.data['next'], None) + self.assertEquals(response.data['previous'], None) + + request = factory.get(response.data['next']) + response = self.view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data['count'], 15) + self.assertEquals(response.data['results'], self.data[10:15]) + self.assertEquals(response.data['next'], None) + self.assertNotEquals(response.data['previous'], None) + + request = factory.get(response.data['previous']) + response = self.view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data['count'], 15) + self.assertEquals(response.data['results'], self.data[:10]) + self.assertNotEquals(response.data['next'], None) + self.assertEquals(response.data['previous'], None) + + class UnitTestPagination(TestCase): """ Unit tests for pagination of primative objects. diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py index 18b6af39..d7b75450 100644 --- a/rest_framework/tests/response.py +++ b/rest_framework/tests/response.py @@ -131,12 +131,6 @@ class RendererIntegrationTests(TestCase): self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEquals(resp.status_code, DUMMYSTATUS) - @unittest.skip('can\'t pass because view is a simple Django view and response is an ImmediateResponse') - def test_unsatisfiable_accept_header_on_request_returns_406_status(self): - """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" - resp = self.client.get('/', HTTP_ACCEPT='foo/bar') - self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) - def test_specified_renderer_serializes_content_on_format_query(self): """If a 'format' query is specified, the renderer with the matching format attribute should serialize the response.""" diff --git a/rest_framework/utils/__init__.py b/rest_framework/utils/__init__.py index a59fff45..84fcb5db 100644 --- a/rest_framework/utils/__init__.py +++ b/rest_framework/utils/__init__.py @@ -1,7 +1,6 @@ from django.utils.encoding import smart_unicode from django.utils.xmlutils import SimplerXMLGenerator from rest_framework.compat import StringIO - import re import xml.etree.ElementTree as ET |
