aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/authentication.py200
-rw-r--r--rest_framework/authtoken/serializers.py2
-rw-r--r--rest_framework/compat.py81
-rw-r--r--rest_framework/exceptions.py37
-rw-r--r--rest_framework/fields.py43
-rw-r--r--rest_framework/filters.py24
-rw-r--r--rest_framework/generics.py230
-rw-r--r--rest_framework/locale/en_US/LC_MESSAGES/django.po316
-rw-r--r--rest_framework/mixins.py13
-rw-r--r--rest_framework/pagination.py699
-rw-r--r--rest_framework/parsers.py98
-rw-r--r--rest_framework/permissions.py28
-rw-r--r--rest_framework/relations.py6
-rw-r--r--rest_framework/renderers.py123
-rw-r--r--rest_framework/request.py4
-rw-r--r--rest_framework/reverse.py12
-rw-r--r--rest_framework/serializers.py699
-rw-r--r--rest_framework/settings.py19
-rw-r--r--rest_framework/static/rest_framework/css/bootstrap-tweaks.css21
-rw-r--r--rest_framework/templates/rest_framework/base.html9
-rw-r--r--rest_framework/templates/rest_framework/pagination/numbers.html27
-rw-r--r--rest_framework/templates/rest_framework/pagination/previous_and_next.html12
-rw-r--r--rest_framework/templatetags/rest_framework.py21
-rw-r--r--rest_framework/utils/encoders.py66
-rw-r--r--rest_framework/utils/formatting.py5
-rw-r--r--rest_framework/utils/model_meta.py10
-rw-r--r--rest_framework/utils/urls.py25
-rw-r--r--rest_framework/versioning.py174
-rw-r--r--rest_framework/views.py56
29 files changed, 1949 insertions, 1111 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index 4832ad33..11db0585 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -3,14 +3,10 @@ Provides various authentication policies.
"""
from __future__ import unicode_literals
import base64
-
from django.contrib.auth import authenticate
-from django.core.exceptions import ImproperlyConfigured
from django.middleware.csrf import CsrfViewMiddleware
-from django.conf import settings
+from django.utils.translation import ugettext_lazy as _
from rest_framework import exceptions, HTTP_HEADER_ENCODING
-from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
-from rest_framework.compat import oauth2_provider, provider_now, check_nonce
from rest_framework.authtoken.models import Token
@@ -70,16 +66,16 @@ class BasicAuthentication(BaseAuthentication):
return None
if len(auth) == 1:
- msg = 'Invalid basic header. No credentials provided.'
+ msg = _('Invalid basic header. No credentials provided.')
raise exceptions.AuthenticationFailed(msg)
elif len(auth) > 2:
- msg = 'Invalid basic header. Credentials string should not contain spaces.'
+ msg = _('Invalid basic header. Credentials string should not contain spaces.')
raise exceptions.AuthenticationFailed(msg)
try:
auth_parts = base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(':')
except (TypeError, UnicodeDecodeError):
- msg = 'Invalid basic header. Credentials not correctly base64 encoded'
+ msg = _('Invalid basic header. Credentials not correctly base64 encoded.')
raise exceptions.AuthenticationFailed(msg)
userid, password = auth_parts[0], auth_parts[2]
@@ -91,7 +87,7 @@ class BasicAuthentication(BaseAuthentication):
"""
user = authenticate(username=userid, password=password)
if user is None or not user.is_active:
- raise exceptions.AuthenticationFailed('Invalid username/password')
+ raise exceptions.AuthenticationFailed(_('Invalid username/password.'))
return (user, None)
def authenticate_header(self, request):
@@ -157,10 +153,10 @@ class TokenAuthentication(BaseAuthentication):
return None
if len(auth) == 1:
- msg = 'Invalid token header. No credentials provided.'
+ msg = _('Invalid token header. No credentials provided.')
raise exceptions.AuthenticationFailed(msg)
elif len(auth) > 2:
- msg = 'Invalid token header. Token string should not contain spaces.'
+ msg = _('Invalid token header. Token string should not contain spaces.')
raise exceptions.AuthenticationFailed(msg)
return self.authenticate_credentials(auth[1])
@@ -169,190 +165,12 @@ class TokenAuthentication(BaseAuthentication):
try:
token = self.model.objects.get(key=key)
except self.model.DoesNotExist:
- raise exceptions.AuthenticationFailed('Invalid token')
+ raise exceptions.AuthenticationFailed(_('Invalid token.'))
if not token.user.is_active:
- raise exceptions.AuthenticationFailed('User inactive or deleted')
+ raise exceptions.AuthenticationFailed(_('User inactive or deleted.'))
return (token.user, token)
def authenticate_header(self, request):
return 'Token'
-
-
-class OAuthAuthentication(BaseAuthentication):
- """
- OAuth 1.0a authentication backend using `django-oauth-plus` and `oauth2`.
-
- Note: The `oauth2` package actually provides oauth1.0a support. Urg.
- We import it from the `compat` module as `oauth`.
- """
- www_authenticate_realm = 'api'
-
- def __init__(self, *args, **kwargs):
- super(OAuthAuthentication, self).__init__(*args, **kwargs)
-
- if oauth is None:
- raise ImproperlyConfigured(
- "The 'oauth2' package could not be imported."
- "It is required for use with the 'OAuthAuthentication' class.")
-
- if oauth_provider is None:
- raise ImproperlyConfigured(
- "The 'django-oauth-plus' package could not be imported."
- "It is required for use with the 'OAuthAuthentication' class.")
-
- def authenticate(self, request):
- """
- Returns two-tuple of (user, token) if authentication succeeds,
- or None otherwise.
- """
- try:
- oauth_request = oauth_provider.utils.get_oauth_request(request)
- except oauth.Error as err:
- raise exceptions.AuthenticationFailed(err.message)
-
- if not oauth_request:
- return None
-
- oauth_params = oauth_provider.consts.OAUTH_PARAMETERS_NAMES
-
- found = any(param for param in oauth_params if param in oauth_request)
- missing = list(param for param in oauth_params if param not in oauth_request)
-
- if not found:
- # OAuth authentication was not attempted.
- return None
-
- if missing:
- # OAuth was attempted but missing parameters.
- msg = 'Missing parameters: %s' % (', '.join(missing))
- raise exceptions.AuthenticationFailed(msg)
-
- if not self.check_nonce(request, oauth_request):
- msg = 'Nonce check failed'
- raise exceptions.AuthenticationFailed(msg)
-
- try:
- consumer_key = oauth_request.get_parameter('oauth_consumer_key')
- consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key)
- except oauth_provider.store.InvalidConsumerError:
- msg = 'Invalid consumer token: %s' % oauth_request.get_parameter('oauth_consumer_key')
- raise exceptions.AuthenticationFailed(msg)
-
- if consumer.status != oauth_provider.consts.ACCEPTED:
- msg = 'Invalid consumer key status: %s' % consumer.get_status_display()
- raise exceptions.AuthenticationFailed(msg)
-
- try:
- token_param = oauth_request.get_parameter('oauth_token')
- token = oauth_provider_store.get_access_token(request, oauth_request, consumer, token_param)
- except oauth_provider.store.InvalidTokenError:
- msg = 'Invalid access token: %s' % oauth_request.get_parameter('oauth_token')
- raise exceptions.AuthenticationFailed(msg)
-
- try:
- self.validate_token(request, consumer, token)
- except oauth.Error as err:
- raise exceptions.AuthenticationFailed(err.message)
-
- user = token.user
-
- if not user.is_active:
- msg = 'User inactive or deleted: %s' % user.username
- raise exceptions.AuthenticationFailed(msg)
-
- return (token.user, token)
-
- def authenticate_header(self, request):
- """
- If permission is denied, return a '401 Unauthorized' response,
- with an appropriate 'WWW-Authenticate' header.
- """
- return 'OAuth realm="%s"' % self.www_authenticate_realm
-
- def validate_token(self, request, consumer, token):
- """
- Check the token and raise an `oauth.Error` exception if invalid.
- """
- oauth_server, oauth_request = oauth_provider.utils.initialize_server_request(request)
- oauth_server.verify_request(oauth_request, consumer, token)
-
- def check_nonce(self, request, oauth_request):
- """
- Checks nonce of request, and return True if valid.
- """
- oauth_nonce = oauth_request['oauth_nonce']
- oauth_timestamp = oauth_request['oauth_timestamp']
- return check_nonce(request, oauth_request, oauth_nonce, oauth_timestamp)
-
-
-class OAuth2Authentication(BaseAuthentication):
- """
- OAuth 2 authentication backend using `django-oauth2-provider`
- """
- www_authenticate_realm = 'api'
- allow_query_params_token = settings.DEBUG
-
- def __init__(self, *args, **kwargs):
- super(OAuth2Authentication, self).__init__(*args, **kwargs)
-
- if oauth2_provider is None:
- raise ImproperlyConfigured(
- "The 'django-oauth2-provider' package could not be imported. "
- "It is required for use with the 'OAuth2Authentication' class.")
-
- def authenticate(self, request):
- """
- Returns two-tuple of (user, token) if authentication succeeds,
- or None otherwise.
- """
-
- auth = get_authorization_header(request).split()
-
- if len(auth) == 1:
- msg = 'Invalid bearer header. No credentials provided.'
- raise exceptions.AuthenticationFailed(msg)
- elif len(auth) > 2:
- msg = 'Invalid bearer header. Token string should not contain spaces.'
- raise exceptions.AuthenticationFailed(msg)
-
- if auth and auth[0].lower() == b'bearer':
- access_token = auth[1]
- elif 'access_token' in request.POST:
- access_token = request.POST['access_token']
- elif 'access_token' in request.GET and self.allow_query_params_token:
- access_token = request.GET['access_token']
- else:
- return None
-
- return self.authenticate_credentials(request, access_token)
-
- def authenticate_credentials(self, request, access_token):
- """
- Authenticate the request, given the access token.
- """
-
- try:
- token = oauth2_provider.oauth2.models.AccessToken.objects.select_related('user')
- # provider_now switches to timezone aware datetime when
- # the oauth2_provider version supports to it.
- token = token.get(token=access_token, expires__gt=provider_now())
- except oauth2_provider.oauth2.models.AccessToken.DoesNotExist:
- raise exceptions.AuthenticationFailed('Invalid token')
-
- user = token.user
-
- if not user.is_active:
- msg = 'User inactive or deleted: %s' % user.get_username()
- raise exceptions.AuthenticationFailed(msg)
-
- return (user, token)
-
- def authenticate_header(self, request):
- """
- Bearer is the only finalized type currently
-
- Check details on the `OAuth2Authentication.authenticate` method
- """
- return 'Bearer realm="%s"' % self.www_authenticate_realm
diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py
index f31dded1..37ade255 100644
--- a/rest_framework/authtoken/serializers.py
+++ b/rest_framework/authtoken/serializers.py
@@ -23,7 +23,7 @@ class AuthTokenSerializer(serializers.Serializer):
msg = _('Unable to log in with provided credentials.')
raise exceptions.ValidationError(msg)
else:
- msg = _('Must include "username" and "password"')
+ msg = _('Must include "username" and "password".')
raise exceptions.ValidationError(msg)
attrs['user'] = user
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 36413394..50f37014 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -5,15 +5,13 @@ versions of django/python, and compatibility wrappers around optional packages.
# flake8: noqa
from __future__ import unicode_literals
-
-import inspect
-
from django.core.exceptions import ImproperlyConfigured
+from django.conf import settings
from django.utils.encoding import force_text
from django.utils.six.moves.urllib.parse import urlparse as _urlparse
-from django.conf import settings
from django.utils import six
import django
+import inspect
def unicode_repr(instance):
@@ -33,6 +31,13 @@ def unicode_to_repr(value):
return value
+def unicode_http_header(value):
+ # Coerce HTTP header value to unicode.
+ if isinstance(value, six.binary_type):
+ return value.decode('iso-8859-1')
+ return value
+
+
def total_seconds(timedelta):
# TimeDelta.total_seconds() is only available in Python 2.7
if hasattr(timedelta, 'total_seconds'):
@@ -232,77 +237,13 @@ except ImportError:
apply_markdown = None
-# Yaml is optional
-try:
- import yaml
-except ImportError:
- yaml = None
-
-
-# XML is optional
-try:
- import defusedxml.ElementTree as etree
-except ImportError:
- etree = None
-
-
-# OAuth2 is optional
-try:
- # Note: The `oauth2` package actually provides oauth1.0a support. Urg.
- import oauth2 as oauth
-except ImportError:
- oauth = None
-
-
-# OAuthProvider is optional
-try:
- import oauth_provider
- from oauth_provider.store import store as oauth_provider_store
-
- # check_nonce's calling signature in django-oauth-plus changes sometime
- # between versions 2.0 and 2.2.1
- def check_nonce(request, oauth_request, oauth_nonce, oauth_timestamp):
- check_nonce_args = inspect.getargspec(oauth_provider_store.check_nonce).args
- if 'timestamp' in check_nonce_args:
- return oauth_provider_store.check_nonce(
- request, oauth_request, oauth_nonce, oauth_timestamp
- )
- return oauth_provider_store.check_nonce(
- request, oauth_request, oauth_nonce
- )
-
-except (ImportError, ImproperlyConfigured):
- oauth_provider = None
- oauth_provider_store = None
- check_nonce = None
-
-
-# OAuth 2 support is optional
-try:
- import provider as oauth2_provider
- from provider import scope as oauth2_provider_scope
- from provider import constants as oauth2_constants
-
- if oauth2_provider.__version__ in ('0.2.3', '0.2.4'):
- # 0.2.3 and 0.2.4 are supported version that do not support
- # timezone aware datetimes
- import datetime
-
- provider_now = datetime.datetime.now
- else:
- # Any other supported version does use timezone aware datetimes
- from django.utils.timezone import now as provider_now
-except ImportError:
- oauth2_provider = None
- oauth2_provider_scope = None
- oauth2_constants = None
- provider_now = None
-
# `separators` argument to `json.dumps()` differs between 2.x and 3.x
# See: http://bugs.python.org/issue22767
if six.PY3:
SHORT_SEPARATORS = (',', ':')
LONG_SEPARATORS = (', ', ': ')
+ INDENT_SEPARATORS = (',', ': ')
else:
SHORT_SEPARATORS = (b',', b':')
LONG_SEPARATORS = (b', ', b': ')
+ INDENT_SEPARATORS = (b',', b': ')
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index 1f381e4e..f954c13e 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -7,8 +7,7 @@ In addition Django's built in 403 and 404 exceptions are handled.
from __future__ import unicode_literals
from django.utils import six
from django.utils.encoding import force_text
-from django.utils.translation import ugettext_lazy as _
-from django.utils.translation import ungettext_lazy
+from django.utils.translation import ugettext_lazy as _, ungettext
from rest_framework import status
import math
@@ -36,7 +35,7 @@ class APIException(Exception):
Subclasses should provide `.status_code` and `.default_detail` properties.
"""
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
- default_detail = _('A server error occured')
+ default_detail = _('A server error occurred.')
def __init__(self, detail=None):
if detail is not None:
@@ -89,20 +88,25 @@ class PermissionDenied(APIException):
default_detail = _('You do not have permission to perform this action.')
+class NotFound(APIException):
+ status_code = status.HTTP_404_NOT_FOUND
+ default_detail = _('Not found.')
+
+
class MethodNotAllowed(APIException):
status_code = status.HTTP_405_METHOD_NOT_ALLOWED
- default_detail = _("Method '%s' not allowed.")
+ default_detail = _('Method "{method}" not allowed.')
def __init__(self, method, detail=None):
if detail is not None:
self.detail = force_text(detail)
else:
- self.detail = force_text(self.default_detail) % method
+ self.detail = force_text(self.default_detail).format(method=method)
class NotAcceptable(APIException):
status_code = status.HTTP_406_NOT_ACCEPTABLE
- default_detail = _('Could not satisfy the request Accept header')
+ default_detail = _('Could not satisfy the request Accept header.')
def __init__(self, detail=None, available_renderers=None):
if detail is not None:
@@ -114,23 +118,22 @@ class NotAcceptable(APIException):
class UnsupportedMediaType(APIException):
status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE
- default_detail = _("Unsupported media type '%s' in request.")
+ default_detail = _('Unsupported media type "{media_type}" in request.')
def __init__(self, media_type, detail=None):
if detail is not None:
self.detail = force_text(detail)
else:
- self.detail = force_text(self.default_detail) % media_type
+ self.detail = force_text(self.default_detail).format(
+ media_type=media_type
+ )
class Throttled(APIException):
status_code = status.HTTP_429_TOO_MANY_REQUESTS
default_detail = _('Request was throttled.')
- extra_detail = ungettext_lazy(
- 'Expected available in %(wait)d second.',
- 'Expected available in %(wait)d seconds.',
- 'wait'
- )
+ extra_detail_singular = 'Expected available in {wait} second.'
+ extra_detail_plural = 'Expected available in {wait} seconds.'
def __init__(self, wait=None, detail=None):
if detail is not None:
@@ -142,6 +145,8 @@ class Throttled(APIException):
self.wait = None
else:
self.wait = math.ceil(wait)
- self.detail += ' ' + force_text(
- self.extra_detail % {'wait': self.wait}
- )
+ self.detail += ' ' + force_text(ungettext(
+ self.extra_detail_singular.format(wait=self.wait),
+ self.extra_detail_plural.format(wait=self.wait),
+ self.wait
+ ))
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 71a9f193..02d2adef 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -484,7 +484,7 @@ class Field(object):
class BooleanField(Field):
default_error_messages = {
- 'invalid': _('`{input}` is not a valid boolean.')
+ 'invalid': _('"{input}" is not a valid boolean.')
}
default_empty_html = False
initial = False
@@ -512,7 +512,7 @@ class BooleanField(Field):
class NullBooleanField(Field):
default_error_messages = {
- 'invalid': _('`{input}` is not a valid boolean.')
+ 'invalid': _('"{input}" is not a valid boolean.')
}
initial = None
TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True))
@@ -612,7 +612,7 @@ class RegexField(CharField):
class SlugField(CharField):
default_error_messages = {
- 'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.")
+ 'invalid': _('Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.')
}
def __init__(self, **kwargs):
@@ -624,7 +624,7 @@ class SlugField(CharField):
class URLField(CharField):
default_error_messages = {
- 'invalid': _("Enter a valid URL.")
+ 'invalid': _('Enter a valid URL.')
}
def __init__(self, **kwargs):
@@ -657,7 +657,7 @@ class IntegerField(Field):
'invalid': _('A valid integer is required.'),
'max_value': _('Ensure this value is less than or equal to {max_value}.'),
'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
- 'max_string_length': _('String value too large')
+ 'max_string_length': _('String value too large.')
}
MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
@@ -688,10 +688,10 @@ class IntegerField(Field):
class FloatField(Field):
default_error_messages = {
- 'invalid': _("A valid number is required."),
+ 'invalid': _('A valid number is required.'),
'max_value': _('Ensure this value is less than or equal to {max_value}.'),
'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
- 'max_string_length': _('String value too large')
+ 'max_string_length': _('String value too large.')
}
MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
@@ -727,7 +727,7 @@ class DecimalField(Field):
'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'),
'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'),
'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.'),
- 'max_string_length': _('String value too large')
+ 'max_string_length': _('String value too large.')
}
MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
@@ -810,7 +810,7 @@ class DecimalField(Field):
class DateTimeField(Field):
default_error_messages = {
- 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'),
+ 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}.'),
'date': _('Expected a datetime but got a date.'),
}
format = api_settings.DATETIME_FORMAT
@@ -875,7 +875,7 @@ class DateTimeField(Field):
class DateField(Field):
default_error_messages = {
- 'invalid': _('Date has wrong format. Use one of these formats instead: {format}'),
+ 'invalid': _('Date has wrong format. Use one of these formats instead: {format}.'),
'datetime': _('Expected a date but got a datetime.'),
}
format = api_settings.DATE_FORMAT
@@ -933,7 +933,7 @@ class DateField(Field):
class TimeField(Field):
default_error_messages = {
- 'invalid': _('Time has wrong format. Use one of these formats instead: {format}'),
+ 'invalid': _('Time has wrong format. Use one of these formats instead: {format}.'),
}
format = api_settings.TIME_FORMAT
input_formats = api_settings.TIME_INPUT_FORMATS
@@ -989,7 +989,7 @@ class TimeField(Field):
class ChoiceField(Field):
default_error_messages = {
- 'invalid_choice': _('`{input}` is not a valid choice.')
+ 'invalid_choice': _('"{input}" is not a valid choice.')
}
def __init__(self, choices, **kwargs):
@@ -1033,8 +1033,8 @@ class ChoiceField(Field):
class MultipleChoiceField(ChoiceField):
default_error_messages = {
- 'invalid_choice': _('`{input}` is not a valid choice.'),
- 'not_a_list': _('Expected a list of items but got type `{input_type}`.')
+ 'invalid_choice': _('"{input}" is not a valid choice.'),
+ 'not_a_list': _('Expected a list of items but got type "{input_type}".')
}
default_empty_html = []
@@ -1064,10 +1064,10 @@ class MultipleChoiceField(ChoiceField):
class FileField(Field):
default_error_messages = {
- 'required': _("No file was submitted."),
- 'invalid': _("The submitted data was not a file. Check the encoding type on the form."),
- 'no_name': _("No filename could be determined."),
- 'empty': _("The submitted file is empty."),
+ 'required': _('No file was submitted.'),
+ 'invalid': _('The submitted data was not a file. Check the encoding type on the form.'),
+ 'no_name': _('No filename could be determined.'),
+ 'empty': _('The submitted file is empty.'),
'max_length': _('Ensure this filename has at most {max_length} characters (it has {length}).'),
}
use_url = api_settings.UPLOADED_FILES_USE_URL
@@ -1110,8 +1110,7 @@ class FileField(Field):
class ImageField(FileField):
default_error_messages = {
'invalid_image': _(
- 'Upload a valid image. The file you uploaded was either not an '
- 'image or a corrupted image.'
+ 'Upload a valid image. The file you uploaded was either not an image or a corrupted image.'
),
}
@@ -1149,7 +1148,7 @@ class ListField(Field):
child = _UnvalidatedField()
initial = []
default_error_messages = {
- 'not_a_list': _('Expected a list of items but got type `{input_type}`')
+ 'not_a_list': _('Expected a list of items but got type "{input_type}".')
}
def __init__(self, *args, **kwargs):
@@ -1186,7 +1185,7 @@ class DictField(Field):
child = _UnvalidatedField()
initial = []
default_error_messages = {
- 'not_a_dict': _('Expected a dictionary of items but got type `{input_type}`')
+ 'not_a_dict': _('Expected a dictionary of items but got type "{input_type}".')
}
def __init__(self, *args, **kwargs):
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index d188a2d1..2bcf3699 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -114,7 +114,7 @@ class OrderingFilter(BaseFilterBackend):
ordering_param = api_settings.ORDERING_PARAM
ordering_fields = None
- def get_ordering(self, request):
+ def get_ordering(self, request, queryset, view):
"""
Ordering is set by a comma delimited ?ordering=... query parameter.
@@ -124,7 +124,13 @@ class OrderingFilter(BaseFilterBackend):
"""
params = request.query_params.get(self.ordering_param)
if params:
- return [param.strip() for param in params.split(',')]
+ fields = [param.strip() for param in params.split(',')]
+ ordering = self.remove_invalid_fields(queryset, fields, view)
+ if ordering:
+ return ordering
+
+ # No ordering was included, or all the ordering fields were invalid
+ return self.get_default_ordering(view)
def get_default_ordering(self, view):
ordering = getattr(view, 'ordering', None)
@@ -132,7 +138,7 @@ class OrderingFilter(BaseFilterBackend):
return (ordering,)
return ordering
- def remove_invalid_fields(self, queryset, ordering, view):
+ def remove_invalid_fields(self, queryset, fields, view):
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
if valid_fields is None:
@@ -152,18 +158,10 @@ class OrderingFilter(BaseFilterBackend):
valid_fields = [field.name for field in queryset.model._meta.fields]
valid_fields += queryset.query.aggregates.keys()
- return [term for term in ordering if term.lstrip('-') in valid_fields]
+ return [term for term in fields if term.lstrip('-') in valid_fields]
def filter_queryset(self, request, queryset, view):
- ordering = self.get_ordering(request)
-
- if ordering:
- # Skip any incorrect parameters
- ordering = self.remove_invalid_fields(queryset, ordering, view)
-
- if not ordering:
- # Use 'ordering' attribute by default
- ordering = self.get_default_ordering(view)
+ ordering = self.get_ordering(request, queryset, view)
if ordering:
return queryset.order_by(*ordering)
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index e6db155e..61dcb84a 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -2,29 +2,13 @@
Generic views that provide commonly needed behaviour.
"""
from __future__ import unicode_literals
-
-from django.core.paginator import Paginator, InvalidPage
from django.db.models.query import QuerySet
from django.http import Http404
from django.shortcuts import get_object_or_404 as _get_object_or_404
-from django.utils import six
-from django.utils.translation import ugettext as _
from rest_framework import views, mixins
from rest_framework.settings import api_settings
-def strict_positive_int(integer_string, cutoff=None):
- """
- Cast a string to a strictly positive integer.
- """
- ret = int(integer_string)
- if ret <= 0:
- raise ValueError()
- if cutoff:
- ret = min(ret, cutoff)
- return ret
-
-
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
"""
Same as Django's standard shortcut, but make sure to also raise 404
@@ -40,7 +24,6 @@ class GenericAPIView(views.APIView):
"""
Base class for all other generic views.
"""
-
# You'll need to either set these attributes,
# or override `get_queryset()`/`get_serializer_class()`.
# If you are overriding a view method, it is important that you call
@@ -50,146 +33,16 @@ class GenericAPIView(views.APIView):
queryset = None
serializer_class = None
- # If you want to use object lookups other than pk, set this attribute.
+ # If you want to use object lookups other than pk, set 'lookup_field'.
# For more complex lookup requirements override `get_object()`.
lookup_field = 'pk'
lookup_url_kwarg = None
- # Pagination settings
- paginate_by = api_settings.PAGINATE_BY
- paginate_by_param = api_settings.PAGINATE_BY_PARAM
- max_paginate_by = api_settings.MAX_PAGINATE_BY
- pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
- page_kwarg = 'page'
-
# The filter backend classes to use for queryset filtering
filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
- # The following attribute may be subject to change,
- # and should be considered private API.
- paginator_class = Paginator
-
- def get_serializer_context(self):
- """
- Extra context provided to the serializer class.
- """
- return {
- 'request': self.request,
- 'format': self.format_kwarg,
- 'view': self
- }
-
- def get_serializer(self, *args, **kwargs):
- """
- Return the serializer instance that should be used for validating and
- deserializing input, and for serializing output.
- """
- serializer_class = self.get_serializer_class()
- kwargs['context'] = self.get_serializer_context()
- return serializer_class(*args, **kwargs)
-
- def get_pagination_serializer(self, page):
- """
- Return a serializer instance to use with paginated data.
- """
- class SerializerClass(self.pagination_serializer_class):
- class Meta:
- object_serializer_class = self.get_serializer_class()
-
- pagination_serializer_class = SerializerClass
- context = self.get_serializer_context()
- return pagination_serializer_class(instance=page, context=context)
-
- def paginate_queryset(self, queryset):
- """
- Paginate a queryset if required, either returning a page object,
- or `None` if pagination is not configured for this view.
- """
- page_size = self.get_paginate_by()
- if not page_size:
- return None
-
- paginator = self.paginator_class(queryset, page_size)
- page_kwarg = self.kwargs.get(self.page_kwarg)
- page_query_param = self.request.query_params.get(self.page_kwarg)
- page = page_kwarg or page_query_param or 1
- try:
- page_number = paginator.validate_number(page)
- except InvalidPage:
- if page == 'last':
- page_number = paginator.num_pages
- else:
- raise Http404(_("Page is not 'last', nor can it be converted to an int."))
- try:
- page = paginator.page(page_number)
- except InvalidPage as exc:
- error_format = _('Invalid page (%(page_number)s): %(message)s')
- raise Http404(error_format % {
- 'page_number': page_number,
- 'message': six.text_type(exc)
- })
-
- return page
-
- def filter_queryset(self, queryset):
- """
- Given a queryset, filter it with whichever filter backend is in use.
-
- You are unlikely to want to override this method, although you may need
- to call it either from a list view, or from a custom `get_object`
- method if you want to apply the configured filtering backend to the
- default queryset.
- """
- for backend in self.get_filter_backends():
- queryset = backend().filter_queryset(self.request, queryset, self)
- return queryset
-
- def get_filter_backends(self):
- """
- Returns the list of filter backends that this view requires.
- """
- return list(self.filter_backends)
-
- # The following methods provide default implementations
- # that you may want to override for more complex cases.
-
- def get_paginate_by(self):
- """
- Return the size of pages to use with pagination.
-
- If `PAGINATE_BY_PARAM` is set it will attempt to get the page size
- from a named query parameter in the url, eg. ?page_size=100
-
- Otherwise defaults to using `self.paginate_by`.
- """
- if self.paginate_by_param:
- try:
- return strict_positive_int(
- self.request.query_params[self.paginate_by_param],
- cutoff=self.max_paginate_by
- )
- except (KeyError, ValueError):
- pass
-
- return self.paginate_by
-
- def get_serializer_class(self):
- """
- Return the class to use for the serializer.
- Defaults to using `self.serializer_class`.
-
- You may want to override this if you need to provide different
- serializations depending on the incoming request.
-
- (Eg. admins get full serialization, others get basic serialization)
- """
- assert self.serializer_class is not None, (
- "'%s' should either include a `serializer_class` attribute, "
- "or override the `get_serializer_class()` method."
- % self.__class__.__name__
- )
-
- return self.serializer_class
+ # The style to use for queryset pagination.
+ pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
def get_queryset(self):
"""
@@ -246,6 +99,83 @@ class GenericAPIView(views.APIView):
return obj
+ def get_serializer(self, *args, **kwargs):
+ """
+ Return the serializer instance that should be used for validating and
+ deserializing input, and for serializing output.
+ """
+ serializer_class = self.get_serializer_class()
+ kwargs['context'] = self.get_serializer_context()
+ return serializer_class(*args, **kwargs)
+
+ def get_serializer_class(self):
+ """
+ Return the class to use for the serializer.
+ Defaults to using `self.serializer_class`.
+
+ You may want to override this if you need to provide different
+ serializations depending on the incoming request.
+
+ (Eg. admins get full serialization, others get basic serialization)
+ """
+ assert self.serializer_class is not None, (
+ "'%s' should either include a `serializer_class` attribute, "
+ "or override the `get_serializer_class()` method."
+ % self.__class__.__name__
+ )
+
+ return self.serializer_class
+
+ def get_serializer_context(self):
+ """
+ Extra context provided to the serializer class.
+ """
+ return {
+ 'request': self.request,
+ 'format': self.format_kwarg,
+ 'view': self
+ }
+
+ def filter_queryset(self, queryset):
+ """
+ Given a queryset, filter it with whichever filter backend is in use.
+
+ You are unlikely to want to override this method, although you may need
+ to call it either from a list view, or from a custom `get_object`
+ method if you want to apply the configured filtering backend to the
+ default queryset.
+ """
+ for backend in list(self.filter_backends):
+ queryset = backend().filter_queryset(self.request, queryset, self)
+ return queryset
+
+ @property
+ def paginator(self):
+ """
+ The paginator instance associated with the view, or `None`.
+ """
+ if not hasattr(self, '_paginator'):
+ if self.pagination_class is None:
+ self._paginator = None
+ else:
+ self._paginator = self.pagination_class()
+ return self._paginator
+
+ def paginate_queryset(self, queryset):
+ """
+ Return a single page of results, or `None` if pagination is disabled.
+ """
+ if self.paginator is None:
+ return None
+ return self.paginator.paginate_queryset(queryset, self.request, view=self)
+
+ def get_paginated_response(self, data):
+ """
+ Return a paginated style `Response` object for the given output data.
+ """
+ assert self.paginator is not None
+ return self.paginator.get_paginated_response(data)
+
# Concrete view classes that provide method handlers
# by composing the mixin classes with the base view.
diff --git a/rest_framework/locale/en_US/LC_MESSAGES/django.po b/rest_framework/locale/en_US/LC_MESSAGES/django.po
new file mode 100644
index 00000000..d98225ce
--- /dev/null
+++ b/rest_framework/locale/en_US/LC_MESSAGES/django.po
@@ -0,0 +1,316 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
+# This file is distributed under the same license as the PACKAGE package.
+# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR.
+#
+#, fuzzy
+msgid ""
+msgstr ""
+"Project-Id-Version: PACKAGE VERSION\n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2015-01-07 18:21+0000\n"
+"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
+"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
+"Language-Team: LANGUAGE <LL@li.org>\n"
+"Language: \n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=UTF-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+
+#: authentication.py:69
+msgid "Invalid basic header. No credentials provided."
+msgstr ""
+
+#: authentication.py:72
+msgid "Invalid basic header. Credentials string should not contain spaces."
+msgstr ""
+
+#: authentication.py:78
+msgid "Invalid basic header. Credentials not correctly base64 encoded."
+msgstr ""
+
+#: authentication.py:90
+msgid "Invalid username/password."
+msgstr ""
+
+#: authentication.py:156
+msgid "Invalid token header. No credentials provided."
+msgstr ""
+
+#: authentication.py:159
+msgid "Invalid token header. Token string should not contain spaces."
+msgstr ""
+
+#: authentication.py:168
+msgid "Invalid token."
+msgstr ""
+
+#: authentication.py:171
+msgid "User inactive or deleted."
+msgstr ""
+
+#: authtoken/serializers.py:20
+msgid "User account is disabled."
+msgstr ""
+
+#: authtoken/serializers.py:23
+msgid "Unable to log in with provided credentials."
+msgstr ""
+
+#: authtoken/serializers.py:26
+msgid "Must include \"username\" and \"password\"."
+msgstr ""
+
+#: exceptions.py:38
+msgid "A server error occurred."
+msgstr ""
+
+#: exceptions.py:73
+msgid "Malformed request."
+msgstr ""
+
+#: exceptions.py:78
+msgid "Incorrect authentication credentials."
+msgstr ""
+
+#: exceptions.py:83
+msgid "Authentication credentials were not provided."
+msgstr ""
+
+#: exceptions.py:88
+msgid "You do not have permission to perform this action."
+msgstr ""
+
+#: exceptions.py:93
+msgid "Not found."
+msgstr ""
+
+#: exceptions.py:98
+msgid "Method \"{method}\" not allowed."
+msgstr ""
+
+#: exceptions.py:109
+msgid "Could not satisfy the request Accept header."
+msgstr ""
+
+#: exceptions.py:121
+msgid "Unsupported media type \"{media_type}\" in request."
+msgstr ""
+
+#: exceptions.py:134
+msgid "Request was throttled."
+msgstr ""
+
+#: fields.py:152 relations.py:131 relations.py:155 validators.py:77
+#: validators.py:155
+msgid "This field is required."
+msgstr ""
+
+#: fields.py:153
+msgid "This field may not be null."
+msgstr ""
+
+#: fields.py:480 fields.py:508
+msgid "\"{input}\" is not a valid boolean."
+msgstr ""
+
+#: fields.py:543
+msgid "This field may not be blank."
+msgstr ""
+
+#: fields.py:544 fields.py:1252
+msgid "Ensure this field has no more than {max_length} characters."
+msgstr ""
+
+#: fields.py:545
+msgid "Ensure this field has at least {min_length} characters."
+msgstr ""
+
+#: fields.py:587
+msgid "Enter a valid email address."
+msgstr ""
+
+#: fields.py:604
+msgid "This value does not match the required pattern."
+msgstr ""
+
+#: fields.py:615
+msgid ""
+"Enter a valid \"slug\" consisting of letters, numbers, underscores or "
+"hyphens."
+msgstr ""
+
+#: fields.py:627
+msgid "Enter a valid URL."
+msgstr ""
+
+#: fields.py:640
+msgid "A valid integer is required."
+msgstr ""
+
+#: fields.py:641 fields.py:675 fields.py:708
+msgid "Ensure this value is less than or equal to {max_value}."
+msgstr ""
+
+#: fields.py:642 fields.py:676 fields.py:709
+msgid "Ensure this value is greater than or equal to {min_value}."
+msgstr ""
+
+#: fields.py:643 fields.py:677 fields.py:713
+msgid "String value too large."
+msgstr ""
+
+#: fields.py:674 fields.py:707
+msgid "A valid number is required."
+msgstr ""
+
+#: fields.py:710
+msgid "Ensure that there are no more than {max_digits} digits in total."
+msgstr ""
+
+#: fields.py:711
+msgid "Ensure that there are no more than {max_decimal_places} decimal places."
+msgstr ""
+
+#: fields.py:712
+msgid ""
+"Ensure that there are no more than {max_whole_digits} digits before the "
+"decimal point."
+msgstr ""
+
+#: fields.py:796
+msgid "Datetime has wrong format. Use one of these formats instead: {format}."
+msgstr ""
+
+#: fields.py:797
+msgid "Expected a datetime but got a date."
+msgstr ""
+
+#: fields.py:861
+msgid "Date has wrong format. Use one of these formats instead: {format}."
+msgstr ""
+
+#: fields.py:862
+msgid "Expected a date but got a datetime."
+msgstr ""
+
+#: fields.py:919
+msgid "Time has wrong format. Use one of these formats instead: {format}."
+msgstr ""
+
+#: fields.py:975 fields.py:1019
+msgid "\"{input}\" is not a valid choice."
+msgstr ""
+
+#: fields.py:1020 fields.py:1121 serializers.py:476
+msgid "Expected a list of items but got type \"{input_type}\"."
+msgstr ""
+
+#: fields.py:1050
+msgid "No file was submitted."
+msgstr ""
+
+#: fields.py:1051
+msgid "The submitted data was not a file. Check the encoding type on the form."
+msgstr ""
+
+#: fields.py:1052
+msgid "No filename could be determined."
+msgstr ""
+
+#: fields.py:1053
+msgid "The submitted file is empty."
+msgstr ""
+
+#: fields.py:1054
+msgid ""
+"Ensure this filename has at most {max_length} characters (it has {length})."
+msgstr ""
+
+#: fields.py:1096
+msgid ""
+"Upload a valid image. The file you uploaded was either not an image or a "
+"corrupted image."
+msgstr ""
+
+#: generics.py:123
+msgid ""
+"Choose a valid page number. Page numbers must be a whole number, or must be "
+"the string \"last\"."
+msgstr ""
+
+#: generics.py:128
+msgid "Invalid page \"{page_number}\": {message}."
+msgstr ""
+
+#: relations.py:132
+msgid "Invalid pk \"{pk_value}\" - object does not exist."
+msgstr ""
+
+#: relations.py:133
+msgid "Incorrect type. Expected pk value, received {data_type}."
+msgstr ""
+
+#: relations.py:156
+msgid "Invalid hyperlink - No URL match."
+msgstr ""
+
+#: relations.py:157
+msgid "Invalid hyperlink - Incorrect URL match."
+msgstr ""
+
+#: relations.py:158
+msgid "Invalid hyperlink - Object does not exist."
+msgstr ""
+
+#: relations.py:159
+msgid "Incorrect type. Expected URL string, received {data_type}."
+msgstr ""
+
+#: relations.py:294
+msgid "Object with {slug_name}={value} does not exist."
+msgstr ""
+
+#: relations.py:295
+msgid "Invalid value."
+msgstr ""
+
+#: serializers.py:299
+msgid "Invalid data. Expected a dictionary, but got {datatype}."
+msgstr ""
+
+#: validators.py:22
+msgid "This field must be unique."
+msgstr ""
+
+#: validators.py:76
+msgid "The fields {field_names} must make a unique set."
+msgstr ""
+
+#: validators.py:219
+msgid "This field must be unique for the \"{date_field}\" date."
+msgstr ""
+
+#: validators.py:234
+msgid "This field must be unique for the \"{date_field}\" month."
+msgstr ""
+
+#: validators.py:247
+msgid "This field must be unique for the \"{date_field}\" year."
+msgstr ""
+
+#: versioning.py:39
+msgid "Invalid version in \"Accept\" header."
+msgstr ""
+
+#: versioning.py:70 versioning.py:112
+msgid "Invalid version in URL path."
+msgstr ""
+
+#: versioning.py:138
+msgid "Invalid version in hostname."
+msgstr ""
+
+#: versioning.py:160
+msgid "Invalid version in query parameter."
+msgstr ""
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 2074a107..c34cfcee 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -5,7 +5,6 @@ We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways.
"""
from __future__ import unicode_literals
-
from rest_framework import status
from rest_framework.response import Response
from rest_framework.settings import api_settings
@@ -37,12 +36,14 @@ class ListModelMixin(object):
List a queryset.
"""
def list(self, request, *args, **kwargs):
- instance = self.filter_queryset(self.get_queryset())
- page = self.paginate_queryset(instance)
+ queryset = self.filter_queryset(self.get_queryset())
+
+ page = self.paginate_queryset(queryset)
if page is not None:
- serializer = self.get_pagination_serializer(page)
- else:
- serializer = self.get_serializer(instance, many=True)
+ serializer = self.get_serializer(page, many=True)
+ return self.get_paginated_response(serializer.data)
+
+ serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index 9c8dda8f..b3658aca 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -1,85 +1,682 @@
+# coding: utf-8
"""
Pagination serializers determine the structure of the output that should
be used for paginated responses.
"""
from __future__ import unicode_literals
-from rest_framework import serializers
-from rest_framework.templatetags.rest_framework import replace_query_param
+from base64 import b64encode, b64decode
+from collections import namedtuple
+from django.core.paginator import InvalidPage, Paginator as DjangoPaginator
+from django.template import Context, loader
+from django.utils import six
+from django.utils.six.moves.urllib import parse as urlparse
+from django.utils.translation import ugettext as _
+from rest_framework.compat import OrderedDict
+from rest_framework.exceptions import NotFound
+from rest_framework.response import Response
+from rest_framework.settings import api_settings
+from rest_framework.utils.urls import (
+ replace_query_param, remove_query_param
+)
-class NextPageField(serializers.Field):
+def _positive_int(integer_string, strict=False, cutoff=None):
"""
- Field that returns a link to the next page in paginated results.
+ Cast a string to a strictly positive integer.
"""
- page_field = 'page'
+ ret = int(integer_string)
+ if ret < 0 or (ret == 0 and strict):
+ raise ValueError()
+ if cutoff:
+ ret = min(ret, cutoff)
+ return ret
- def to_representation(self, value):
- if not value.has_next():
- return None
- page = value.next_page_number()
- request = self.context.get('request')
- url = request and request.build_absolute_uri() or ''
- return replace_query_param(url, self.page_field, page)
+def _divide_with_ceil(a, b):
+ """
+ Returns 'a' divded by 'b', with any remainder rounded up.
+ """
+ if a % b:
+ return (a // b) + 1
+ return a // b
-class PreviousPageField(serializers.Field):
+
+def _get_count(queryset):
"""
- Field that returns a link to the previous page in paginated results.
+ Determine an object count, supporting either querysets or regular lists.
"""
- page_field = 'page'
+ try:
+ return queryset.count()
+ except (AttributeError, TypeError):
+ return len(queryset)
- def to_representation(self, value):
- if not value.has_previous():
- return None
- page = value.previous_page_number()
- request = self.context.get('request')
- url = request and request.build_absolute_uri() or ''
- return replace_query_param(url, self.page_field, page)
+
+def _get_displayed_page_numbers(current, final):
+ """
+ This utility function determines a list of page numbers to display.
+ This gives us a nice contextually relevant set of page numbers.
+
+ For example:
+ current=14, final=16 -> [1, None, 13, 14, 15, 16]
+
+ This implementation gives one page to each side of the cursor,
+ or two pages to the side when the cursor is at the edge, then
+ ensures that any breaks between non-continous page numbers never
+ remove only a single page.
+
+ For an alernativative implementation which gives two pages to each side of
+ the cursor, eg. as in GitHub issue list pagination, see:
+
+ https://gist.github.com/tomchristie/321140cebb1c4a558b15
+ """
+ assert current >= 1
+ assert final >= current
+
+ if final <= 5:
+ return list(range(1, final + 1))
+
+ # We always include the first two pages, last two pages, and
+ # two pages either side of the current page.
+ included = set((
+ 1,
+ current - 1, current, current + 1,
+ final
+ ))
+
+ # If the break would only exclude a single page number then we
+ # may as well include the page number instead of the break.
+ if current <= 4:
+ included.add(2)
+ included.add(3)
+ if current >= final - 3:
+ included.add(final - 1)
+ included.add(final - 2)
+
+ # Now sort the page numbers and drop anything outside the limits.
+ included = [
+ idx for idx in sorted(list(included))
+ if idx > 0 and idx <= final
+ ]
+
+ # Finally insert any `...` breaks
+ if current > 4:
+ included.insert(1, None)
+ if current < final - 3:
+ included.insert(len(included) - 1, None)
+ return included
-class DefaultObjectSerializer(serializers.Serializer):
+def _get_page_links(page_numbers, current, url_func):
"""
- If no object serializer is specified, then this serializer will be applied
- as the default.
+ Given a list of page numbers and `None` page breaks,
+ return a list of `PageLink` objects.
"""
- def to_representation(self, value):
- return value
+ page_links = []
+ for page_number in page_numbers:
+ if page_number is None:
+ page_link = PAGE_BREAK
+ else:
+ page_link = PageLink(
+ url=url_func(page_number),
+ number=page_number,
+ is_active=(page_number == current),
+ is_break=False
+ )
+ page_links.append(page_link)
+ return page_links
-class BasePaginationSerializer(serializers.Serializer):
+def _decode_cursor(encoded):
"""
- A base class for pagination serializers to inherit from,
- to make implementing custom serializers more easy.
+ Given a string representing an encoded cursor, return a `Cursor` instance.
"""
- results_field = 'results'
+ try:
+ querystring = b64decode(encoded.encode('ascii')).decode('ascii')
+ tokens = urlparse.parse_qs(querystring, keep_blank_values=True)
+
+ offset = tokens.get('o', ['0'])[0]
+ offset = _positive_int(offset)
+
+ reverse = tokens.get('r', ['0'])[0]
+ reverse = bool(int(reverse))
+
+ position = tokens.get('p', [None])[0]
+ except (TypeError, ValueError):
+ return None
+
+ return Cursor(offset=offset, reverse=reverse, position=position)
+
+
+def _encode_cursor(cursor):
+ """
+ Given a Cursor instance, return an encoded string representation.
+ """
+ tokens = {}
+ if cursor.offset != 0:
+ tokens['o'] = str(cursor.offset)
+ if cursor.reverse:
+ tokens['r'] = '1'
+ if cursor.position is not None:
+ tokens['p'] = cursor.position
+
+ querystring = urlparse.urlencode(tokens, doseq=True)
+ return b64encode(querystring.encode('ascii')).decode('ascii')
+
+
+def _reverse_ordering(ordering_tuple):
+ """
+ Given an order_by tuple such as `('-created', 'uuid')` reverse the
+ ordering and return a new tuple, eg. `('created', '-uuid')`.
+ """
+ invert = lambda x: x[1:] if (x.startswith('-')) else '-' + x
+ return tuple([invert(item) for item in ordering_tuple])
+
+
+Cursor = namedtuple('Cursor', ['offset', 'reverse', 'position'])
+PageLink = namedtuple('PageLink', ['url', 'number', 'is_active', 'is_break'])
+
+PAGE_BREAK = PageLink(url=None, number=None, is_active=False, is_break=True)
- def __init__(self, *args, **kwargs):
+
+class BasePagination(object):
+ display_page_controls = False
+
+ def paginate_queryset(self, queryset, request, view=None): # pragma: no cover
+ raise NotImplementedError('paginate_queryset() must be implemented.')
+
+ def get_paginated_response(self, data): # pragma: no cover
+ raise NotImplementedError('get_paginated_response() must be implemented.')
+
+ def to_html(self): # pragma: no cover
+ raise NotImplementedError('to_html() must be implemented to display page controls.')
+
+
+class PageNumberPagination(BasePagination):
+ """
+ A simple page number based style that supports page numbers as
+ query parameters. For example:
+
+ http://api.example.org/accounts/?page=4
+ http://api.example.org/accounts/?page=4&page_size=100
+ """
+ # The default page size.
+ # Defaults to `None`, meaning pagination is disabled.
+ paginate_by = api_settings.PAGINATE_BY
+
+ # Client can control the page using this query parameter.
+ page_query_param = 'page'
+
+ # Client can control the page size using this query parameter.
+ # Default is 'None'. Set to eg 'page_size' to enable usage.
+ paginate_by_param = api_settings.PAGINATE_BY_PARAM
+
+ # Set to an integer to limit the maximum page size the client may request.
+ # Only relevant if 'paginate_by_param' has also been set.
+ max_paginate_by = api_settings.MAX_PAGINATE_BY
+
+ last_page_strings = ('last',)
+
+ template = 'rest_framework/pagination/numbers.html'
+
+ invalid_page_message = _('Invalid page "{page_number}": {message}.')
+
+ def _handle_backwards_compat(self, view):
"""
- Override init to add in the object serializer field on-the-fly.
+ Prior to version 3.1, pagination was handled in the view, and the
+ attributes were set there. The attributes should now be set on
+ the pagination class, but the old style is still pending deprecation.
"""
- super(BasePaginationSerializer, self).__init__(*args, **kwargs)
- results_field = self.results_field
+ for attr in (
+ 'paginate_by', 'page_query_param',
+ 'paginate_by_param', 'max_paginate_by'
+ ):
+ if hasattr(view, attr):
+ setattr(self, attr, getattr(view, attr))
- try:
- object_serializer = self.Meta.object_serializer_class
- except AttributeError:
- object_serializer = DefaultObjectSerializer
+ def paginate_queryset(self, queryset, request, view=None):
+ """
+ Paginate a queryset if required, either returning a
+ page object, or `None` if pagination is not configured for this view.
+ """
+ self._handle_backwards_compat(view)
+
+ page_size = self.get_page_size(request)
+ if not page_size:
+ return None
+
+ paginator = DjangoPaginator(queryset, page_size)
+ page_number = request.query_params.get(self.page_query_param, 1)
+ if page_number in self.last_page_strings:
+ page_number = paginator.num_pages
try:
- list_serializer_class = object_serializer.Meta.list_serializer_class
- except AttributeError:
- list_serializer_class = serializers.ListSerializer
+ self.page = paginator.page(page_number)
+ except InvalidPage as exc:
+ msg = self.invalid_page_message.format(
+ page_number=page_number, message=six.text_type(exc)
+ )
+ raise NotFound(msg)
- self.fields[results_field] = list_serializer_class(
- child=object_serializer(*args, **kwargs),
- source='object_list'
- )
+ if paginator.count > 1:
+ # The browsable API should display pagination controls.
+ self.display_page_controls = True
+
+ self.request = request
+ return self.page
+
+ def get_paginated_response(self, data):
+ return Response(OrderedDict([
+ ('count', self.page.paginator.count),
+ ('next', self.get_next_link()),
+ ('previous', self.get_previous_link()),
+ ('results', data)
+ ]))
+
+ def get_page_size(self, request):
+ if self.paginate_by_param:
+ try:
+ return _positive_int(
+ request.query_params[self.paginate_by_param],
+ strict=True,
+ cutoff=self.max_paginate_by
+ )
+ except (KeyError, ValueError):
+ pass
+
+ return self.paginate_by
+
+ def get_next_link(self):
+ if not self.page.has_next():
+ return None
+ url = self.request.build_absolute_uri()
+ page_number = self.page.next_page_number()
+ return replace_query_param(url, self.page_query_param, page_number)
+
+ def get_previous_link(self):
+ if not self.page.has_previous():
+ return None
+ url = self.request.build_absolute_uri()
+ page_number = self.page.previous_page_number()
+ if page_number == 1:
+ return remove_query_param(url, self.page_query_param)
+ return replace_query_param(url, self.page_query_param, page_number)
+
+ def get_html_context(self):
+ base_url = self.request.build_absolute_uri()
+ def page_number_to_url(page_number):
+ if page_number == 1:
+ return remove_query_param(base_url, self.page_query_param)
+ else:
+ return replace_query_param(base_url, self.page_query_param, page_number)
-class PaginationSerializer(BasePaginationSerializer):
+ current = self.page.number
+ final = self.page.paginator.num_pages
+ page_numbers = _get_displayed_page_numbers(current, final)
+ page_links = _get_page_links(page_numbers, current, page_number_to_url)
+
+ return {
+ 'previous_url': self.get_previous_link(),
+ 'next_url': self.get_next_link(),
+ 'page_links': page_links
+ }
+
+ def to_html(self):
+ template = loader.get_template(self.template)
+ context = Context(self.get_html_context())
+ return template.render(context)
+
+
+class LimitOffsetPagination(BasePagination):
"""
- A default implementation of a pagination serializer.
+ A limit/offset based style. For example:
+
+ http://api.example.org/accounts/?limit=100
+ http://api.example.org/accounts/?offset=400&limit=100
"""
- count = serializers.ReadOnlyField(source='paginator.count')
- next = NextPageField(source='*')
- previous = PreviousPageField(source='*')
+ default_limit = api_settings.PAGINATE_BY
+ limit_query_param = 'limit'
+ offset_query_param = 'offset'
+ max_limit = None
+ template = 'rest_framework/pagination/numbers.html'
+
+ def paginate_queryset(self, queryset, request, view=None):
+ self.limit = self.get_limit(request)
+ self.offset = self.get_offset(request)
+ self.count = _get_count(queryset)
+ self.request = request
+ if self.count > self.limit:
+ self.display_page_controls = True
+ return queryset[self.offset:self.offset + self.limit]
+
+ def get_paginated_response(self, data):
+ return Response(OrderedDict([
+ ('count', self.count),
+ ('next', self.get_next_link()),
+ ('previous', self.get_previous_link()),
+ ('results', data)
+ ]))
+
+ def get_limit(self, request):
+ if self.limit_query_param:
+ try:
+ return _positive_int(
+ request.query_params[self.limit_query_param],
+ cutoff=self.max_limit
+ )
+ except (KeyError, ValueError):
+ pass
+
+ return self.default_limit
+
+ def get_offset(self, request):
+ try:
+ return _positive_int(
+ request.query_params[self.offset_query_param],
+ )
+ except (KeyError, ValueError):
+ return 0
+
+ def get_next_link(self):
+ if self.offset + self.limit >= self.count:
+ return None
+
+ url = self.request.build_absolute_uri()
+ offset = self.offset + self.limit
+ return replace_query_param(url, self.offset_query_param, offset)
+
+ def get_previous_link(self):
+ if self.offset <= 0:
+ return None
+
+ url = self.request.build_absolute_uri()
+
+ if self.offset - self.limit <= 0:
+ return remove_query_param(url, self.offset_query_param)
+
+ offset = self.offset - self.limit
+ return replace_query_param(url, self.offset_query_param, offset)
+
+ def get_html_context(self):
+ base_url = self.request.build_absolute_uri()
+ current = _divide_with_ceil(self.offset, self.limit) + 1
+ # The number of pages is a little bit fiddly.
+ # We need to sum both the number of pages from current offset to end
+ # plus the number of pages up to the current offset.
+ # When offset is not strictly divisible by the limit then we may
+ # end up introducing an extra page as an artifact.
+ final = (
+ _divide_with_ceil(self.count - self.offset, self.limit) +
+ _divide_with_ceil(self.offset, self.limit)
+ )
+
+ def page_number_to_url(page_number):
+ if page_number == 1:
+ return remove_query_param(base_url, self.offset_query_param)
+ else:
+ offset = self.offset + ((page_number - current) * self.limit)
+ return replace_query_param(base_url, self.offset_query_param, offset)
+
+ page_numbers = _get_displayed_page_numbers(current, final)
+ page_links = _get_page_links(page_numbers, current, page_number_to_url)
+
+ return {
+ 'previous_url': self.get_previous_link(),
+ 'next_url': self.get_next_link(),
+ 'page_links': page_links
+ }
+
+ def to_html(self):
+ template = loader.get_template(self.template)
+ context = Context(self.get_html_context())
+ return template.render(context)
+
+
+class CursorPagination(BasePagination):
+ # Determine how/if True, False and None positions work - do the string
+ # encodings work with Django queryset filters?
+ # Consider a max offset cap.
+ # Tidy up the `get_ordering` API (eg remove queryset from it)
+ cursor_query_param = 'cursor'
+ page_size = api_settings.PAGINATE_BY
+ invalid_cursor_message = _('Invalid cursor')
+ ordering = None
+ template = 'rest_framework/pagination/previous_and_next.html'
+
+ def paginate_queryset(self, queryset, request, view=None):
+ self.base_url = request.build_absolute_uri()
+ self.ordering = self.get_ordering(request, queryset, view)
+
+ # Determine if we have a cursor, and if so then decode it.
+ encoded = request.query_params.get(self.cursor_query_param)
+ if encoded is None:
+ self.cursor = None
+ (offset, reverse, current_position) = (0, False, None)
+ else:
+ self.cursor = _decode_cursor(encoded)
+ if self.cursor is None:
+ raise NotFound(self.invalid_cursor_message)
+ (offset, reverse, current_position) = self.cursor
+
+ # Cursor pagination always enforces an ordering.
+ if reverse:
+ queryset = queryset.order_by(*_reverse_ordering(self.ordering))
+ else:
+ queryset = queryset.order_by(*self.ordering)
+
+ # If we have a cursor with a fixed position then filter by that.
+ if current_position is not None:
+ order = self.ordering[0]
+ is_reversed = order.startswith('-')
+ order_attr = order.lstrip('-')
+
+ # Test for: (cursor reversed) XOR (queryset reversed)
+ if self.cursor.reverse != is_reversed:
+ kwargs = {order_attr + '__lt': current_position}
+ else:
+ kwargs = {order_attr + '__gt': current_position}
+
+ queryset = queryset.filter(**kwargs)
+
+ # If we have an offset cursor then offset the entire page by that amount.
+ # We also always fetch an extra item in order to determine if there is a
+ # page following on from this one.
+ results = list(queryset[offset:offset + self.page_size + 1])
+ self.page = results[:self.page_size]
+
+ # Determine the position of the final item following the page.
+ if len(results) > len(self.page):
+ has_following_postion = True
+ following_position = self._get_position_from_instance(results[-1], self.ordering)
+ else:
+ has_following_postion = False
+ following_position = None
+
+ # If we have a reverse queryset, then the query ordering was in reverse
+ # so we need to reverse the items again before returning them to the user.
+ if reverse:
+ self.page = list(reversed(self.page))
+
+ if reverse:
+ # Determine next and previous positions for reverse cursors.
+ self.has_next = (current_position is not None) or (offset > 0)
+ self.has_previous = has_following_postion
+ if self.has_next:
+ self.next_position = current_position
+ if self.has_previous:
+ self.previous_position = following_position
+ else:
+ # Determine next and previous positions for forward cursors.
+ self.has_next = has_following_postion
+ self.has_previous = (current_position is not None) or (offset > 0)
+ if self.has_next:
+ self.next_position = following_position
+ if self.has_previous:
+ self.previous_position = current_position
+
+ # Display page controls in the browsable API if there is more
+ # than one page.
+ if self.has_previous or self.has_next:
+ self.display_page_controls = True
+
+ return self.page
+
+ def get_next_link(self):
+ if not self.has_next:
+ return None
+
+ if self.cursor and self.cursor.reverse and self.cursor.offset != 0:
+ # If we're reversing direction and we have an offset cursor
+ # then we cannot use the first position we find as a marker.
+ compare = self._get_position_from_instance(self.page[-1], self.ordering)
+ else:
+ compare = self.next_position
+ offset = 0
+
+ for item in reversed(self.page):
+ position = self._get_position_from_instance(item, self.ordering)
+ if position != compare:
+ # The item in this position and the item following it
+ # have different positions. We can use this position as
+ # our marker.
+ break
+
+ # The item in this postion has the same position as the item
+ # following it, we can't use it as a marker position, so increment
+ # the offset and keep seeking to the previous item.
+ compare = position
+ offset += 1
+
+ else:
+ # There were no unique positions in the page.
+ if not self.has_previous:
+ # We are on the first page.
+ # Our cursor will have an offset equal to the page size,
+ # but no position to filter against yet.
+ offset = self.page_size
+ position = None
+ elif self.cursor.reverse:
+ # The change in direction will introduce a paging artifact,
+ # where we end up skipping forward a few extra items.
+ offset = 0
+ position = self.previous_position
+ else:
+ # Use the position from the existing cursor and increment
+ # it's offset by the page size.
+ offset = self.cursor.offset + self.page_size
+ position = self.previous_position
+
+ cursor = Cursor(offset=offset, reverse=False, position=position)
+ encoded = _encode_cursor(cursor)
+ return replace_query_param(self.base_url, self.cursor_query_param, encoded)
+
+ def get_previous_link(self):
+ if not self.has_previous:
+ return None
+
+ if self.cursor and not self.cursor.reverse and self.cursor.offset != 0:
+ # If we're reversing direction and we have an offset cursor
+ # then we cannot use the first position we find as a marker.
+ compare = self._get_position_from_instance(self.page[0], self.ordering)
+ else:
+ compare = self.previous_position
+ offset = 0
+
+ for item in self.page:
+ position = self._get_position_from_instance(item, self.ordering)
+ if position != compare:
+ # The item in this position and the item following it
+ # have different positions. We can use this position as
+ # our marker.
+ break
+
+ # The item in this postion has the same position as the item
+ # following it, we can't use it as a marker position, so increment
+ # the offset and keep seeking to the previous item.
+ compare = position
+ offset += 1
+
+ else:
+ # There were no unique positions in the page.
+ if not self.has_next:
+ # We are on the final page.
+ # Our cursor will have an offset equal to the page size,
+ # but no position to filter against yet.
+ offset = self.page_size
+ position = None
+ elif self.cursor.reverse:
+ # Use the position from the existing cursor and increment
+ # it's offset by the page size.
+ offset = self.cursor.offset + self.page_size
+ position = self.next_position
+ else:
+ # The change in direction will introduce a paging artifact,
+ # where we end up skipping back a few extra items.
+ offset = 0
+ position = self.next_position
+
+ cursor = Cursor(offset=offset, reverse=True, position=position)
+ encoded = _encode_cursor(cursor)
+ return replace_query_param(self.base_url, self.cursor_query_param, encoded)
+
+ def get_ordering(self, request, queryset, view):
+ """
+ Return a tuple of strings, that may be used in an `order_by` method.
+ """
+ ordering_filters = [
+ filter_cls for filter_cls in getattr(view, 'filter_backends', [])
+ if hasattr(filter_cls, 'get_ordering')
+ ]
+
+ if ordering_filters:
+ # If a filter exists on the view that implements `get_ordering`
+ # then we defer to that filter to determine the ordering.
+ filter_cls = ordering_filters[0]
+ filter_instance = filter_cls()
+ ordering = filter_instance.get_ordering(request, queryset, view)
+ assert ordering is not None, (
+ 'Using cursor pagination, but filter class {filter_cls} '
+ 'returned a `None` ordering.'.format(
+ filter_cls=filter_cls.__name__
+ )
+ )
+ else:
+ # The default case is to check for an `ordering` attribute,
+ # first on the view instance, and then on this pagination instance.
+ ordering = getattr(view, 'ordering', getattr(self, 'ordering', None))
+ assert ordering is not None, (
+ 'Using cursor pagination, but no ordering attribute was declared '
+ 'on the view or on the pagination class.'
+ )
+
+ assert isinstance(ordering, (six.string_types, list, tuple)), (
+ 'Invalid ordering. Expected string or tuple, but got {type}'.format(
+ type=type(ordering).__name__
+ )
+ )
+
+ if isinstance(ordering, six.string_types):
+ return (ordering,)
+ return tuple(ordering)
+
+ def _get_position_from_instance(self, instance, ordering):
+ attr = getattr(instance, ordering[0].lstrip('-'))
+ return six.text_type(attr)
+
+ def get_paginated_response(self, data):
+ return Response(OrderedDict([
+ ('next', self.get_next_link()),
+ ('previous', self.get_previous_link()),
+ ('results', data)
+ ]))
+
+ def get_html_context(self):
+ return {
+ 'previous_url': self.get_previous_link(),
+ 'next_url': self.get_next_link()
+ }
+
+ def to_html(self):
+ template = loader.get_template(self.template)
+ context = Context(self.get_html_context())
+ return template.render(context)
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index 1efab85b..437d1339 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -14,12 +14,9 @@ from django.http.multipartparser import MultiPartParserError, parse_header, Chun
from django.utils import six
from django.utils.six.moves.urllib import parse as urlparse
from django.utils.encoding import force_text
-from rest_framework.compat import etree, yaml
from rest_framework.exceptions import ParseError
from rest_framework import renderers
import json
-import datetime
-import decimal
class DataAndFiles(object):
@@ -67,29 +64,6 @@ class JSONParser(BaseParser):
raise ParseError('JSON parse error - %s' % six.text_type(exc))
-class YAMLParser(BaseParser):
- """
- Parses YAML-serialized data.
- """
-
- media_type = 'application/yaml'
-
- def parse(self, stream, media_type=None, parser_context=None):
- """
- Parses the incoming bytestream as YAML and returns the resulting data.
- """
- assert yaml, 'YAMLParser requires pyyaml to be installed'
-
- parser_context = parser_context or {}
- encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
-
- try:
- data = stream.read().decode(encoding)
- return yaml.safe_load(data)
- except (ValueError, yaml.parser.ParserError) as exc:
- raise ParseError('YAML parse error - %s' % six.text_type(exc))
-
-
class FormParser(BaseParser):
"""
Parser for form data.
@@ -138,78 +112,6 @@ class MultiPartParser(BaseParser):
raise ParseError('Multipart form parse error - %s' % six.text_type(exc))
-class XMLParser(BaseParser):
- """
- XML parser.
- """
-
- media_type = 'application/xml'
-
- def parse(self, stream, media_type=None, parser_context=None):
- """
- Parses the incoming bytestream as XML and returns the resulting data.
- """
- assert etree, 'XMLParser requires defusedxml to be installed'
-
- parser_context = parser_context or {}
- encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
- parser = etree.DefusedXMLParser(encoding=encoding)
- try:
- tree = etree.parse(stream, parser=parser, forbid_dtd=True)
- except (etree.ParseError, ValueError) as exc:
- raise ParseError('XML parse error - %s' % six.text_type(exc))
- data = self._xml_convert(tree.getroot())
-
- return data
-
- def _xml_convert(self, element):
- """
- convert the xml `element` into the corresponding python object
- """
-
- children = list(element)
-
- if len(children) == 0:
- return self._type_convert(element.text)
- else:
- # if the fist child tag is list-item means all children are list-item
- if children[0].tag == "list-item":
- data = []
- for child in children:
- data.append(self._xml_convert(child))
- else:
- data = {}
- for child in children:
- data[child.tag] = self._xml_convert(child)
-
- return data
-
- def _type_convert(self, value):
- """
- Converts the value returned by the XMl parse into the equivalent
- Python type
- """
- if value is None:
- return value
-
- try:
- return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
- except ValueError:
- pass
-
- try:
- return int(value)
- except ValueError:
- pass
-
- try:
- return decimal.Decimal(value)
- except decimal.InvalidOperation:
- pass
-
- return value
-
-
class FileUploadParser(BaseParser):
"""
Parser for file upload data.
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index 3f6f5961..9069d315 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -3,8 +3,7 @@ Provides a set of pluggable permission policies.
"""
from __future__ import unicode_literals
from django.http import Http404
-from rest_framework.compat import (get_model_name, oauth2_provider_scope,
- oauth2_constants)
+from rest_framework.compat import get_model_name
SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS']
@@ -199,28 +198,3 @@ class DjangoObjectPermissions(DjangoModelPermissions):
return False
return True
-
-
-class TokenHasReadWriteScope(BasePermission):
- """
- The request is authenticated as a user and the token used has the right scope
- """
-
- def has_permission(self, request, view):
- token = request.auth
- read_only = request.method in SAFE_METHODS
-
- if not token:
- return False
-
- if hasattr(token, 'resource'): # OAuth 1
- return read_only or not request.auth.resource.is_readonly
- elif hasattr(token, 'scope'): # OAuth 2
- required = oauth2_constants.READ if read_only else oauth2_constants.WRITE
- return oauth2_provider_scope.check(required, request.auth.scope)
-
- assert False, (
- 'TokenHasReadWriteScope requires either the'
- '`OAuthAuthentication` or `OAuth2Authentication` authentication '
- 'class to be used.'
- )
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 13793f37..66857a41 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -130,7 +130,7 @@ class StringRelatedField(RelatedField):
class PrimaryKeyRelatedField(RelatedField):
default_error_messages = {
'required': _('This field is required.'),
- 'does_not_exist': _("Invalid pk '{pk_value}' - object does not exist."),
+ 'does_not_exist': _('Invalid pk "{pk_value}" - object does not exist.'),
'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),
}
@@ -154,7 +154,7 @@ class HyperlinkedRelatedField(RelatedField):
default_error_messages = {
'required': _('This field is required.'),
- 'no_match': _('Invalid hyperlink - No URL match'),
+ 'no_match': _('Invalid hyperlink - No URL match.'),
'incorrect_match': _('Invalid hyperlink - Incorrect URL match.'),
'does_not_exist': _('Invalid hyperlink - Object does not exist.'),
'incorrect_type': _('Incorrect type. Expected URL string, received {data_type}.'),
@@ -292,7 +292,7 @@ class SlugRelatedField(RelatedField):
"""
default_error_messages = {
- 'does_not_exist': _("Object with {slug_name}={value} does not exist."),
+ 'does_not_exist': _('Object with {slug_name}={value} does not exist.'),
'invalid': _('Invalid value.'),
}
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 584332e6..6256acdd 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -17,11 +17,8 @@ from django.http.multipartparser import parse_header
from django.template import Context, RequestContext, loader, Template
from django.test.client import encode_multipart
from django.utils import six
-from django.utils.encoding import smart_text
-from django.utils.xmlutils import SimplerXMLGenerator
-from django.utils.six.moves import StringIO
from rest_framework import exceptions, serializers, status, VERSION
-from rest_framework.compat import SHORT_SEPARATORS, LONG_SEPARATORS, yaml
+from rest_framework.compat import SHORT_SEPARATORS, LONG_SEPARATORS, INDENT_SEPARATORS
from rest_framework.exceptions import ParseError
from rest_framework.settings import api_settings
from rest_framework.request import is_form_media_type, override_method
@@ -90,7 +87,11 @@ class JSONRenderer(BaseRenderer):
renderer_context = renderer_context or {}
indent = self.get_indent(accepted_media_type, renderer_context)
- separators = SHORT_SEPARATORS if (indent is None and self.compact) else LONG_SEPARATORS
+
+ if indent is None:
+ separators = SHORT_SEPARATORS if self.compact else LONG_SEPARATORS
+ else:
+ separators = INDENT_SEPARATORS
ret = json.dumps(
data, cls=self.encoder_class,
@@ -112,112 +113,6 @@ class JSONRenderer(BaseRenderer):
return ret
-class JSONPRenderer(JSONRenderer):
- """
- Renderer which serializes to json,
- wrapping the json output in a callback function.
- """
-
- media_type = 'application/javascript'
- format = 'jsonp'
- callback_parameter = 'callback'
- default_callback = 'callback'
- charset = 'utf-8'
-
- def get_callback(self, renderer_context):
- """
- Determine the name of the callback to wrap around the json output.
- """
- request = renderer_context.get('request', None)
- params = request and request.query_params or {}
- return params.get(self.callback_parameter, self.default_callback)
-
- def render(self, data, accepted_media_type=None, renderer_context=None):
- """
- Renders into jsonp, wrapping the json output in a callback function.
-
- Clients may set the callback function name using a query parameter
- on the URL, for example: ?callback=exampleCallbackName
- """
- renderer_context = renderer_context or {}
- callback = self.get_callback(renderer_context)
- json = super(JSONPRenderer, self).render(data, accepted_media_type,
- renderer_context)
- return callback.encode(self.charset) + b'(' + json + b');'
-
-
-class XMLRenderer(BaseRenderer):
- """
- Renderer which serializes to XML.
- """
-
- media_type = 'application/xml'
- format = 'xml'
- charset = 'utf-8'
-
- def render(self, data, accepted_media_type=None, renderer_context=None):
- """
- Renders `data` into serialized XML.
- """
- if data is None:
- return ''
-
- stream = StringIO()
-
- xml = SimplerXMLGenerator(stream, self.charset)
- xml.startDocument()
- xml.startElement("root", {})
-
- self._to_xml(xml, data)
-
- xml.endElement("root")
- xml.endDocument()
- return stream.getvalue()
-
- def _to_xml(self, xml, data):
- if isinstance(data, (list, tuple)):
- for item in data:
- xml.startElement("list-item", {})
- self._to_xml(xml, item)
- xml.endElement("list-item")
-
- elif isinstance(data, dict):
- for key, value in six.iteritems(data):
- xml.startElement(key, {})
- self._to_xml(xml, value)
- xml.endElement(key)
-
- elif data is None:
- # Don't output any value
- pass
-
- else:
- xml.characters(smart_text(data))
-
-
-class YAMLRenderer(BaseRenderer):
- """
- Renderer which serializes to YAML.
- """
-
- media_type = 'application/yaml'
- format = 'yaml'
- encoder = encoders.SafeDumper
- charset = 'utf-8'
- ensure_ascii = False
-
- def render(self, data, accepted_media_type=None, renderer_context=None):
- """
- Renders `data` into serialized YAML.
- """
- assert yaml, 'YAMLRenderer requires pyyaml to be installed'
-
- if data is None:
- return ''
-
- return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder, allow_unicode=not self.ensure_ascii)
-
-
class TemplateHTMLRenderer(BaseRenderer):
"""
An HTML renderer for use with templates.
@@ -696,6 +591,11 @@ class BrowsableAPIRenderer(BaseRenderer):
renderer_content_type += ' ;%s' % renderer.charset
response_headers['Content-Type'] = renderer_content_type
+ if hasattr(view, 'paginator') and view.paginator.display_page_controls:
+ paginator = view.paginator
+ else:
+ paginator = None
+
context = {
'content': self.get_content(renderer, data, accepted_media_type, renderer_context),
'view': view,
@@ -704,6 +604,7 @@ class BrowsableAPIRenderer(BaseRenderer):
'description': self.get_description(view),
'name': self.get_name(view),
'version': VERSION,
+ 'paginator': paginator,
'breadcrumblist': self.get_breadcrumbs(request),
'allowed_methods': view.allowed_methods,
'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes],
diff --git a/rest_framework/request.py b/rest_framework/request.py
index cfbbdecc..bf6ff670 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -107,6 +107,10 @@ def clone_request(request, method):
ret.accepted_renderer = request.accepted_renderer
if hasattr(request, 'accepted_media_type'):
ret.accepted_media_type = request.accepted_media_type
+ if hasattr(request, 'version'):
+ ret.version = request.version
+ if hasattr(request, 'versioning_scheme'):
+ ret.versioning_scheme = request.versioning_scheme
return ret
diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py
index a74e8aa2..8fcca55b 100644
--- a/rest_framework/reverse.py
+++ b/rest_framework/reverse.py
@@ -9,6 +9,18 @@ from django.utils.functional import lazy
def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
"""
+ If versioning is being used then we pass any `reverse` calls through
+ to the versioning scheme instance, so that the resulting URL
+ can be modified if needed.
+ """
+ scheme = getattr(request, 'versioning_scheme', None)
+ if scheme is not None:
+ return scheme.reverse(viewname, args, kwargs, request, format, **extra)
+ return _reverse(viewname, args, kwargs, request, format, **extra)
+
+
+def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
+ """
Same as `django.core.urlresolvers.reverse`, but optionally takes a request
and returns a fully qualified URL, using the request to get the base URL.
"""
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 42d1e370..a3b8196b 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -12,7 +12,7 @@ response content is handled by parsers and renderers.
"""
from __future__ import unicode_literals
from django.db import models
-from django.db.models.fields import FieldDoesNotExist, Field as DjangoField
+from django.db.models.fields import FieldDoesNotExist, Field as DjangoModelField
from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import postgres_fields, unicode_to_repr
from rest_framework.utils import model_meta
@@ -327,7 +327,9 @@ class Serializer(BaseSerializer):
Returns a list of validator callables.
"""
# Used by the lazily-evaluated `validators` property.
- return getattr(getattr(self, 'Meta', None), 'validators', [])
+ meta = getattr(self, 'Meta', None)
+ validators = getattr(meta, 'validators', None)
+ return validators[:] if validators else []
def get_initial(self):
if hasattr(self, 'initial_data'):
@@ -477,7 +479,7 @@ class ListSerializer(BaseSerializer):
many = True
default_error_messages = {
- 'not_a_list': _('Expected a list of items but got type `{input_type}`.')
+ 'not_a_list': _('Expected a list of items but got type "{input_type}".')
}
def __init__(self, *args, **kwargs):
@@ -702,8 +704,7 @@ class ModelSerializer(Serializer):
you need you should either declare the extra/differing fields explicitly on
the serializer class, or simply use a `Serializer` class.
"""
-
- _field_mapping = ClassLookupDict({
+ serializer_field_mapping = {
models.AutoField: IntegerField,
models.BigIntegerField: IntegerField,
models.BooleanField: BooleanField,
@@ -725,10 +726,11 @@ class ModelSerializer(Serializer):
models.SmallIntegerField: IntegerField,
models.TextField: CharField,
models.TimeField: TimeField,
- models.URLField: URLField
- # Note: Some version-specific mappings also defined below.
- })
- _related_class = PrimaryKeyRelatedField
+ models.URLField: URLField,
+ }
+ serializer_related_class = PrimaryKeyRelatedField
+
+ # Default `create` and `update` behavior...
def create(self, validated_data):
"""
@@ -799,69 +801,81 @@ class ModelSerializer(Serializer):
return instance
- def get_validators(self):
- # If the validators have been declared explicitly then use that.
- validators = getattr(getattr(self, 'Meta', None), 'validators', None)
- if validators is not None:
- return validators
+ # Determine the fields to apply...
- # Determine the default set of validators.
- validators = []
- model_class = self.Meta.model
- field_names = set([
- field.source for field in self.fields.values()
- if (field.source != '*') and ('.' not in field.source)
- ])
+ def get_fields(self):
+ """
+ Return the dict of field names -> field instances that should be
+ used for `self.fields` when instantiating the serializer.
+ """
+ assert hasattr(self, 'Meta'), (
+ 'Class {serializer_class} missing "Meta" attribute'.format(
+ serializer_class=self.__class__.__name__
+ )
+ )
+ assert hasattr(self.Meta, 'model'), (
+ 'Class {serializer_class} missing "Meta.model" attribute'.format(
+ serializer_class=self.__class__.__name__
+ )
+ )
- # Note that we make sure to check `unique_together` both on the
- # base model class, but also on any parent classes.
- for parent_class in [model_class] + list(model_class._meta.parents.keys()):
- for unique_together in parent_class._meta.unique_together:
- if field_names.issuperset(set(unique_together)):
- validator = UniqueTogetherValidator(
- queryset=parent_class._default_manager,
- fields=unique_together
- )
- validators.append(validator)
+ declared_fields = copy.deepcopy(self._declared_fields)
+ model = getattr(self.Meta, 'model')
+ depth = getattr(self.Meta, 'depth', 0)
- # Add any unique_for_date/unique_for_month/unique_for_year constraints.
- info = model_meta.get_field_info(model_class)
- for field_name, field in info.fields_and_pk.items():
- if field.unique_for_date and field_name in field_names:
- validator = UniqueForDateValidator(
- queryset=model_class._default_manager,
- field=field_name,
- date_field=field.unique_for_date
- )
- validators.append(validator)
+ if depth is not None:
+ assert depth >= 0, "'depth' may not be negative."
+ assert depth <= 10, "'depth' may not be greater than 10."
- if field.unique_for_month and field_name in field_names:
- validator = UniqueForMonthValidator(
- queryset=model_class._default_manager,
- field=field_name,
- date_field=field.unique_for_month
- )
- validators.append(validator)
+ # Retrieve metadata about fields & relationships on the model class.
+ info = model_meta.get_field_info(model)
+ field_names = self.get_field_names(declared_fields, info)
- if field.unique_for_year and field_name in field_names:
- validator = UniqueForYearValidator(
- queryset=model_class._default_manager,
- field=field_name,
- date_field=field.unique_for_year
- )
- validators.append(validator)
+ # Determine any extra field arguments and hidden fields that
+ # should be included
+ extra_kwargs = self.get_extra_kwargs()
+ extra_kwargs, hidden_fields = self.get_uniqueness_extra_kwargs(
+ field_names, declared_fields, extra_kwargs
+ )
- return validators
+ # Determine the fields that should be included on the serializer.
+ fields = OrderedDict()
- def get_fields(self):
- declared_fields = copy.deepcopy(self._declared_fields)
+ for field_name in field_names:
+ # If the field is explicitly declared on the class then use that.
+ if field_name in declared_fields:
+ fields[field_name] = declared_fields[field_name]
+ continue
- ret = OrderedDict()
- model = getattr(self.Meta, 'model')
+ # Determine the serializer field class and keyword arguments.
+ field_class, field_kwargs = self.build_field(
+ field_name, info, model, depth
+ )
+
+ # Include any kwargs defined in `Meta.extra_kwargs`
+ field_kwargs = self.build_field_kwargs(
+ field_kwargs, extra_kwargs, field_name
+ )
+
+ # Create the serializer field.
+ fields[field_name] = field_class(**field_kwargs)
+
+ # Add in any hidden fields.
+ fields.update(hidden_fields)
+
+ return fields
+
+ # Methods for determining the set of field names to include...
+
+ def get_field_names(self, declared_fields, info):
+ """
+ Returns the list of all field names that should be created when
+ instantiating this serializer class. This is based on the default
+ set of fields, but also takes into account the `Meta.fields` or
+ `Meta.exclude` options if they have been specified.
+ """
fields = getattr(self.Meta, 'fields', None)
exclude = getattr(self.Meta, 'exclude', None)
- depth = getattr(self.Meta, 'depth', 0)
- extra_kwargs = getattr(self.Meta, 'extra_kwargs', {})
if fields and not isinstance(fields, (list, tuple)):
raise TypeError(
@@ -875,201 +889,199 @@ class ModelSerializer(Serializer):
type(exclude).__name__
)
- assert not (fields and exclude), "Cannot set both 'fields' and 'exclude'."
-
- extra_kwargs = self._include_additional_options(extra_kwargs)
+ assert not (fields and exclude), (
+ "Cannot set both 'fields' and 'exclude' options on "
+ "serializer {serializer_class}.".format(
+ serializer_class=self.__class__.__name__
+ )
+ )
- # Retrieve metadata about fields & relationships on the model class.
- info = model_meta.get_field_info(model)
+ if fields is not None:
+ # Ensure that all declared fields have also been included in the
+ # `Meta.fields` option.
- if fields is None:
- # Use the default set of field names if none is supplied explicitly.
- fields = self._get_default_field_names(declared_fields, info)
- exclude = getattr(self.Meta, 'exclude', None)
- if exclude is not None:
- for field_name in exclude:
- assert field_name in fields, (
- 'The field in the `exclude` option must be a model field. Got %s.' %
- field_name
+ # Do not require any fields that are declared a parent class,
+ # in order to allow serializer subclasses to only include
+ # a subset of fields.
+ required_field_names = set(declared_fields)
+ for cls in self.__class__.__bases__:
+ required_field_names -= set(getattr(cls, '_declared_fields', []))
+
+ for field_name in required_field_names:
+ assert field_name in fields, (
+ "The field '{field_name}' was declared on serializer "
+ "{serializer_class}, but has not been included in the "
+ "'fields' option.".format(
+ field_name=field_name,
+ serializer_class=self.__class__.__name__
)
- fields.remove(field_name)
- else:
- # Check that any fields declared on the class are
- # also explicitly included in `Meta.fields`.
+ )
+ return fields
+
+ # Use the default set of field names if `Meta.fields` is not specified.
+ fields = self.get_default_field_names(declared_fields, info)
+
+ if exclude is not None:
+ # If `Meta.exclude` is included, then remove those fields.
+ for field_name in exclude:
+ assert field_name in fields, (
+ "The field '{field_name}' was include on serializer "
+ "{serializer_class} in the 'exclude' option, but does "
+ "not match any model field.".format(
+ field_name=field_name,
+ serializer_class=self.__class__.__name__
+ )
+ )
+ fields.remove(field_name)
- # Note that we ignore any fields that were declared on a parent
- # class, in order to support only including a subset of fields
- # when subclassing serializers.
- declared_field_names = set(declared_fields.keys())
- for cls in self.__class__.__bases__:
- declared_field_names -= set(getattr(cls, '_declared_fields', []))
+ return fields
- missing_fields = declared_field_names - set(fields)
- assert not missing_fields, (
- 'Field `%s` has been declared on serializer `%s`, but '
- 'is missing from `Meta.fields`.' %
- (list(missing_fields)[0], self.__class__.__name__)
- )
+ def get_default_field_names(self, declared_fields, model_info):
+ """
+ Return the default list of field names that will be used if the
+ `Meta.fields` option is not specified.
+ """
+ return (
+ [model_info.pk.name] +
+ list(declared_fields.keys()) +
+ list(model_info.fields.keys()) +
+ list(model_info.forward_relations.keys())
+ )
- # Determine the set of model fields, and the fields that they map to.
- # We actually only need this to deal with the slightly awkward case
- # of supporting `unique_for_date`/`unique_for_month`/`unique_for_year`.
- model_field_mapping = {}
- for field_name in fields:
- if field_name in declared_fields:
- field = declared_fields[field_name]
- source = field.source or field_name
+ # Methods for constructing serializer fields...
+
+ def build_field(self, field_name, info, model_class, nested_depth):
+ """
+ Return a two tuple of (cls, kwargs) to build a serializer field with.
+ """
+ if field_name in info.fields_and_pk:
+ model_field = info.fields_and_pk[field_name]
+ return self.build_standard_field(field_name, model_field)
+
+ elif field_name in info.relations:
+ relation_info = info.relations[field_name]
+ if not nested_depth:
+ return self.build_relational_field(field_name, relation_info)
else:
- try:
- source = extra_kwargs[field_name]['source']
- except KeyError:
- source = field_name
- # Model fields will always have a simple source mapping,
- # they can't be nested attribute lookups.
- if '.' not in source and source != '*':
- model_field_mapping[source] = field_name
+ return self.build_nested_field(field_name, relation_info, nested_depth)
- # Determine if we need any additional `HiddenField` or extra keyword
- # arguments to deal with `unique_for` dates that are required to
- # be in the input data in order to validate it.
- hidden_fields = {}
- unique_constraint_names = set()
+ elif hasattr(model_class, field_name):
+ return self.build_property_field(field_name, model_class)
- for model_field_name, field_name in model_field_mapping.items():
- try:
- model_field = model._meta.get_field(model_field_name)
- except FieldDoesNotExist:
- continue
+ elif field_name == api_settings.URL_FIELD_NAME:
+ return self.build_url_field(field_name, model_class)
- if not isinstance(model_field, DjangoField):
- continue
+ return self.build_unknown_field(field_name, model_class)
- # Include each of the `unique_for_*` field names.
- unique_constraint_names |= set([
- model_field.unique_for_date,
- model_field.unique_for_month,
- model_field.unique_for_year
- ])
+ def build_standard_field(self, field_name, model_field):
+ """
+ Create regular model fields.
+ """
+ field_mapping = ClassLookupDict(self.serializer_field_mapping)
+
+ field_class = field_mapping[model_field]
+ field_kwargs = get_field_kwargs(field_name, model_field)
+
+ if 'choices' in field_kwargs:
+ # Fields with choices get coerced into `ChoiceField`
+ # instead of using their regular typed field.
+ field_class = ChoiceField
+ if not issubclass(field_class, ModelField):
+ # `model_field` is only valid for the fallback case of
+ # `ModelField`, which is used when no other typed field
+ # matched to the model field.
+ field_kwargs.pop('model_field', None)
+ if not issubclass(field_class, CharField) and not issubclass(field_class, ChoiceField):
+ # `allow_blank` is only valid for textual fields.
+ field_kwargs.pop('allow_blank', None)
+
+ return field_class, field_kwargs
+
+ def build_relational_field(self, field_name, relation_info):
+ """
+ Create fields for forward and reverse relationships.
+ """
+ field_class = self.serializer_related_class
+ field_kwargs = get_relation_kwargs(field_name, relation_info)
- unique_constraint_names -= set([None])
+ # `view_name` is only valid for hyperlinked relationships.
+ if not issubclass(field_class, HyperlinkedRelatedField):
+ field_kwargs.pop('view_name', None)
- # Include each of the `unique_together` field names,
- # so long as all the field names are included on the serializer.
- for parent_class in [model] + list(model._meta.parents.keys()):
- for unique_together_list in parent_class._meta.unique_together:
- if set(fields).issuperset(set(unique_together_list)):
- unique_constraint_names |= set(unique_together_list)
+ return field_class, field_kwargs
- # Now we have all the field names that have uniqueness constraints
- # applied, we can add the extra 'required=...' or 'default=...'
- # arguments that are appropriate to these fields, or add a `HiddenField` for it.
- for unique_constraint_name in unique_constraint_names:
- # Get the model field that is referred too.
- unique_constraint_field = model._meta.get_field(unique_constraint_name)
+ def build_nested_field(self, field_name, relation_info, nested_depth):
+ """
+ Create nested fields for forward and reverse relationships.
+ """
+ class NestedSerializer(ModelSerializer):
+ class Meta:
+ model = relation_info.related_model
+ depth = nested_depth
- if getattr(unique_constraint_field, 'auto_now_add', None):
- default = CreateOnlyDefault(timezone.now)
- elif getattr(unique_constraint_field, 'auto_now', None):
- default = timezone.now
- elif unique_constraint_field.has_default():
- default = unique_constraint_field.default
- else:
- default = empty
+ field_class = NestedSerializer
+ field_kwargs = get_nested_relation_kwargs(relation_info)
- if unique_constraint_name in model_field_mapping:
- # The corresponding field is present in the serializer
- if unique_constraint_name not in extra_kwargs:
- extra_kwargs[unique_constraint_name] = {}
- if default is empty:
- if 'required' not in extra_kwargs[unique_constraint_name]:
- extra_kwargs[unique_constraint_name]['required'] = True
- else:
- if 'default' not in extra_kwargs[unique_constraint_name]:
- extra_kwargs[unique_constraint_name]['default'] = default
- elif default is not empty:
- # The corresponding field is not present in the,
- # serializer. We have a default to use for it, so
- # add in a hidden field that populates it.
- hidden_fields[unique_constraint_name] = HiddenField(default=default)
+ return field_class, field_kwargs
- # Now determine the fields that should be included on the serializer.
- for field_name in fields:
- if field_name in declared_fields:
- # Field is explicitly declared on the class, use that.
- ret[field_name] = declared_fields[field_name]
- continue
+ def build_property_field(self, field_name, model_class):
+ """
+ Create a read only field for model methods and properties.
+ """
+ field_class = ReadOnlyField
+ field_kwargs = {}
- elif field_name in info.fields_and_pk:
- # Create regular model fields.
- model_field = info.fields_and_pk[field_name]
- field_cls = self._field_mapping[model_field]
- kwargs = get_field_kwargs(field_name, model_field)
- if 'choices' in kwargs:
- # Fields with choices get coerced into `ChoiceField`
- # instead of using their regular typed field.
- field_cls = ChoiceField
- if not issubclass(field_cls, ModelField):
- # `model_field` is only valid for the fallback case of
- # `ModelField`, which is used when no other typed field
- # matched to the model field.
- kwargs.pop('model_field', None)
- if not issubclass(field_cls, CharField) and not issubclass(field_cls, ChoiceField):
- # `allow_blank` is only valid for textual fields.
- kwargs.pop('allow_blank', None)
-
- elif field_name in info.relations:
- # Create forward and reverse relationships.
- relation_info = info.relations[field_name]
- if depth:
- field_cls = self._get_nested_class(depth, relation_info)
- kwargs = get_nested_relation_kwargs(relation_info)
- else:
- field_cls = self._related_class
- kwargs = get_relation_kwargs(field_name, relation_info)
- # `view_name` is only valid for hyperlinked relationships.
- if not issubclass(field_cls, HyperlinkedRelatedField):
- kwargs.pop('view_name', None)
-
- elif hasattr(model, field_name):
- # Create a read only field for model methods and properties.
- field_cls = ReadOnlyField
- kwargs = {}
-
- elif field_name == api_settings.URL_FIELD_NAME:
- # Create the URL field.
- field_cls = HyperlinkedIdentityField
- kwargs = get_url_kwargs(model)
+ return field_class, field_kwargs
- else:
- raise ImproperlyConfigured(
- 'Field name `%s` is not valid for model `%s`.' %
- (field_name, model.__class__.__name__)
- )
+ def build_url_field(self, field_name, model_class):
+ """
+ Create a field representing the object's own URL.
+ """
+ field_class = HyperlinkedIdentityField
+ field_kwargs = get_url_kwargs(model_class)
+
+ return field_class, field_kwargs
- # Populate any kwargs defined in `Meta.extra_kwargs`
- extras = extra_kwargs.get(field_name, {})
- if extras.get('read_only', False):
- for attr in [
- 'required', 'default', 'allow_blank', 'allow_null',
- 'min_length', 'max_length', 'min_value', 'max_value',
- 'validators', 'queryset'
- ]:
- kwargs.pop(attr, None)
+ def build_unknown_field(self, field_name, model_class):
+ """
+ Raise an error on any unknown fields.
+ """
+ raise ImproperlyConfigured(
+ 'Field name `%s` is not valid for model `%s`.' %
+ (field_name, model_class.__name__)
+ )
- if extras.get('default') and kwargs.get('required') is False:
- kwargs.pop('required')
+ def build_field_kwargs(self, kwargs, extra_kwargs, field_name):
+ """
+ Include an 'extra_kwargs' that have been included for this field,
+ possibly removing any incompatible existing keyword arguments.
+ """
+ extras = extra_kwargs.get(field_name, {})
- kwargs.update(extras)
+ if extras.get('read_only', False):
+ for attr in [
+ 'required', 'default', 'allow_blank', 'allow_null',
+ 'min_length', 'max_length', 'min_value', 'max_value',
+ 'validators', 'queryset'
+ ]:
+ kwargs.pop(attr, None)
- # Create the serializer field.
- ret[field_name] = field_cls(**kwargs)
+ if extras.get('default') and kwargs.get('required') is False:
+ kwargs.pop('required')
- for field_name, field in hidden_fields.items():
- ret[field_name] = field
+ kwargs.update(extras)
- return ret
+ return kwargs
+
+ # Methods for determining additional keyword arguments to apply...
+
+ def get_extra_kwargs(self):
+ """
+ Return a dictionary mapping field names to a dictionary of
+ additional keyword arguments.
+ """
+ extra_kwargs = getattr(self.Meta, 'extra_kwargs', {})
- def _include_additional_options(self, extra_kwargs):
read_only_fields = getattr(self.Meta, 'read_only_fields', None)
if read_only_fields is not None:
for field_name in read_only_fields:
@@ -1117,21 +1129,204 @@ class ModelSerializer(Serializer):
return extra_kwargs
- def _get_default_field_names(self, declared_fields, model_info):
+ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
+ """
+ Return any additional field options that need to be included as a
+ result of uniqueness constraints on the model. This is returned as
+ a two-tuple of:
+
+ ('dict of updated extra kwargs', 'mapping of hidden fields')
+ """
+ model = getattr(self.Meta, 'model')
+ model_fields = self._get_model_fields(
+ field_names, declared_fields, extra_kwargs
+ )
+
+ # Determine if we need any additional `HiddenField` or extra keyword
+ # arguments to deal with `unique_for` dates that are required to
+ # be in the input data in order to validate it.
+ unique_constraint_names = set()
+
+ for model_field in model_fields.values():
+ # Include each of the `unique_for_*` field names.
+ unique_constraint_names |= set([
+ model_field.unique_for_date,
+ model_field.unique_for_month,
+ model_field.unique_for_year
+ ])
+
+ unique_constraint_names -= set([None])
+
+ # Include each of the `unique_together` field names,
+ # so long as all the field names are included on the serializer.
+ for parent_class in [model] + list(model._meta.parents.keys()):
+ for unique_together_list in parent_class._meta.unique_together:
+ if set(field_names).issuperset(set(unique_together_list)):
+ unique_constraint_names |= set(unique_together_list)
+
+ # Now we have all the field names that have uniqueness constraints
+ # applied, we can add the extra 'required=...' or 'default=...'
+ # arguments that are appropriate to these fields, or add a `HiddenField` for it.
+ hidden_fields = {}
+ uniqueness_extra_kwargs = {}
+
+ for unique_constraint_name in unique_constraint_names:
+ # Get the model field that is referred too.
+ unique_constraint_field = model._meta.get_field(unique_constraint_name)
+
+ if getattr(unique_constraint_field, 'auto_now_add', None):
+ default = CreateOnlyDefault(timezone.now)
+ elif getattr(unique_constraint_field, 'auto_now', None):
+ default = timezone.now
+ elif unique_constraint_field.has_default():
+ default = unique_constraint_field.default
+ else:
+ default = empty
+
+ if unique_constraint_name in model_fields:
+ # The corresponding field is present in the serializer
+ if default is empty:
+ uniqueness_extra_kwargs[unique_constraint_name] = {'required': True}
+ else:
+ uniqueness_extra_kwargs[unique_constraint_name] = {'default': default}
+ elif default is not empty:
+ # The corresponding field is not present in the,
+ # serializer. We have a default to use for it, so
+ # add in a hidden field that populates it.
+ hidden_fields[unique_constraint_name] = HiddenField(default=default)
+
+ # Update `extra_kwargs` with any new options.
+ for key, value in uniqueness_extra_kwargs.items():
+ if key in extra_kwargs:
+ extra_kwargs[key].update(value)
+ else:
+ extra_kwargs[key] = value
+
+ return extra_kwargs, hidden_fields
+
+ def _get_model_fields(self, field_names, declared_fields, extra_kwargs):
+ """
+ Returns all the model fields that are being mapped to by fields
+ on the serializer class.
+ Returned as a dict of 'model field name' -> 'model field'.
+ Used internally by `get_uniqueness_field_options`.
+ """
+ model = getattr(self.Meta, 'model')
+ model_fields = {}
+
+ for field_name in field_names:
+ if field_name in declared_fields:
+ # If the field is declared on the serializer
+ field = declared_fields[field_name]
+ source = field.source or field_name
+ else:
+ try:
+ source = extra_kwargs[field_name]['source']
+ except KeyError:
+ source = field_name
+
+ if '.' in source or source == '*':
+ # Model fields will always have a simple source mapping,
+ # they can't be nested attribute lookups.
+ continue
+
+ try:
+ field = model._meta.get_field(source)
+ if isinstance(field, DjangoModelField):
+ model_fields[source] = field
+ except FieldDoesNotExist:
+ pass
+
+ return model_fields
+
+ # Determine the validators to apply...
+
+ def get_validators(self):
+ """
+ Determine the set of validators to use when instantiating serializer.
+ """
+ # If the validators have been declared explicitly then use that.
+ validators = getattr(getattr(self, 'Meta', None), 'validators', None)
+ if validators is not None:
+ return validators[:]
+
+ # Otherwise use the default set of validators.
return (
- [model_info.pk.name] +
- list(declared_fields.keys()) +
- list(model_info.fields.keys()) +
- list(model_info.forward_relations.keys())
+ self.get_unique_together_validators() +
+ self.get_unique_for_date_validators()
)
- def _get_nested_class(self, nested_depth, relation_info):
- class NestedSerializer(ModelSerializer):
- class Meta:
- model = relation_info.related
- depth = nested_depth - 1
+ def get_unique_together_validators(self):
+ """
+ Determine a default set of validators for any unique_together contraints.
+ """
+ model_class_inheritance_tree = (
+ [self.Meta.model] +
+ list(self.Meta.model._meta.parents.keys())
+ )
+
+ # The field names we're passing though here only include fields
+ # which may map onto a model field. Any dotted field name lookups
+ # cannot map to a field, and must be a traversal, so we're not
+ # including those.
+ field_names = set([
+ field.source for field in self.fields.values()
+ if (field.source != '*') and ('.' not in field.source)
+ ])
+
+ # Note that we make sure to check `unique_together` both on the
+ # base model class, but also on any parent classes.
+ validators = []
+ for parent_class in model_class_inheritance_tree:
+ for unique_together in parent_class._meta.unique_together:
+ if field_names.issuperset(set(unique_together)):
+ validator = UniqueTogetherValidator(
+ queryset=parent_class._default_manager,
+ fields=unique_together
+ )
+ validators.append(validator)
+ return validators
+
+ def get_unique_for_date_validators(self):
+ """
+ Determine a default set of validators for the following contraints:
+
+ * unique_for_date
+ * unique_for_month
+ * unique_for_year
+ """
+ info = model_meta.get_field_info(self.Meta.model)
+ default_manager = self.Meta.model._default_manager
+ field_names = [field.source for field in self.fields.values()]
+
+ validators = []
+
+ for field_name, field in info.fields_and_pk.items():
+ if field.unique_for_date and field_name in field_names:
+ validator = UniqueForDateValidator(
+ queryset=default_manager,
+ field=field_name,
+ date_field=field.unique_for_date
+ )
+ validators.append(validator)
- return NestedSerializer
+ if field.unique_for_month and field_name in field_names:
+ validator = UniqueForMonthValidator(
+ queryset=default_manager,
+ field=field_name,
+ date_field=field.unique_for_month
+ )
+ validators.append(validator)
+
+ if field.unique_for_year and field_name in field_names:
+ validator = UniqueForYearValidator(
+ queryset=default_manager,
+ field=field_name,
+ date_field=field.unique_for_year
+ )
+ validators.append(validator)
+
+ return validators
if hasattr(models, 'UUIDField'):
@@ -1152,9 +1347,13 @@ class HyperlinkedModelSerializer(ModelSerializer):
* A 'url' field is included instead of the 'id' field.
* Relationships to other instances are hyperlinks, instead of primary keys.
"""
- _related_class = HyperlinkedRelatedField
+ serializer_related_class = HyperlinkedRelatedField
- def _get_default_field_names(self, declared_fields, model_info):
+ def get_default_field_names(self, declared_fields, model_info):
+ """
+ Return the default list of field names that will be used if the
+ `Meta.fields` option is not specified.
+ """
return (
[api_settings.URL_FIELD_NAME] +
list(declared_fields.keys()) +
@@ -1162,10 +1361,16 @@ class HyperlinkedModelSerializer(ModelSerializer):
list(model_info.forward_relations.keys())
)
- def _get_nested_class(self, nested_depth, relation_info):
+ def build_nested_field(self, field_name, relation_info, nested_depth):
+ """
+ Create nested fields for forward and reverse relationships.
+ """
class NestedSerializer(HyperlinkedModelSerializer):
class Meta:
- model = relation_info.related
+ model = relation_info.related_model
depth = nested_depth - 1
- return NestedSerializer
+ field_class = NestedSerializer
+ field_kwargs = get_nested_relation_kwargs(relation_info)
+
+ return field_class, field_kwargs
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index e5e5edaf..7331f265 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -5,11 +5,11 @@ For example your project's `settings.py` file might look like this:
REST_FRAMEWORK = {
'DEFAULT_RENDERER_CLASSES': (
'rest_framework.renderers.JSONRenderer',
- 'rest_framework.renderers.YAMLRenderer',
+ 'rest_framework.renderers.TemplateHTMLRenderer',
)
'DEFAULT_PARSER_CLASSES': (
'rest_framework.parsers.JSONParser',
- 'rest_framework.parsers.YAMLParser',
+ 'rest_framework.parsers.TemplateHTMLRenderer',
)
}
@@ -47,9 +47,10 @@ DEFAULTS = {
'DEFAULT_THROTTLE_CLASSES': (),
'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',
'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata',
+ 'DEFAULT_VERSIONING_CLASS': None,
# Generic view behavior
- 'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer',
+ 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
'DEFAULT_FILTER_BACKENDS': (),
# Throttling
@@ -68,6 +69,11 @@ DEFAULTS = {
'SEARCH_PARAM': 'search',
'ORDERING_PARAM': 'ordering',
+ # Versioning
+ 'DEFAULT_VERSION': None,
+ 'ALLOWED_VERSIONS': None,
+ 'VERSION_PARAM': 'version',
+
# Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
@@ -124,7 +130,8 @@ IMPORT_STRINGS = (
'DEFAULT_THROTTLE_CLASSES',
'DEFAULT_CONTENT_NEGOTIATION_CLASS',
'DEFAULT_METADATA_CLASS',
- 'DEFAULT_PAGINATION_SERIALIZER_CLASS',
+ 'DEFAULT_VERSIONING_CLASS',
+ 'DEFAULT_PAGINATION_CLASS',
'DEFAULT_FILTER_BACKENDS',
'EXCEPTION_HANDLER',
'TEST_REQUEST_RENDERER_CLASSES',
@@ -140,7 +147,9 @@ def perform_import(val, setting_name):
If the given setting is a string import notation,
then perform the necessary import or imports.
"""
- if isinstance(val, six.string_types):
+ if val is None:
+ return None
+ elif isinstance(val, six.string_types):
return import_from_string(val, setting_name)
elif isinstance(val, (list, tuple)):
return [import_from_string(item, setting_name) for item in val]
diff --git a/rest_framework/static/rest_framework/css/bootstrap-tweaks.css b/rest_framework/static/rest_framework/css/bootstrap-tweaks.css
index 36c7be48..04f12ed3 100644
--- a/rest_framework/static/rest_framework/css/bootstrap-tweaks.css
+++ b/rest_framework/static/rest_framework/css/bootstrap-tweaks.css
@@ -60,6 +60,23 @@ a single block in the template.
color: #C20000;
}
+.pagination>.disabled>a,
+.pagination>.disabled>a:hover,
+.pagination>.disabled>a:focus {
+ cursor: not-allowed;
+ pointer-events: none;
+}
+
+.pager>.disabled>a,
+.pager>.disabled>a:hover,
+.pager>.disabled>a:focus {
+ pointer-events: none;
+}
+
+.pager .next {
+ margin-left: 10px;
+}
+
/*=== dabapps bootstrap styles ====*/
html {
@@ -185,10 +202,6 @@ body a:hover {
color: #c20000;
}
-#content a span {
- text-decoration: underline;
- }
-
.request-info {
clear:both;
}
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index e9668193..877387f2 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -119,9 +119,18 @@
<div class="page-header">
<h1>{{ name }}</h1>
</div>
+ <div style="float:left">
{% block description %}
{{ description }}
{% endblock %}
+ </div>
+
+ {% if paginator %}
+ <nav style="float: right">
+ {% get_pagination_html paginator %}
+ </nav>
+ {% endif %}
+
<div class="request-info" style="clear: both" >
<pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre>
</div>
diff --git a/rest_framework/templates/rest_framework/pagination/numbers.html b/rest_framework/templates/rest_framework/pagination/numbers.html
new file mode 100644
index 00000000..04045810
--- /dev/null
+++ b/rest_framework/templates/rest_framework/pagination/numbers.html
@@ -0,0 +1,27 @@
+<ul class="pagination" style="margin: 5px 0 10px 0">
+ {% if previous_url %}
+ <li><a href="{{ previous_url }}" aria-label="Previous"><span aria-hidden="true">&laquo;</span></a></li>
+ {% else %}
+ <li class="disabled"><a href="#" aria-label="Previous"><span aria-hidden="true">&laquo;</span></a></li>
+ {% endif %}
+
+ {% for page_link in page_links %}
+ {% if page_link.is_break %}
+ <li class="disabled">
+ <a href="#"><span aria-hidden="true">&hellip;</span></a>
+ </li>
+ {% else %}
+ {% if page_link.is_active %}
+ <li class="active"><a href="{{ page_link.url }}">{{ page_link.number }}</a></li>
+ {% else %}
+ <li><a href="{{ page_link.url }}">{{ page_link.number }}</a></li>
+ {% endif %}
+ {% endif %}
+ {% endfor %}
+
+ {% if next_url %}
+ <li><a href="{{ next_url }}" aria-label="Next"><span aria-hidden="true">&raquo;</span></a></li>
+ {% else %}
+ <li class="disabled"><a href="#" aria-label="Next"><span aria-hidden="true">&raquo;</span></a></li>
+ {% endif %}
+</ul>
diff --git a/rest_framework/templates/rest_framework/pagination/previous_and_next.html b/rest_framework/templates/rest_framework/pagination/previous_and_next.html
new file mode 100644
index 00000000..eacbfff4
--- /dev/null
+++ b/rest_framework/templates/rest_framework/pagination/previous_and_next.html
@@ -0,0 +1,12 @@
+<ul class="pager">
+{% if previous_url %}
+ <li class="previous"><a href="{{ previous_url }}">&laquo; Previous</a></li>
+{% else %}
+ <li class="previous disabled"><a href="#">&laquo; Previous</a></li>
+{% endif %}
+{% if next_url %}
+ <li class="next"><a href="{{ next_url }}">Next &raquo;</a></li>
+{% else %}
+ <li class="next disabled"><a href="#">Next &raquo;</li>
+{% endif %}
+</ul>
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index 69e03af4..a969836f 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -1,36 +1,25 @@
from __future__ import unicode_literals, absolute_import
from django import template
from django.core.urlresolvers import reverse, NoReverseMatch
-from django.http import QueryDict
from django.utils import six
-from django.utils.six.moves.urllib import parse as urlparse
from django.utils.encoding import iri_to_uri, force_text
from django.utils.html import escape
from django.utils.safestring import SafeData, mark_safe
from django.utils.html import smart_urlquote
from rest_framework.renderers import HTMLFormRenderer
+from rest_framework.utils.urls import replace_query_param
import re
register = template.Library()
-
-def replace_query_param(url, key, val):
- """
- Given a URL and a key/val pair, set or replace an item in the query
- parameters of the URL, and return the new URL.
- """
- (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url)
- query_dict = QueryDict(query).copy()
- query_dict[key] = val
- query = query_dict.urlencode()
- return urlparse.urlunsplit((scheme, netloc, path, query, fragment))
-
-
# Regex for adding classes to html snippets
class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])')
-# And the template tags themselves...
+@register.simple_tag
+def get_pagination_html(pager):
+ return pager.to_html()
+
@register.simple_tag
def render_field(field, style=None):
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index bf753271..2160d18b 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -6,11 +6,9 @@ from django.db.models.query import QuerySet
from django.utils import six, timezone
from django.utils.encoding import force_text
from django.utils.functional import Promise
-from rest_framework.compat import OrderedDict, total_seconds
-from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList
+from rest_framework.compat import total_seconds
import datetime
import decimal
-import types
import json
import uuid
@@ -61,65 +59,3 @@ class JSONEncoder(json.JSONEncoder):
elif hasattr(obj, '__iter__'):
return tuple(item for item in obj)
return super(JSONEncoder, self).default(obj)
-
-
-try:
- import yaml
-except ImportError:
- SafeDumper = None
-else:
- # Adapted from http://pyyaml.org/attachment/ticket/161/use_ordered_dict.py
- class SafeDumper(yaml.SafeDumper):
- """
- Handles decimals as strings.
- Handles OrderedDicts as usual dicts, but preserves field order, rather
- than the usual behaviour of sorting the keys.
- """
- def represent_decimal(self, data):
- return self.represent_scalar('tag:yaml.org,2002:str', six.text_type(data))
-
- def represent_mapping(self, tag, mapping, flow_style=None):
- value = []
- node = yaml.MappingNode(tag, value, flow_style=flow_style)
- if self.alias_key is not None:
- self.represented_objects[self.alias_key] = node
- best_style = True
- if hasattr(mapping, 'items'):
- mapping = list(mapping.items())
- if not isinstance(mapping, OrderedDict):
- mapping.sort()
- for item_key, item_value in mapping:
- node_key = self.represent_data(item_key)
- node_value = self.represent_data(item_value)
- if not (isinstance(node_key, yaml.ScalarNode) and not node_key.style):
- best_style = False
- if not (isinstance(node_value, yaml.ScalarNode) and not node_value.style):
- best_style = False
- value.append((node_key, node_value))
- if flow_style is None:
- if self.default_flow_style is not None:
- node.flow_style = self.default_flow_style
- else:
- node.flow_style = best_style
- return node
-
- SafeDumper.add_representer(
- decimal.Decimal,
- SafeDumper.represent_decimal
- )
- SafeDumper.add_representer(
- OrderedDict,
- yaml.representer.SafeRepresenter.represent_dict
- )
- SafeDumper.add_representer(
- ReturnDict,
- yaml.representer.SafeRepresenter.represent_dict
- )
- SafeDumper.add_representer(
- ReturnList,
- yaml.representer.SafeRepresenter.represent_list
- )
- SafeDumper.add_representer(
- types.GeneratorType,
- yaml.representer.SafeRepresenter.represent_list
- )
diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py
index 470af51b..8b6f005e 100644
--- a/rest_framework/utils/formatting.py
+++ b/rest_framework/utils/formatting.py
@@ -2,12 +2,10 @@
Utility functions to return a formatted name and description for a given view.
"""
from __future__ import unicode_literals
-import re
-
from django.utils.html import escape
from django.utils.safestring import mark_safe
-
from rest_framework.compat import apply_markdown, force_text
+import re
def remove_trailing_string(content, trailing):
@@ -59,4 +57,5 @@ def markup_description(description):
description = apply_markdown(description)
else:
description = escape(description).replace('\n', '<br />')
+ description = '<p>' + description + '</p>'
return mark_safe(description)
diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py
index 6a5835f5..dd92f8b6 100644
--- a/rest_framework/utils/model_meta.py
+++ b/rest_framework/utils/model_meta.py
@@ -24,7 +24,7 @@ FieldInfo = namedtuple('FieldResult', [
RelationInfo = namedtuple('RelationInfo', [
'model_field',
- 'related',
+ 'related_model',
'to_many',
'has_through_model'
])
@@ -98,7 +98,7 @@ def _get_forward_relationships(opts):
for field in [field for field in opts.fields if field.serialize and field.rel]:
forward_relations[field.name] = RelationInfo(
model_field=field,
- related=_resolve_model(field.rel.to),
+ related_model=_resolve_model(field.rel.to),
to_many=False,
has_through_model=False
)
@@ -107,7 +107,7 @@ def _get_forward_relationships(opts):
for field in [field for field in opts.many_to_many if field.serialize]:
forward_relations[field.name] = RelationInfo(
model_field=field,
- related=_resolve_model(field.rel.to),
+ related_model=_resolve_model(field.rel.to),
to_many=True,
has_through_model=(
not field.rel.through._meta.auto_created
@@ -131,7 +131,7 @@ def _get_reverse_relationships(opts):
related = getattr(relation, 'related_model', relation.model)
reverse_relations[accessor_name] = RelationInfo(
model_field=None,
- related=related,
+ related_model=related,
to_many=relation.field.rel.multiple,
has_through_model=False
)
@@ -142,7 +142,7 @@ def _get_reverse_relationships(opts):
related = getattr(relation, 'related_model', relation.model)
reverse_relations[accessor_name] = RelationInfo(
model_field=None,
- related=related,
+ related_model=related,
to_many=True,
has_through_model=(
(getattr(relation.field.rel, 'through', None) is not None)
diff --git a/rest_framework/utils/urls.py b/rest_framework/utils/urls.py
new file mode 100644
index 00000000..880ef9ed
--- /dev/null
+++ b/rest_framework/utils/urls.py
@@ -0,0 +1,25 @@
+from django.utils.six.moves.urllib import parse as urlparse
+
+
+def replace_query_param(url, key, val):
+ """
+ Given a URL and a key/val pair, set or replace an item in the query
+ parameters of the URL, and return the new URL.
+ """
+ (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url)
+ query_dict = urlparse.parse_qs(query)
+ query_dict[key] = [val]
+ query = urlparse.urlencode(sorted(list(query_dict.items())), doseq=True)
+ return urlparse.urlunsplit((scheme, netloc, path, query, fragment))
+
+
+def remove_query_param(url, key):
+ """
+ Given a URL and a key/val pair, remove an item in the query
+ parameters of the URL, and return the new URL.
+ """
+ (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url)
+ query_dict = urlparse.parse_qs(query)
+ query_dict.pop(key, None)
+ query = urlparse.urlencode(sorted(list(query_dict.items())), doseq=True)
+ return urlparse.urlunsplit((scheme, netloc, path, query, fragment))
diff --git a/rest_framework/versioning.py b/rest_framework/versioning.py
new file mode 100644
index 00000000..a07b629f
--- /dev/null
+++ b/rest_framework/versioning.py
@@ -0,0 +1,174 @@
+# coding: utf-8
+from __future__ import unicode_literals
+from django.utils.translation import ugettext_lazy as _
+from rest_framework import exceptions
+from rest_framework.compat import unicode_http_header
+from rest_framework.reverse import _reverse
+from rest_framework.settings import api_settings
+from rest_framework.templatetags.rest_framework import replace_query_param
+from rest_framework.utils.mediatypes import _MediaType
+import re
+
+
+class BaseVersioning(object):
+ default_version = api_settings.DEFAULT_VERSION
+ allowed_versions = api_settings.ALLOWED_VERSIONS
+ version_param = api_settings.VERSION_PARAM
+
+ def determine_version(self, request, *args, **kwargs):
+ msg = '{cls}.determine_version() must be implemented.'
+ raise NotImplementedError(msg.format(
+ cls=self.__class__.__name__
+ ))
+
+ def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
+ return _reverse(viewname, args, kwargs, request, format, **extra)
+
+ def is_allowed_version(self, version):
+ if not self.allowed_versions:
+ return True
+ return (version == self.default_version) or (version in self.allowed_versions)
+
+
+class AcceptHeaderVersioning(BaseVersioning):
+ """
+ GET /something/ HTTP/1.1
+ Host: example.com
+ Accept: application/json; version=1.0
+ """
+ invalid_version_message = _('Invalid version in "Accept" header.')
+
+ def determine_version(self, request, *args, **kwargs):
+ media_type = _MediaType(request.accepted_media_type)
+ version = media_type.params.get(self.version_param, self.default_version)
+ version = unicode_http_header(version)
+ if not self.is_allowed_version(version):
+ raise exceptions.NotAcceptable(self.invalid_version_message)
+ return version
+
+ # We don't need to implement `reverse`, as the versioning is based
+ # on the `Accept` header, not on the request URL.
+
+
+class URLPathVersioning(BaseVersioning):
+ """
+ To the client this is the same style as `NamespaceVersioning`.
+ The difference is in the backend - this implementation uses
+ Django's URL keyword arguments to determine the version.
+
+ An example URL conf for two views that accept two different versions.
+
+ urlpatterns = [
+ url(r'^(?P<version>{v1,v2})/users/$', users_list, name='users-list'),
+ url(r'^(?P<version>{v1,v2})/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail')
+ ]
+
+ GET /1.0/something/ HTTP/1.1
+ Host: example.com
+ Accept: application/json
+ """
+ invalid_version_message = _('Invalid version in URL path.')
+
+ def determine_version(self, request, *args, **kwargs):
+ version = kwargs.get(self.version_param, self.default_version)
+ if not self.is_allowed_version(version):
+ raise exceptions.NotFound(self.invalid_version_message)
+ return version
+
+ def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
+ if request.version is not None:
+ kwargs = {} if (kwargs is None) else kwargs
+ kwargs[self.version_param] = request.version
+
+ return super(URLPathVersioning, self).reverse(
+ viewname, args, kwargs, request, format, **extra
+ )
+
+
+class NamespaceVersioning(BaseVersioning):
+ """
+ To the client this is the same style as `URLPathVersioning`.
+ The difference is in the backend - this implementation uses
+ Django's URL namespaces to determine the version.
+
+ An example URL conf that is namespaced into two seperate versions
+
+ # users/urls.py
+ urlpatterns = [
+ url(r'^/users/$', users_list, name='users-list'),
+ url(r'^/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail')
+ ]
+
+ # urls.py
+ urlpatterns = [
+ url(r'^v1/', include('users.urls', namespace='v1')),
+ url(r'^v2/', include('users.urls', namespace='v2'))
+ ]
+
+ GET /1.0/something/ HTTP/1.1
+ Host: example.com
+ Accept: application/json
+ """
+ invalid_version_message = _('Invalid version in URL path.')
+
+ def determine_version(self, request, *args, **kwargs):
+ resolver_match = getattr(request, 'resolver_match', None)
+ if (resolver_match is None or not resolver_match.namespace):
+ return self.default_version
+ version = resolver_match.namespace
+ if not self.is_allowed_version(version):
+ raise exceptions.NotFound(self.invalid_version_message)
+ return version
+
+ def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
+ if request.version is not None:
+ viewname = request.version + ':' + viewname
+ return super(NamespaceVersioning, self).reverse(
+ viewname, args, kwargs, request, format, **extra
+ )
+
+
+class HostNameVersioning(BaseVersioning):
+ """
+ GET /something/ HTTP/1.1
+ Host: v1.example.com
+ Accept: application/json
+ """
+ hostname_regex = re.compile(r'^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$')
+ invalid_version_message = _('Invalid version in hostname.')
+
+ def determine_version(self, request, *args, **kwargs):
+ hostname, seperator, port = request.get_host().partition(':')
+ match = self.hostname_regex.match(hostname)
+ if not match:
+ return self.default_version
+ version = match.group(1)
+ if not self.is_allowed_version(version):
+ raise exceptions.NotFound(self.invalid_version_message)
+ return version
+
+ # We don't need to implement `reverse`, as the hostname will already be
+ # preserved as part of the REST framework `reverse` implementation.
+
+
+class QueryParameterVersioning(BaseVersioning):
+ """
+ GET /something/?version=0.1 HTTP/1.1
+ Host: example.com
+ Accept: application/json
+ """
+ invalid_version_message = _('Invalid version in query parameter.')
+
+ def determine_version(self, request, *args, **kwargs):
+ version = request.query_params.get(self.version_param)
+ if not self.is_allowed_version(version):
+ raise exceptions.NotFound(self.invalid_version_message)
+ return version
+
+ def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
+ url = super(QueryParameterVersioning, self).reverse(
+ viewname, args, kwargs, request, format, **extra
+ )
+ if request.version is not None:
+ return replace_query_param(url, self.version_param, request.version)
+ return url
diff --git a/rest_framework/views.py b/rest_framework/views.py
index bc870417..12bb78bd 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -2,6 +2,8 @@
Provides an APIView class that is the base of all views in REST framework.
"""
from __future__ import unicode_literals
+import inspect
+import warnings
from django.core.exceptions import PermissionDenied
from django.http import Http404
@@ -46,7 +48,7 @@ def get_view_description(view_cls, html=False):
return description
-def exception_handler(exc):
+def exception_handler(exc, context):
"""
Returns the response that should be used for any given exception.
@@ -93,6 +95,7 @@ class APIView(View):
permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
metadata_class = api_settings.DEFAULT_METADATA_CLASS
+ versioning_class = api_settings.DEFAULT_VERSIONING_CLASS
# Allow dependency injection of other settings to make testing easier.
settings = api_settings
@@ -184,6 +187,18 @@ class APIView(View):
'request': getattr(self, 'request', None)
}
+ def get_exception_handler_context(self):
+ """
+ Returns a dict that is passed through to EXCEPTION_HANDLER,
+ as the `context` argument.
+ """
+ return {
+ 'view': self,
+ 'args': getattr(self, 'args', ()),
+ 'kwargs': getattr(self, 'kwargs', {}),
+ 'request': getattr(self, 'request', None)
+ }
+
def get_view_name(self):
"""
Return the view name, as used in OPTIONS responses and in the
@@ -300,6 +315,16 @@ class APIView(View):
if not throttle.allow_request(request, self):
self.throttled(request, throttle.wait())
+ def determine_version(self, request, *args, **kwargs):
+ """
+ If versioning is being used, then determine any API version for the
+ incoming request. Returns a two-tuple of (version, versioning_scheme)
+ """
+ if self.versioning_class is None:
+ return (None, None)
+ scheme = self.versioning_class()
+ return (scheme.determine_version(request, *args, **kwargs), scheme)
+
# Dispatch methods
def initialize_request(self, request, *args, **kwargs):
@@ -308,11 +333,13 @@ class APIView(View):
"""
parser_context = self.get_parser_context(request)
- return Request(request,
- parsers=self.get_parsers(),
- authenticators=self.get_authenticators(),
- negotiator=self.get_content_negotiator(),
- parser_context=parser_context)
+ return Request(
+ request,
+ parsers=self.get_parsers(),
+ authenticators=self.get_authenticators(),
+ negotiator=self.get_content_negotiator(),
+ parser_context=parser_context
+ )
def initial(self, request, *args, **kwargs):
"""
@@ -329,6 +356,10 @@ class APIView(View):
neg = self.perform_content_negotiation(request)
request.accepted_renderer, request.accepted_media_type = neg
+ # Determine the API version, if versioning is in use.
+ version, scheme = self.determine_version(request, *args, **kwargs)
+ request.version, request.versioning_scheme = version, scheme
+
def finalize_response(self, request, response, *args, **kwargs):
"""
Returns the final response object.
@@ -369,7 +400,18 @@ class APIView(View):
else:
exc.status_code = status.HTTP_403_FORBIDDEN
- response = self.settings.EXCEPTION_HANDLER(exc)
+ exception_handler = self.settings.EXCEPTION_HANDLER
+
+ if len(inspect.getargspec(exception_handler).args) == 1:
+ warnings.warn(
+ 'The `exception_handler(exc)` call signature is deprecated. '
+ 'Use `exception_handler(exc, context) instead.',
+ PendingDeprecationWarning
+ )
+ response = exception_handler(exc)
+ else:
+ context = self.get_exception_handler_context()
+ response = exception_handler(exc, context)
if response is None:
raise