diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/authentication.py | 183 | ||||
| -rw-r--r-- | rest_framework/compat.py | 79 | ||||
| -rw-r--r-- | rest_framework/exceptions.py | 5 | ||||
| -rw-r--r-- | rest_framework/parsers.py | 98 | ||||
| -rw-r--r-- | rest_framework/permissions.py | 28 | ||||
| -rw-r--r-- | rest_framework/renderers.py | 111 | ||||
| -rw-r--r-- | rest_framework/reverse.py | 12 | ||||
| -rw-r--r-- | rest_framework/settings.py | 15 | ||||
| -rw-r--r-- | rest_framework/utils/encoders.py | 65 | ||||
| -rw-r--r-- | rest_framework/versioning.py | 174 | ||||
| -rw-r--r-- | rest_framework/views.py | 56 |
11 files changed, 263 insertions, 563 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 4832ad33..124ef68a 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -3,14 +3,9 @@ Provides various authentication policies. """ from __future__ import unicode_literals import base64 - from django.contrib.auth import authenticate -from django.core.exceptions import ImproperlyConfigured from django.middleware.csrf import CsrfViewMiddleware -from django.conf import settings from rest_framework import exceptions, HTTP_HEADER_ENCODING -from rest_framework.compat import oauth, oauth_provider, oauth_provider_store -from rest_framework.compat import oauth2_provider, provider_now, check_nonce from rest_framework.authtoken.models import Token @@ -178,181 +173,3 @@ class TokenAuthentication(BaseAuthentication): def authenticate_header(self, request): return 'Token' - - -class OAuthAuthentication(BaseAuthentication): - """ - OAuth 1.0a authentication backend using `django-oauth-plus` and `oauth2`. - - Note: The `oauth2` package actually provides oauth1.0a support. Urg. - We import it from the `compat` module as `oauth`. - """ - www_authenticate_realm = 'api' - - def __init__(self, *args, **kwargs): - super(OAuthAuthentication, self).__init__(*args, **kwargs) - - if oauth is None: - raise ImproperlyConfigured( - "The 'oauth2' package could not be imported." - "It is required for use with the 'OAuthAuthentication' class.") - - if oauth_provider is None: - raise ImproperlyConfigured( - "The 'django-oauth-plus' package could not be imported." - "It is required for use with the 'OAuthAuthentication' class.") - - def authenticate(self, request): - """ - Returns two-tuple of (user, token) if authentication succeeds, - or None otherwise. - """ - try: - oauth_request = oauth_provider.utils.get_oauth_request(request) - except oauth.Error as err: - raise exceptions.AuthenticationFailed(err.message) - - if not oauth_request: - return None - - oauth_params = oauth_provider.consts.OAUTH_PARAMETERS_NAMES - - found = any(param for param in oauth_params if param in oauth_request) - missing = list(param for param in oauth_params if param not in oauth_request) - - if not found: - # OAuth authentication was not attempted. - return None - - if missing: - # OAuth was attempted but missing parameters. - msg = 'Missing parameters: %s' % (', '.join(missing)) - raise exceptions.AuthenticationFailed(msg) - - if not self.check_nonce(request, oauth_request): - msg = 'Nonce check failed' - raise exceptions.AuthenticationFailed(msg) - - try: - consumer_key = oauth_request.get_parameter('oauth_consumer_key') - consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key) - except oauth_provider.store.InvalidConsumerError: - msg = 'Invalid consumer token: %s' % oauth_request.get_parameter('oauth_consumer_key') - raise exceptions.AuthenticationFailed(msg) - - if consumer.status != oauth_provider.consts.ACCEPTED: - msg = 'Invalid consumer key status: %s' % consumer.get_status_display() - raise exceptions.AuthenticationFailed(msg) - - try: - token_param = oauth_request.get_parameter('oauth_token') - token = oauth_provider_store.get_access_token(request, oauth_request, consumer, token_param) - except oauth_provider.store.InvalidTokenError: - msg = 'Invalid access token: %s' % oauth_request.get_parameter('oauth_token') - raise exceptions.AuthenticationFailed(msg) - - try: - self.validate_token(request, consumer, token) - except oauth.Error as err: - raise exceptions.AuthenticationFailed(err.message) - - user = token.user - - if not user.is_active: - msg = 'User inactive or deleted: %s' % user.username - raise exceptions.AuthenticationFailed(msg) - - return (token.user, token) - - def authenticate_header(self, request): - """ - If permission is denied, return a '401 Unauthorized' response, - with an appropriate 'WWW-Authenticate' header. - """ - return 'OAuth realm="%s"' % self.www_authenticate_realm - - def validate_token(self, request, consumer, token): - """ - Check the token and raise an `oauth.Error` exception if invalid. - """ - oauth_server, oauth_request = oauth_provider.utils.initialize_server_request(request) - oauth_server.verify_request(oauth_request, consumer, token) - - def check_nonce(self, request, oauth_request): - """ - Checks nonce of request, and return True if valid. - """ - oauth_nonce = oauth_request['oauth_nonce'] - oauth_timestamp = oauth_request['oauth_timestamp'] - return check_nonce(request, oauth_request, oauth_nonce, oauth_timestamp) - - -class OAuth2Authentication(BaseAuthentication): - """ - OAuth 2 authentication backend using `django-oauth2-provider` - """ - www_authenticate_realm = 'api' - allow_query_params_token = settings.DEBUG - - def __init__(self, *args, **kwargs): - super(OAuth2Authentication, self).__init__(*args, **kwargs) - - if oauth2_provider is None: - raise ImproperlyConfigured( - "The 'django-oauth2-provider' package could not be imported. " - "It is required for use with the 'OAuth2Authentication' class.") - - def authenticate(self, request): - """ - Returns two-tuple of (user, token) if authentication succeeds, - or None otherwise. - """ - - auth = get_authorization_header(request).split() - - if len(auth) == 1: - msg = 'Invalid bearer header. No credentials provided.' - raise exceptions.AuthenticationFailed(msg) - elif len(auth) > 2: - msg = 'Invalid bearer header. Token string should not contain spaces.' - raise exceptions.AuthenticationFailed(msg) - - if auth and auth[0].lower() == b'bearer': - access_token = auth[1] - elif 'access_token' in request.POST: - access_token = request.POST['access_token'] - elif 'access_token' in request.GET and self.allow_query_params_token: - access_token = request.GET['access_token'] - else: - return None - - return self.authenticate_credentials(request, access_token) - - def authenticate_credentials(self, request, access_token): - """ - Authenticate the request, given the access token. - """ - - try: - token = oauth2_provider.oauth2.models.AccessToken.objects.select_related('user') - # provider_now switches to timezone aware datetime when - # the oauth2_provider version supports to it. - token = token.get(token=access_token, expires__gt=provider_now()) - except oauth2_provider.oauth2.models.AccessToken.DoesNotExist: - raise exceptions.AuthenticationFailed('Invalid token') - - user = token.user - - if not user.is_active: - msg = 'User inactive or deleted: %s' % user.get_username() - raise exceptions.AuthenticationFailed(msg) - - return (user, token) - - def authenticate_header(self, request): - """ - Bearer is the only finalized type currently - - Check details on the `OAuth2Authentication.authenticate` method - """ - return 'Bearer realm="%s"' % self.www_authenticate_realm diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 69fdd793..3c8fb0da 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -5,15 +5,13 @@ versions of django/python, and compatibility wrappers around optional packages. # flake8: noqa from __future__ import unicode_literals - -import inspect - from django.core.exceptions import ImproperlyConfigured +from django.conf import settings from django.utils.encoding import force_text from django.utils.six.moves.urllib import parse as urlparse -from django.conf import settings from django.utils import six import django +import inspect def unicode_repr(instance): @@ -33,6 +31,13 @@ def unicode_to_repr(value): return value +def unicode_http_header(value): + # Coerce HTTP header value to unicode. + if isinstance(value, six.binary_type): + return value.decode('iso-8859-1') + return value + + # OrderedDict only available in Python 2.7. # This will always be the case in Django 1.7 and above, as these versions # no longer support Python 2.6. @@ -207,72 +212,6 @@ except ImportError: apply_markdown = None -# Yaml is optional -try: - import yaml -except ImportError: - yaml = None - - -# XML is optional -try: - import defusedxml.ElementTree as etree -except ImportError: - etree = None - - -# OAuth2 is optional -try: - # Note: The `oauth2` package actually provides oauth1.0a support. Urg. - import oauth2 as oauth -except ImportError: - oauth = None - - -# OAuthProvider is optional -try: - import oauth_provider - from oauth_provider.store import store as oauth_provider_store - - # check_nonce's calling signature in django-oauth-plus changes sometime - # between versions 2.0 and 2.2.1 - def check_nonce(request, oauth_request, oauth_nonce, oauth_timestamp): - check_nonce_args = inspect.getargspec(oauth_provider_store.check_nonce).args - if 'timestamp' in check_nonce_args: - return oauth_provider_store.check_nonce( - request, oauth_request, oauth_nonce, oauth_timestamp - ) - return oauth_provider_store.check_nonce( - request, oauth_request, oauth_nonce - ) - -except (ImportError, ImproperlyConfigured): - oauth_provider = None - oauth_provider_store = None - check_nonce = None - - -# OAuth 2 support is optional -try: - import provider as oauth2_provider - from provider import scope as oauth2_provider_scope - from provider import constants as oauth2_constants - - if oauth2_provider.__version__ in ('0.2.3', '0.2.4'): - # 0.2.3 and 0.2.4 are supported version that do not support - # timezone aware datetimes - import datetime - - provider_now = datetime.datetime.now - else: - # Any other supported version does use timezone aware datetimes - from django.utils.timezone import now as provider_now -except ImportError: - oauth2_provider = None - oauth2_provider_scope = None - oauth2_constants = None - provider_now = None - # `separators` argument to `json.dumps()` differs between 2.x and 3.x # See: http://bugs.python.org/issue22767 if six.PY3: diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 1f381e4e..bcfd8961 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -89,6 +89,11 @@ class PermissionDenied(APIException): default_detail = _('You do not have permission to perform this action.') +class NotFound(APIException): + status_code = status.HTTP_404_NOT_FOUND + default_detail = _('Not found') + + class MethodNotAllowed(APIException): status_code = status.HTTP_405_METHOD_NOT_ALLOWED default_detail = _("Method '%s' not allowed.") diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 3e3395c0..cb23423d 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -14,12 +14,9 @@ from django.http.multipartparser import MultiPartParserError, parse_header, Chun from django.utils import six from django.utils.six.moves.urllib import parse as urlparse from django.utils.encoding import force_text -from rest_framework.compat import etree, yaml from rest_framework.exceptions import ParseError from rest_framework import renderers import json -import datetime -import decimal class DataAndFiles(object): @@ -67,29 +64,6 @@ class JSONParser(BaseParser): raise ParseError('JSON parse error - %s' % six.text_type(exc)) -class YAMLParser(BaseParser): - """ - Parses YAML-serialized data. - """ - - media_type = 'application/yaml' - - def parse(self, stream, media_type=None, parser_context=None): - """ - Parses the incoming bytestream as YAML and returns the resulting data. - """ - assert yaml, 'YAMLParser requires pyyaml to be installed' - - parser_context = parser_context or {} - encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) - - try: - data = stream.read().decode(encoding) - return yaml.safe_load(data) - except (ValueError, yaml.parser.ParserError) as exc: - raise ParseError('YAML parse error - %s' % six.text_type(exc)) - - class FormParser(BaseParser): """ Parser for form data. @@ -138,78 +112,6 @@ class MultiPartParser(BaseParser): raise ParseError('Multipart form parse error - %s' % six.text_type(exc)) -class XMLParser(BaseParser): - """ - XML parser. - """ - - media_type = 'application/xml' - - def parse(self, stream, media_type=None, parser_context=None): - """ - Parses the incoming bytestream as XML and returns the resulting data. - """ - assert etree, 'XMLParser requires defusedxml to be installed' - - parser_context = parser_context or {} - encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) - parser = etree.DefusedXMLParser(encoding=encoding) - try: - tree = etree.parse(stream, parser=parser, forbid_dtd=True) - except (etree.ParseError, ValueError) as exc: - raise ParseError('XML parse error - %s' % six.text_type(exc)) - data = self._xml_convert(tree.getroot()) - - return data - - def _xml_convert(self, element): - """ - convert the xml `element` into the corresponding python object - """ - - children = list(element) - - if len(children) == 0: - return self._type_convert(element.text) - else: - # if the fist child tag is list-item means all children are list-item - if children[0].tag == "list-item": - data = [] - for child in children: - data.append(self._xml_convert(child)) - else: - data = {} - for child in children: - data[child.tag] = self._xml_convert(child) - - return data - - def _type_convert(self, value): - """ - Converts the value returned by the XMl parse into the equivalent - Python type - """ - if value is None: - return value - - try: - return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S') - except ValueError: - pass - - try: - return int(value) - except ValueError: - pass - - try: - return decimal.Decimal(value) - except decimal.InvalidOperation: - pass - - return value - - class FileUploadParser(BaseParser): """ Parser for file upload data. diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 3f6f5961..9069d315 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -3,8 +3,7 @@ Provides a set of pluggable permission policies. """ from __future__ import unicode_literals from django.http import Http404 -from rest_framework.compat import (get_model_name, oauth2_provider_scope, - oauth2_constants) +from rest_framework.compat import get_model_name SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS'] @@ -199,28 +198,3 @@ class DjangoObjectPermissions(DjangoModelPermissions): return False return True - - -class TokenHasReadWriteScope(BasePermission): - """ - The request is authenticated as a user and the token used has the right scope - """ - - def has_permission(self, request, view): - token = request.auth - read_only = request.method in SAFE_METHODS - - if not token: - return False - - if hasattr(token, 'resource'): # OAuth 1 - return read_only or not request.auth.resource.is_readonly - elif hasattr(token, 'scope'): # OAuth 2 - required = oauth2_constants.READ if read_only else oauth2_constants.WRITE - return oauth2_provider_scope.check(required, request.auth.scope) - - assert False, ( - 'TokenHasReadWriteScope requires either the' - '`OAuthAuthentication` or `OAuth2Authentication` authentication ' - 'class to be used.' - ) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 634338e9..c4de30db 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -17,11 +17,8 @@ from django.http.multipartparser import parse_header from django.template import Context, RequestContext, loader, Template from django.test.client import encode_multipart from django.utils import six -from django.utils.encoding import smart_text -from django.utils.xmlutils import SimplerXMLGenerator -from django.utils.six.moves import StringIO from rest_framework import exceptions, serializers, status, VERSION -from rest_framework.compat import SHORT_SEPARATORS, LONG_SEPARATORS, yaml +from rest_framework.compat import SHORT_SEPARATORS, LONG_SEPARATORS from rest_framework.exceptions import ParseError from rest_framework.settings import api_settings from rest_framework.request import is_form_media_type, override_method @@ -112,112 +109,6 @@ class JSONRenderer(BaseRenderer): return ret -class JSONPRenderer(JSONRenderer): - """ - Renderer which serializes to json, - wrapping the json output in a callback function. - """ - - media_type = 'application/javascript' - format = 'jsonp' - callback_parameter = 'callback' - default_callback = 'callback' - charset = 'utf-8' - - def get_callback(self, renderer_context): - """ - Determine the name of the callback to wrap around the json output. - """ - request = renderer_context.get('request', None) - params = request and request.query_params or {} - return params.get(self.callback_parameter, self.default_callback) - - def render(self, data, accepted_media_type=None, renderer_context=None): - """ - Renders into jsonp, wrapping the json output in a callback function. - - Clients may set the callback function name using a query parameter - on the URL, for example: ?callback=exampleCallbackName - """ - renderer_context = renderer_context or {} - callback = self.get_callback(renderer_context) - json = super(JSONPRenderer, self).render(data, accepted_media_type, - renderer_context) - return callback.encode(self.charset) + b'(' + json + b');' - - -class XMLRenderer(BaseRenderer): - """ - Renderer which serializes to XML. - """ - - media_type = 'application/xml' - format = 'xml' - charset = 'utf-8' - - def render(self, data, accepted_media_type=None, renderer_context=None): - """ - Renders `data` into serialized XML. - """ - if data is None: - return '' - - stream = StringIO() - - xml = SimplerXMLGenerator(stream, self.charset) - xml.startDocument() - xml.startElement("root", {}) - - self._to_xml(xml, data) - - xml.endElement("root") - xml.endDocument() - return stream.getvalue() - - def _to_xml(self, xml, data): - if isinstance(data, (list, tuple)): - for item in data: - xml.startElement("list-item", {}) - self._to_xml(xml, item) - xml.endElement("list-item") - - elif isinstance(data, dict): - for key, value in six.iteritems(data): - xml.startElement(key, {}) - self._to_xml(xml, value) - xml.endElement(key) - - elif data is None: - # Don't output any value - pass - - else: - xml.characters(smart_text(data)) - - -class YAMLRenderer(BaseRenderer): - """ - Renderer which serializes to YAML. - """ - - media_type = 'application/yaml' - format = 'yaml' - encoder = encoders.SafeDumper - charset = 'utf-8' - ensure_ascii = False - - def render(self, data, accepted_media_type=None, renderer_context=None): - """ - Renders `data` into serialized YAML. - """ - assert yaml, 'YAMLRenderer requires pyyaml to be installed' - - if data is None: - return '' - - return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder, allow_unicode=not self.ensure_ascii) - - class TemplateHTMLRenderer(BaseRenderer): """ An HTML renderer for use with templates. diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py index a74e8aa2..8fcca55b 100644 --- a/rest_framework/reverse.py +++ b/rest_framework/reverse.py @@ -9,6 +9,18 @@ from django.utils.functional import lazy def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra): """ + If versioning is being used then we pass any `reverse` calls through + to the versioning scheme instance, so that the resulting URL + can be modified if needed. + """ + scheme = getattr(request, 'versioning_scheme', None) + if scheme is not None: + return scheme.reverse(viewname, args, kwargs, request, format, **extra) + return _reverse(viewname, args, kwargs, request, format, **extra) + + +def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra): + """ Same as `django.core.urlresolvers.reverse`, but optionally takes a request and returns a fully qualified URL, using the request to get the base URL. """ diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 33f84813..877d461b 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -5,11 +5,11 @@ For example your project's `settings.py` file might look like this: REST_FRAMEWORK = { 'DEFAULT_RENDERER_CLASSES': ( 'rest_framework.renderers.JSONRenderer', - 'rest_framework.renderers.YAMLRenderer', + 'rest_framework.renderers.TemplateHTMLRenderer', ) 'DEFAULT_PARSER_CLASSES': ( 'rest_framework.parsers.JSONParser', - 'rest_framework.parsers.YAMLParser', + 'rest_framework.parsers.TemplateHTMLRenderer', ) } @@ -46,6 +46,7 @@ DEFAULTS = { 'DEFAULT_THROTTLE_CLASSES': (), 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', 'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata', + 'DEFAULT_VERSIONING_CLASS': None, # Generic view behavior 'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer', @@ -67,6 +68,11 @@ DEFAULTS = { 'SEARCH_PARAM': 'search', 'ORDERING_PARAM': 'ordering', + # Versioning + 'DEFAULT_VERSION': None, + 'ALLOWED_VERSIONS': None, + 'VERSION_PARAM': 'version', + # Authentication 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, @@ -123,6 +129,7 @@ IMPORT_STRINGS = ( 'DEFAULT_THROTTLE_CLASSES', 'DEFAULT_CONTENT_NEGOTIATION_CLASS', 'DEFAULT_METADATA_CLASS', + 'DEFAULT_VERSIONING_CLASS', 'DEFAULT_PAGINATION_SERIALIZER_CLASS', 'DEFAULT_FILTER_BACKENDS', 'EXCEPTION_HANDLER', @@ -139,7 +146,9 @@ def perform_import(val, setting_name): If the given setting is a string import notation, then perform the necessary import or imports. """ - if isinstance(val, six.string_types): + if val is None: + return None + elif isinstance(val, six.string_types): return import_from_string(val, setting_name) elif isinstance(val, (list, tuple)): return [import_from_string(item, setting_name) for item in val] diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 73cbe5d8..0bd24939 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -6,11 +6,8 @@ from django.db.models.query import QuerySet from django.utils import six, timezone from django.utils.encoding import force_text from django.utils.functional import Promise -from rest_framework.compat import OrderedDict -from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList import datetime import decimal -import types import json @@ -58,65 +55,3 @@ class JSONEncoder(json.JSONEncoder): elif hasattr(obj, '__iter__'): return tuple(item for item in obj) return super(JSONEncoder, self).default(obj) - - -try: - import yaml -except ImportError: - SafeDumper = None -else: - # Adapted from http://pyyaml.org/attachment/ticket/161/use_ordered_dict.py - class SafeDumper(yaml.SafeDumper): - """ - Handles decimals as strings. - Handles OrderedDicts as usual dicts, but preserves field order, rather - than the usual behaviour of sorting the keys. - """ - def represent_decimal(self, data): - return self.represent_scalar('tag:yaml.org,2002:str', six.text_type(data)) - - def represent_mapping(self, tag, mapping, flow_style=None): - value = [] - node = yaml.MappingNode(tag, value, flow_style=flow_style) - if self.alias_key is not None: - self.represented_objects[self.alias_key] = node - best_style = True - if hasattr(mapping, 'items'): - mapping = list(mapping.items()) - if not isinstance(mapping, OrderedDict): - mapping.sort() - for item_key, item_value in mapping: - node_key = self.represent_data(item_key) - node_value = self.represent_data(item_value) - if not (isinstance(node_key, yaml.ScalarNode) and not node_key.style): - best_style = False - if not (isinstance(node_value, yaml.ScalarNode) and not node_value.style): - best_style = False - value.append((node_key, node_value)) - if flow_style is None: - if self.default_flow_style is not None: - node.flow_style = self.default_flow_style - else: - node.flow_style = best_style - return node - - SafeDumper.add_representer( - decimal.Decimal, - SafeDumper.represent_decimal - ) - SafeDumper.add_representer( - OrderedDict, - yaml.representer.SafeRepresenter.represent_dict - ) - SafeDumper.add_representer( - ReturnDict, - yaml.representer.SafeRepresenter.represent_dict - ) - SafeDumper.add_representer( - ReturnList, - yaml.representer.SafeRepresenter.represent_list - ) - SafeDumper.add_representer( - types.GeneratorType, - yaml.representer.SafeRepresenter.represent_list - ) diff --git a/rest_framework/versioning.py b/rest_framework/versioning.py new file mode 100644 index 00000000..440efd13 --- /dev/null +++ b/rest_framework/versioning.py @@ -0,0 +1,174 @@ +# coding: utf-8 +from __future__ import unicode_literals +from django.utils.translation import ugettext_lazy as _ +from rest_framework import exceptions +from rest_framework.compat import unicode_http_header +from rest_framework.reverse import _reverse +from rest_framework.settings import api_settings +from rest_framework.templatetags.rest_framework import replace_query_param +from rest_framework.utils.mediatypes import _MediaType +import re + + +class BaseVersioning(object): + default_version = api_settings.DEFAULT_VERSION + allowed_versions = api_settings.ALLOWED_VERSIONS + version_param = api_settings.VERSION_PARAM + + def determine_version(self, request, *args, **kwargs): + msg = '{cls}.determine_version() must be implemented.' + raise NotImplemented(msg.format( + cls=self.__class__.__name__ + )) + + def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): + return _reverse(viewname, args, kwargs, request, format, **extra) + + def is_allowed_version(self, version): + if not self.allowed_versions: + return True + return (version == self.default_version) or (version in self.allowed_versions) + + +class AcceptHeaderVersioning(BaseVersioning): + """ + GET /something/ HTTP/1.1 + Host: example.com + Accept: application/json; version=1.0 + """ + invalid_version_message = _("Invalid version in 'Accept' header.") + + def determine_version(self, request, *args, **kwargs): + media_type = _MediaType(request.accepted_media_type) + version = media_type.params.get(self.version_param, self.default_version) + version = unicode_http_header(version) + if not self.is_allowed_version(version): + raise exceptions.NotAcceptable(self.invalid_version_message) + return version + + # We don't need to implement `reverse`, as the versioning is based + # on the `Accept` header, not on the request URL. + + +class URLPathVersioning(BaseVersioning): + """ + To the client this is the same style as `NamespaceVersioning`. + The difference is in the backend - this implementation uses + Django's URL keyword arguments to determine the version. + + An example URL conf for two views that accept two different versions. + + urlpatterns = [ + url(r'^(?P<version>{v1,v2})/users/$', users_list, name='users-list'), + url(r'^(?P<version>{v1,v2})/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail') + ] + + GET /1.0/something/ HTTP/1.1 + Host: example.com + Accept: application/json + """ + invalid_version_message = _('Invalid version in URL path.') + + def determine_version(self, request, *args, **kwargs): + version = kwargs.get(self.version_param, self.default_version) + if not self.is_allowed_version(version): + raise exceptions.NotFound(self.invalid_version_message) + return version + + def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): + if request.version is not None: + kwargs = {} if (kwargs is None) else kwargs + kwargs[self.version_param] = request.version + + return super(URLPathVersioning, self).reverse( + viewname, args, kwargs, request, format, **extra + ) + + +class NamespaceVersioning(BaseVersioning): + """ + To the client this is the same style as `URLPathVersioning`. + The difference is in the backend - this implementation uses + Django's URL namespaces to determine the version. + + An example URL conf that is namespaced into two seperate versions + + # users/urls.py + urlpatterns = [ + url(r'^/users/$', users_list, name='users-list'), + url(r'^/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail') + ] + + # urls.py + urlpatterns = [ + url(r'^v1/', include('users.urls', namespace='v1')), + url(r'^v2/', include('users.urls', namespace='v2')) + ] + + GET /1.0/something/ HTTP/1.1 + Host: example.com + Accept: application/json + """ + invalid_version_message = _('Invalid version in URL path.') + + def determine_version(self, request, *args, **kwargs): + resolver_match = getattr(request, 'resolver_match', None) + if (resolver_match is None or not resolver_match.namespace): + return self.default_version + version = resolver_match.namespace + if not self.is_allowed_version(version): + raise exceptions.NotFound(self.invalid_version_message) + return version + + def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): + if request.version is not None: + viewname = request.version + ':' + viewname + return super(NamespaceVersioning, self).reverse( + viewname, args, kwargs, request, format, **extra + ) + + +class HostNameVersioning(BaseVersioning): + """ + GET /something/ HTTP/1.1 + Host: v1.example.com + Accept: application/json + """ + hostname_regex = re.compile(r'^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$') + invalid_version_message = _('Invalid version in hostname.') + + def determine_version(self, request, *args, **kwargs): + hostname, seperator, port = request.get_host().partition(':') + match = self.hostname_regex.match(hostname) + if not match: + return self.default_version + version = match.group(1) + if not self.is_allowed_version(version): + raise exceptions.NotFound(self.invalid_version_message) + return version + + # We don't need to implement `reverse`, as the hostname will already be + # preserved as part of the REST framework `reverse` implementation. + + +class QueryParameterVersioning(BaseVersioning): + """ + GET /something/?version=0.1 HTTP/1.1 + Host: example.com + Accept: application/json + """ + invalid_version_message = _('Invalid version in query parameter.') + + def determine_version(self, request, *args, **kwargs): + version = request.query_params.get(self.version_param) + if not self.is_allowed_version(version): + raise exceptions.NotFound(self.invalid_version_message) + return version + + def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): + url = super(QueryParameterVersioning, self).reverse( + viewname, args, kwargs, request, format, **extra + ) + if request.version is not None: + return replace_query_param(url, self.version_param, request.version) + return url diff --git a/rest_framework/views.py b/rest_framework/views.py index bc870417..12bb78bd 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -2,6 +2,8 @@ Provides an APIView class that is the base of all views in REST framework. """ from __future__ import unicode_literals +import inspect +import warnings from django.core.exceptions import PermissionDenied from django.http import Http404 @@ -46,7 +48,7 @@ def get_view_description(view_cls, html=False): return description -def exception_handler(exc): +def exception_handler(exc, context): """ Returns the response that should be used for any given exception. @@ -93,6 +95,7 @@ class APIView(View): permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS metadata_class = api_settings.DEFAULT_METADATA_CLASS + versioning_class = api_settings.DEFAULT_VERSIONING_CLASS # Allow dependency injection of other settings to make testing easier. settings = api_settings @@ -184,6 +187,18 @@ class APIView(View): 'request': getattr(self, 'request', None) } + def get_exception_handler_context(self): + """ + Returns a dict that is passed through to EXCEPTION_HANDLER, + as the `context` argument. + """ + return { + 'view': self, + 'args': getattr(self, 'args', ()), + 'kwargs': getattr(self, 'kwargs', {}), + 'request': getattr(self, 'request', None) + } + def get_view_name(self): """ Return the view name, as used in OPTIONS responses and in the @@ -300,6 +315,16 @@ class APIView(View): if not throttle.allow_request(request, self): self.throttled(request, throttle.wait()) + def determine_version(self, request, *args, **kwargs): + """ + If versioning is being used, then determine any API version for the + incoming request. Returns a two-tuple of (version, versioning_scheme) + """ + if self.versioning_class is None: + return (None, None) + scheme = self.versioning_class() + return (scheme.determine_version(request, *args, **kwargs), scheme) + # Dispatch methods def initialize_request(self, request, *args, **kwargs): @@ -308,11 +333,13 @@ class APIView(View): """ parser_context = self.get_parser_context(request) - return Request(request, - parsers=self.get_parsers(), - authenticators=self.get_authenticators(), - negotiator=self.get_content_negotiator(), - parser_context=parser_context) + return Request( + request, + parsers=self.get_parsers(), + authenticators=self.get_authenticators(), + negotiator=self.get_content_negotiator(), + parser_context=parser_context + ) def initial(self, request, *args, **kwargs): """ @@ -329,6 +356,10 @@ class APIView(View): neg = self.perform_content_negotiation(request) request.accepted_renderer, request.accepted_media_type = neg + # Determine the API version, if versioning is in use. + version, scheme = self.determine_version(request, *args, **kwargs) + request.version, request.versioning_scheme = version, scheme + def finalize_response(self, request, response, *args, **kwargs): """ Returns the final response object. @@ -369,7 +400,18 @@ class APIView(View): else: exc.status_code = status.HTTP_403_FORBIDDEN - response = self.settings.EXCEPTION_HANDLER(exc) + exception_handler = self.settings.EXCEPTION_HANDLER + + if len(inspect.getargspec(exception_handler).args) == 1: + warnings.warn( + 'The `exception_handler(exc)` call signature is deprecated. ' + 'Use `exception_handler(exc, context) instead.', + PendingDeprecationWarning + ) + response = exception_handler(exc) + else: + context = self.get_exception_handler_context() + response = exception_handler(exc, context) if response is None: raise |
