aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/filters.py53
-rw-r--r--rest_framework/routers.py2
-rw-r--r--rest_framework/tests/filters.py (renamed from rest_framework/tests/filterset.py)116
3 files changed, 166 insertions, 5 deletions
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index 57f0f7c8..6a3e055d 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -4,7 +4,7 @@ returned by list views.
"""
from __future__ import unicode_literals
from django.db import models
-from rest_framework.compat import django_filters
+from rest_framework.compat import django_filters, six
from functools import reduce
import operator
@@ -75,7 +75,14 @@ class DjangoFilterBackend(BaseFilterBackend):
class SearchFilter(BaseFilterBackend):
search_param = 'search' # The URL query parameter used for the search.
- delimiter = None # For example, set to ',' for comma delimited searchs.
+
+ def get_search_terms(self, request):
+ """
+ Search terms are set by a ?search=... query parameter,
+ and may be comma and/or whitespace delimited.
+ """
+ params = request.QUERY_PARAMS.get(self.search_param)
+ return params.replace(',', ' ').split()
def construct_search(self, field_name):
if field_name.startswith('^'):
@@ -93,13 +100,51 @@ class SearchFilter(BaseFilterBackend):
if not search_fields:
return None
- search_terms = request.QUERY_PARAMS.get(self.search_param)
orm_lookups = [self.construct_search(str(search_field))
for search_field in search_fields]
- for search_term in search_terms.split(self.delimiter):
+ for search_term in self.get_search_terms(request):
or_queries = [models.Q(**{orm_lookup: search_term})
for orm_lookup in orm_lookups]
queryset = queryset.filter(reduce(operator.or_, or_queries))
return queryset
+
+
+class OrderingFilter(BaseFilterBackend):
+ ordering_param = 'ordering' # The URL query parameter used for the ordering.
+
+ def get_ordering(self, request):
+ """
+ Search terms are set by a ?search=... query parameter,
+ and may be comma and/or whitespace delimited.
+ """
+ params = request.QUERY_PARAMS.get(self.ordering_param)
+ if params:
+ return [param.strip() for param in params.split(',')]
+
+ def get_default_ordering(self, view):
+ ordering = getattr(view, 'ordering', None)
+ if isinstance(ordering, six.string_types):
+ return (ordering,)
+ return ordering
+
+ def remove_invalid_fields(self, queryset, ordering):
+ field_names = [field.name for field in queryset.model._meta.fields]
+ return [term for term in ordering if term.lstrip('-') in field_names]
+
+ def filter_queryset(self, request, queryset, view):
+ ordering = self.get_ordering(request)
+
+ if ordering:
+ # Skip any incorrect parameters
+ ordering = self.remove_invalid_fields(queryset, ordering)
+
+ if not ordering:
+ # Use 'ordering' attribtue by default
+ ordering = self.get_default_ordering(view)
+
+ if ordering:
+ return queryset.order_by(*ordering)
+
+ return queryset
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index ebdf2b2a..76714fd0 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -16,7 +16,7 @@ For example, you might have a `urls.py` that looks something like this:
from __future__ import unicode_literals
from collections import namedtuple
-from django.conf.urls import url, patterns
+from rest_framework.compat import patterns, url
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.reverse import reverse
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filters.py
index e5414232..b20d5980 100644
--- a/rest_framework/tests/filterset.py
+++ b/rest_framework/tests/filters.py
@@ -335,3 +335,119 @@ class SearchFilterTests(TestCase):
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
)
+
+
+class OrdringFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class OrderingFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # zyx abc
+ # yxw bcd
+ # xwv cde
+ for idx in range(3):
+ title = (
+ chr(ord('z') - idx) +
+ chr(ord('y') - idx) +
+ chr(ord('x') - idx)
+ )
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ OrdringFilterModel(title=title, text=text).save()
+
+ def test_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ ]
+ )
+
+ def test_reverse_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=-text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_incorrectfield_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=foobar')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering_using_string(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )