diff options
| -rw-r--r-- | rest_framework/filters.py | 14 | ||||
| -rw-r--r-- | rest_framework/tests/filterset.py | 28 | 
2 files changed, 37 insertions, 5 deletions
diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 571704dc..f2163f6f 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -38,21 +38,27 @@ class DjangoFilterBackend(BaseFilterBackend):          """          filter_class = getattr(view, 'filter_class', None)          filter_fields = getattr(view, 'filter_fields', None) -        view_model = getattr(view, 'model', None) +        model_cls = getattr(view, 'model', None) +        queryset = getattr(view, 'queryset', None) +        if model_cls is None and queryset is not None: +            model_cls = queryset.model          if filter_class:              filter_model = filter_class.Meta.model -            assert issubclass(filter_model, view_model), \ +            assert issubclass(filter_model, model_cls), \                  'FilterSet model %s does not match view model %s' % \ -                (filter_model, view_model) +                (filter_model, model_cls)              return filter_class          if filter_fields: +            assert model_cls is not None, 'Cannot use DjangoFilterBackend ' \ +                'on a view which does not have a .model or .queryset attribute.' +              class AutoFilterSet(self.default_filter_set):                  class Meta: -                    model = view_model +                    model = model_cls                      fields = filter_fields              return AutoFilterSet diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 1e53a5cd..023bd016 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -5,7 +5,7 @@ from django.core.urlresolvers import reverse  from django.test import TestCase  from django.test.client import RequestFactory  from django.utils import unittest -from rest_framework import generics, status, filters +from rest_framework import generics, serializers, status, filters  from rest_framework.compat import django_filters, patterns, url  from rest_framework.tests.models import FilterableItem, BasicModel @@ -52,6 +52,17 @@ if django_filters:          filter_class = SeveralFieldsFilter          filter_backend = filters.DjangoFilterBackend +    # Regression test for #814 +    class FilterableItemSerializer(serializers.ModelSerializer): +        class Meta: +            model = FilterableItem + +    class FilterFieldsQuerysetView(generics.ListCreateAPIView): +        queryset = FilterableItem.objects.all() +        serializer_class = FilterableItemSerializer +        filter_fields = ['decimal', 'date'] +        filter_backend = filters.DjangoFilterBackend +      urlpatterns = patterns('',          url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),          url(r'^$', FilterClassRootView.as_view(), name='root-view'), @@ -115,6 +126,21 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          self.assertEqual(response.data, expected_data)      @unittest.skipUnless(django_filters, 'django-filters not installed') +    def test_filter_with_queryset(self): +        """ +        Regression test for #814. +        """ +        view = FilterFieldsQuerysetView.as_view() + +        # Tests that the decimal filter works. +        search_decimal = Decimal('2.25') +        request = factory.get('/?decimal=%s' % search_decimal) +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        expected_data = [f for f in self.data if f['decimal'] == search_decimal] +        self.assertEqual(response.data, expected_data) + +    @unittest.skipUnless(django_filters, 'django-filters not installed')      def test_get_filtered_class_root_view(self):          """          GET requests to filtered ListCreateAPIView that have a filter_class set  | 
