diff options
| author | Tom Christie | 2012-11-07 21:07:24 +0000 |
|---|---|---|
| committer | Tom Christie | 2012-11-07 21:07:24 +0000 |
| commit | 47b534a13e42d498629bf9522225633122c563d5 (patch) | |
| tree | fc7acddb14038fc5f159c1399dac7974a76caf4b /rest_framework | |
| parent | 9fd061a0b68f0cef6683bf195911a2cc7ff2fa06 (diff) | |
| download | django-rest-framework-47b534a13e42d498629bf9522225633122c563d5.tar.bz2 | |
Make filtering optional, and pluggable.
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/compat.py | 34 | ||||
| -rw-r--r-- | rest_framework/filters.py | 52 | ||||
| -rw-r--r-- | rest_framework/generics.py | 33 | ||||
| -rw-r--r-- | rest_framework/pagination.py | 19 | ||||
| -rw-r--r-- | rest_framework/settings.py | 2 | ||||
| -rw-r--r-- | rest_framework/templatetags/rest_framework.py | 27 | ||||
| -rw-r--r-- | rest_framework/tests/filterset.py | 71 | ||||
| -rw-r--r-- | rest_framework/tests/pagination.py | 23 | ||||
| -rw-r--r-- | rest_framework/tests/response.py | 6 | ||||
| -rw-r--r-- | rest_framework/utils/__init__.py | 1 |
10 files changed, 136 insertions, 132 deletions
diff --git a/rest_framework/compat.py b/rest_framework/compat.py index b0367a32..02e50604 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -5,6 +5,13 @@ versions of django/python, and compatbility wrappers around optional packages. # flake8: noqa import django +# django-filter is optional +try: + import django_filters +except: + django_filters = None + + # cStringIO only if it's available, otherwise StringIO try: import cStringIO as StringIO @@ -348,33 +355,6 @@ except ImportError: yaml = None -import unittest -try: - import unittest.skip -except ImportError: # python < 2.7 - from unittest import TestCase - import functools - - def skip(reason): - # Pasted from py27/lib/unittest/case.py - """ - Unconditionally skip a test. - """ - def decorator(test_item): - if not (isinstance(test_item, type) and issubclass(test_item, TestCase)): - @functools.wraps(test_item) - def skip_wrapper(*args, **kwargs): - pass - test_item = skip_wrapper - - test_item.__unittest_skip__ = True - test_item.__unittest_skip_why__ = reason - return test_item - return decorator - - unittest.skip = skip - - # xml.etree.parse only throws ParseError for python >= 2.7 try: from xml.etree import ParseError as ETParseError diff --git a/rest_framework/filters.py b/rest_framework/filters.py new file mode 100644 index 00000000..b972e82a --- /dev/null +++ b/rest_framework/filters.py @@ -0,0 +1,52 @@ +from rest_framework.compat import django_filters + + +class BaseFilterBackend(object): + """ + A base class from which all filter backend classes should inherit. + """ + + def filter_queryset(self, request, queryset, view): + """ + Return a filtered queryset. + """ + raise NotImplementedError(".filter_queryset() must be overridden.") + + +class DjangoFilterBackend(BaseFilterBackend): + """ + A filter backend that uses django-filter. + """ + + def get_filter_class(self, view): + """ + Return the django-filters `FilterSet` used to filter the queryset. + """ + filter_class = getattr(view, 'filter_class', None) + filter_fields = getattr(view, 'filter_fields', None) + filter_model = getattr(view, 'model', None) + + if filter_class or filter_fields: + assert django_filters, 'django-filter is not installed' + + if filter_class: + assert issubclass(filter_class.Meta.model, filter_model), \ + '%s is not a subclass of %s' % (filter_class.Meta.model, filter_model) + return filter_class + + if filter_fields: + class AutoFilterSet(django_filters.FilterSet): + class Meta: + model = filter_model + fields = filter_fields + return AutoFilterSet + + return None + + def filter_queryset(self, request, queryset, view): + filter_class = self.get_filter_class(view) + + if filter_class: + return filter_class(request.GET, queryset=queryset) + + return queryset diff --git a/rest_framework/generics.py b/rest_framework/generics.py index ac02d3da..ebd06e45 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -6,7 +6,7 @@ from rest_framework import views, mixins from rest_framework.settings import api_settings from django.views.generic.detail import SingleObjectMixin from django.views.generic.list import MultipleObjectMixin -import django_filters + ### Base classes for the generic views ### @@ -58,34 +58,13 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS paginate_by = api_settings.PAGINATE_BY - filter_class = None - filter_fields = None - - def get_filter_class(self): - """ - Return the django-filters `FilterSet` used to filter the queryset. - """ - if self.filter_class: - return self.filter_class - - if self.filter_fields: - class AutoFilterSet(django_filters.FilterSet): - class Meta: - model = self.model - fields = self.filter_fields - return AutoFilterSet - - return None + filter_backend = api_settings.FILTER_BACKEND def filter_queryset(self, queryset): - filter_class = self.get_filter_class() - - if filter_class: - assert issubclass(filter_class.Meta.model, self.model), \ - "%s is not a subclass of %s" % (filter_class.Meta.model, self.model) - return filter_class(self.request.GET, queryset=queryset) - - return queryset + if not self.filter_backend: + return queryset + backend = self.filter_backend() + return backend.filter_queryset(self.request, queryset, self) def get_filtered_queryset(self): return self.filter_queryset(self.get_queryset()) diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index c77a1005..aa54d154 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -1,4 +1,5 @@ from rest_framework import serializers +from rest_framework.templatetags.rest_framework import replace_query_param # TODO: Support URLconf kwarg-style paging @@ -16,13 +17,8 @@ class NextPageField(PageField): return None page = value.next_page_number() request = self.context.get('request') - relative_url = '?%s=%d' % (self.page_field, page) - if request: - for field, value in request.QUERY_PARAMS.iteritems(): - if field != self.page_field: - relative_url += '&%s=%s' % (field, value) - return request.build_absolute_uri(relative_url) - return relative_url + url = request and request.get_full_path() or '' + return replace_query_param(url, self.page_field, page) class PreviousPageField(PageField): @@ -34,13 +30,8 @@ class PreviousPageField(PageField): return None page = value.previous_page_number() request = self.context.get('request') - relative_url = '?%s=%d' % (self.page_field, page) - if request: - for field, value in request.QUERY_PARAMS.iteritems(): - if field != self.page_field: - relative_url += '&%s=%s' % (field, value) - return request.build_absolute_uri(relative_url) - return relative_url + url = request and request.get_full_path() or '' + return replace_query_param(url, self.page_field, page) class PaginationSerializerOptions(serializers.SerializerOptions): diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 9c40a214..da647658 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -55,6 +55,7 @@ DEFAULTS = { 'anon': None, }, 'PAGINATE_BY': None, + 'FILTER_BACKEND': 'rest_framework.filters.DjangoFilterBackend', 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, @@ -79,6 +80,7 @@ IMPORT_STRINGS = ( 'DEFAULT_CONTENT_NEGOTIATION_CLASS', 'DEFAULT_MODEL_SERIALIZER_CLASS', 'DEFAULT_PAGINATION_SERIALIZER_CLASS', + 'FILTER_BACKEND', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', ) diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index c9b6eb10..0672ee4f 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -1,9 +1,9 @@ from django import template from django.core.urlresolvers import reverse -from django.http import QueryDict from django.utils.encoding import force_unicode from django.utils.html import escape from django.utils.safestring import SafeData, mark_safe +from django.http import QueryDict from urlparse import urlsplit, urlunsplit import re import string @@ -11,6 +11,18 @@ import string register = template.Library() +def replace_query_param(url, key, val): + """ + Given a URL and a key/val pair, set or replace an item in the query + parameters of the URL, and return the new URL. + """ + (scheme, netloc, path, query, fragment) = urlsplit(url) + query_dict = QueryDict(query).copy() + query_dict[key] = val + query = query_dict.urlencode() + return urlunsplit((scheme, netloc, path, query, fragment)) + + # Regex for adding classes to html snippets class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') @@ -31,19 +43,6 @@ hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '| trailing_empty_content_re = re.compile(r'(?:<p>(?: |\s|<br \/>)*?</p>\s*)+\Z') -# Helper function for 'add_query_param' -def replace_query_param(url, key, val): - """ - Given a URL and a key/val pair, set or replace an item in the query - parameters of the URL, and return the new URL. - """ - (scheme, netloc, path, query, fragment) = urlsplit(url) - query_dict = QueryDict(query).copy() - query_dict[key] = val - query = query_dict.urlencode() - return urlunsplit((scheme, netloc, path, query, fragment)) - - # And the template tags themselves... @register.simple_tag diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 5374eefc..6cdea32f 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -2,44 +2,45 @@ import datetime from decimal import Decimal from django.test import TestCase from django.test.client import RequestFactory +from django.utils import unittest from rest_framework import generics, status +from rest_framework.compat import django_filters from rest_framework.tests.models import FilterableItem, BasicModel -import django_filters factory = RequestFactory() -# Basic filter on a list view. -class FilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem - filter_fields = ['decimal', 'date'] - -# These class are used to test a filter class. -class SeveralFieldsFilter(django_filters.FilterSet): - text = django_filters.CharFilter(lookup_type='icontains') - decimal = django_filters.NumberFilter(lookup_type='lt') - date = django_filters.DateFilter(lookup_type='gt') - class Meta: +if django_filters: + # Basic filter on a list view. + class FilterFieldsRootView(generics.ListCreateAPIView): model = FilterableItem - fields = ['text', 'decimal', 'date'] + filter_fields = ['decimal', 'date'] + # These class are used to test a filter class. + class SeveralFieldsFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + decimal = django_filters.NumberFilter(lookup_type='lt') + date = django_filters.DateFilter(lookup_type='gt') -class FilterClassRootView(generics.ListCreateAPIView): - model = FilterableItem - filter_class = SeveralFieldsFilter + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + class FilterClassRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = SeveralFieldsFilter -# These classes are used to test a misconfigured filter class. -class MisconfiguredFilter(django_filters.FilterSet): - text = django_filters.CharFilter(lookup_type='icontains') - class Meta: - model = BasicModel - fields = ['text'] + # These classes are used to test a misconfigured filter class. + class MisconfiguredFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + class Meta: + model = BasicModel + fields = ['text'] -class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): - model = FilterableItem - filter_class = MisconfiguredFilter + class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = MisconfiguredFilter class IntegrationTestFiltering(TestCase): @@ -64,6 +65,7 @@ class IntegrationTestFiltering(TestCase): for obj in self.objects.all() ] + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_filtered_fields_root_view(self): """ GET requests to paginated ListCreateAPIView should return paginated results. @@ -81,7 +83,7 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?decimal=%s' % search_decimal) response = view(request).render() self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [ f for f in self.data if f['decimal'] == search_decimal ] + expected_data = [f for f in self.data if f['decimal'] == search_decimal] self.assertEquals(response.data, expected_data) # Tests that the date filter works. @@ -89,9 +91,10 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' response = view(request).render() self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [ f for f in self.data if f['date'] == search_date ] + expected_data = [f for f in self.data if f['date'] == search_date] self.assertEquals(response.data, expected_data) + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_filtered_class_root_view(self): """ GET requests to filtered ListCreateAPIView that have a filter_class set @@ -110,7 +113,7 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?decimal=%s' % search_decimal) response = view(request).render() self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [ f for f in self.data if f['decimal'] < search_decimal ] + expected_data = [f for f in self.data if f['decimal'] < search_decimal] self.assertEquals(response.data, expected_data) # Tests that the date filter set with 'gt' in the filter class works. @@ -118,7 +121,7 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' response = view(request).render() self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [ f for f in self.data if f['date'] > search_date ] + expected_data = [f for f in self.data if f['date'] > search_date] self.assertEquals(response.data, expected_data) # Tests that the text filter set with 'icontains' in the filter class works. @@ -126,7 +129,7 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?text=%s' % search_text) response = view(request).render() self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [ f for f in self.data if search_text in f['text'].lower() ] + expected_data = [f for f in self.data if search_text in f['text'].lower()] self.assertEquals(response.data, expected_data) # Tests that multiple filters works. @@ -135,10 +138,11 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) response = view(request).render() self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [ f for f in self.data if f['date'] > search_date and - f['decimal'] < search_decimal ] + expected_data = [f for f in self.data if f['date'] > search_date and + f['decimal'] < search_decimal] self.assertEquals(response.data, expected_data) + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_incorrectly_configured_filter(self): """ An error should be displayed when the filter class is misconfigured. @@ -148,6 +152,7 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/') self.assertRaises(AssertionError, view, request) + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_unknown_filter(self): """ GET requests with filters that aren't configured should return 200. @@ -157,4 +162,4 @@ class IntegrationTestFiltering(TestCase): search_integer = 10 request = factory.get('/?integer=%s' % search_integer) response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK)
\ No newline at end of file + self.assertEquals(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 7a2134e0..7f8cd524 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -3,9 +3,10 @@ from decimal import Decimal from django.core.paginator import Paginator from django.test import TestCase from django.test.client import RequestFactory +from django.utils import unittest from rest_framework import generics, status, pagination +from rest_framework.compat import django_filters from rest_framework.tests.models import BasicModel, FilterableItem -import django_filters factory = RequestFactory() @@ -18,17 +19,18 @@ class RootView(generics.ListCreateAPIView): paginate_by = 10 -class DecimalFilter(django_filters.FilterSet): - decimal = django_filters.NumberFilter(lookup_type='lt') - class Meta: - model = FilterableItem - fields = ['text', 'decimal', 'date'] +if django_filters: + class DecimalFilter(django_filters.FilterSet): + decimal = django_filters.NumberFilter(lookup_type='lt') + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] -class FilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem - paginate_by = 10 - filter_class = DecimalFilter + class FilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + paginate_by = 10 + filter_class = DecimalFilter class IntegrationTestPagination(TestCase): @@ -98,6 +100,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): ] self.view = FilterFieldsRootView.as_view() + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_paginated_filtered_root_view(self): """ GET requests to paginated filtered ListCreateAPIView should return diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py index 18b6af39..d7b75450 100644 --- a/rest_framework/tests/response.py +++ b/rest_framework/tests/response.py @@ -131,12 +131,6 @@ class RendererIntegrationTests(TestCase): self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEquals(resp.status_code, DUMMYSTATUS) - @unittest.skip('can\'t pass because view is a simple Django view and response is an ImmediateResponse') - def test_unsatisfiable_accept_header_on_request_returns_406_status(self): - """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" - resp = self.client.get('/', HTTP_ACCEPT='foo/bar') - self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) - def test_specified_renderer_serializes_content_on_format_query(self): """If a 'format' query is specified, the renderer with the matching format attribute should serialize the response.""" diff --git a/rest_framework/utils/__init__.py b/rest_framework/utils/__init__.py index a59fff45..84fcb5db 100644 --- a/rest_framework/utils/__init__.py +++ b/rest_framework/utils/__init__.py @@ -1,7 +1,6 @@ from django.utils.encoding import smart_unicode from django.utils.xmlutils import SimplerXMLGenerator from rest_framework.compat import StringIO - import re import xml.etree.ElementTree as ET |
