diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/filters.py | 21 | ||||
| -rw-r--r-- | rest_framework/runtests/settings.py | 1 | ||||
| -rw-r--r-- | rest_framework/settings.py | 9 | ||||
| -rw-r--r-- | rest_framework/tests/filterset.py | 5 | ||||
| -rw-r--r-- | rest_framework/tests/pagination.py | 3 | 
5 files changed, 28 insertions, 11 deletions
diff --git a/rest_framework/filters.py b/rest_framework/filters.py index b972e82a..14902a69 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -17,6 +17,10 @@ class DjangoFilterBackend(BaseFilterBackend):      """      A filter backend that uses django-filter.      """ +    default_filter_set = django_filters.FilterSet + +    def __init__(self): +        assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed'      def get_filter_class(self, view):          """ @@ -24,20 +28,21 @@ class DjangoFilterBackend(BaseFilterBackend):          """          filter_class = getattr(view, 'filter_class', None)          filter_fields = getattr(view, 'filter_fields', None) -        filter_model = getattr(view, 'model', None) - -        if filter_class or filter_fields: -            assert django_filters, 'django-filter is not installed' +        view_model = getattr(view, 'model', None)          if filter_class: -            assert issubclass(filter_class.Meta.model, filter_model), \ -                '%s is not a subclass of %s' % (filter_class.Meta.model, filter_model) +            filter_model = filter_class.Meta.model + +            assert issubclass(filter_model, view_model), \ +                'FilterSet model %s does not match view model %s' % \ +                (filter_model, view_model) +              return filter_class          if filter_fields: -            class AutoFilterSet(django_filters.FilterSet): +            class AutoFilterSet(self.default_filter_set):                  class Meta: -                    model = filter_model +                    model = view_model                  fields = filter_fields              return AutoFilterSet diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py index b48f85e4..dd5d9dc3 100644 --- a/rest_framework/runtests/settings.py +++ b/rest_framework/runtests/settings.py @@ -107,6 +107,7 @@ import django  if django.VERSION < (1, 3):      INSTALLED_APPS += ('staticfiles',) +  # If we're running on the Jenkins server we want to archive the coverage reports as XML.  import os  if os.environ.get('HUDSON_URL', None): diff --git a/rest_framework/settings.py b/rest_framework/settings.py index da647658..906a7cf6 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -55,7 +55,7 @@ DEFAULTS = {          'anon': None,      },      'PAGINATE_BY': None, -    'FILTER_BACKEND': 'rest_framework.filters.DjangoFilterBackend', +    'FILTER_BACKEND': None,      'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',      'UNAUTHENTICATED_TOKEN': None, @@ -144,8 +144,15 @@ class APISettings(object):          if val and attr in self.import_strings:              val = perform_import(val, attr) +        self.validate_setting(attr, val) +          # Cache the result          setattr(self, attr, val)          return val +    def validate_setting(self, attr, val): +        if attr == 'FILTER_BACKEND' and val is not None: +            # Make sure we can initilize the class +            val() +  api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS) diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 6cdea32f..af2e6c2e 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -3,7 +3,7 @@ from decimal import Decimal  from django.test import TestCase  from django.test.client import RequestFactory  from django.utils import unittest -from rest_framework import generics, status +from rest_framework import generics, status, filters  from rest_framework.compat import django_filters  from rest_framework.tests.models import FilterableItem, BasicModel @@ -15,6 +15,7 @@ if django_filters:      class FilterFieldsRootView(generics.ListCreateAPIView):          model = FilterableItem          filter_fields = ['decimal', 'date'] +        filter_backend = filters.DjangoFilterBackend      # These class are used to test a filter class.      class SeveralFieldsFilter(django_filters.FilterSet): @@ -29,6 +30,7 @@ if django_filters:      class FilterClassRootView(generics.ListCreateAPIView):          model = FilterableItem          filter_class = SeveralFieldsFilter +        filter_backend = filters.DjangoFilterBackend      # These classes are used to test a misconfigured filter class.      class MisconfiguredFilter(django_filters.FilterSet): @@ -41,6 +43,7 @@ if django_filters:      class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):          model = FilterableItem          filter_class = MisconfiguredFilter +        filter_backend = filters.DjangoFilterBackend  class IntegrationTestFiltering(TestCase): diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 7f8cd524..713a7255 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -4,7 +4,7 @@ from django.core.paginator import Paginator  from django.test import TestCase  from django.test.client import RequestFactory  from django.utils import unittest -from rest_framework import generics, status, pagination +from rest_framework import generics, status, pagination, filters  from rest_framework.compat import django_filters  from rest_framework.tests.models import BasicModel, FilterableItem @@ -31,6 +31,7 @@ if django_filters:          model = FilterableItem          paginate_by = 10          filter_class = DecimalFilter +        filter_backend = filters.DjangoFilterBackend  class IntegrationTestPagination(TestCase):  | 
