aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2012-11-07 21:07:24 +0000
committerTom Christie2012-11-07 21:07:24 +0000
commit47b534a13e42d498629bf9522225633122c563d5 (patch)
treefc7acddb14038fc5f159c1399dac7974a76caf4b /rest_framework
parent9fd061a0b68f0cef6683bf195911a2cc7ff2fa06 (diff)
downloaddjango-rest-framework-47b534a13e42d498629bf9522225633122c563d5.tar.bz2
Make filtering optional, and pluggable.
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/compat.py34
-rw-r--r--rest_framework/filters.py52
-rw-r--r--rest_framework/generics.py33
-rw-r--r--rest_framework/pagination.py19
-rw-r--r--rest_framework/settings.py2
-rw-r--r--rest_framework/templatetags/rest_framework.py27
-rw-r--r--rest_framework/tests/filterset.py71
-rw-r--r--rest_framework/tests/pagination.py23
-rw-r--r--rest_framework/tests/response.py6
-rw-r--r--rest_framework/utils/__init__.py1
10 files changed, 136 insertions, 132 deletions
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index b0367a32..02e50604 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -5,6 +5,13 @@ versions of django/python, and compatbility wrappers around optional packages.
# flake8: noqa
import django
+# django-filter is optional
+try:
+ import django_filters
+except:
+ django_filters = None
+
+
# cStringIO only if it's available, otherwise StringIO
try:
import cStringIO as StringIO
@@ -348,33 +355,6 @@ except ImportError:
yaml = None
-import unittest
-try:
- import unittest.skip
-except ImportError: # python < 2.7
- from unittest import TestCase
- import functools
-
- def skip(reason):
- # Pasted from py27/lib/unittest/case.py
- """
- Unconditionally skip a test.
- """
- def decorator(test_item):
- if not (isinstance(test_item, type) and issubclass(test_item, TestCase)):
- @functools.wraps(test_item)
- def skip_wrapper(*args, **kwargs):
- pass
- test_item = skip_wrapper
-
- test_item.__unittest_skip__ = True
- test_item.__unittest_skip_why__ = reason
- return test_item
- return decorator
-
- unittest.skip = skip
-
-
# xml.etree.parse only throws ParseError for python >= 2.7
try:
from xml.etree import ParseError as ETParseError
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
new file mode 100644
index 00000000..b972e82a
--- /dev/null
+++ b/rest_framework/filters.py
@@ -0,0 +1,52 @@
+from rest_framework.compat import django_filters
+
+
+class BaseFilterBackend(object):
+ """
+ A base class from which all filter backend classes should inherit.
+ """
+
+ def filter_queryset(self, request, queryset, view):
+ """
+ Return a filtered queryset.
+ """
+ raise NotImplementedError(".filter_queryset() must be overridden.")
+
+
+class DjangoFilterBackend(BaseFilterBackend):
+ """
+ A filter backend that uses django-filter.
+ """
+
+ def get_filter_class(self, view):
+ """
+ Return the django-filters `FilterSet` used to filter the queryset.
+ """
+ filter_class = getattr(view, 'filter_class', None)
+ filter_fields = getattr(view, 'filter_fields', None)
+ filter_model = getattr(view, 'model', None)
+
+ if filter_class or filter_fields:
+ assert django_filters, 'django-filter is not installed'
+
+ if filter_class:
+ assert issubclass(filter_class.Meta.model, filter_model), \
+ '%s is not a subclass of %s' % (filter_class.Meta.model, filter_model)
+ return filter_class
+
+ if filter_fields:
+ class AutoFilterSet(django_filters.FilterSet):
+ class Meta:
+ model = filter_model
+ fields = filter_fields
+ return AutoFilterSet
+
+ return None
+
+ def filter_queryset(self, request, queryset, view):
+ filter_class = self.get_filter_class(view)
+
+ if filter_class:
+ return filter_class(request.GET, queryset=queryset)
+
+ return queryset
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index ac02d3da..ebd06e45 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -6,7 +6,7 @@ from rest_framework import views, mixins
from rest_framework.settings import api_settings
from django.views.generic.detail import SingleObjectMixin
from django.views.generic.list import MultipleObjectMixin
-import django_filters
+
### Base classes for the generic views ###
@@ -58,34 +58,13 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
paginate_by = api_settings.PAGINATE_BY
- filter_class = None
- filter_fields = None
-
- def get_filter_class(self):
- """
- Return the django-filters `FilterSet` used to filter the queryset.
- """
- if self.filter_class:
- return self.filter_class
-
- if self.filter_fields:
- class AutoFilterSet(django_filters.FilterSet):
- class Meta:
- model = self.model
- fields = self.filter_fields
- return AutoFilterSet
-
- return None
+ filter_backend = api_settings.FILTER_BACKEND
def filter_queryset(self, queryset):
- filter_class = self.get_filter_class()
-
- if filter_class:
- assert issubclass(filter_class.Meta.model, self.model), \
- "%s is not a subclass of %s" % (filter_class.Meta.model, self.model)
- return filter_class(self.request.GET, queryset=queryset)
-
- return queryset
+ if not self.filter_backend:
+ return queryset
+ backend = self.filter_backend()
+ return backend.filter_queryset(self.request, queryset, self)
def get_filtered_queryset(self):
return self.filter_queryset(self.get_queryset())
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index c77a1005..aa54d154 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -1,4 +1,5 @@
from rest_framework import serializers
+from rest_framework.templatetags.rest_framework import replace_query_param
# TODO: Support URLconf kwarg-style paging
@@ -16,13 +17,8 @@ class NextPageField(PageField):
return None
page = value.next_page_number()
request = self.context.get('request')
- relative_url = '?%s=%d' % (self.page_field, page)
- if request:
- for field, value in request.QUERY_PARAMS.iteritems():
- if field != self.page_field:
- relative_url += '&%s=%s' % (field, value)
- return request.build_absolute_uri(relative_url)
- return relative_url
+ url = request and request.get_full_path() or ''
+ return replace_query_param(url, self.page_field, page)
class PreviousPageField(PageField):
@@ -34,13 +30,8 @@ class PreviousPageField(PageField):
return None
page = value.previous_page_number()
request = self.context.get('request')
- relative_url = '?%s=%d' % (self.page_field, page)
- if request:
- for field, value in request.QUERY_PARAMS.iteritems():
- if field != self.page_field:
- relative_url += '&%s=%s' % (field, value)
- return request.build_absolute_uri(relative_url)
- return relative_url
+ url = request and request.get_full_path() or ''
+ return replace_query_param(url, self.page_field, page)
class PaginationSerializerOptions(serializers.SerializerOptions):
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 9c40a214..da647658 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -55,6 +55,7 @@ DEFAULTS = {
'anon': None,
},
'PAGINATE_BY': None,
+ 'FILTER_BACKEND': 'rest_framework.filters.DjangoFilterBackend',
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
@@ -79,6 +80,7 @@ IMPORT_STRINGS = (
'DEFAULT_CONTENT_NEGOTIATION_CLASS',
'DEFAULT_MODEL_SERIALIZER_CLASS',
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
+ 'FILTER_BACKEND',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
)
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index c9b6eb10..0672ee4f 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -1,9 +1,9 @@
from django import template
from django.core.urlresolvers import reverse
-from django.http import QueryDict
from django.utils.encoding import force_unicode
from django.utils.html import escape
from django.utils.safestring import SafeData, mark_safe
+from django.http import QueryDict
from urlparse import urlsplit, urlunsplit
import re
import string
@@ -11,6 +11,18 @@ import string
register = template.Library()
+def replace_query_param(url, key, val):
+ """
+ Given a URL and a key/val pair, set or replace an item in the query
+ parameters of the URL, and return the new URL.
+ """
+ (scheme, netloc, path, query, fragment) = urlsplit(url)
+ query_dict = QueryDict(query).copy()
+ query_dict[key] = val
+ query = query_dict.urlencode()
+ return urlunsplit((scheme, netloc, path, query, fragment))
+
+
# Regex for adding classes to html snippets
class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])')
@@ -31,19 +43,6 @@ hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|
trailing_empty_content_re = re.compile(r'(?:<p>(?:&nbsp;|\s|<br \/>)*?</p>\s*)+\Z')
-# Helper function for 'add_query_param'
-def replace_query_param(url, key, val):
- """
- Given a URL and a key/val pair, set or replace an item in the query
- parameters of the URL, and return the new URL.
- """
- (scheme, netloc, path, query, fragment) = urlsplit(url)
- query_dict = QueryDict(query).copy()
- query_dict[key] = val
- query = query_dict.urlencode()
- return urlunsplit((scheme, netloc, path, query, fragment))
-
-
# And the template tags themselves...
@register.simple_tag
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
index 5374eefc..6cdea32f 100644
--- a/rest_framework/tests/filterset.py
+++ b/rest_framework/tests/filterset.py
@@ -2,44 +2,45 @@ import datetime
from decimal import Decimal
from django.test import TestCase
from django.test.client import RequestFactory
+from django.utils import unittest
from rest_framework import generics, status
+from rest_framework.compat import django_filters
from rest_framework.tests.models import FilterableItem, BasicModel
-import django_filters
factory = RequestFactory()
-# Basic filter on a list view.
-class FilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_fields = ['decimal', 'date']
-
-# These class are used to test a filter class.
-class SeveralFieldsFilter(django_filters.FilterSet):
- text = django_filters.CharFilter(lookup_type='icontains')
- decimal = django_filters.NumberFilter(lookup_type='lt')
- date = django_filters.DateFilter(lookup_type='gt')
- class Meta:
+if django_filters:
+ # Basic filter on a list view.
+ class FilterFieldsRootView(generics.ListCreateAPIView):
model = FilterableItem
- fields = ['text', 'decimal', 'date']
+ filter_fields = ['decimal', 'date']
+ # These class are used to test a filter class.
+ class SeveralFieldsFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+ date = django_filters.DateFilter(lookup_type='gt')
-class FilterClassRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_class = SeveralFieldsFilter
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+ class FilterClassRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
-# These classes are used to test a misconfigured filter class.
-class MisconfiguredFilter(django_filters.FilterSet):
- text = django_filters.CharFilter(lookup_type='icontains')
- class Meta:
- model = BasicModel
- fields = ['text']
+ # These classes are used to test a misconfigured filter class.
+ class MisconfiguredFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+ class Meta:
+ model = BasicModel
+ fields = ['text']
-class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_class = MisconfiguredFilter
+ class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = MisconfiguredFilter
class IntegrationTestFiltering(TestCase):
@@ -64,6 +65,7 @@ class IntegrationTestFiltering(TestCase):
for obj in self.objects.all()
]
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
def test_get_filtered_fields_root_view(self):
"""
GET requests to paginated ListCreateAPIView should return paginated results.
@@ -81,7 +83,7 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?decimal=%s' % search_decimal)
response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
- expected_data = [ f for f in self.data if f['decimal'] == search_decimal ]
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
self.assertEquals(response.data, expected_data)
# Tests that the date filter works.
@@ -89,9 +91,10 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
- expected_data = [ f for f in self.data if f['date'] == search_date ]
+ expected_data = [f for f in self.data if f['date'] == search_date]
self.assertEquals(response.data, expected_data)
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
def test_get_filtered_class_root_view(self):
"""
GET requests to filtered ListCreateAPIView that have a filter_class set
@@ -110,7 +113,7 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?decimal=%s' % search_decimal)
response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
- expected_data = [ f for f in self.data if f['decimal'] < search_decimal ]
+ expected_data = [f for f in self.data if f['decimal'] < search_decimal]
self.assertEquals(response.data, expected_data)
# Tests that the date filter set with 'gt' in the filter class works.
@@ -118,7 +121,7 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
- expected_data = [ f for f in self.data if f['date'] > search_date ]
+ expected_data = [f for f in self.data if f['date'] > search_date]
self.assertEquals(response.data, expected_data)
# Tests that the text filter set with 'icontains' in the filter class works.
@@ -126,7 +129,7 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?text=%s' % search_text)
response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
- expected_data = [ f for f in self.data if search_text in f['text'].lower() ]
+ expected_data = [f for f in self.data if search_text in f['text'].lower()]
self.assertEquals(response.data, expected_data)
# Tests that multiple filters works.
@@ -135,10 +138,11 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
- expected_data = [ f for f in self.data if f['date'] > search_date and
- f['decimal'] < search_decimal ]
+ expected_data = [f for f in self.data if f['date'] > search_date and
+ f['decimal'] < search_decimal]
self.assertEquals(response.data, expected_data)
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
def test_incorrectly_configured_filter(self):
"""
An error should be displayed when the filter class is misconfigured.
@@ -148,6 +152,7 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/')
self.assertRaises(AssertionError, view, request)
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
def test_unknown_filter(self):
"""
GET requests with filters that aren't configured should return 200.
@@ -157,4 +162,4 @@ class IntegrationTestFiltering(TestCase):
search_integer = 10
request = factory.get('/?integer=%s' % search_integer)
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK) \ No newline at end of file
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index 7a2134e0..7f8cd524 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -3,9 +3,10 @@ from decimal import Decimal
from django.core.paginator import Paginator
from django.test import TestCase
from django.test.client import RequestFactory
+from django.utils import unittest
from rest_framework import generics, status, pagination
+from rest_framework.compat import django_filters
from rest_framework.tests.models import BasicModel, FilterableItem
-import django_filters
factory = RequestFactory()
@@ -18,17 +19,18 @@ class RootView(generics.ListCreateAPIView):
paginate_by = 10
-class DecimalFilter(django_filters.FilterSet):
- decimal = django_filters.NumberFilter(lookup_type='lt')
- class Meta:
- model = FilterableItem
- fields = ['text', 'decimal', 'date']
+if django_filters:
+ class DecimalFilter(django_filters.FilterSet):
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
-class FilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
- paginate_by = 10
- filter_class = DecimalFilter
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_class = DecimalFilter
class IntegrationTestPagination(TestCase):
@@ -98,6 +100,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
]
self.view = FilterFieldsRootView.as_view()
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
def test_get_paginated_filtered_root_view(self):
"""
GET requests to paginated filtered ListCreateAPIView should return
diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py
index 18b6af39..d7b75450 100644
--- a/rest_framework/tests/response.py
+++ b/rest_framework/tests/response.py
@@ -131,12 +131,6 @@ class RendererIntegrationTests(TestCase):
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
- @unittest.skip('can\'t pass because view is a simple Django view and response is an ImmediateResponse')
- def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
- """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
- resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
- self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
-
def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
diff --git a/rest_framework/utils/__init__.py b/rest_framework/utils/__init__.py
index a59fff45..84fcb5db 100644
--- a/rest_framework/utils/__init__.py
+++ b/rest_framework/utils/__init__.py
@@ -1,7 +1,6 @@
from django.utils.encoding import smart_unicode
from django.utils.xmlutils import SimplerXMLGenerator
from rest_framework.compat import StringIO
-
import re
import xml.etree.ElementTree as ET