diff options
Diffstat (limited to 'rest_framework/tests')
| -rw-r--r-- | rest_framework/tests/filterset.py | 160 | ||||
| -rw-r--r-- | rest_framework/tests/models.py | 7 | ||||
| -rw-r--r-- | rest_framework/tests/pagination.py | 71 | 
3 files changed, 236 insertions, 2 deletions
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py new file mode 100644 index 00000000..5374eefc --- /dev/null +++ b/rest_framework/tests/filterset.py @@ -0,0 +1,160 @@ +import datetime +from decimal import Decimal +from django.test import TestCase +from django.test.client import RequestFactory +from rest_framework import generics, status +from rest_framework.tests.models import FilterableItem, BasicModel +import django_filters + +factory = RequestFactory() + +# Basic filter on a list view. +class FilterFieldsRootView(generics.ListCreateAPIView): +    model = FilterableItem +    filter_fields = ['decimal', 'date'] + + +# These class are used to test a filter class. +class SeveralFieldsFilter(django_filters.FilterSet): +    text = django_filters.CharFilter(lookup_type='icontains') +    decimal = django_filters.NumberFilter(lookup_type='lt') +    date = django_filters.DateFilter(lookup_type='gt') +    class Meta: +        model = FilterableItem +        fields = ['text', 'decimal', 'date'] + + +class FilterClassRootView(generics.ListCreateAPIView): +    model = FilterableItem +    filter_class = SeveralFieldsFilter + + +# These classes are used to test a misconfigured filter class. +class MisconfiguredFilter(django_filters.FilterSet): +    text = django_filters.CharFilter(lookup_type='icontains') +    class Meta: +        model = BasicModel +        fields = ['text'] + + +class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): +    model = FilterableItem +    filter_class = MisconfiguredFilter + + +class IntegrationTestFiltering(TestCase): +    """ +    Integration tests for filtered list views. +    """ + +    def setUp(self): +        """ +        Create 10 FilterableItem instances. +        """ +        base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) +        for i in range(10): +            text = chr(i + ord(base_data[0])) * 3  # Produces string 'aaa', 'bbb', etc. +            decimal = base_data[1] + i +            date = base_data[2] - datetime.timedelta(days=i * 2) +            FilterableItem(text=text, decimal=decimal, date=date).save() + +        self.objects = FilterableItem.objects +        self.data = [ +        {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} +        for obj in self.objects.all() +        ] + +    def test_get_filtered_fields_root_view(self): +        """ +        GET requests to paginated ListCreateAPIView should return paginated results. +        """ +        view = FilterFieldsRootView.as_view() + +        # Basic test with no filter. +        request = factory.get('/') +        response = view(request).render() +        self.assertEquals(response.status_code, status.HTTP_200_OK) +        self.assertEquals(response.data, self.data) + +        # Tests that the decimal filter works. +        search_decimal = Decimal('2.25') +        request = factory.get('/?decimal=%s' % search_decimal) +        response = view(request).render() +        self.assertEquals(response.status_code, status.HTTP_200_OK) +        expected_data = [ f for f in self.data if f['decimal'] == search_decimal ] +        self.assertEquals(response.data, expected_data) + +        # Tests that the date filter works. +        search_date = datetime.date(2012, 9, 22) +        request = factory.get('/?date=%s' % search_date)  # search_date str: '2012-09-22' +        response = view(request).render() +        self.assertEquals(response.status_code, status.HTTP_200_OK) +        expected_data = [ f for f in self.data if f['date'] == search_date ] +        self.assertEquals(response.data, expected_data) + +    def test_get_filtered_class_root_view(self): +        """ +        GET requests to filtered ListCreateAPIView that have a filter_class set +        should return filtered results. +        """ +        view = FilterClassRootView.as_view() + +        # Basic test with no filter. +        request = factory.get('/') +        response = view(request).render() +        self.assertEquals(response.status_code, status.HTTP_200_OK) +        self.assertEquals(response.data, self.data) + +        # Tests that the decimal filter set with 'lt' in the filter class works. +        search_decimal = Decimal('4.25') +        request = factory.get('/?decimal=%s' % search_decimal) +        response = view(request).render() +        self.assertEquals(response.status_code, status.HTTP_200_OK) +        expected_data = [ f for f in self.data if f['decimal'] < search_decimal ] +        self.assertEquals(response.data, expected_data) + +        # Tests that the date filter set with 'gt' in the filter class works. +        search_date = datetime.date(2012, 10, 2) +        request = factory.get('/?date=%s' % search_date)  # search_date str: '2012-10-02' +        response = view(request).render() +        self.assertEquals(response.status_code, status.HTTP_200_OK) +        expected_data = [ f for f in self.data if f['date'] > search_date ] +        self.assertEquals(response.data, expected_data) + +        # Tests that the text filter set with 'icontains' in the filter class works. +        search_text = 'ff' +        request = factory.get('/?text=%s' % search_text) +        response = view(request).render() +        self.assertEquals(response.status_code, status.HTTP_200_OK) +        expected_data = [ f for f in self.data if search_text in f['text'].lower() ] +        self.assertEquals(response.data, expected_data) + +        # Tests that multiple filters works. +        search_decimal = Decimal('5.25') +        search_date = datetime.date(2012, 10, 2) +        request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) +        response = view(request).render() +        self.assertEquals(response.status_code, status.HTTP_200_OK) +        expected_data = [ f for f in self.data if f['date'] > search_date and +                                                  f['decimal'] < search_decimal ] +        self.assertEquals(response.data, expected_data) + +    def test_incorrectly_configured_filter(self): +        """ +        An error should be displayed when the filter class is misconfigured. +        """ +        view = IncorrectlyConfiguredRootView.as_view() + +        request = factory.get('/') +        self.assertRaises(AssertionError, view, request) + +    def test_unknown_filter(self): +        """ +        GET requests with filters that aren't configured should return 200. +        """ +        view = FilterFieldsRootView.as_view() + +        search_integer = 10 +        request = factory.get('/?integer=%s' % search_integer) +        response = view(request).render() +        self.assertEquals(response.status_code, status.HTTP_200_OK)
\ No newline at end of file diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 0e23734e..a2aba5be 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -95,6 +95,13 @@ class Bookmark(RESTFrameworkModel):      tags = GenericRelation(TaggedItem) +# Model to test filtering. +class FilterableItem(RESTFrameworkModel): +    text = models.CharField(max_length=100) +    decimal = models.DecimalField(max_digits=4, decimal_places=2) +    date = models.DateField() + +  # Model for regression test for #285  class Comment(RESTFrameworkModel): diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 64e8d822..7a2134e0 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -1,8 +1,11 @@ +import datetime +from decimal import Decimal  from django.core.paginator import Paginator  from django.test import TestCase  from django.test.client import RequestFactory  from rest_framework import generics, status, pagination -from rest_framework.tests.models import BasicModel +from rest_framework.tests.models import BasicModel, FilterableItem +import django_filters  factory = RequestFactory() @@ -15,6 +18,19 @@ class RootView(generics.ListCreateAPIView):      paginate_by = 10 +class DecimalFilter(django_filters.FilterSet): +    decimal = django_filters.NumberFilter(lookup_type='lt') +    class Meta: +        model = FilterableItem +        fields = ['text', 'decimal', 'date'] + + +class FilterFieldsRootView(generics.ListCreateAPIView): +    model = FilterableItem +    paginate_by = 10 +    filter_class = DecimalFilter + +  class IntegrationTestPagination(TestCase):      """      Integration tests for paginated list views. @@ -22,7 +38,7 @@ class IntegrationTestPagination(TestCase):      def setUp(self):          """ -        Create 26 BasicModel intances. +        Create 26 BasicModel instances.          """          for char in 'abcdefghijklmnopqrstuvwxyz':              BasicModel(text=char * 3).save() @@ -62,6 +78,57 @@ class IntegrationTestPagination(TestCase):          self.assertNotEquals(response.data['previous'], None) +class IntegrationTestPaginationAndFiltering(TestCase): + +    def setUp(self): +        """ +        Create 50 FilterableItem instances. +        """ +        base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) +        for i in range(26): +            text = chr(i + ord(base_data[0])) * 3  # Produces string 'aaa', 'bbb', etc. +            decimal = base_data[1] + i +            date = base_data[2] - datetime.timedelta(days=i * 2) +            FilterableItem(text=text, decimal=decimal, date=date).save() + +        self.objects = FilterableItem.objects +        self.data = [ +        {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} +        for obj in self.objects.all() +        ] +        self.view = FilterFieldsRootView.as_view() + +    def test_get_paginated_filtered_root_view(self): +        """ +        GET requests to paginated filtered ListCreateAPIView should return +        paginated results. The next and previous links should preserve the +        filtered parameters. +        """ +        request = factory.get('/?decimal=15.20') +        response = self.view(request).render() +        self.assertEquals(response.status_code, status.HTTP_200_OK) +        self.assertEquals(response.data['count'], 15) +        self.assertEquals(response.data['results'], self.data[:10]) +        self.assertNotEquals(response.data['next'], None) +        self.assertEquals(response.data['previous'], None) + +        request = factory.get(response.data['next']) +        response = self.view(request).render() +        self.assertEquals(response.status_code, status.HTTP_200_OK) +        self.assertEquals(response.data['count'], 15) +        self.assertEquals(response.data['results'], self.data[10:15]) +        self.assertEquals(response.data['next'], None) +        self.assertNotEquals(response.data['previous'], None) + +        request = factory.get(response.data['previous']) +        response = self.view(request).render() +        self.assertEquals(response.status_code, status.HTTP_200_OK) +        self.assertEquals(response.data['count'], 15) +        self.assertEquals(response.data['results'], self.data[:10]) +        self.assertNotEquals(response.data['next'], None) +        self.assertEquals(response.data['previous'], None) + +  class UnitTestPagination(TestCase):      """      Unit tests for pagination of primative objects.  | 
