diff options
| -rw-r--r-- | docs/api-guide/authentication.md | 8 | ||||
| -rw-r--r-- | docs/css/default.css | 11 | ||||
| -rw-r--r-- | docs/index.md | 4 | ||||
| -rw-r--r-- | docs/template.html | 31 | ||||
| -rw-r--r-- | docs/topics/contributing.md | 2 | ||||
| -rw-r--r-- | docs/topics/release-notes.md | 6 | ||||
| -rwxr-xr-x | mkdocs.py | 75 | ||||
| -rw-r--r-- | rest_framework/authentication.py | 26 | ||||
| -rw-r--r-- | rest_framework/compat.py | 33 | ||||
| -rw-r--r-- | rest_framework/filters.py | 2 | ||||
| -rw-r--r-- | rest_framework/templatetags/rest_framework.py | 84 | ||||
| -rw-r--r-- | rest_framework/tests/authentication.py | 44 | ||||
| -rw-r--r-- | rest_framework/tests/filterset.py | 75 | ||||
| -rw-r--r-- | rest_framework/tests/pagination.py | 10 |
14 files changed, 301 insertions, 110 deletions
diff --git a/docs/api-guide/authentication.md b/docs/api-guide/authentication.md index 541c6575..0eea31d7 100644 --- a/docs/api-guide/authentication.md +++ b/docs/api-guide/authentication.md @@ -119,6 +119,8 @@ To use the `TokenAuthentication` scheme, include `rest_framework.authtoken` in y ... 'rest_framework.authtoken' ) + +Make sure to run `manage.py syncdb` after changing your settings. You'll also need to create tokens for your users. @@ -140,6 +142,10 @@ Unauthenticated responses that are denied permission will result in an `HTTP 401 WWW-Authenticate: Token +The `curl` command line tool may be useful for testing token authenticated APIs. For example: + + curl -X GET http://127.0.0.1:8000/api/example/ -H 'Authorization: Token 9944b09199c62bcf9418ad846dd0e4bbdfc6ee4b' + --- **Note:** If you use `TokenAuthentication` in production you must ensure that your API is only available over `https` only. @@ -294,7 +300,7 @@ The only thing needed to make the `OAuth2Authentication` class work is to insert The command line to test the authentication looks like: - curl -H "Authorization: Bearer <your-access-token>" http://localhost:8000/api/?client_id=YOUR_CLIENT_ID\&client_secret=YOUR_CLIENT_SECRET + curl -H "Authorization: Bearer <your-access-token>" http://localhost:8000/api/ --- diff --git a/docs/css/default.css b/docs/css/default.css index c160b63d..173d70e0 100644 --- a/docs/css/default.css +++ b/docs/css/default.css @@ -277,3 +277,14 @@ footer a { footer a:hover { color: gray; } + +.btn-inverse { + background-image: -webkit-gradient(linear, 0 0, 0 100%, from(#606060), to(#404040)) !important; + background-image: -webkit-linear-gradient(top, #606060, #404040) !important; +} + +.modal-open .modal,.btn:focus{outline:none;} + +@media (max-width: 650px) { + .repo-link.btn-inverse {display: none;} +} diff --git a/docs/index.md b/docs/index.md index 5357536d..4c2720c8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -9,9 +9,9 @@ # Django REST framework -**Web APIs for Django, made easy.** +**Awesome web-browseable Web APIs.** -Django REST framework is a flexible, powerful library that makes it incredibly easy to build Web APIs. It is designed as a modular and easy to customize architecture, based on Django's class based views. +Django REST framework is a flexible, powerful Web API toolkit. It is designed as a modular and easy to customize architecture, based on Django's class based views. APIs built using REST framework are fully self-describing and web browseable - a huge useability win for your developers. It also supports a wide range of media types, authentication and permission policies out of the box. diff --git a/docs/template.html b/docs/template.html index 3e0f29aa..7e929762 100644 --- a/docs/template.html +++ b/docs/template.html @@ -41,6 +41,9 @@ <div class="navbar-inner"> <div class="container-fluid"> <a class="repo-link btn btn-primary btn-small" href="https://github.com/tomchristie/django-rest-framework/tree/master">GitHub</a> + <a class="repo-link btn btn-inverse btn-small {{ next_url_disabled }}" href="{{ next_url }}">Next <i class="icon-arrow-right icon-white"></i></a> + <a class="repo-link btn btn-inverse btn-small {{ prev_url_disabled }}" href="{{ prev_url }}"><i class="icon-arrow-left icon-white"></i> Previous</a> + <a class="repo-link btn btn-inverse btn-small" href="#searchModal" data-toggle="modal"><i class="icon-search icon-white"></i> Search</a> <a class="btn btn-navbar" data-toggle="collapse" data-target=".nav-collapse"> <span class="icon-bar"></span> <span class="icon-bar"></span> @@ -118,6 +121,34 @@ <div class="body-content"> <div class="container-fluid"> + +<!-- Search Modal --> +<div id="searchModal" class="modal hide fade" tabindex="-1" role="dialog" aria-labelledby="myModalLabel" aria-hidden="true"> + <div class="modal-header"> + <button type="button" class="close" data-dismiss="modal" aria-hidden="true">×</button> + <h3 id="myModalLabel">Documentation search</h3> + </div> + <div class="modal-body"> + <!-- Custom google search --> + <script> + (function() { + var cx = '015016005043623903336:rxraeohqk6w'; + var gcse = document.createElement('script'); + gcse.type = 'text/javascript'; + gcse.async = true; + gcse.src = (document.location.protocol == 'https:' ? 'https:' : 'http:') + + '//www.google.com/cse/cse.js?cx=' + cx; + var s = document.getElementsByTagName('script')[0]; + s.parentNode.insertBefore(gcse, s); + })(); + </script> + <gcse:search></gcse:search> + </div> + <div class="modal-footer"> + <button class="btn" data-dismiss="modal" aria-hidden="true">Close</button> + </div> +</div> + <div class="row-fluid"> <div class="span3"> diff --git a/docs/topics/contributing.md b/docs/topics/contributing.md index a13f4461..1d1fe892 100644 --- a/docs/topics/contributing.md +++ b/docs/topics/contributing.md @@ -18,7 +18,7 @@ When answering questions make sure to help future contributors find their way ar # Issues -Usage questions should be directed to the [discussion group][google-group]. Feature requests, bug reports and other issues should be raised on the GitHub [issue tracker][issues]. +It's really helpful if you make sure you address issues to the correct channel. Usage questions should be directed to the [discussion group][google-group]. Feature requests, bug reports and other issues should be raised on the GitHub [issue tracker][issues]. Some tips on good issue reporting: diff --git a/docs/topics/release-notes.md b/docs/topics/release-notes.md index e63aee49..62c31358 100644 --- a/docs/topics/release-notes.md +++ b/docs/topics/release-notes.md @@ -40,6 +40,12 @@ You can determine your currently installed version using `pip freeze`: ## 2.2.x series +### Master + +* OAuth2 authentication no longer requires unneccessary URL parameters in addition to the token. +* URL hyperlinking in browseable API now handles more cases correctly. +* Bugfix: Fix regression with DjangoFilterBackend not worthing correctly with single object views. + ### 2.2.5 **Date**: 26th March 2013 @@ -37,6 +37,60 @@ page = open(os.path.join(docs_dir, 'template.html'), 'r').read() # shutil.rmtree(target) # shutil.copytree(source, target) + +# Hacky, but what the hell, it'll do the job +path_list = [ + 'index.md', + 'tutorial/quickstart.md', + 'tutorial/1-serialization.md', + 'tutorial/2-requests-and-responses.md', + 'tutorial/3-class-based-views.md', + 'tutorial/4-authentication-and-permissions.md', + 'tutorial/5-relationships-and-hyperlinked-apis.md', + 'api-guide/requests.md', + 'api-guide/responses.md', + 'api-guide/views.md', + 'api-guide/generic-views.md', + 'api-guide/parsers.md', + 'api-guide/renderers.md', + 'api-guide/serializers.md', + 'api-guide/fields.md', + 'api-guide/relations.md', + 'api-guide/authentication.md', + 'api-guide/permissions.md', + 'api-guide/throttling.md', + 'api-guide/filtering.md', + 'api-guide/pagination.md', + 'api-guide/content-negotiation.md', + 'api-guide/format-suffixes.md', + 'api-guide/reverse.md', + 'api-guide/exceptions.md', + 'api-guide/status-codes.md', + 'api-guide/settings.md', + 'topics/ajax-csrf-cors.md', + 'topics/browser-enhancements.md', + 'topics/browsable-api.md', + 'topics/rest-hypermedia-hateoas.md', + 'topics/contributing.md', + 'topics/rest-framework-2-announcement.md', + 'topics/2.2-announcement.md', + 'topics/release-notes.md', + 'topics/credits.md', +] + +prev_url_map = {} +next_url_map = {} +for idx in range(len(path_list)): + path = path_list[idx] + rel = '../' * path.count('/') + + if idx > 0: + prev_url_map[path] = rel + path_list[idx - 1][:-3] + suffix + + if idx < len(path_list) - 1: + next_url_map[path] = rel + path_list[idx + 1][:-3] + suffix + + for (dirpath, dirnames, filenames) in os.walk(docs_dir): relative_dir = dirpath.replace(docs_dir, '').lstrip(os.path.sep) build_dir = os.path.join(html_dir, relative_dir) @@ -46,6 +100,7 @@ for (dirpath, dirnames, filenames) in os.walk(docs_dir): for filename in filenames: path = os.path.join(dirpath, filename) + relative_path = os.path.join(relative_dir, filename) if not filename.endswith('.md'): if relative_dir: @@ -78,16 +133,34 @@ for (dirpath, dirnames, filenames) in os.walk(docs_dir): toc += template + '\n' if filename == 'index.md': - main_title = 'Django REST framework - APIs made easy' + main_title = 'Django REST framework - Web Browseable APIs' else: main_title = 'Django REST framework - ' + main_title + prev_url = prev_url_map.get(relative_path) + next_url = next_url_map.get(relative_path) + content = markdown.markdown(text, ['headerid']) output = page.replace('{{ content }}', content).replace('{{ toc }}', toc).replace('{{ base_url }}', base_url).replace('{{ suffix }}', suffix).replace('{{ index }}', index) output = output.replace('{{ title }}', main_title) output = output.replace('{{ description }}', description) output = output.replace('{{ page_id }}', filename[:-3]) + + if prev_url: + output = output.replace('{{ prev_url }}', prev_url) + output = output.replace('{{ prev_url_disabled }}', '') + else: + output = output.replace('{{ prev_url }}', '#') + output = output.replace('{{ prev_url_disabled }}', 'disabled') + + if next_url: + output = output.replace('{{ next_url }}', next_url) + output = output.replace('{{ next_url_disabled }}', '') + else: + output = output.replace('{{ next_url }}', '#') + output = output.replace('{{ next_url_disabled }}', 'disabled') + output = re.sub(r'a href="([^"]*)\.md"', r'a href="\1%s"' % suffix, output) output = re.sub(r'<pre><code>:::bash', r'<pre class="prettyprint lang-bsh">', output) output = re.sub(r'<pre>', r'<pre class="prettyprint lang-py">', output) diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 8f4ec536..145d4295 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -2,14 +2,16 @@ Provides a set of pluggable authentication policies. """ from __future__ import unicode_literals +import base64 +from datetime import datetime + from django.contrib.auth import authenticate from django.core.exceptions import ImproperlyConfigured from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework.compat import CsrfViewMiddleware from rest_framework.compat import oauth, oauth_provider, oauth_provider_store -from rest_framework.compat import oauth2_provider, oauth2_provider_forms, oauth2_provider_backends +from rest_framework.compat import oauth2_provider, oauth2_provider_forms from rest_framework.authtoken.models import Token -import base64 def get_authorization_header(request): @@ -315,21 +317,15 @@ class OAuth2Authentication(BaseAuthentication): Authenticate the request, given the access token. """ - # Authenticate the client - oauth2_client_form = oauth2_provider_forms.ClientAuthForm(request.REQUEST) - if not oauth2_client_form.is_valid(): - raise exceptions.AuthenticationFailed('Client could not be validated') - client = oauth2_client_form.cleaned_data.get('client') - - # Retrieve the `OAuth2AccessToken` instance from the access_token - auth_backend = oauth2_provider_backends.AccessTokenBackend() - token = auth_backend.authenticate(access_token, client) - if token is None: + try: + token = oauth2_provider.models.AccessToken.objects.select_related('user') + # TODO: Change to timezone aware datetime when oauth2_provider add + # support to it. + token = token.get(token=access_token, expires__gt=datetime.now()) + except oauth2_provider.models.AccessToken.DoesNotExist: raise exceptions.AuthenticationFailed('Invalid token') - user = token.user - - if not user.is_active: + if not token.user.is_active: msg = 'User inactive or deleted: %s' % user.username raise exceptions.AuthenticationFailed(msg) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 7b2ef738..6551723a 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -395,6 +395,37 @@ except ImportError: kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None) return datetime.datetime(**kw) + +# smart_urlquote is new on Django 1.4 +try: + from django.utils.html import smart_urlquote +except ImportError: + try: + from urllib.parse import quote, urlsplit, urlunsplit + except ImportError: # Python 2 + from urllib import quote + from urlparse import urlsplit, urlunsplit + + def smart_urlquote(url): + "Quotes a URL if it isn't already quoted." + # Handle IDN before quoting. + scheme, netloc, path, query, fragment = urlsplit(url) + try: + netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE + except UnicodeError: # invalid domain part + pass + else: + url = urlunsplit((scheme, netloc, path, query, fragment)) + + # An URL is considered unquoted if it contains no % characters or + # contains a % not followed by two hexadecimal digits. See #9655. + if '%' not in url or unquoted_percents_re.search(url): + # See http://bugs.python.org/issue2637 + url = quote(force_bytes(url), safe=b'!*\'();:@&=+$,/?#[]~') + + return force_text(url) + + # Markdown is optional try: import markdown @@ -445,14 +476,12 @@ except ImportError: # OAuth 2 support is optional try: import provider.oauth2 as oauth2_provider - from provider.oauth2 import backends as oauth2_provider_backends from provider.oauth2 import models as oauth2_provider_models from provider.oauth2 import forms as oauth2_provider_forms from provider import scope as oauth2_provider_scope from provider import constants as oauth2_constants except ImportError: oauth2_provider = None - oauth2_provider_backends = None oauth2_provider_models = None oauth2_provider_forms = None oauth2_provider_scope = None diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 6fea46fa..413fa0d2 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -55,6 +55,6 @@ class DjangoFilterBackend(BaseFilterBackend): filter_class = self.get_filter_class(view) if filter_class: - return filter_class(request.QUERY_PARAMS, queryset=queryset) + return filter_class(request.QUERY_PARAMS, queryset=queryset).qs return queryset diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index c21ddcd7..b6ab2de3 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -4,11 +4,8 @@ from django.core.urlresolvers import reverse, NoReverseMatch from django.http import QueryDict from django.utils.html import escape from django.utils.safestring import SafeData, mark_safe -from rest_framework.compat import urlparse -from rest_framework.compat import force_text -from rest_framework.compat import six -import re -import string +from rest_framework.compat import urlparse, force_text, six, smart_urlquote +import re, string register = template.Library() @@ -112,22 +109,6 @@ def replace_query_param(url, key, val): class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') -# Bunch of stuff cloned from urlize -LEADING_PUNCTUATION = ['(', '<', '<', '"', "'"] -TRAILING_PUNCTUATION = ['.', ',', ')', '>', '\n', '>', '"', "'"] -DOTS = ['·', '*', '\xe2\x80\xa2', '•', '•', '•'] -unencoded_ampersands_re = re.compile(r'&(?!(\w+|#\d+);)') -word_split_re = re.compile(r'(\s+)') -punctuation_re = re.compile('^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' % \ - ('|'.join([re.escape(x) for x in LEADING_PUNCTUATION]), - '|'.join([re.escape(x) for x in TRAILING_PUNCTUATION]))) -simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$') -link_target_attribute_re = re.compile(r'(<a [^>]*?)target=[^\s>]+') -html_gunk_re = re.compile(r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)', re.IGNORECASE) -hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|'.join([re.escape(x) for x in DOTS]), re.DOTALL) -trailing_empty_content_re = re.compile(r'(?:<p>(?: |\s|<br \/>)*?</p>\s*)+\Z') - - # And the template tags themselves... @register.simple_tag @@ -195,15 +176,25 @@ def add_class(value, css_class): return value +# Bunch of stuff cloned from urlize +TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"] +WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'), + ('"', '"'), ("'", "'")] +word_split_re = re.compile(r'(\s+)') +simple_url_re = re.compile(r'^https?://\w', re.IGNORECASE) +simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE) +simple_email_re = re.compile(r'^\S+@\S+\.\S+$') + + @register.filter def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True): """ Converts any URLs in text into clickable links. - Works on http://, https://, www. links and links ending in .org, .net or - .com. Links can have trailing punctuation (periods, commas, close-parens) - and leading punctuation (opening parens) and it'll still do the right - thing. + Works on http://, https://, www. links, and also on links ending in one of + the original seven gTLDs (.com, .edu, .gov, .int, .mil, .net, and .org). + Links can have trailing punctuation (periods, commas, close-parens) and + leading punctuation (opening parens) and it'll still do the right thing. If trim_url_limit is not None, the URLs in link text longer than this limit will truncated to trim_url_limit-3 characters and appended with an elipsis. @@ -216,24 +207,41 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x safe_input = isinstance(text, SafeData) words = word_split_re.split(force_text(text)) - nofollow_attr = nofollow and ' rel="nofollow"' or '' for i, word in enumerate(words): match = None if '.' in word or '@' in word or ':' in word: - match = punctuation_re.match(word) - if match: - lead, middle, trail = match.groups() + # Deal with punctuation. + lead, middle, trail = '', word, '' + for punctuation in TRAILING_PUNCTUATION: + if middle.endswith(punctuation): + middle = middle[:-len(punctuation)] + trail = punctuation + trail + for opening, closing in WRAPPING_PUNCTUATION: + if middle.startswith(opening): + middle = middle[len(opening):] + lead = lead + opening + # Keep parentheses at the end only if they're balanced. + if (middle.endswith(closing) + and middle.count(closing) == middle.count(opening) + 1): + middle = middle[:-len(closing)] + trail = closing + trail + # Make URL we want to point to. url = None - if middle.startswith('http://') or middle.startswith('https://'): - url = middle - elif middle.startswith('www.') or ('@' not in middle and \ - middle and middle[0] in string.ascii_letters + string.digits and \ - (middle.endswith('.org') or middle.endswith('.net') or middle.endswith('.com'))): - url = 'http://%s' % middle - elif '@' in middle and not ':' in middle and simple_email_re.match(middle): - url = 'mailto:%s' % middle + nofollow_attr = ' rel="nofollow"' if nofollow else '' + if simple_url_re.match(middle): + url = smart_urlquote(middle) + elif simple_url_2_re.match(middle): + url = smart_urlquote('http://%s' % middle) + elif not ':' in middle and simple_email_re.match(middle): + local, domain = middle.rsplit('@', 1) + try: + domain = domain.encode('idna').decode('ascii') + except UnicodeError: + continue + url = 'mailto:%s@%s' % (local, domain) nofollow_attr = '' + # Make link. if url: trimmed = trim_url(middle) @@ -251,4 +259,4 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru words[i] = mark_safe(word) elif autoescape: words[i] = escape(word) - return mark_safe(''.join(words)) + return ''.join(words) diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index b663ca48..8e6d3e51 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -466,17 +466,13 @@ class OAuth2Tests(TestCase): def _create_authorization_header(self, token=None): return "Bearer {0}".format(token or self.access_token.token) - def _client_credentials_params(self): - return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET} - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_get_form_with_wrong_authorization_header_token_type_failing(self): """Ensure that a wrong token type lead to the correct HTTP error status code""" auth = "Wrong token-type-obsviously" response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -485,8 +481,7 @@ class OAuth2Tests(TestCase): auth = "Bearer wrong token format" response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -495,33 +490,21 @@ class OAuth2Tests(TestCase): auth = "Bearer wrong-token" response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_with_wrong_client_data_failing_auth(self): - """Ensure GETing form over OAuth with incorrect client credentials fails""" - auth = self._create_authorization_header() - params = self._client_credentials_params() - params['client_id'] += 'a' - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_get_form_passing_auth(self): """Ensure GETing form over OAuth with correct client credentials succeed""" auth = self._create_authorization_header() - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_post_form_passing_auth(self): """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" auth = self._create_authorization_header() - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -529,16 +512,14 @@ class OAuth2Tests(TestCase): """Ensure POSTing when there is no OAuth access token in db fails""" self.access_token.delete() auth = self._create_authorization_header() - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_post_form_with_refresh_token_failing_auth(self): """Ensure POSTing with refresh token instead of access token fails""" auth = self._create_authorization_header(token=self.refresh_token.token) - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -547,8 +528,7 @@ class OAuth2Tests(TestCase): self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late self.access_token.save() auth = self._create_authorization_header() - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) self.assertIn('Invalid token', response.content) @@ -559,10 +539,9 @@ class OAuth2Tests(TestCase): read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read'] read_only_access_token.save() auth = self._create_authorization_header(token=read_only_access_token.token) - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) - response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -572,6 +551,5 @@ class OAuth2Tests(TestCase): read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write'] read_write_access_token.save() auth = self._create_authorization_header(token=read_write_access_token.token) - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 238da56e..1a71558c 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -1,11 +1,12 @@ from __future__ import unicode_literals import datetime from decimal import Decimal +from django.core.urlresolvers import reverse from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, status, filters -from rest_framework.compat import django_filters +from rest_framework.compat import django_filters, patterns, url from rest_framework.tests.models import FilterableItem, BasicModel factory = RequestFactory() @@ -46,12 +47,21 @@ if django_filters: filter_class = MisconfiguredFilter filter_backend = filters.DjangoFilterBackend + class FilterClassDetailView(generics.RetrieveAPIView): + model = FilterableItem + filter_class = SeveralFieldsFilter + filter_backend = filters.DjangoFilterBackend + + urlpatterns = patterns('', + url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), + url(r'^$', FilterClassRootView.as_view(), name='root-view'), + ) -class IntegrationTestFiltering(TestCase): - """ - Integration tests for filtered list views. - """ +class CommonFilteringTestCase(TestCase): + def _serialize_object(self, obj): + return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + def setUp(self): """ Create 10 FilterableItem instances. @@ -65,10 +75,16 @@ class IntegrationTestFiltering(TestCase): self.objects = FilterableItem.objects self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + self._serialize_object(obj) for obj in self.objects.all() ] + +class IntegrationTestFiltering(CommonFilteringTestCase): + """ + Integration tests for filtered list views. + """ + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_filtered_fields_root_view(self): """ @@ -167,3 +183,50 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?integer=%s' % search_integer) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class IntegrationTestDetailFiltering(CommonFilteringTestCase): + """ + Integration tests for filtered detail views. + """ + urls = 'rest_framework.tests.filterset' + + def _get_url(self, item): + return reverse('detail-view', kwargs=dict(pk=item.pk)) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_detail_view(self): + """ + GET requests to filtered RetrieveAPIView that have a filter_class set + should return filtered results. + """ + item = self.objects.all()[0] + data = self._serialize_object(item) + + # Basic test with no filter. + response = self.client.get(self._get_url(item)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, data) + + # Tests that the decimal filter set that should fail. + search_decimal = Decimal('4.25') + high_item = self.objects.filter(decimal__gt=search_decimal)[0] + response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + # Tests that the decimal filter set that should succeed. + search_decimal = Decimal('4.25') + low_item = self.objects.filter(decimal__lt=search_decimal)[0] + low_item_data = self._serialize_object(low_item) + response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, low_item_data) + + # Tests that multiple filters works. + search_decimal = Decimal('5.25') + search_date = datetime.date(2012, 10, 2) + valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] + valid_item_data = self._serialize_object(valid_item) + response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, valid_item_data) diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index d2c9b051..6b8ef02f 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -129,16 +129,6 @@ class IntegrationTestPaginationAndFiltering(TestCase): view = FilterFieldsRootView.as_view() EXPECTED_NUM_QUERIES = 2 - if django.VERSION < (1, 4): - # On Django 1.3 we need to use django-filter 0.5.4 - # - # The filter objects there don't expose a `.count()` method, - # which means we only make a single query *but* it's a single - # query across *all* of the queryset, instead of a COUNT and then - # a SELECT with a LIMIT. - # - # Although this is fewer queries, it's actually a regression. - EXPECTED_NUM_QUERIES = 1 request = factory.get('/?decimal=15.20') with self.assertNumQueries(EXPECTED_NUM_QUERIES): |
