aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/generics.py
diff options
context:
space:
mode:
authorTom Christie2012-11-07 20:13:27 +0000
committerTom Christie2012-11-07 20:13:27 +0000
commit9fd061a0b68f0cef6683bf195911a2cc7ff2fa06 (patch)
tree60769e37be41acf4bf12ba4ad59737e57c55da6a /rest_framework/generics.py
parent066d51faa16d1cfd3c8370c6bfe46f8494bbc26a (diff)
parent09f39bd23b3c688c89551845d665395e1aabbfab (diff)
downloaddjango-rest-framework-9fd061a0b68f0cef6683bf195911a2cc7ff2fa06.tar.bz2
Merge branch 'restframework2-filter' of git://github.com/onepercentclub/django-rest-framework into filtering
Diffstat (limited to 'rest_framework/generics.py')
-rw-r--r--rest_framework/generics.py33
1 files changed, 32 insertions, 1 deletions
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 45cedd8b..ac02d3da 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -6,7 +6,7 @@ from rest_framework import views, mixins
from rest_framework.settings import api_settings
from django.views.generic.detail import SingleObjectMixin
from django.views.generic.list import MultipleObjectMixin
-
+import django_filters
### Base classes for the generic views ###
@@ -58,6 +58,37 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
paginate_by = api_settings.PAGINATE_BY
+ filter_class = None
+ filter_fields = None
+
+ def get_filter_class(self):
+ """
+ Return the django-filters `FilterSet` used to filter the queryset.
+ """
+ if self.filter_class:
+ return self.filter_class
+
+ if self.filter_fields:
+ class AutoFilterSet(django_filters.FilterSet):
+ class Meta:
+ model = self.model
+ fields = self.filter_fields
+ return AutoFilterSet
+
+ return None
+
+ def filter_queryset(self, queryset):
+ filter_class = self.get_filter_class()
+
+ if filter_class:
+ assert issubclass(filter_class.Meta.model, self.model), \
+ "%s is not a subclass of %s" % (filter_class.Meta.model, self.model)
+ return filter_class(self.request.GET, queryset=queryset)
+
+ return queryset
+
+ def get_filtered_queryset(self):
+ return self.filter_queryset(self.get_queryset())
def get_pagination_serializer_class(self):
"""