aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/tests
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/tests')
-rw-r--r--rest_framework/tests/filterset.py75
-rw-r--r--rest_framework/tests/pagination.py10
2 files changed, 69 insertions, 16 deletions
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
index 238da56e..1a71558c 100644
--- a/rest_framework/tests/filterset.py
+++ b/rest_framework/tests/filterset.py
@@ -1,11 +1,12 @@
from __future__ import unicode_literals
import datetime
from decimal import Decimal
+from django.core.urlresolvers import reverse
from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest
from rest_framework import generics, status, filters
-from rest_framework.compat import django_filters
+from rest_framework.compat import django_filters, patterns, url
from rest_framework.tests.models import FilterableItem, BasicModel
factory = RequestFactory()
@@ -46,12 +47,21 @@ if django_filters:
filter_class = MisconfiguredFilter
filter_backend = filters.DjangoFilterBackend
+ class FilterClassDetailView(generics.RetrieveAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backend = filters.DjangoFilterBackend
+
+ urlpatterns = patterns('',
+ url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
+ url(r'^$', FilterClassRootView.as_view(), name='root-view'),
+ )
-class IntegrationTestFiltering(TestCase):
- """
- Integration tests for filtered list views.
- """
+class CommonFilteringTestCase(TestCase):
+ def _serialize_object(self, obj):
+ return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+
def setUp(self):
"""
Create 10 FilterableItem instances.
@@ -65,10 +75,16 @@ class IntegrationTestFiltering(TestCase):
self.objects = FilterableItem.objects
self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+ self._serialize_object(obj)
for obj in self.objects.all()
]
+
+class IntegrationTestFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered list views.
+ """
+
@unittest.skipUnless(django_filters, 'django-filters not installed')
def test_get_filtered_fields_root_view(self):
"""
@@ -167,3 +183,50 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?integer=%s' % search_integer)
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class IntegrationTestDetailFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered detail views.
+ """
+ urls = 'rest_framework.tests.filterset'
+
+ def _get_url(self, item):
+ return reverse('detail-view', kwargs=dict(pk=item.pk))
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_detail_view(self):
+ """
+ GET requests to filtered RetrieveAPIView that have a filter_class set
+ should return filtered results.
+ """
+ item = self.objects.all()[0]
+ data = self._serialize_object(item)
+
+ # Basic test with no filter.
+ response = self.client.get(self._get_url(item))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, data)
+
+ # Tests that the decimal filter set that should fail.
+ search_decimal = Decimal('4.25')
+ high_item = self.objects.filter(decimal__gt=search_decimal)[0]
+ response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal))
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ # Tests that the decimal filter set that should succeed.
+ search_decimal = Decimal('4.25')
+ low_item = self.objects.filter(decimal__lt=search_decimal)[0]
+ low_item_data = self._serialize_object(low_item)
+ response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, low_item_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
+ valid_item_data = self._serialize_object(valid_item)
+ response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, valid_item_data)
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index d2c9b051..6b8ef02f 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -129,16 +129,6 @@ class IntegrationTestPaginationAndFiltering(TestCase):
view = FilterFieldsRootView.as_view()
EXPECTED_NUM_QUERIES = 2
- if django.VERSION < (1, 4):
- # On Django 1.3 we need to use django-filter 0.5.4
- #
- # The filter objects there don't expose a `.count()` method,
- # which means we only make a single query *but* it's a single
- # query across *all* of the queryset, instead of a COUNT and then
- # a SELECT with a LIMIT.
- #
- # Although this is fewer queries, it's actually a regression.
- EXPECTED_NUM_QUERIES = 1
request = factory.get('/?decimal=15.20')
with self.assertNumQueries(EXPECTED_NUM_QUERIES):