diff options
| author | Tom Christie | 2012-11-07 20:13:27 +0000 | 
|---|---|---|
| committer | Tom Christie | 2012-11-07 20:13:27 +0000 | 
| commit | 9fd061a0b68f0cef6683bf195911a2cc7ff2fa06 (patch) | |
| tree | 60769e37be41acf4bf12ba4ad59737e57c55da6a /rest_framework/generics.py | |
| parent | 066d51faa16d1cfd3c8370c6bfe46f8494bbc26a (diff) | |
| parent | 09f39bd23b3c688c89551845d665395e1aabbfab (diff) | |
| download | django-rest-framework-9fd061a0b68f0cef6683bf195911a2cc7ff2fa06.tar.bz2 | |
Merge branch 'restframework2-filter' of git://github.com/onepercentclub/django-rest-framework into filtering
Diffstat (limited to 'rest_framework/generics.py')
| -rw-r--r-- | rest_framework/generics.py | 33 | 
1 files changed, 32 insertions, 1 deletions
| diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 45cedd8b..ac02d3da 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -6,7 +6,7 @@ from rest_framework import views, mixins  from rest_framework.settings import api_settings  from django.views.generic.detail import SingleObjectMixin  from django.views.generic.list import MultipleObjectMixin - +import django_filters  ### Base classes for the generic views ### @@ -58,6 +58,37 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):      pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS      paginate_by = api_settings.PAGINATE_BY +    filter_class = None +    filter_fields = None + +    def get_filter_class(self): +        """ +        Return the django-filters `FilterSet` used to filter the queryset. +        """ +        if self.filter_class: +            return self.filter_class + +        if self.filter_fields: +            class AutoFilterSet(django_filters.FilterSet): +                class Meta: +                    model = self.model +                fields = self.filter_fields +            return AutoFilterSet + +        return None + +    def filter_queryset(self, queryset): +        filter_class = self.get_filter_class() + +        if filter_class: +            assert issubclass(filter_class.Meta.model, self.model), \ +                "%s is not a subclass of %s" % (filter_class.Meta.model, self.model) +            return filter_class(self.request.GET, queryset=queryset) + +        return queryset + +    def get_filtered_queryset(self): +        return self.filter_queryset(self.get_queryset())      def get_pagination_serializer_class(self):          """ | 
