diff options
Diffstat (limited to 'rest_framework/tests')
| -rw-r--r-- | rest_framework/tests/authentication.py | 44 | ||||
| -rw-r--r-- | rest_framework/tests/filterset.py | 75 | ||||
| -rw-r--r-- | rest_framework/tests/pagination.py | 10 |
3 files changed, 80 insertions, 49 deletions
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index b663ca48..8e6d3e51 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -466,17 +466,13 @@ class OAuth2Tests(TestCase): def _create_authorization_header(self, token=None): return "Bearer {0}".format(token or self.access_token.token) - def _client_credentials_params(self): - return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET} - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_get_form_with_wrong_authorization_header_token_type_failing(self): """Ensure that a wrong token type lead to the correct HTTP error status code""" auth = "Wrong token-type-obsviously" response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -485,8 +481,7 @@ class OAuth2Tests(TestCase): auth = "Bearer wrong token format" response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -495,33 +490,21 @@ class OAuth2Tests(TestCase): auth = "Bearer wrong-token" response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_with_wrong_client_data_failing_auth(self): - """Ensure GETing form over OAuth with incorrect client credentials fails""" - auth = self._create_authorization_header() - params = self._client_credentials_params() - params['client_id'] += 'a' - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_get_form_passing_auth(self): """Ensure GETing form over OAuth with correct client credentials succeed""" auth = self._create_authorization_header() - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_post_form_passing_auth(self): """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" auth = self._create_authorization_header() - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -529,16 +512,14 @@ class OAuth2Tests(TestCase): """Ensure POSTing when there is no OAuth access token in db fails""" self.access_token.delete() auth = self._create_authorization_header() - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_post_form_with_refresh_token_failing_auth(self): """Ensure POSTing with refresh token instead of access token fails""" auth = self._create_authorization_header(token=self.refresh_token.token) - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -547,8 +528,7 @@ class OAuth2Tests(TestCase): self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late self.access_token.save() auth = self._create_authorization_header() - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) self.assertIn('Invalid token', response.content) @@ -559,10 +539,9 @@ class OAuth2Tests(TestCase): read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read'] read_only_access_token.save() auth = self._create_authorization_header(token=read_only_access_token.token) - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) - response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -572,6 +551,5 @@ class OAuth2Tests(TestCase): read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write'] read_write_access_token.save() auth = self._create_authorization_header(token=read_write_access_token.token) - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 238da56e..1a71558c 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -1,11 +1,12 @@ from __future__ import unicode_literals import datetime from decimal import Decimal +from django.core.urlresolvers import reverse from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, status, filters -from rest_framework.compat import django_filters +from rest_framework.compat import django_filters, patterns, url from rest_framework.tests.models import FilterableItem, BasicModel factory = RequestFactory() @@ -46,12 +47,21 @@ if django_filters: filter_class = MisconfiguredFilter filter_backend = filters.DjangoFilterBackend + class FilterClassDetailView(generics.RetrieveAPIView): + model = FilterableItem + filter_class = SeveralFieldsFilter + filter_backend = filters.DjangoFilterBackend + + urlpatterns = patterns('', + url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), + url(r'^$', FilterClassRootView.as_view(), name='root-view'), + ) -class IntegrationTestFiltering(TestCase): - """ - Integration tests for filtered list views. - """ +class CommonFilteringTestCase(TestCase): + def _serialize_object(self, obj): + return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + def setUp(self): """ Create 10 FilterableItem instances. @@ -65,10 +75,16 @@ class IntegrationTestFiltering(TestCase): self.objects = FilterableItem.objects self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + self._serialize_object(obj) for obj in self.objects.all() ] + +class IntegrationTestFiltering(CommonFilteringTestCase): + """ + Integration tests for filtered list views. + """ + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_filtered_fields_root_view(self): """ @@ -167,3 +183,50 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?integer=%s' % search_integer) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class IntegrationTestDetailFiltering(CommonFilteringTestCase): + """ + Integration tests for filtered detail views. + """ + urls = 'rest_framework.tests.filterset' + + def _get_url(self, item): + return reverse('detail-view', kwargs=dict(pk=item.pk)) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_detail_view(self): + """ + GET requests to filtered RetrieveAPIView that have a filter_class set + should return filtered results. + """ + item = self.objects.all()[0] + data = self._serialize_object(item) + + # Basic test with no filter. + response = self.client.get(self._get_url(item)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, data) + + # Tests that the decimal filter set that should fail. + search_decimal = Decimal('4.25') + high_item = self.objects.filter(decimal__gt=search_decimal)[0] + response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + # Tests that the decimal filter set that should succeed. + search_decimal = Decimal('4.25') + low_item = self.objects.filter(decimal__lt=search_decimal)[0] + low_item_data = self._serialize_object(low_item) + response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, low_item_data) + + # Tests that multiple filters works. + search_decimal = Decimal('5.25') + search_date = datetime.date(2012, 10, 2) + valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] + valid_item_data = self._serialize_object(valid_item) + response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, valid_item_data) diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index d2c9b051..6b8ef02f 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -129,16 +129,6 @@ class IntegrationTestPaginationAndFiltering(TestCase): view = FilterFieldsRootView.as_view() EXPECTED_NUM_QUERIES = 2 - if django.VERSION < (1, 4): - # On Django 1.3 we need to use django-filter 0.5.4 - # - # The filter objects there don't expose a `.count()` method, - # which means we only make a single query *but* it's a single - # query across *all* of the queryset, instead of a COUNT and then - # a SELECT with a LIMIT. - # - # Although this is fewer queries, it's actually a regression. - EXPECTED_NUM_QUERIES = 1 request = factory.get('/?decimal=15.20') with self.assertNumQueries(EXPECTED_NUM_QUERIES): |
