diff options
| -rw-r--r-- | rest_framework/filters.py | 9 | ||||
| -rw-r--r-- | rest_framework/tests/filterset.py | 81 |
2 files changed, 87 insertions, 3 deletions
diff --git a/rest_framework/filters.py b/rest_framework/filters.py index f2163f6f..54cbbde3 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -74,6 +74,8 @@ class DjangoFilterBackend(BaseFilterBackend): class SearchFilter(BaseFilterBackend): + search_param = 'search' + def construct_search(self, field_name): if field_name.startswith('^'): return "%s__istartswith" % field_name[1:] @@ -90,10 +92,13 @@ class SearchFilter(BaseFilterBackend): if not search_fields: return None + search_terms = request.QUERY_PARAMS.get(self.search_param) orm_lookups = [self.construct_search(str(search_field)) - for search_field in self.search_fields] - for bit in self.query.split(): + for search_field in search_fields] + + for bit in search_terms.split(): or_queries = [models.Q(**{orm_lookup: bit}) for orm_lookup in orm_lookups] queryset = queryset.filter(reduce(operator.or_, or_queries)) + return queryset diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 023bd016..7865fedd 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -1,17 +1,24 @@ from __future__ import unicode_literals import datetime from decimal import Decimal +from django.db import models from django.core.urlresolvers import reverse from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters, patterns, url -from rest_framework.tests.models import FilterableItem, BasicModel +from rest_framework.tests.models import BasicModel factory = RequestFactory() +class FilterableItem(models.Model): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + if django_filters: # Basic filter on a list view. class FilterFieldsRootView(generics.ListCreateAPIView): @@ -256,3 +263,75 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, valid_item_data) + + +class SearchFilterModel(models.Model): + title = models.CharField(max_length=20) + text = models.CharField(max_length=100) + + +class SearchFilterTests(TestCase): + def setUp(self): + # Sequence of title/text is: + # + # z abc + # zz bcd + # zzz cde + # ... + for idx in range(10): + title = 'z' * (idx + 1) + text = ( + chr(idx + ord('a')) + + chr(idx + ord('b')) + + chr(idx + ord('c')) + ) + SearchFilterModel(title=title, text=text).save() + + def test_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('title', 'text') + + view = SearchListView.as_view() + request = factory.get('?search=b') + response = view(request) + self.assertEqual( + response.data, + [ + {u'id': 1, 'title': u'z', 'text': u'abc'}, + {u'id': 2, 'title': u'zz', 'text': u'bcd'} + ] + ) + + def test_exact_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('=title', 'text') + + view = SearchListView.as_view() + request = factory.get('?search=zzz') + response = view(request) + self.assertEqual( + response.data, + [ + {u'id': 3, 'title': u'zzz', 'text': u'cde'} + ] + ) + + def test_startswith_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('title', '^text') + + view = SearchListView.as_view() + request = factory.get('?search=b') + response = view(request) + self.assertEqual( + response.data, + [ + {u'id': 2, 'title': u'zz', 'text': u'bcd'} + ] + ) |
