aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--rest_framework/filters.py9
-rw-r--r--rest_framework/tests/filterset.py81
2 files changed, 87 insertions, 3 deletions
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index f2163f6f..54cbbde3 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -74,6 +74,8 @@ class DjangoFilterBackend(BaseFilterBackend):
class SearchFilter(BaseFilterBackend):
+ search_param = 'search'
+
def construct_search(self, field_name):
if field_name.startswith('^'):
return "%s__istartswith" % field_name[1:]
@@ -90,10 +92,13 @@ class SearchFilter(BaseFilterBackend):
if not search_fields:
return None
+ search_terms = request.QUERY_PARAMS.get(self.search_param)
orm_lookups = [self.construct_search(str(search_field))
- for search_field in self.search_fields]
- for bit in self.query.split():
+ for search_field in search_fields]
+
+ for bit in search_terms.split():
or_queries = [models.Q(**{orm_lookup: bit})
for orm_lookup in orm_lookups]
queryset = queryset.filter(reduce(operator.or_, or_queries))
+
return queryset
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
index 023bd016..7865fedd 100644
--- a/rest_framework/tests/filterset.py
+++ b/rest_framework/tests/filterset.py
@@ -1,17 +1,24 @@
from __future__ import unicode_literals
import datetime
from decimal import Decimal
+from django.db import models
from django.core.urlresolvers import reverse
from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest
from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url
-from rest_framework.tests.models import FilterableItem, BasicModel
+from rest_framework.tests.models import BasicModel
factory = RequestFactory()
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
if django_filters:
# Basic filter on a list view.
class FilterFieldsRootView(generics.ListCreateAPIView):
@@ -256,3 +263,75 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, valid_item_data)
+
+
+class SearchFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class SearchFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # z abc
+ # zz bcd
+ # zzz cde
+ # ...
+ for idx in range(10):
+ title = 'z' * (idx + 1)
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ SearchFilterModel(title=title, text=text).save()
+
+ def test_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {u'id': 1, 'title': u'z', 'text': u'abc'},
+ {u'id': 2, 'title': u'zz', 'text': u'bcd'}
+ ]
+ )
+
+ def test_exact_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('=title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=zzz')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {u'id': 3, 'title': u'zzz', 'text': u'cde'}
+ ]
+ )
+
+ def test_startswith_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', '^text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {u'id': 2, 'title': u'zz', 'text': u'bcd'}
+ ]
+ )