aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSimon Charette2014-02-09 00:50:03 -0500
committerSimon Charette2014-02-09 00:50:03 -0500
commit4d45865bd73ba16801950e3f47199aa6da0f7c19 (patch)
tree88cb9ab4ee1440526c1c4e9adb0bbc39bff96305
parent00b187710623d8efda62f207573fa4e356d1f8ef (diff)
downloaddjango-rest-framework-4d45865bd73ba16801950e3f47199aa6da0f7c19.tar.bz2
Allow filter model to be a subclass of the queryset one.
-rw-r--r--rest_framework/filters.py2
-rw-r--r--rest_framework/tests/models.py12
-rw-r--r--rest_framework/tests/test_filters.py35
-rw-r--r--rest_framework/tests/test_pagination.py8
4 files changed, 41 insertions, 16 deletions
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index de91caed..f7ad37ba 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -43,7 +43,7 @@ class DjangoFilterBackend(BaseFilterBackend):
if filter_class:
filter_model = filter_class.Meta.model
- assert issubclass(filter_model, queryset.model), \
+ assert issubclass(queryset.model, filter_model), \
'FilterSet model %s does not match queryset model %s' % \
(filter_model, queryset.model)
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index 32a726c0..0137d45a 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -60,6 +60,18 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel):
rel = models.ManyToManyField(Anchor)
+class BaseFilterableItem(RESTFrameworkModel):
+ text = models.CharField(max_length=100)
+
+ class Meta:
+ abstract = True
+
+
+class FilterableItem(BaseFilterableItem):
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
# Model for regression test for #285
class Comment(RESTFrameworkModel):
diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py
index 18188186..769d3426 100644
--- a/rest_framework/tests/test_filters.py
+++ b/rest_framework/tests/test_filters.py
@@ -8,17 +8,12 @@ from django.utils import unittest
from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url
from rest_framework.test import APIRequestFactory
-from rest_framework.tests.models import BasicModel
+from rest_framework.tests.models import (BaseFilterableItem, BasicModel,
+ FilterableItem)
factory = APIRequestFactory()
-class FilterableItem(models.Model):
- text = models.CharField(max_length=100)
- decimal = models.DecimalField(max_digits=4, decimal_places=2)
- date = models.DateField()
-
-
if django_filters:
# Basic filter on a list view.
class FilterFieldsRootView(generics.ListCreateAPIView):
@@ -59,6 +54,18 @@ if django_filters:
filter_class = SeveralFieldsFilter
filter_backends = (filters.DjangoFilterBackend,)
+ # These classes are used to test base model filter support
+ class BaseFilterableItemFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter()
+
+ class Meta:
+ model = BaseFilterableItem
+
+ class BaseFilterableItemFilterRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = BaseFilterableItemFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
# Regression test for #814
class FilterableItemSerializer(serializers.ModelSerializer):
class Meta:
@@ -227,6 +234,18 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
self.assertRaises(AssertionError, view, request)
@unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_base_model_filter(self):
+ """
+ The `get_filter_class` model checks should allow base model filters.
+ """
+ view = BaseFilterableItemFilterRootView.as_view()
+
+ request = factory.get('/?text=aaa')
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(len(response.data), 1)
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
def test_unknown_filter(self):
"""
GET requests with filters that aren't configured should return 200.
@@ -612,4 +631,4 @@ class SensitiveOrderingFilterTests(TestCase):
{'id': 2, username_field: 'userB'}, # PassC
{'id': 3, username_field: 'userC'}, # PassA
]
- ) \ No newline at end of file
+ )
diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py
index cadb515f..cd299613 100644
--- a/rest_framework/tests/test_pagination.py
+++ b/rest_framework/tests/test_pagination.py
@@ -8,17 +8,11 @@ from django.utils import unittest
from rest_framework import generics, status, pagination, filters, serializers
from rest_framework.compat import django_filters
from rest_framework.test import APIRequestFactory
-from rest_framework.tests.models import BasicModel
+from rest_framework.tests.models import BasicModel, FilterableItem
factory = APIRequestFactory()
-class FilterableItem(models.Model):
- text = models.CharField(max_length=100)
- decimal = models.DecimalField(max_digits=4, decimal_places=2)
- date = models.DateField()
-
-
class RootView(generics.ListCreateAPIView):
"""
Example description for OPTIONS.