diff options
| -rw-r--r-- | rest_framework/filters.py | 2 | ||||
| -rw-r--r-- | rest_framework/tests/models.py | 12 | ||||
| -rw-r--r-- | rest_framework/tests/test_filters.py | 35 | ||||
| -rw-r--r-- | rest_framework/tests/test_pagination.py | 8 |
4 files changed, 41 insertions, 16 deletions
diff --git a/rest_framework/filters.py b/rest_framework/filters.py index de91caed..f7ad37ba 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -43,7 +43,7 @@ class DjangoFilterBackend(BaseFilterBackend): if filter_class: filter_model = filter_class.Meta.model - assert issubclass(filter_model, queryset.model), \ + assert issubclass(queryset.model, filter_model), \ 'FilterSet model %s does not match queryset model %s' % \ (filter_model, queryset.model) diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 32a726c0..0137d45a 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -60,6 +60,18 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel): rel = models.ManyToManyField(Anchor) +class BaseFilterableItem(RESTFrameworkModel): + text = models.CharField(max_length=100) + + class Meta: + abstract = True + + +class FilterableItem(BaseFilterableItem): + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + # Model for regression test for #285 class Comment(RESTFrameworkModel): diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py index 18188186..769d3426 100644 --- a/rest_framework/tests/test_filters.py +++ b/rest_framework/tests/test_filters.py @@ -8,17 +8,12 @@ from django.utils import unittest from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters, patterns, url from rest_framework.test import APIRequestFactory -from rest_framework.tests.models import BasicModel +from rest_framework.tests.models import (BaseFilterableItem, BasicModel, + FilterableItem) factory = APIRequestFactory() -class FilterableItem(models.Model): - text = models.CharField(max_length=100) - decimal = models.DecimalField(max_digits=4, decimal_places=2) - date = models.DateField() - - if django_filters: # Basic filter on a list view. class FilterFieldsRootView(generics.ListCreateAPIView): @@ -59,6 +54,18 @@ if django_filters: filter_class = SeveralFieldsFilter filter_backends = (filters.DjangoFilterBackend,) + # These classes are used to test base model filter support + class BaseFilterableItemFilter(django_filters.FilterSet): + text = django_filters.CharFilter() + + class Meta: + model = BaseFilterableItem + + class BaseFilterableItemFilterRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = BaseFilterableItemFilter + filter_backends = (filters.DjangoFilterBackend,) + # Regression test for #814 class FilterableItemSerializer(serializers.ModelSerializer): class Meta: @@ -227,6 +234,18 @@ class IntegrationTestFiltering(CommonFilteringTestCase): self.assertRaises(AssertionError, view, request) @unittest.skipUnless(django_filters, 'django-filter not installed') + def test_base_model_filter(self): + """ + The `get_filter_class` model checks should allow base model filters. + """ + view = BaseFilterableItemFilterRootView.as_view() + + request = factory.get('/?text=aaa') + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + + @unittest.skipUnless(django_filters, 'django-filter not installed') def test_unknown_filter(self): """ GET requests with filters that aren't configured should return 200. @@ -612,4 +631,4 @@ class SensitiveOrderingFilterTests(TestCase): {'id': 2, username_field: 'userB'}, # PassC {'id': 3, username_field: 'userC'}, # PassA ] - )
\ No newline at end of file + ) diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index cadb515f..cd299613 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -8,17 +8,11 @@ from django.utils import unittest from rest_framework import generics, status, pagination, filters, serializers from rest_framework.compat import django_filters from rest_framework.test import APIRequestFactory -from rest_framework.tests.models import BasicModel +from rest_framework.tests.models import BasicModel, FilterableItem factory = APIRequestFactory() -class FilterableItem(models.Model): - text = models.CharField(max_length=100) - decimal = models.DecimalField(max_digits=4, decimal_places=2) - date = models.DateField() - - class RootView(generics.ListCreateAPIView): """ Example description for OPTIONS. |
