aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2013-05-08 20:18:01 +0100
committerTom Christie2013-05-08 20:18:01 +0100
commitde69a28b9e786b8c759cda4acedb0a1b8542298b (patch)
tree85e3044c696d297d2721dfc480383efea0969eb5 /rest_framework
parent9d59e55cec6458e17cba758bb11986f01fd401c4 (diff)
downloaddjango-rest-framework-de69a28b9e786b8c759cda4acedb0a1b8542298b.tar.bz2
Test and fix for #814.
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/filters.py14
-rw-r--r--rest_framework/tests/filterset.py28
2 files changed, 37 insertions, 5 deletions
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index 571704dc..f2163f6f 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -38,21 +38,27 @@ class DjangoFilterBackend(BaseFilterBackend):
"""
filter_class = getattr(view, 'filter_class', None)
filter_fields = getattr(view, 'filter_fields', None)
- view_model = getattr(view, 'model', None)
+ model_cls = getattr(view, 'model', None)
+ queryset = getattr(view, 'queryset', None)
+ if model_cls is None and queryset is not None:
+ model_cls = queryset.model
if filter_class:
filter_model = filter_class.Meta.model
- assert issubclass(filter_model, view_model), \
+ assert issubclass(filter_model, model_cls), \
'FilterSet model %s does not match view model %s' % \
- (filter_model, view_model)
+ (filter_model, model_cls)
return filter_class
if filter_fields:
+ assert model_cls is not None, 'Cannot use DjangoFilterBackend ' \
+ 'on a view which does not have a .model or .queryset attribute.'
+
class AutoFilterSet(self.default_filter_set):
class Meta:
- model = view_model
+ model = model_cls
fields = filter_fields
return AutoFilterSet
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
index 1e53a5cd..023bd016 100644
--- a/rest_framework/tests/filterset.py
+++ b/rest_framework/tests/filterset.py
@@ -5,7 +5,7 @@ from django.core.urlresolvers import reverse
from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest
-from rest_framework import generics, status, filters
+from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url
from rest_framework.tests.models import FilterableItem, BasicModel
@@ -52,6 +52,17 @@ if django_filters:
filter_class = SeveralFieldsFilter
filter_backend = filters.DjangoFilterBackend
+ # Regression test for #814
+ class FilterableItemSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = FilterableItem
+
+ class FilterFieldsQuerysetView(generics.ListCreateAPIView):
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
+ filter_fields = ['decimal', 'date']
+ filter_backend = filters.DjangoFilterBackend
+
urlpatterns = patterns('',
url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
url(r'^$', FilterClassRootView.as_view(), name='root-view'),
@@ -115,6 +126,21 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_filter_with_queryset(self):
+ """
+ Regression test for #814.
+ """
+ view = FilterFieldsQuerysetView.as_view()
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
def test_get_filtered_class_root_view(self):
"""
GET requests to filtered ListCreateAPIView that have a filter_class set