diff options
Diffstat (limited to 'rest_framework')
29 files changed, 1949 insertions, 1111 deletions
| diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 4832ad33..11db0585 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -3,14 +3,10 @@ Provides various authentication policies.  """  from __future__ import unicode_literals  import base64 -  from django.contrib.auth import authenticate -from django.core.exceptions import ImproperlyConfigured  from django.middleware.csrf import CsrfViewMiddleware -from django.conf import settings +from django.utils.translation import ugettext_lazy as _  from rest_framework import exceptions, HTTP_HEADER_ENCODING -from rest_framework.compat import oauth, oauth_provider, oauth_provider_store -from rest_framework.compat import oauth2_provider, provider_now, check_nonce  from rest_framework.authtoken.models import Token @@ -70,16 +66,16 @@ class BasicAuthentication(BaseAuthentication):              return None          if len(auth) == 1: -            msg = 'Invalid basic header. No credentials provided.' +            msg = _('Invalid basic header. No credentials provided.')              raise exceptions.AuthenticationFailed(msg)          elif len(auth) > 2: -            msg = 'Invalid basic header. Credentials string should not contain spaces.' +            msg = _('Invalid basic header. Credentials string should not contain spaces.')              raise exceptions.AuthenticationFailed(msg)          try:              auth_parts = base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(':')          except (TypeError, UnicodeDecodeError): -            msg = 'Invalid basic header. Credentials not correctly base64 encoded' +            msg = _('Invalid basic header. Credentials not correctly base64 encoded.')              raise exceptions.AuthenticationFailed(msg)          userid, password = auth_parts[0], auth_parts[2] @@ -91,7 +87,7 @@ class BasicAuthentication(BaseAuthentication):          """          user = authenticate(username=userid, password=password)          if user is None or not user.is_active: -            raise exceptions.AuthenticationFailed('Invalid username/password') +            raise exceptions.AuthenticationFailed(_('Invalid username/password.'))          return (user, None)      def authenticate_header(self, request): @@ -157,10 +153,10 @@ class TokenAuthentication(BaseAuthentication):              return None          if len(auth) == 1: -            msg = 'Invalid token header. No credentials provided.' +            msg = _('Invalid token header. No credentials provided.')              raise exceptions.AuthenticationFailed(msg)          elif len(auth) > 2: -            msg = 'Invalid token header. Token string should not contain spaces.' +            msg = _('Invalid token header. Token string should not contain spaces.')              raise exceptions.AuthenticationFailed(msg)          return self.authenticate_credentials(auth[1]) @@ -169,190 +165,12 @@ class TokenAuthentication(BaseAuthentication):          try:              token = self.model.objects.get(key=key)          except self.model.DoesNotExist: -            raise exceptions.AuthenticationFailed('Invalid token') +            raise exceptions.AuthenticationFailed(_('Invalid token.'))          if not token.user.is_active: -            raise exceptions.AuthenticationFailed('User inactive or deleted') +            raise exceptions.AuthenticationFailed(_('User inactive or deleted.'))          return (token.user, token)      def authenticate_header(self, request):          return 'Token' - - -class OAuthAuthentication(BaseAuthentication): -    """ -    OAuth 1.0a authentication backend using `django-oauth-plus` and `oauth2`. - -    Note: The `oauth2` package actually provides oauth1.0a support.  Urg. -          We import it from the `compat` module as `oauth`. -    """ -    www_authenticate_realm = 'api' - -    def __init__(self, *args, **kwargs): -        super(OAuthAuthentication, self).__init__(*args, **kwargs) - -        if oauth is None: -            raise ImproperlyConfigured( -                "The 'oauth2' package could not be imported." -                "It is required for use with the 'OAuthAuthentication' class.") - -        if oauth_provider is None: -            raise ImproperlyConfigured( -                "The 'django-oauth-plus' package could not be imported." -                "It is required for use with the 'OAuthAuthentication' class.") - -    def authenticate(self, request): -        """ -        Returns two-tuple of (user, token) if authentication succeeds, -        or None otherwise. -        """ -        try: -            oauth_request = oauth_provider.utils.get_oauth_request(request) -        except oauth.Error as err: -            raise exceptions.AuthenticationFailed(err.message) - -        if not oauth_request: -            return None - -        oauth_params = oauth_provider.consts.OAUTH_PARAMETERS_NAMES - -        found = any(param for param in oauth_params if param in oauth_request) -        missing = list(param for param in oauth_params if param not in oauth_request) - -        if not found: -            # OAuth authentication was not attempted. -            return None - -        if missing: -            # OAuth was attempted but missing parameters. -            msg = 'Missing parameters: %s' % (', '.join(missing)) -            raise exceptions.AuthenticationFailed(msg) - -        if not self.check_nonce(request, oauth_request): -            msg = 'Nonce check failed' -            raise exceptions.AuthenticationFailed(msg) - -        try: -            consumer_key = oauth_request.get_parameter('oauth_consumer_key') -            consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key) -        except oauth_provider.store.InvalidConsumerError: -            msg = 'Invalid consumer token: %s' % oauth_request.get_parameter('oauth_consumer_key') -            raise exceptions.AuthenticationFailed(msg) - -        if consumer.status != oauth_provider.consts.ACCEPTED: -            msg = 'Invalid consumer key status: %s' % consumer.get_status_display() -            raise exceptions.AuthenticationFailed(msg) - -        try: -            token_param = oauth_request.get_parameter('oauth_token') -            token = oauth_provider_store.get_access_token(request, oauth_request, consumer, token_param) -        except oauth_provider.store.InvalidTokenError: -            msg = 'Invalid access token: %s' % oauth_request.get_parameter('oauth_token') -            raise exceptions.AuthenticationFailed(msg) - -        try: -            self.validate_token(request, consumer, token) -        except oauth.Error as err: -            raise exceptions.AuthenticationFailed(err.message) - -        user = token.user - -        if not user.is_active: -            msg = 'User inactive or deleted: %s' % user.username -            raise exceptions.AuthenticationFailed(msg) - -        return (token.user, token) - -    def authenticate_header(self, request): -        """ -        If permission is denied, return a '401 Unauthorized' response, -        with an appropriate 'WWW-Authenticate' header. -        """ -        return 'OAuth realm="%s"' % self.www_authenticate_realm - -    def validate_token(self, request, consumer, token): -        """ -        Check the token and raise an `oauth.Error` exception if invalid. -        """ -        oauth_server, oauth_request = oauth_provider.utils.initialize_server_request(request) -        oauth_server.verify_request(oauth_request, consumer, token) - -    def check_nonce(self, request, oauth_request): -        """ -        Checks nonce of request, and return True if valid. -        """ -        oauth_nonce = oauth_request['oauth_nonce'] -        oauth_timestamp = oauth_request['oauth_timestamp'] -        return check_nonce(request, oauth_request, oauth_nonce, oauth_timestamp) - - -class OAuth2Authentication(BaseAuthentication): -    """ -    OAuth 2 authentication backend using `django-oauth2-provider` -    """ -    www_authenticate_realm = 'api' -    allow_query_params_token = settings.DEBUG - -    def __init__(self, *args, **kwargs): -        super(OAuth2Authentication, self).__init__(*args, **kwargs) - -        if oauth2_provider is None: -            raise ImproperlyConfigured( -                "The 'django-oauth2-provider' package could not be imported. " -                "It is required for use with the 'OAuth2Authentication' class.") - -    def authenticate(self, request): -        """ -        Returns two-tuple of (user, token) if authentication succeeds, -        or None otherwise. -        """ - -        auth = get_authorization_header(request).split() - -        if len(auth) == 1: -            msg = 'Invalid bearer header. No credentials provided.' -            raise exceptions.AuthenticationFailed(msg) -        elif len(auth) > 2: -            msg = 'Invalid bearer header. Token string should not contain spaces.' -            raise exceptions.AuthenticationFailed(msg) - -        if auth and auth[0].lower() == b'bearer': -            access_token = auth[1] -        elif 'access_token' in request.POST: -            access_token = request.POST['access_token'] -        elif 'access_token' in request.GET and self.allow_query_params_token: -            access_token = request.GET['access_token'] -        else: -            return None - -        return self.authenticate_credentials(request, access_token) - -    def authenticate_credentials(self, request, access_token): -        """ -        Authenticate the request, given the access token. -        """ - -        try: -            token = oauth2_provider.oauth2.models.AccessToken.objects.select_related('user') -            # provider_now switches to timezone aware datetime when -            # the oauth2_provider version supports to it. -            token = token.get(token=access_token, expires__gt=provider_now()) -        except oauth2_provider.oauth2.models.AccessToken.DoesNotExist: -            raise exceptions.AuthenticationFailed('Invalid token') - -        user = token.user - -        if not user.is_active: -            msg = 'User inactive or deleted: %s' % user.get_username() -            raise exceptions.AuthenticationFailed(msg) - -        return (user, token) - -    def authenticate_header(self, request): -        """ -        Bearer is the only finalized type currently - -        Check details on the `OAuth2Authentication.authenticate` method -        """ -        return 'Bearer realm="%s"' % self.www_authenticate_realm diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index f31dded1..37ade255 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -23,7 +23,7 @@ class AuthTokenSerializer(serializers.Serializer):                  msg = _('Unable to log in with provided credentials.')                  raise exceptions.ValidationError(msg)          else: -            msg = _('Must include "username" and "password"') +            msg = _('Must include "username" and "password".')              raise exceptions.ValidationError(msg)          attrs['user'] = user diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 36413394..50f37014 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -5,15 +5,13 @@ versions of django/python, and compatibility wrappers around optional packages.  # flake8: noqa  from __future__ import unicode_literals - -import inspect -  from django.core.exceptions import ImproperlyConfigured +from django.conf import settings  from django.utils.encoding import force_text  from django.utils.six.moves.urllib.parse import urlparse as _urlparse -from django.conf import settings  from django.utils import six  import django +import inspect  def unicode_repr(instance): @@ -33,6 +31,13 @@ def unicode_to_repr(value):      return value +def unicode_http_header(value): +    # Coerce HTTP header value to unicode. +    if isinstance(value, six.binary_type): +        return value.decode('iso-8859-1') +    return value + +  def total_seconds(timedelta):      # TimeDelta.total_seconds() is only available in Python 2.7      if hasattr(timedelta, 'total_seconds'): @@ -232,77 +237,13 @@ except ImportError:      apply_markdown = None -# Yaml is optional -try: -    import yaml -except ImportError: -    yaml = None - - -# XML is optional -try: -    import defusedxml.ElementTree as etree -except ImportError: -    etree = None - - -# OAuth2 is optional -try: -    # Note: The `oauth2` package actually provides oauth1.0a support.  Urg. -    import oauth2 as oauth -except ImportError: -    oauth = None - - -# OAuthProvider is optional -try: -    import oauth_provider -    from oauth_provider.store import store as oauth_provider_store - -    # check_nonce's calling signature in django-oauth-plus changes sometime -    # between versions 2.0 and 2.2.1 -    def check_nonce(request, oauth_request, oauth_nonce, oauth_timestamp): -        check_nonce_args = inspect.getargspec(oauth_provider_store.check_nonce).args -        if 'timestamp' in check_nonce_args: -            return oauth_provider_store.check_nonce( -                request, oauth_request, oauth_nonce, oauth_timestamp -            ) -        return oauth_provider_store.check_nonce( -            request, oauth_request, oauth_nonce -        ) - -except (ImportError, ImproperlyConfigured): -    oauth_provider = None -    oauth_provider_store = None -    check_nonce = None - - -# OAuth 2 support is optional -try: -    import provider as oauth2_provider -    from provider import scope as oauth2_provider_scope -    from provider import constants as oauth2_constants - -    if oauth2_provider.__version__ in ('0.2.3', '0.2.4'): -        # 0.2.3 and 0.2.4 are supported version that do not support -        # timezone aware datetimes -        import datetime - -        provider_now = datetime.datetime.now -    else: -        # Any other supported version does use timezone aware datetimes -        from django.utils.timezone import now as provider_now -except ImportError: -    oauth2_provider = None -    oauth2_provider_scope = None -    oauth2_constants = None -    provider_now = None -  # `separators` argument to `json.dumps()` differs between 2.x and 3.x  # See: http://bugs.python.org/issue22767  if six.PY3:      SHORT_SEPARATORS = (',', ':')      LONG_SEPARATORS = (', ', ': ') +    INDENT_SEPARATORS = (',', ': ')  else:      SHORT_SEPARATORS = (b',', b':')      LONG_SEPARATORS = (b', ', b': ') +    INDENT_SEPARATORS = (b',', b': ') diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 1f381e4e..f954c13e 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -7,8 +7,7 @@ In addition Django's built in 403 and 404 exceptions are handled.  from __future__ import unicode_literals  from django.utils import six  from django.utils.encoding import force_text -from django.utils.translation import ugettext_lazy as _ -from django.utils.translation import ungettext_lazy +from django.utils.translation import ugettext_lazy as _, ungettext  from rest_framework import status  import math @@ -36,7 +35,7 @@ class APIException(Exception):      Subclasses should provide `.status_code` and `.default_detail` properties.      """      status_code = status.HTTP_500_INTERNAL_SERVER_ERROR -    default_detail = _('A server error occured') +    default_detail = _('A server error occurred.')      def __init__(self, detail=None):          if detail is not None: @@ -89,20 +88,25 @@ class PermissionDenied(APIException):      default_detail = _('You do not have permission to perform this action.') +class NotFound(APIException): +    status_code = status.HTTP_404_NOT_FOUND +    default_detail = _('Not found.') + +  class MethodNotAllowed(APIException):      status_code = status.HTTP_405_METHOD_NOT_ALLOWED -    default_detail = _("Method '%s' not allowed.") +    default_detail = _('Method "{method}" not allowed.')      def __init__(self, method, detail=None):          if detail is not None:              self.detail = force_text(detail)          else: -            self.detail = force_text(self.default_detail) % method +            self.detail = force_text(self.default_detail).format(method=method)  class NotAcceptable(APIException):      status_code = status.HTTP_406_NOT_ACCEPTABLE -    default_detail = _('Could not satisfy the request Accept header') +    default_detail = _('Could not satisfy the request Accept header.')      def __init__(self, detail=None, available_renderers=None):          if detail is not None: @@ -114,23 +118,22 @@ class NotAcceptable(APIException):  class UnsupportedMediaType(APIException):      status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE -    default_detail = _("Unsupported media type '%s' in request.") +    default_detail = _('Unsupported media type "{media_type}" in request.')      def __init__(self, media_type, detail=None):          if detail is not None:              self.detail = force_text(detail)          else: -            self.detail = force_text(self.default_detail) % media_type +            self.detail = force_text(self.default_detail).format( +                media_type=media_type +            )  class Throttled(APIException):      status_code = status.HTTP_429_TOO_MANY_REQUESTS      default_detail = _('Request was throttled.') -    extra_detail = ungettext_lazy( -        'Expected available in %(wait)d second.', -        'Expected available in %(wait)d seconds.', -        'wait' -    ) +    extra_detail_singular = 'Expected available in {wait} second.' +    extra_detail_plural = 'Expected available in {wait} seconds.'      def __init__(self, wait=None, detail=None):          if detail is not None: @@ -142,6 +145,8 @@ class Throttled(APIException):              self.wait = None          else:              self.wait = math.ceil(wait) -            self.detail += ' ' + force_text( -                self.extra_detail % {'wait': self.wait} -            ) +            self.detail += ' ' + force_text(ungettext( +                self.extra_detail_singular.format(wait=self.wait), +                self.extra_detail_plural.format(wait=self.wait), +                self.wait +            )) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 71a9f193..02d2adef 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -484,7 +484,7 @@ class Field(object):  class BooleanField(Field):      default_error_messages = { -        'invalid': _('`{input}` is not a valid boolean.') +        'invalid': _('"{input}" is not a valid boolean.')      }      default_empty_html = False      initial = False @@ -512,7 +512,7 @@ class BooleanField(Field):  class NullBooleanField(Field):      default_error_messages = { -        'invalid': _('`{input}` is not a valid boolean.') +        'invalid': _('"{input}" is not a valid boolean.')      }      initial = None      TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True)) @@ -612,7 +612,7 @@ class RegexField(CharField):  class SlugField(CharField):      default_error_messages = { -        'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.") +        'invalid': _('Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.')      }      def __init__(self, **kwargs): @@ -624,7 +624,7 @@ class SlugField(CharField):  class URLField(CharField):      default_error_messages = { -        'invalid': _("Enter a valid URL.") +        'invalid': _('Enter a valid URL.')      }      def __init__(self, **kwargs): @@ -657,7 +657,7 @@ class IntegerField(Field):          'invalid': _('A valid integer is required.'),          'max_value': _('Ensure this value is less than or equal to {max_value}.'),          'min_value': _('Ensure this value is greater than or equal to {min_value}.'), -        'max_string_length': _('String value too large') +        'max_string_length': _('String value too large.')      }      MAX_STRING_LENGTH = 1000  # Guard against malicious string inputs. @@ -688,10 +688,10 @@ class IntegerField(Field):  class FloatField(Field):      default_error_messages = { -        'invalid': _("A valid number is required."), +        'invalid': _('A valid number is required.'),          'max_value': _('Ensure this value is less than or equal to {max_value}.'),          'min_value': _('Ensure this value is greater than or equal to {min_value}.'), -        'max_string_length': _('String value too large') +        'max_string_length': _('String value too large.')      }      MAX_STRING_LENGTH = 1000  # Guard against malicious string inputs. @@ -727,7 +727,7 @@ class DecimalField(Field):          'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'),          'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'),          'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.'), -        'max_string_length': _('String value too large') +        'max_string_length': _('String value too large.')      }      MAX_STRING_LENGTH = 1000  # Guard against malicious string inputs. @@ -810,7 +810,7 @@ class DecimalField(Field):  class DateTimeField(Field):      default_error_messages = { -        'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'), +        'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}.'),          'date': _('Expected a datetime but got a date.'),      }      format = api_settings.DATETIME_FORMAT @@ -875,7 +875,7 @@ class DateTimeField(Field):  class DateField(Field):      default_error_messages = { -        'invalid': _('Date has wrong format. Use one of these formats instead: {format}'), +        'invalid': _('Date has wrong format. Use one of these formats instead: {format}.'),          'datetime': _('Expected a date but got a datetime.'),      }      format = api_settings.DATE_FORMAT @@ -933,7 +933,7 @@ class DateField(Field):  class TimeField(Field):      default_error_messages = { -        'invalid': _('Time has wrong format. Use one of these formats instead: {format}'), +        'invalid': _('Time has wrong format. Use one of these formats instead: {format}.'),      }      format = api_settings.TIME_FORMAT      input_formats = api_settings.TIME_INPUT_FORMATS @@ -989,7 +989,7 @@ class TimeField(Field):  class ChoiceField(Field):      default_error_messages = { -        'invalid_choice': _('`{input}` is not a valid choice.') +        'invalid_choice': _('"{input}" is not a valid choice.')      }      def __init__(self, choices, **kwargs): @@ -1033,8 +1033,8 @@ class ChoiceField(Field):  class MultipleChoiceField(ChoiceField):      default_error_messages = { -        'invalid_choice': _('`{input}` is not a valid choice.'), -        'not_a_list': _('Expected a list of items but got type `{input_type}`.') +        'invalid_choice': _('"{input}" is not a valid choice.'), +        'not_a_list': _('Expected a list of items but got type "{input_type}".')      }      default_empty_html = [] @@ -1064,10 +1064,10 @@ class MultipleChoiceField(ChoiceField):  class FileField(Field):      default_error_messages = { -        'required': _("No file was submitted."), -        'invalid': _("The submitted data was not a file. Check the encoding type on the form."), -        'no_name': _("No filename could be determined."), -        'empty': _("The submitted file is empty."), +        'required': _('No file was submitted.'), +        'invalid': _('The submitted data was not a file. Check the encoding type on the form.'), +        'no_name': _('No filename could be determined.'), +        'empty': _('The submitted file is empty.'),          'max_length': _('Ensure this filename has at most {max_length} characters (it has {length}).'),      }      use_url = api_settings.UPLOADED_FILES_USE_URL @@ -1110,8 +1110,7 @@ class FileField(Field):  class ImageField(FileField):      default_error_messages = {          'invalid_image': _( -            'Upload a valid image. The file you uploaded was either not an ' -            'image or a corrupted image.' +            'Upload a valid image. The file you uploaded was either not an image or a corrupted image.'          ),      } @@ -1149,7 +1148,7 @@ class ListField(Field):      child = _UnvalidatedField()      initial = []      default_error_messages = { -        'not_a_list': _('Expected a list of items but got type `{input_type}`') +        'not_a_list': _('Expected a list of items but got type "{input_type}".')      }      def __init__(self, *args, **kwargs): @@ -1186,7 +1185,7 @@ class DictField(Field):      child = _UnvalidatedField()      initial = []      default_error_messages = { -        'not_a_dict': _('Expected a dictionary of items but got type `{input_type}`') +        'not_a_dict': _('Expected a dictionary of items but got type "{input_type}".')      }      def __init__(self, *args, **kwargs): diff --git a/rest_framework/filters.py b/rest_framework/filters.py index d188a2d1..2bcf3699 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -114,7 +114,7 @@ class OrderingFilter(BaseFilterBackend):      ordering_param = api_settings.ORDERING_PARAM      ordering_fields = None -    def get_ordering(self, request): +    def get_ordering(self, request, queryset, view):          """          Ordering is set by a comma delimited ?ordering=... query parameter. @@ -124,7 +124,13 @@ class OrderingFilter(BaseFilterBackend):          """          params = request.query_params.get(self.ordering_param)          if params: -            return [param.strip() for param in params.split(',')] +            fields = [param.strip() for param in params.split(',')] +            ordering = self.remove_invalid_fields(queryset, fields, view) +            if ordering: +                return ordering + +        # No ordering was included, or all the ordering fields were invalid +        return self.get_default_ordering(view)      def get_default_ordering(self, view):          ordering = getattr(view, 'ordering', None) @@ -132,7 +138,7 @@ class OrderingFilter(BaseFilterBackend):              return (ordering,)          return ordering -    def remove_invalid_fields(self, queryset, ordering, view): +    def remove_invalid_fields(self, queryset, fields, view):          valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)          if valid_fields is None: @@ -152,18 +158,10 @@ class OrderingFilter(BaseFilterBackend):              valid_fields = [field.name for field in queryset.model._meta.fields]              valid_fields += queryset.query.aggregates.keys() -        return [term for term in ordering if term.lstrip('-') in valid_fields] +        return [term for term in fields if term.lstrip('-') in valid_fields]      def filter_queryset(self, request, queryset, view): -        ordering = self.get_ordering(request) - -        if ordering: -            # Skip any incorrect parameters -            ordering = self.remove_invalid_fields(queryset, ordering, view) - -        if not ordering: -            # Use 'ordering' attribute by default -            ordering = self.get_default_ordering(view) +        ordering = self.get_ordering(request, queryset, view)          if ordering:              return queryset.order_by(*ordering) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index e6db155e..61dcb84a 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -2,29 +2,13 @@  Generic views that provide commonly needed behaviour.  """  from __future__ import unicode_literals - -from django.core.paginator import Paginator, InvalidPage  from django.db.models.query import QuerySet  from django.http import Http404  from django.shortcuts import get_object_or_404 as _get_object_or_404 -from django.utils import six -from django.utils.translation import ugettext as _  from rest_framework import views, mixins  from rest_framework.settings import api_settings -def strict_positive_int(integer_string, cutoff=None): -    """ -    Cast a string to a strictly positive integer. -    """ -    ret = int(integer_string) -    if ret <= 0: -        raise ValueError() -    if cutoff: -        ret = min(ret, cutoff) -    return ret - -  def get_object_or_404(queryset, *filter_args, **filter_kwargs):      """      Same as Django's standard shortcut, but make sure to also raise 404 @@ -40,7 +24,6 @@ class GenericAPIView(views.APIView):      """      Base class for all other generic views.      """ -      # You'll need to either set these attributes,      # or override `get_queryset()`/`get_serializer_class()`.      # If you are overriding a view method, it is important that you call @@ -50,146 +33,16 @@ class GenericAPIView(views.APIView):      queryset = None      serializer_class = None -    # If you want to use object lookups other than pk, set this attribute. +    # If you want to use object lookups other than pk, set 'lookup_field'.      # For more complex lookup requirements override `get_object()`.      lookup_field = 'pk'      lookup_url_kwarg = None -    # Pagination settings -    paginate_by = api_settings.PAGINATE_BY -    paginate_by_param = api_settings.PAGINATE_BY_PARAM -    max_paginate_by = api_settings.MAX_PAGINATE_BY -    pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS -    page_kwarg = 'page' -      # The filter backend classes to use for queryset filtering      filter_backends = api_settings.DEFAULT_FILTER_BACKENDS -    # The following attribute may be subject to change, -    # and should be considered private API. -    paginator_class = Paginator - -    def get_serializer_context(self): -        """ -        Extra context provided to the serializer class. -        """ -        return { -            'request': self.request, -            'format': self.format_kwarg, -            'view': self -        } - -    def get_serializer(self, *args, **kwargs): -        """ -        Return the serializer instance that should be used for validating and -        deserializing input, and for serializing output. -        """ -        serializer_class = self.get_serializer_class() -        kwargs['context'] = self.get_serializer_context() -        return serializer_class(*args, **kwargs) - -    def get_pagination_serializer(self, page): -        """ -        Return a serializer instance to use with paginated data. -        """ -        class SerializerClass(self.pagination_serializer_class): -            class Meta: -                object_serializer_class = self.get_serializer_class() - -        pagination_serializer_class = SerializerClass -        context = self.get_serializer_context() -        return pagination_serializer_class(instance=page, context=context) - -    def paginate_queryset(self, queryset): -        """ -        Paginate a queryset if required, either returning a page object, -        or `None` if pagination is not configured for this view. -        """ -        page_size = self.get_paginate_by() -        if not page_size: -            return None - -        paginator = self.paginator_class(queryset, page_size) -        page_kwarg = self.kwargs.get(self.page_kwarg) -        page_query_param = self.request.query_params.get(self.page_kwarg) -        page = page_kwarg or page_query_param or 1 -        try: -            page_number = paginator.validate_number(page) -        except InvalidPage: -            if page == 'last': -                page_number = paginator.num_pages -            else: -                raise Http404(_("Page is not 'last', nor can it be converted to an int.")) -        try: -            page = paginator.page(page_number) -        except InvalidPage as exc: -            error_format = _('Invalid page (%(page_number)s): %(message)s') -            raise Http404(error_format % { -                'page_number': page_number, -                'message': six.text_type(exc) -            }) - -        return page - -    def filter_queryset(self, queryset): -        """ -        Given a queryset, filter it with whichever filter backend is in use. - -        You are unlikely to want to override this method, although you may need -        to call it either from a list view, or from a custom `get_object` -        method if you want to apply the configured filtering backend to the -        default queryset. -        """ -        for backend in self.get_filter_backends(): -            queryset = backend().filter_queryset(self.request, queryset, self) -        return queryset - -    def get_filter_backends(self): -        """ -        Returns the list of filter backends that this view requires. -        """ -        return list(self.filter_backends) - -    # The following methods provide default implementations -    # that you may want to override for more complex cases. - -    def get_paginate_by(self): -        """ -        Return the size of pages to use with pagination. - -        If `PAGINATE_BY_PARAM` is set it will attempt to get the page size -        from a named query parameter in the url, eg. ?page_size=100 - -        Otherwise defaults to using `self.paginate_by`. -        """ -        if self.paginate_by_param: -            try: -                return strict_positive_int( -                    self.request.query_params[self.paginate_by_param], -                    cutoff=self.max_paginate_by -                ) -            except (KeyError, ValueError): -                pass - -        return self.paginate_by - -    def get_serializer_class(self): -        """ -        Return the class to use for the serializer. -        Defaults to using `self.serializer_class`. - -        You may want to override this if you need to provide different -        serializations depending on the incoming request. - -        (Eg. admins get full serialization, others get basic serialization) -        """ -        assert self.serializer_class is not None, ( -            "'%s' should either include a `serializer_class` attribute, " -            "or override the `get_serializer_class()` method." -            % self.__class__.__name__ -        ) - -        return self.serializer_class +    # The style to use for queryset pagination. +    pagination_class = api_settings.DEFAULT_PAGINATION_CLASS      def get_queryset(self):          """ @@ -246,6 +99,83 @@ class GenericAPIView(views.APIView):          return obj +    def get_serializer(self, *args, **kwargs): +        """ +        Return the serializer instance that should be used for validating and +        deserializing input, and for serializing output. +        """ +        serializer_class = self.get_serializer_class() +        kwargs['context'] = self.get_serializer_context() +        return serializer_class(*args, **kwargs) + +    def get_serializer_class(self): +        """ +        Return the class to use for the serializer. +        Defaults to using `self.serializer_class`. + +        You may want to override this if you need to provide different +        serializations depending on the incoming request. + +        (Eg. admins get full serialization, others get basic serialization) +        """ +        assert self.serializer_class is not None, ( +            "'%s' should either include a `serializer_class` attribute, " +            "or override the `get_serializer_class()` method." +            % self.__class__.__name__ +        ) + +        return self.serializer_class + +    def get_serializer_context(self): +        """ +        Extra context provided to the serializer class. +        """ +        return { +            'request': self.request, +            'format': self.format_kwarg, +            'view': self +        } + +    def filter_queryset(self, queryset): +        """ +        Given a queryset, filter it with whichever filter backend is in use. + +        You are unlikely to want to override this method, although you may need +        to call it either from a list view, or from a custom `get_object` +        method if you want to apply the configured filtering backend to the +        default queryset. +        """ +        for backend in list(self.filter_backends): +            queryset = backend().filter_queryset(self.request, queryset, self) +        return queryset + +    @property +    def paginator(self): +        """ +        The paginator instance associated with the view, or `None`. +        """ +        if not hasattr(self, '_paginator'): +            if self.pagination_class is None: +                self._paginator = None +            else: +                self._paginator = self.pagination_class() +        return self._paginator + +    def paginate_queryset(self, queryset): +        """ +        Return a single page of results, or `None` if pagination is disabled. +        """ +        if self.paginator is None: +            return None +        return self.paginator.paginate_queryset(queryset, self.request, view=self) + +    def get_paginated_response(self, data): +        """ +        Return a paginated style `Response` object for the given output data. +        """ +        assert self.paginator is not None +        return self.paginator.get_paginated_response(data) +  # Concrete view classes that provide method handlers  # by composing the mixin classes with the base view. diff --git a/rest_framework/locale/en_US/LC_MESSAGES/django.po b/rest_framework/locale/en_US/LC_MESSAGES/django.po new file mode 100644 index 00000000..d98225ce --- /dev/null +++ b/rest_framework/locale/en_US/LC_MESSAGES/django.po @@ -0,0 +1,316 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2015-01-07 18:21+0000\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n" +"Language-Team: LANGUAGE <LL@li.org>\n" +"Language: \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" + +#: authentication.py:69 +msgid "Invalid basic header. No credentials provided." +msgstr "" + +#: authentication.py:72 +msgid "Invalid basic header. Credentials string should not contain spaces." +msgstr "" + +#: authentication.py:78 +msgid "Invalid basic header. Credentials not correctly base64 encoded." +msgstr "" + +#: authentication.py:90 +msgid "Invalid username/password." +msgstr "" + +#: authentication.py:156 +msgid "Invalid token header. No credentials provided." +msgstr "" + +#: authentication.py:159 +msgid "Invalid token header. Token string should not contain spaces." +msgstr "" + +#: authentication.py:168 +msgid "Invalid token." +msgstr "" + +#: authentication.py:171 +msgid "User inactive or deleted." +msgstr "" + +#: authtoken/serializers.py:20 +msgid "User account is disabled." +msgstr "" + +#: authtoken/serializers.py:23 +msgid "Unable to log in with provided credentials." +msgstr "" + +#: authtoken/serializers.py:26 +msgid "Must include \"username\" and \"password\"." +msgstr "" + +#: exceptions.py:38 +msgid "A server error occurred." +msgstr "" + +#: exceptions.py:73 +msgid "Malformed request." +msgstr "" + +#: exceptions.py:78 +msgid "Incorrect authentication credentials." +msgstr "" + +#: exceptions.py:83 +msgid "Authentication credentials were not provided." +msgstr "" + +#: exceptions.py:88 +msgid "You do not have permission to perform this action." +msgstr "" + +#: exceptions.py:93 +msgid "Not found." +msgstr "" + +#: exceptions.py:98 +msgid "Method \"{method}\" not allowed." +msgstr "" + +#: exceptions.py:109 +msgid "Could not satisfy the request Accept header." +msgstr "" + +#: exceptions.py:121 +msgid "Unsupported media type \"{media_type}\" in request." +msgstr "" + +#: exceptions.py:134 +msgid "Request was throttled." +msgstr "" + +#: fields.py:152 relations.py:131 relations.py:155 validators.py:77 +#: validators.py:155 +msgid "This field is required." +msgstr "" + +#: fields.py:153 +msgid "This field may not be null." +msgstr "" + +#: fields.py:480 fields.py:508 +msgid "\"{input}\" is not a valid boolean." +msgstr "" + +#: fields.py:543 +msgid "This field may not be blank." +msgstr "" + +#: fields.py:544 fields.py:1252 +msgid "Ensure this field has no more than {max_length} characters." +msgstr "" + +#: fields.py:545 +msgid "Ensure this field has at least {min_length} characters." +msgstr "" + +#: fields.py:587 +msgid "Enter a valid email address." +msgstr "" + +#: fields.py:604 +msgid "This value does not match the required pattern." +msgstr "" + +#: fields.py:615 +msgid "" +"Enter a valid \"slug\" consisting of letters, numbers, underscores or " +"hyphens." +msgstr "" + +#: fields.py:627 +msgid "Enter a valid URL." +msgstr "" + +#: fields.py:640 +msgid "A valid integer is required." +msgstr "" + +#: fields.py:641 fields.py:675 fields.py:708 +msgid "Ensure this value is less than or equal to {max_value}." +msgstr "" + +#: fields.py:642 fields.py:676 fields.py:709 +msgid "Ensure this value is greater than or equal to {min_value}." +msgstr "" + +#: fields.py:643 fields.py:677 fields.py:713 +msgid "String value too large." +msgstr "" + +#: fields.py:674 fields.py:707 +msgid "A valid number is required." +msgstr "" + +#: fields.py:710 +msgid "Ensure that there are no more than {max_digits} digits in total." +msgstr "" + +#: fields.py:711 +msgid "Ensure that there are no more than {max_decimal_places} decimal places." +msgstr "" + +#: fields.py:712 +msgid "" +"Ensure that there are no more than {max_whole_digits} digits before the " +"decimal point." +msgstr "" + +#: fields.py:796 +msgid "Datetime has wrong format. Use one of these formats instead: {format}." +msgstr "" + +#: fields.py:797 +msgid "Expected a datetime but got a date." +msgstr "" + +#: fields.py:861 +msgid "Date has wrong format. Use one of these formats instead: {format}." +msgstr "" + +#: fields.py:862 +msgid "Expected a date but got a datetime." +msgstr "" + +#: fields.py:919 +msgid "Time has wrong format. Use one of these formats instead: {format}." +msgstr "" + +#: fields.py:975 fields.py:1019 +msgid "\"{input}\" is not a valid choice." +msgstr "" + +#: fields.py:1020 fields.py:1121 serializers.py:476 +msgid "Expected a list of items but got type \"{input_type}\"." +msgstr "" + +#: fields.py:1050 +msgid "No file was submitted." +msgstr "" + +#: fields.py:1051 +msgid "The submitted data was not a file. Check the encoding type on the form." +msgstr "" + +#: fields.py:1052 +msgid "No filename could be determined." +msgstr "" + +#: fields.py:1053 +msgid "The submitted file is empty." +msgstr "" + +#: fields.py:1054 +msgid "" +"Ensure this filename has at most {max_length} characters (it has {length})." +msgstr "" + +#: fields.py:1096 +msgid "" +"Upload a valid image. The file you uploaded was either not an image or a " +"corrupted image." +msgstr "" + +#: generics.py:123 +msgid "" +"Choose a valid page number. Page numbers must be a whole number, or must be " +"the string \"last\"." +msgstr "" + +#: generics.py:128 +msgid "Invalid page \"{page_number}\": {message}." +msgstr "" + +#: relations.py:132 +msgid "Invalid pk \"{pk_value}\" - object does not exist." +msgstr "" + +#: relations.py:133 +msgid "Incorrect type. Expected pk value, received {data_type}." +msgstr "" + +#: relations.py:156 +msgid "Invalid hyperlink - No URL match." +msgstr "" + +#: relations.py:157 +msgid "Invalid hyperlink - Incorrect URL match." +msgstr "" + +#: relations.py:158 +msgid "Invalid hyperlink - Object does not exist." +msgstr "" + +#: relations.py:159 +msgid "Incorrect type. Expected URL string, received {data_type}." +msgstr "" + +#: relations.py:294 +msgid "Object with {slug_name}={value} does not exist." +msgstr "" + +#: relations.py:295 +msgid "Invalid value." +msgstr "" + +#: serializers.py:299 +msgid "Invalid data. Expected a dictionary, but got {datatype}." +msgstr "" + +#: validators.py:22 +msgid "This field must be unique." +msgstr "" + +#: validators.py:76 +msgid "The fields {field_names} must make a unique set." +msgstr "" + +#: validators.py:219 +msgid "This field must be unique for the \"{date_field}\" date." +msgstr "" + +#: validators.py:234 +msgid "This field must be unique for the \"{date_field}\" month." +msgstr "" + +#: validators.py:247 +msgid "This field must be unique for the \"{date_field}\" year." +msgstr "" + +#: versioning.py:39 +msgid "Invalid version in \"Accept\" header." +msgstr "" + +#: versioning.py:70 versioning.py:112 +msgid "Invalid version in URL path." +msgstr "" + +#: versioning.py:138 +msgid "Invalid version in hostname." +msgstr "" + +#: versioning.py:160 +msgid "Invalid version in query parameter." +msgstr "" diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 2074a107..c34cfcee 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -5,7 +5,6 @@ We don't bind behaviour to http method handlers yet,  which allows mixin classes to be composed in interesting ways.  """  from __future__ import unicode_literals -  from rest_framework import status  from rest_framework.response import Response  from rest_framework.settings import api_settings @@ -37,12 +36,14 @@ class ListModelMixin(object):      List a queryset.      """      def list(self, request, *args, **kwargs): -        instance = self.filter_queryset(self.get_queryset()) -        page = self.paginate_queryset(instance) +        queryset = self.filter_queryset(self.get_queryset()) + +        page = self.paginate_queryset(queryset)          if page is not None: -            serializer = self.get_pagination_serializer(page) -        else: -            serializer = self.get_serializer(instance, many=True) +            serializer = self.get_serializer(page, many=True) +            return self.get_paginated_response(serializer.data) + +        serializer = self.get_serializer(queryset, many=True)          return Response(serializer.data) diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 9c8dda8f..b3658aca 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -1,85 +1,682 @@ +# coding: utf-8  """  Pagination serializers determine the structure of the output that should  be used for paginated responses.  """  from __future__ import unicode_literals -from rest_framework import serializers -from rest_framework.templatetags.rest_framework import replace_query_param +from base64 import b64encode, b64decode +from collections import namedtuple +from django.core.paginator import InvalidPage, Paginator as DjangoPaginator +from django.template import Context, loader +from django.utils import six +from django.utils.six.moves.urllib import parse as urlparse +from django.utils.translation import ugettext as _ +from rest_framework.compat import OrderedDict +from rest_framework.exceptions import NotFound +from rest_framework.response import Response +from rest_framework.settings import api_settings +from rest_framework.utils.urls import ( +    replace_query_param, remove_query_param +) -class NextPageField(serializers.Field): +def _positive_int(integer_string, strict=False, cutoff=None):      """ -    Field that returns a link to the next page in paginated results. +    Cast a string to a strictly positive integer.      """ -    page_field = 'page' +    ret = int(integer_string) +    if ret < 0 or (ret == 0 and strict): +        raise ValueError() +    if cutoff: +        ret = min(ret, cutoff) +    return ret -    def to_representation(self, value): -        if not value.has_next(): -            return None -        page = value.next_page_number() -        request = self.context.get('request') -        url = request and request.build_absolute_uri() or '' -        return replace_query_param(url, self.page_field, page) +def _divide_with_ceil(a, b): +    """ +    Returns 'a' divded by 'b', with any remainder rounded up. +    """ +    if a % b: +        return (a // b) + 1 +    return a // b -class PreviousPageField(serializers.Field): + +def _get_count(queryset):      """ -    Field that returns a link to the previous page in paginated results. +    Determine an object count, supporting either querysets or regular lists.      """ -    page_field = 'page' +    try: +        return queryset.count() +    except (AttributeError, TypeError): +        return len(queryset) -    def to_representation(self, value): -        if not value.has_previous(): -            return None -        page = value.previous_page_number() -        request = self.context.get('request') -        url = request and request.build_absolute_uri() or '' -        return replace_query_param(url, self.page_field, page) + +def _get_displayed_page_numbers(current, final): +    """ +    This utility function determines a list of page numbers to display. +    This gives us a nice contextually relevant set of page numbers. + +    For example: +    current=14, final=16 -> [1, None, 13, 14, 15, 16] + +    This implementation gives one page to each side of the cursor, +    or two pages to the side when the cursor is at the edge, then +    ensures that any breaks between non-continous page numbers never +    remove only a single page. + +    For an alernativative implementation which gives two pages to each side of +    the cursor, eg. as in GitHub issue list pagination, see: + +    https://gist.github.com/tomchristie/321140cebb1c4a558b15 +    """ +    assert current >= 1 +    assert final >= current + +    if final <= 5: +        return list(range(1, final + 1)) + +    # We always include the first two pages, last two pages, and +    # two pages either side of the current page. +    included = set(( +        1, +        current - 1, current, current + 1, +        final +    )) + +    # If the break would only exclude a single page number then we +    # may as well include the page number instead of the break. +    if current <= 4: +        included.add(2) +        included.add(3) +    if current >= final - 3: +        included.add(final - 1) +        included.add(final - 2) + +    # Now sort the page numbers and drop anything outside the limits. +    included = [ +        idx for idx in sorted(list(included)) +        if idx > 0 and idx <= final +    ] + +    # Finally insert any `...` breaks +    if current > 4: +        included.insert(1, None) +    if current < final - 3: +        included.insert(len(included) - 1, None) +    return included -class DefaultObjectSerializer(serializers.Serializer): +def _get_page_links(page_numbers, current, url_func):      """ -    If no object serializer is specified, then this serializer will be applied -    as the default. +    Given a list of page numbers and `None` page breaks, +    return a list of `PageLink` objects.      """ -    def to_representation(self, value): -        return value +    page_links = [] +    for page_number in page_numbers: +        if page_number is None: +            page_link = PAGE_BREAK +        else: +            page_link = PageLink( +                url=url_func(page_number), +                number=page_number, +                is_active=(page_number == current), +                is_break=False +            ) +        page_links.append(page_link) +    return page_links -class BasePaginationSerializer(serializers.Serializer): +def _decode_cursor(encoded):      """ -    A base class for pagination serializers to inherit from, -    to make implementing custom serializers more easy. +    Given a string representing an encoded cursor, return a `Cursor` instance.      """ -    results_field = 'results' +    try: +        querystring = b64decode(encoded.encode('ascii')).decode('ascii') +        tokens = urlparse.parse_qs(querystring, keep_blank_values=True) + +        offset = tokens.get('o', ['0'])[0] +        offset = _positive_int(offset) + +        reverse = tokens.get('r', ['0'])[0] +        reverse = bool(int(reverse)) + +        position = tokens.get('p', [None])[0] +    except (TypeError, ValueError): +        return None + +    return Cursor(offset=offset, reverse=reverse, position=position) + + +def _encode_cursor(cursor): +    """ +    Given a Cursor instance, return an encoded string representation. +    """ +    tokens = {} +    if cursor.offset != 0: +        tokens['o'] = str(cursor.offset) +    if cursor.reverse: +        tokens['r'] = '1' +    if cursor.position is not None: +        tokens['p'] = cursor.position + +    querystring = urlparse.urlencode(tokens, doseq=True) +    return b64encode(querystring.encode('ascii')).decode('ascii') + + +def _reverse_ordering(ordering_tuple): +    """ +    Given an order_by tuple such as `('-created', 'uuid')` reverse the +    ordering and return a new tuple, eg. `('created', '-uuid')`. +    """ +    invert = lambda x: x[1:] if (x.startswith('-')) else '-' + x +    return tuple([invert(item) for item in ordering_tuple]) + + +Cursor = namedtuple('Cursor', ['offset', 'reverse', 'position']) +PageLink = namedtuple('PageLink', ['url', 'number', 'is_active', 'is_break']) + +PAGE_BREAK = PageLink(url=None, number=None, is_active=False, is_break=True) -    def __init__(self, *args, **kwargs): + +class BasePagination(object): +    display_page_controls = False + +    def paginate_queryset(self, queryset, request, view=None):  # pragma: no cover +        raise NotImplementedError('paginate_queryset() must be implemented.') + +    def get_paginated_response(self, data):  # pragma: no cover +        raise NotImplementedError('get_paginated_response() must be implemented.') + +    def to_html(self):  # pragma: no cover +        raise NotImplementedError('to_html() must be implemented to display page controls.') + + +class PageNumberPagination(BasePagination): +    """ +    A simple page number based style that supports page numbers as +    query parameters. For example: + +    http://api.example.org/accounts/?page=4 +    http://api.example.org/accounts/?page=4&page_size=100 +    """ +    # The default page size. +    # Defaults to `None`, meaning pagination is disabled. +    paginate_by = api_settings.PAGINATE_BY + +    # Client can control the page using this query parameter. +    page_query_param = 'page' + +    # Client can control the page size using this query parameter. +    # Default is 'None'. Set to eg 'page_size' to enable usage. +    paginate_by_param = api_settings.PAGINATE_BY_PARAM + +    # Set to an integer to limit the maximum page size the client may request. +    # Only relevant if 'paginate_by_param' has also been set. +    max_paginate_by = api_settings.MAX_PAGINATE_BY + +    last_page_strings = ('last',) + +    template = 'rest_framework/pagination/numbers.html' + +    invalid_page_message = _('Invalid page "{page_number}": {message}.') + +    def _handle_backwards_compat(self, view):          """ -        Override init to add in the object serializer field on-the-fly. +        Prior to version 3.1, pagination was handled in the view, and the +        attributes were set there. The attributes should now be set on +        the pagination class, but the old style is still pending deprecation.          """ -        super(BasePaginationSerializer, self).__init__(*args, **kwargs) -        results_field = self.results_field +        for attr in ( +            'paginate_by', 'page_query_param', +            'paginate_by_param', 'max_paginate_by' +        ): +            if hasattr(view, attr): +                setattr(self, attr, getattr(view, attr)) -        try: -            object_serializer = self.Meta.object_serializer_class -        except AttributeError: -            object_serializer = DefaultObjectSerializer +    def paginate_queryset(self, queryset, request, view=None): +        """ +        Paginate a queryset if required, either returning a +        page object, or `None` if pagination is not configured for this view. +        """ +        self._handle_backwards_compat(view) + +        page_size = self.get_page_size(request) +        if not page_size: +            return None + +        paginator = DjangoPaginator(queryset, page_size) +        page_number = request.query_params.get(self.page_query_param, 1) +        if page_number in self.last_page_strings: +            page_number = paginator.num_pages          try: -            list_serializer_class = object_serializer.Meta.list_serializer_class -        except AttributeError: -            list_serializer_class = serializers.ListSerializer +            self.page = paginator.page(page_number) +        except InvalidPage as exc: +            msg = self.invalid_page_message.format( +                page_number=page_number, message=six.text_type(exc) +            ) +            raise NotFound(msg) -        self.fields[results_field] = list_serializer_class( -            child=object_serializer(*args, **kwargs), -            source='object_list' -        ) +        if paginator.count > 1: +            # The browsable API should display pagination controls. +            self.display_page_controls = True + +        self.request = request +        return self.page + +    def get_paginated_response(self, data): +        return Response(OrderedDict([ +            ('count', self.page.paginator.count), +            ('next', self.get_next_link()), +            ('previous', self.get_previous_link()), +            ('results', data) +        ])) + +    def get_page_size(self, request): +        if self.paginate_by_param: +            try: +                return _positive_int( +                    request.query_params[self.paginate_by_param], +                    strict=True, +                    cutoff=self.max_paginate_by +                ) +            except (KeyError, ValueError): +                pass + +        return self.paginate_by + +    def get_next_link(self): +        if not self.page.has_next(): +            return None +        url = self.request.build_absolute_uri() +        page_number = self.page.next_page_number() +        return replace_query_param(url, self.page_query_param, page_number) + +    def get_previous_link(self): +        if not self.page.has_previous(): +            return None +        url = self.request.build_absolute_uri() +        page_number = self.page.previous_page_number() +        if page_number == 1: +            return remove_query_param(url, self.page_query_param) +        return replace_query_param(url, self.page_query_param, page_number) + +    def get_html_context(self): +        base_url = self.request.build_absolute_uri() +        def page_number_to_url(page_number): +            if page_number == 1: +                return remove_query_param(base_url, self.page_query_param) +            else: +                return replace_query_param(base_url, self.page_query_param, page_number) -class PaginationSerializer(BasePaginationSerializer): +        current = self.page.number +        final = self.page.paginator.num_pages +        page_numbers = _get_displayed_page_numbers(current, final) +        page_links = _get_page_links(page_numbers, current, page_number_to_url) + +        return { +            'previous_url': self.get_previous_link(), +            'next_url': self.get_next_link(), +            'page_links': page_links +        } + +    def to_html(self): +        template = loader.get_template(self.template) +        context = Context(self.get_html_context()) +        return template.render(context) + + +class LimitOffsetPagination(BasePagination):      """ -    A default implementation of a pagination serializer. +    A limit/offset based style. For example: + +    http://api.example.org/accounts/?limit=100 +    http://api.example.org/accounts/?offset=400&limit=100      """ -    count = serializers.ReadOnlyField(source='paginator.count') -    next = NextPageField(source='*') -    previous = PreviousPageField(source='*') +    default_limit = api_settings.PAGINATE_BY +    limit_query_param = 'limit' +    offset_query_param = 'offset' +    max_limit = None +    template = 'rest_framework/pagination/numbers.html' + +    def paginate_queryset(self, queryset, request, view=None): +        self.limit = self.get_limit(request) +        self.offset = self.get_offset(request) +        self.count = _get_count(queryset) +        self.request = request +        if self.count > self.limit: +            self.display_page_controls = True +        return queryset[self.offset:self.offset + self.limit] + +    def get_paginated_response(self, data): +        return Response(OrderedDict([ +            ('count', self.count), +            ('next', self.get_next_link()), +            ('previous', self.get_previous_link()), +            ('results', data) +        ])) + +    def get_limit(self, request): +        if self.limit_query_param: +            try: +                return _positive_int( +                    request.query_params[self.limit_query_param], +                    cutoff=self.max_limit +                ) +            except (KeyError, ValueError): +                pass + +        return self.default_limit + +    def get_offset(self, request): +        try: +            return _positive_int( +                request.query_params[self.offset_query_param], +            ) +        except (KeyError, ValueError): +            return 0 + +    def get_next_link(self): +        if self.offset + self.limit >= self.count: +            return None + +        url = self.request.build_absolute_uri() +        offset = self.offset + self.limit +        return replace_query_param(url, self.offset_query_param, offset) + +    def get_previous_link(self): +        if self.offset <= 0: +            return None + +        url = self.request.build_absolute_uri() + +        if self.offset - self.limit <= 0: +            return remove_query_param(url, self.offset_query_param) + +        offset = self.offset - self.limit +        return replace_query_param(url, self.offset_query_param, offset) + +    def get_html_context(self): +        base_url = self.request.build_absolute_uri() +        current = _divide_with_ceil(self.offset, self.limit) + 1 +        # The number of pages is a little bit fiddly. +        # We need to sum both the number of pages from current offset to end +        # plus the number of pages up to the current offset. +        # When offset is not strictly divisible by the limit then we may +        # end up introducing an extra page as an artifact. +        final = ( +            _divide_with_ceil(self.count - self.offset, self.limit) + +            _divide_with_ceil(self.offset, self.limit) +        ) + +        def page_number_to_url(page_number): +            if page_number == 1: +                return remove_query_param(base_url, self.offset_query_param) +            else: +                offset = self.offset + ((page_number - current) * self.limit) +                return replace_query_param(base_url, self.offset_query_param, offset) + +        page_numbers = _get_displayed_page_numbers(current, final) +        page_links = _get_page_links(page_numbers, current, page_number_to_url) + +        return { +            'previous_url': self.get_previous_link(), +            'next_url': self.get_next_link(), +            'page_links': page_links +        } + +    def to_html(self): +        template = loader.get_template(self.template) +        context = Context(self.get_html_context()) +        return template.render(context) + + +class CursorPagination(BasePagination): +    # Determine how/if True, False and None positions work - do the string +    # encodings work with Django queryset filters? +    # Consider a max offset cap. +    # Tidy up the `get_ordering` API (eg remove queryset from it) +    cursor_query_param = 'cursor' +    page_size = api_settings.PAGINATE_BY +    invalid_cursor_message = _('Invalid cursor') +    ordering = None +    template = 'rest_framework/pagination/previous_and_next.html' + +    def paginate_queryset(self, queryset, request, view=None): +        self.base_url = request.build_absolute_uri() +        self.ordering = self.get_ordering(request, queryset, view) + +        # Determine if we have a cursor, and if so then decode it. +        encoded = request.query_params.get(self.cursor_query_param) +        if encoded is None: +            self.cursor = None +            (offset, reverse, current_position) = (0, False, None) +        else: +            self.cursor = _decode_cursor(encoded) +            if self.cursor is None: +                raise NotFound(self.invalid_cursor_message) +            (offset, reverse, current_position) = self.cursor + +        # Cursor pagination always enforces an ordering. +        if reverse: +            queryset = queryset.order_by(*_reverse_ordering(self.ordering)) +        else: +            queryset = queryset.order_by(*self.ordering) + +        # If we have a cursor with a fixed position then filter by that. +        if current_position is not None: +            order = self.ordering[0] +            is_reversed = order.startswith('-') +            order_attr = order.lstrip('-') + +            # Test for: (cursor reversed) XOR (queryset reversed) +            if self.cursor.reverse != is_reversed: +                kwargs = {order_attr + '__lt': current_position} +            else: +                kwargs = {order_attr + '__gt': current_position} + +            queryset = queryset.filter(**kwargs) + +        # If we have an offset cursor then offset the entire page by that amount. +        # We also always fetch an extra item in order to determine if there is a +        # page following on from this one. +        results = list(queryset[offset:offset + self.page_size + 1]) +        self.page = results[:self.page_size] + +        # Determine the position of the final item following the page. +        if len(results) > len(self.page): +            has_following_postion = True +            following_position = self._get_position_from_instance(results[-1], self.ordering) +        else: +            has_following_postion = False +            following_position = None + +        # If we have a reverse queryset, then the query ordering was in reverse +        # so we need to reverse the items again before returning them to the user. +        if reverse: +            self.page = list(reversed(self.page)) + +        if reverse: +            # Determine next and previous positions for reverse cursors. +            self.has_next = (current_position is not None) or (offset > 0) +            self.has_previous = has_following_postion +            if self.has_next: +                self.next_position = current_position +            if self.has_previous: +                self.previous_position = following_position +        else: +            # Determine next and previous positions for forward cursors. +            self.has_next = has_following_postion +            self.has_previous = (current_position is not None) or (offset > 0) +            if self.has_next: +                self.next_position = following_position +            if self.has_previous: +                self.previous_position = current_position + +        # Display page controls in the browsable API if there is more +        # than one page. +        if self.has_previous or self.has_next: +            self.display_page_controls = True + +        return self.page + +    def get_next_link(self): +        if not self.has_next: +            return None + +        if self.cursor and self.cursor.reverse and self.cursor.offset != 0: +            # If we're reversing direction and we have an offset cursor +            # then we cannot use the first position we find as a marker. +            compare = self._get_position_from_instance(self.page[-1], self.ordering) +        else: +            compare = self.next_position +        offset = 0 + +        for item in reversed(self.page): +            position = self._get_position_from_instance(item, self.ordering) +            if position != compare: +                # The item in this position and the item following it +                # have different positions. We can use this position as +                # our marker. +                break + +            # The item in this postion has the same position as the item +            # following it, we can't use it as a marker position, so increment +            # the offset and keep seeking to the previous item. +            compare = position +            offset += 1 + +        else: +            # There were no unique positions in the page. +            if not self.has_previous: +                # We are on the first page. +                # Our cursor will have an offset equal to the page size, +                # but no position to filter against yet. +                offset = self.page_size +                position = None +            elif self.cursor.reverse: +                # The change in direction will introduce a paging artifact, +                # where we end up skipping forward a few extra items. +                offset = 0 +                position = self.previous_position +            else: +                # Use the position from the existing cursor and increment +                # it's offset by the page size. +                offset = self.cursor.offset + self.page_size +                position = self.previous_position + +        cursor = Cursor(offset=offset, reverse=False, position=position) +        encoded = _encode_cursor(cursor) +        return replace_query_param(self.base_url, self.cursor_query_param, encoded) + +    def get_previous_link(self): +        if not self.has_previous: +            return None + +        if self.cursor and not self.cursor.reverse and self.cursor.offset != 0: +            # If we're reversing direction and we have an offset cursor +            # then we cannot use the first position we find as a marker. +            compare = self._get_position_from_instance(self.page[0], self.ordering) +        else: +            compare = self.previous_position +        offset = 0 + +        for item in self.page: +            position = self._get_position_from_instance(item, self.ordering) +            if position != compare: +                # The item in this position and the item following it +                # have different positions. We can use this position as +                # our marker. +                break + +            # The item in this postion has the same position as the item +            # following it, we can't use it as a marker position, so increment +            # the offset and keep seeking to the previous item. +            compare = position +            offset += 1 + +        else: +            # There were no unique positions in the page. +            if not self.has_next: +                # We are on the final page. +                # Our cursor will have an offset equal to the page size, +                # but no position to filter against yet. +                offset = self.page_size +                position = None +            elif self.cursor.reverse: +                # Use the position from the existing cursor and increment +                # it's offset by the page size. +                offset = self.cursor.offset + self.page_size +                position = self.next_position +            else: +                # The change in direction will introduce a paging artifact, +                # where we end up skipping back a few extra items. +                offset = 0 +                position = self.next_position + +        cursor = Cursor(offset=offset, reverse=True, position=position) +        encoded = _encode_cursor(cursor) +        return replace_query_param(self.base_url, self.cursor_query_param, encoded) + +    def get_ordering(self, request, queryset, view): +        """ +        Return a tuple of strings, that may be used in an `order_by` method. +        """ +        ordering_filters = [ +            filter_cls for filter_cls in getattr(view, 'filter_backends', []) +            if hasattr(filter_cls, 'get_ordering') +        ] + +        if ordering_filters: +            # If a filter exists on the view that implements `get_ordering` +            # then we defer to that filter to determine the ordering. +            filter_cls = ordering_filters[0] +            filter_instance = filter_cls() +            ordering = filter_instance.get_ordering(request, queryset, view) +            assert ordering is not None, ( +                'Using cursor pagination, but filter class {filter_cls} ' +                'returned a `None` ordering.'.format( +                    filter_cls=filter_cls.__name__ +                ) +            ) +        else: +            # The default case is to check for an `ordering` attribute, +            # first on the view instance, and then on this pagination instance. +            ordering = getattr(view, 'ordering', getattr(self, 'ordering', None)) +            assert ordering is not None, ( +                'Using cursor pagination, but no ordering attribute was declared ' +                'on the view or on the pagination class.' +            ) + +        assert isinstance(ordering, (six.string_types, list, tuple)), ( +            'Invalid ordering. Expected string or tuple, but got {type}'.format( +                type=type(ordering).__name__ +            ) +        ) + +        if isinstance(ordering, six.string_types): +            return (ordering,) +        return tuple(ordering) + +    def _get_position_from_instance(self, instance, ordering): +        attr = getattr(instance, ordering[0].lstrip('-')) +        return six.text_type(attr) + +    def get_paginated_response(self, data): +        return Response(OrderedDict([ +            ('next', self.get_next_link()), +            ('previous', self.get_previous_link()), +            ('results', data) +        ])) + +    def get_html_context(self): +        return { +            'previous_url': self.get_previous_link(), +            'next_url': self.get_next_link() +        } + +    def to_html(self): +        template = loader.get_template(self.template) +        context = Context(self.get_html_context()) +        return template.render(context) diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 1efab85b..437d1339 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -14,12 +14,9 @@ from django.http.multipartparser import MultiPartParserError, parse_header, Chun  from django.utils import six  from django.utils.six.moves.urllib import parse as urlparse  from django.utils.encoding import force_text -from rest_framework.compat import etree, yaml  from rest_framework.exceptions import ParseError  from rest_framework import renderers  import json -import datetime -import decimal  class DataAndFiles(object): @@ -67,29 +64,6 @@ class JSONParser(BaseParser):              raise ParseError('JSON parse error - %s' % six.text_type(exc)) -class YAMLParser(BaseParser): -    """ -    Parses YAML-serialized data. -    """ - -    media_type = 'application/yaml' - -    def parse(self, stream, media_type=None, parser_context=None): -        """ -        Parses the incoming bytestream as YAML and returns the resulting data. -        """ -        assert yaml, 'YAMLParser requires pyyaml to be installed' - -        parser_context = parser_context or {} -        encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) - -        try: -            data = stream.read().decode(encoding) -            return yaml.safe_load(data) -        except (ValueError, yaml.parser.ParserError) as exc: -            raise ParseError('YAML parse error - %s' % six.text_type(exc)) - -  class FormParser(BaseParser):      """      Parser for form data. @@ -138,78 +112,6 @@ class MultiPartParser(BaseParser):              raise ParseError('Multipart form parse error - %s' % six.text_type(exc)) -class XMLParser(BaseParser): -    """ -    XML parser. -    """ - -    media_type = 'application/xml' - -    def parse(self, stream, media_type=None, parser_context=None): -        """ -        Parses the incoming bytestream as XML and returns the resulting data. -        """ -        assert etree, 'XMLParser requires defusedxml to be installed' - -        parser_context = parser_context or {} -        encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) -        parser = etree.DefusedXMLParser(encoding=encoding) -        try: -            tree = etree.parse(stream, parser=parser, forbid_dtd=True) -        except (etree.ParseError, ValueError) as exc: -            raise ParseError('XML parse error - %s' % six.text_type(exc)) -        data = self._xml_convert(tree.getroot()) - -        return data - -    def _xml_convert(self, element): -        """ -        convert the xml `element` into the corresponding python object -        """ - -        children = list(element) - -        if len(children) == 0: -            return self._type_convert(element.text) -        else: -            # if the fist child tag is list-item means all children are list-item -            if children[0].tag == "list-item": -                data = [] -                for child in children: -                    data.append(self._xml_convert(child)) -            else: -                data = {} -                for child in children: -                    data[child.tag] = self._xml_convert(child) - -            return data - -    def _type_convert(self, value): -        """ -        Converts the value returned by the XMl parse into the equivalent -        Python type -        """ -        if value is None: -            return value - -        try: -            return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S') -        except ValueError: -            pass - -        try: -            return int(value) -        except ValueError: -            pass - -        try: -            return decimal.Decimal(value) -        except decimal.InvalidOperation: -            pass - -        return value - -  class FileUploadParser(BaseParser):      """      Parser for file upload data. diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 3f6f5961..9069d315 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -3,8 +3,7 @@ Provides a set of pluggable permission policies.  """  from __future__ import unicode_literals  from django.http import Http404 -from rest_framework.compat import (get_model_name, oauth2_provider_scope, -                                   oauth2_constants) +from rest_framework.compat import get_model_name  SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS'] @@ -199,28 +198,3 @@ class DjangoObjectPermissions(DjangoModelPermissions):              return False          return True - - -class TokenHasReadWriteScope(BasePermission): -    """ -    The request is authenticated as a user and the token used has the right scope -    """ - -    def has_permission(self, request, view): -        token = request.auth -        read_only = request.method in SAFE_METHODS - -        if not token: -            return False - -        if hasattr(token, 'resource'):  # OAuth 1 -            return read_only or not request.auth.resource.is_readonly -        elif hasattr(token, 'scope'):  # OAuth 2 -            required = oauth2_constants.READ if read_only else oauth2_constants.WRITE -            return oauth2_provider_scope.check(required, request.auth.scope) - -        assert False, ( -            'TokenHasReadWriteScope requires either the' -            '`OAuthAuthentication` or `OAuth2Authentication` authentication ' -            'class to be used.' -        ) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 13793f37..66857a41 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -130,7 +130,7 @@ class StringRelatedField(RelatedField):  class PrimaryKeyRelatedField(RelatedField):      default_error_messages = {          'required': _('This field is required.'), -        'does_not_exist': _("Invalid pk '{pk_value}' - object does not exist."), +        'does_not_exist': _('Invalid pk "{pk_value}" - object does not exist.'),          'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),      } @@ -154,7 +154,7 @@ class HyperlinkedRelatedField(RelatedField):      default_error_messages = {          'required': _('This field is required.'), -        'no_match': _('Invalid hyperlink - No URL match'), +        'no_match': _('Invalid hyperlink - No URL match.'),          'incorrect_match': _('Invalid hyperlink - Incorrect URL match.'),          'does_not_exist': _('Invalid hyperlink - Object does not exist.'),          'incorrect_type': _('Incorrect type. Expected URL string, received {data_type}.'), @@ -292,7 +292,7 @@ class SlugRelatedField(RelatedField):      """      default_error_messages = { -        'does_not_exist': _("Object with {slug_name}={value} does not exist."), +        'does_not_exist': _('Object with {slug_name}={value} does not exist.'),          'invalid': _('Invalid value.'),      } diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 584332e6..6256acdd 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -17,11 +17,8 @@ from django.http.multipartparser import parse_header  from django.template import Context, RequestContext, loader, Template  from django.test.client import encode_multipart  from django.utils import six -from django.utils.encoding import smart_text -from django.utils.xmlutils import SimplerXMLGenerator -from django.utils.six.moves import StringIO  from rest_framework import exceptions, serializers, status, VERSION -from rest_framework.compat import SHORT_SEPARATORS, LONG_SEPARATORS, yaml +from rest_framework.compat import SHORT_SEPARATORS, LONG_SEPARATORS, INDENT_SEPARATORS  from rest_framework.exceptions import ParseError  from rest_framework.settings import api_settings  from rest_framework.request import is_form_media_type, override_method @@ -90,7 +87,11 @@ class JSONRenderer(BaseRenderer):          renderer_context = renderer_context or {}          indent = self.get_indent(accepted_media_type, renderer_context) -        separators = SHORT_SEPARATORS if (indent is None and self.compact) else LONG_SEPARATORS + +        if indent is None: +            separators = SHORT_SEPARATORS if self.compact else LONG_SEPARATORS +        else: +            separators = INDENT_SEPARATORS          ret = json.dumps(              data, cls=self.encoder_class, @@ -112,112 +113,6 @@ class JSONRenderer(BaseRenderer):          return ret -class JSONPRenderer(JSONRenderer): -    """ -    Renderer which serializes to json, -    wrapping the json output in a callback function. -    """ - -    media_type = 'application/javascript' -    format = 'jsonp' -    callback_parameter = 'callback' -    default_callback = 'callback' -    charset = 'utf-8' - -    def get_callback(self, renderer_context): -        """ -        Determine the name of the callback to wrap around the json output. -        """ -        request = renderer_context.get('request', None) -        params = request and request.query_params or {} -        return params.get(self.callback_parameter, self.default_callback) - -    def render(self, data, accepted_media_type=None, renderer_context=None): -        """ -        Renders into jsonp, wrapping the json output in a callback function. - -        Clients may set the callback function name using a query parameter -        on the URL, for example: ?callback=exampleCallbackName -        """ -        renderer_context = renderer_context or {} -        callback = self.get_callback(renderer_context) -        json = super(JSONPRenderer, self).render(data, accepted_media_type, -                                                 renderer_context) -        return callback.encode(self.charset) + b'(' + json + b');' - - -class XMLRenderer(BaseRenderer): -    """ -    Renderer which serializes to XML. -    """ - -    media_type = 'application/xml' -    format = 'xml' -    charset = 'utf-8' - -    def render(self, data, accepted_media_type=None, renderer_context=None): -        """ -        Renders `data` into serialized XML. -        """ -        if data is None: -            return '' - -        stream = StringIO() - -        xml = SimplerXMLGenerator(stream, self.charset) -        xml.startDocument() -        xml.startElement("root", {}) - -        self._to_xml(xml, data) - -        xml.endElement("root") -        xml.endDocument() -        return stream.getvalue() - -    def _to_xml(self, xml, data): -        if isinstance(data, (list, tuple)): -            for item in data: -                xml.startElement("list-item", {}) -                self._to_xml(xml, item) -                xml.endElement("list-item") - -        elif isinstance(data, dict): -            for key, value in six.iteritems(data): -                xml.startElement(key, {}) -                self._to_xml(xml, value) -                xml.endElement(key) - -        elif data is None: -            # Don't output any value -            pass - -        else: -            xml.characters(smart_text(data)) - - -class YAMLRenderer(BaseRenderer): -    """ -    Renderer which serializes to YAML. -    """ - -    media_type = 'application/yaml' -    format = 'yaml' -    encoder = encoders.SafeDumper -    charset = 'utf-8' -    ensure_ascii = False - -    def render(self, data, accepted_media_type=None, renderer_context=None): -        """ -        Renders `data` into serialized YAML. -        """ -        assert yaml, 'YAMLRenderer requires pyyaml to be installed' - -        if data is None: -            return '' - -        return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder, allow_unicode=not self.ensure_ascii) - -  class TemplateHTMLRenderer(BaseRenderer):      """      An HTML renderer for use with templates. @@ -696,6 +591,11 @@ class BrowsableAPIRenderer(BaseRenderer):                  renderer_content_type += ' ;%s' % renderer.charset          response_headers['Content-Type'] = renderer_content_type +        if hasattr(view, 'paginator') and view.paginator.display_page_controls: +            paginator = view.paginator +        else: +            paginator = None +          context = {              'content': self.get_content(renderer, data, accepted_media_type, renderer_context),              'view': view, @@ -704,6 +604,7 @@ class BrowsableAPIRenderer(BaseRenderer):              'description': self.get_description(view),              'name': self.get_name(view),              'version': VERSION, +            'paginator': paginator,              'breadcrumblist': self.get_breadcrumbs(request),              'allowed_methods': view.allowed_methods,              'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes], diff --git a/rest_framework/request.py b/rest_framework/request.py index cfbbdecc..bf6ff670 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -107,6 +107,10 @@ def clone_request(request, method):          ret.accepted_renderer = request.accepted_renderer      if hasattr(request, 'accepted_media_type'):          ret.accepted_media_type = request.accepted_media_type +    if hasattr(request, 'version'): +        ret.version = request.version +    if hasattr(request, 'versioning_scheme'): +        ret.versioning_scheme = request.versioning_scheme      return ret diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py index a74e8aa2..8fcca55b 100644 --- a/rest_framework/reverse.py +++ b/rest_framework/reverse.py @@ -9,6 +9,18 @@ from django.utils.functional import lazy  def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):      """ +    If versioning is being used then we pass any `reverse` calls through +    to the versioning scheme instance, so that the resulting URL +    can be modified if needed. +    """ +    scheme = getattr(request, 'versioning_scheme', None) +    if scheme is not None: +        return scheme.reverse(viewname, args, kwargs, request, format, **extra) +    return _reverse(viewname, args, kwargs, request, format, **extra) + + +def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra): +    """      Same as `django.core.urlresolvers.reverse`, but optionally takes a request      and returns a fully qualified URL, using the request to get the base URL.      """ diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 42d1e370..a3b8196b 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -12,7 +12,7 @@ response content is handled by parsers and renderers.  """  from __future__ import unicode_literals  from django.db import models -from django.db.models.fields import FieldDoesNotExist, Field as DjangoField +from django.db.models.fields import FieldDoesNotExist, Field as DjangoModelField  from django.utils.translation import ugettext_lazy as _  from rest_framework.compat import postgres_fields, unicode_to_repr  from rest_framework.utils import model_meta @@ -327,7 +327,9 @@ class Serializer(BaseSerializer):          Returns a list of validator callables.          """          # Used by the lazily-evaluated `validators` property. -        return getattr(getattr(self, 'Meta', None), 'validators', []) +        meta = getattr(self, 'Meta', None) +        validators = getattr(meta, 'validators', None) +        return validators[:] if validators else []      def get_initial(self):          if hasattr(self, 'initial_data'): @@ -477,7 +479,7 @@ class ListSerializer(BaseSerializer):      many = True      default_error_messages = { -        'not_a_list': _('Expected a list of items but got type `{input_type}`.') +        'not_a_list': _('Expected a list of items but got type "{input_type}".')      }      def __init__(self, *args, **kwargs): @@ -702,8 +704,7 @@ class ModelSerializer(Serializer):      you need you should either declare the extra/differing fields explicitly on      the serializer class, or simply use a `Serializer` class.      """ - -    _field_mapping = ClassLookupDict({ +    serializer_field_mapping = {          models.AutoField: IntegerField,          models.BigIntegerField: IntegerField,          models.BooleanField: BooleanField, @@ -725,10 +726,11 @@ class ModelSerializer(Serializer):          models.SmallIntegerField: IntegerField,          models.TextField: CharField,          models.TimeField: TimeField, -        models.URLField: URLField -        # Note: Some version-specific mappings also defined below. -    }) -    _related_class = PrimaryKeyRelatedField +        models.URLField: URLField, +    } +    serializer_related_class = PrimaryKeyRelatedField + +    # Default `create` and `update` behavior...      def create(self, validated_data):          """ @@ -799,69 +801,81 @@ class ModelSerializer(Serializer):          return instance -    def get_validators(self): -        # If the validators have been declared explicitly then use that. -        validators = getattr(getattr(self, 'Meta', None), 'validators', None) -        if validators is not None: -            return validators +    # Determine the fields to apply... -        # Determine the default set of validators. -        validators = [] -        model_class = self.Meta.model -        field_names = set([ -            field.source for field in self.fields.values() -            if (field.source != '*') and ('.' not in field.source) -        ]) +    def get_fields(self): +        """ +        Return the dict of field names -> field instances that should be +        used for `self.fields` when instantiating the serializer. +        """ +        assert hasattr(self, 'Meta'), ( +            'Class {serializer_class} missing "Meta" attribute'.format( +                serializer_class=self.__class__.__name__ +            ) +        ) +        assert hasattr(self.Meta, 'model'), ( +            'Class {serializer_class} missing "Meta.model" attribute'.format( +                serializer_class=self.__class__.__name__ +            ) +        ) -        # Note that we make sure to check `unique_together` both on the -        # base model class, but also on any parent classes. -        for parent_class in [model_class] + list(model_class._meta.parents.keys()): -            for unique_together in parent_class._meta.unique_together: -                if field_names.issuperset(set(unique_together)): -                    validator = UniqueTogetherValidator( -                        queryset=parent_class._default_manager, -                        fields=unique_together -                    ) -                    validators.append(validator) +        declared_fields = copy.deepcopy(self._declared_fields) +        model = getattr(self.Meta, 'model') +        depth = getattr(self.Meta, 'depth', 0) -        # Add any unique_for_date/unique_for_month/unique_for_year constraints. -        info = model_meta.get_field_info(model_class) -        for field_name, field in info.fields_and_pk.items(): -            if field.unique_for_date and field_name in field_names: -                validator = UniqueForDateValidator( -                    queryset=model_class._default_manager, -                    field=field_name, -                    date_field=field.unique_for_date -                ) -                validators.append(validator) +        if depth is not None: +            assert depth >= 0, "'depth' may not be negative." +            assert depth <= 10, "'depth' may not be greater than 10." -            if field.unique_for_month and field_name in field_names: -                validator = UniqueForMonthValidator( -                    queryset=model_class._default_manager, -                    field=field_name, -                    date_field=field.unique_for_month -                ) -                validators.append(validator) +        # Retrieve metadata about fields & relationships on the model class. +        info = model_meta.get_field_info(model) +        field_names = self.get_field_names(declared_fields, info) -            if field.unique_for_year and field_name in field_names: -                validator = UniqueForYearValidator( -                    queryset=model_class._default_manager, -                    field=field_name, -                    date_field=field.unique_for_year -                ) -                validators.append(validator) +        # Determine any extra field arguments and hidden fields that +        # should be included +        extra_kwargs = self.get_extra_kwargs() +        extra_kwargs, hidden_fields = self.get_uniqueness_extra_kwargs( +            field_names, declared_fields, extra_kwargs +        ) -        return validators +        # Determine the fields that should be included on the serializer. +        fields = OrderedDict() -    def get_fields(self): -        declared_fields = copy.deepcopy(self._declared_fields) +        for field_name in field_names: +            # If the field is explicitly declared on the class then use that. +            if field_name in declared_fields: +                fields[field_name] = declared_fields[field_name] +                continue -        ret = OrderedDict() -        model = getattr(self.Meta, 'model') +            # Determine the serializer field class and keyword arguments. +            field_class, field_kwargs = self.build_field( +                field_name, info, model, depth +            ) + +            # Include any kwargs defined in `Meta.extra_kwargs` +            field_kwargs = self.build_field_kwargs( +                field_kwargs, extra_kwargs, field_name +            ) + +            # Create the serializer field. +            fields[field_name] = field_class(**field_kwargs) + +        # Add in any hidden fields. +        fields.update(hidden_fields) + +        return fields + +    # Methods for determining the set of field names to include... + +    def get_field_names(self, declared_fields, info): +        """ +        Returns the list of all field names that should be created when +        instantiating this serializer class. This is based on the default +        set of fields, but also takes into account the `Meta.fields` or +        `Meta.exclude` options if they have been specified. +        """          fields = getattr(self.Meta, 'fields', None)          exclude = getattr(self.Meta, 'exclude', None) -        depth = getattr(self.Meta, 'depth', 0) -        extra_kwargs = getattr(self.Meta, 'extra_kwargs', {})          if fields and not isinstance(fields, (list, tuple)):              raise TypeError( @@ -875,201 +889,199 @@ class ModelSerializer(Serializer):                  type(exclude).__name__              ) -        assert not (fields and exclude), "Cannot set both 'fields' and 'exclude'." - -        extra_kwargs = self._include_additional_options(extra_kwargs) +        assert not (fields and exclude), ( +            "Cannot set both 'fields' and 'exclude' options on " +            "serializer {serializer_class}.".format( +                serializer_class=self.__class__.__name__ +            ) +        ) -        # Retrieve metadata about fields & relationships on the model class. -        info = model_meta.get_field_info(model) +        if fields is not None: +            # Ensure that all declared fields have also been included in the +            # `Meta.fields` option. -        if fields is None: -            # Use the default set of field names if none is supplied explicitly. -            fields = self._get_default_field_names(declared_fields, info) -            exclude = getattr(self.Meta, 'exclude', None) -            if exclude is not None: -                for field_name in exclude: -                    assert field_name in fields, ( -                        'The field in the `exclude` option must be a model field. Got %s.' % -                        field_name +            # Do not require any fields that are declared a parent class, +            # in order to allow serializer subclasses to only include +            # a subset of fields. +            required_field_names = set(declared_fields) +            for cls in self.__class__.__bases__: +                required_field_names -= set(getattr(cls, '_declared_fields', [])) + +            for field_name in required_field_names: +                assert field_name in fields, ( +                    "The field '{field_name}' was declared on serializer " +                    "{serializer_class}, but has not been included in the " +                    "'fields' option.".format( +                        field_name=field_name, +                        serializer_class=self.__class__.__name__                      ) -                    fields.remove(field_name) -        else: -            # Check that any fields declared on the class are -            # also explicitly included in `Meta.fields`. +                ) +            return fields + +        # Use the default set of field names if `Meta.fields` is not specified. +        fields = self.get_default_field_names(declared_fields, info) + +        if exclude is not None: +            # If `Meta.exclude` is included, then remove those fields. +            for field_name in exclude: +                assert field_name in fields, ( +                    "The field '{field_name}' was include on serializer " +                    "{serializer_class} in the 'exclude' option, but does " +                    "not match any model field.".format( +                        field_name=field_name, +                        serializer_class=self.__class__.__name__ +                    ) +                ) +                fields.remove(field_name) -            # Note that we ignore any fields that were declared on a parent -            # class, in order to support only including a subset of fields -            # when subclassing serializers. -            declared_field_names = set(declared_fields.keys()) -            for cls in self.__class__.__bases__: -                declared_field_names -= set(getattr(cls, '_declared_fields', [])) +        return fields -            missing_fields = declared_field_names - set(fields) -            assert not missing_fields, ( -                'Field `%s` has been declared on serializer `%s`, but ' -                'is missing from `Meta.fields`.' % -                (list(missing_fields)[0], self.__class__.__name__) -            ) +    def get_default_field_names(self, declared_fields, model_info): +        """ +        Return the default list of field names that will be used if the +        `Meta.fields` option is not specified. +        """ +        return ( +            [model_info.pk.name] + +            list(declared_fields.keys()) + +            list(model_info.fields.keys()) + +            list(model_info.forward_relations.keys()) +        ) -        # Determine the set of model fields, and the fields that they map to. -        # We actually only need this to deal with the slightly awkward case -        # of supporting `unique_for_date`/`unique_for_month`/`unique_for_year`. -        model_field_mapping = {} -        for field_name in fields: -            if field_name in declared_fields: -                field = declared_fields[field_name] -                source = field.source or field_name +    # Methods for constructing serializer fields... + +    def build_field(self, field_name, info, model_class, nested_depth): +        """ +        Return a two tuple of (cls, kwargs) to build a serializer field with. +        """ +        if field_name in info.fields_and_pk: +            model_field = info.fields_and_pk[field_name] +            return self.build_standard_field(field_name, model_field) + +        elif field_name in info.relations: +            relation_info = info.relations[field_name] +            if not nested_depth: +                return self.build_relational_field(field_name, relation_info)              else: -                try: -                    source = extra_kwargs[field_name]['source'] -                except KeyError: -                    source = field_name -            # Model fields will always have a simple source mapping, -            # they can't be nested attribute lookups. -            if '.' not in source and source != '*': -                model_field_mapping[source] = field_name +                return self.build_nested_field(field_name, relation_info, nested_depth) -        # Determine if we need any additional `HiddenField` or extra keyword -        # arguments to deal with `unique_for` dates that are required to -        # be in the input data in order to validate it. -        hidden_fields = {} -        unique_constraint_names = set() +        elif hasattr(model_class, field_name): +            return self.build_property_field(field_name, model_class) -        for model_field_name, field_name in model_field_mapping.items(): -            try: -                model_field = model._meta.get_field(model_field_name) -            except FieldDoesNotExist: -                continue +        elif field_name == api_settings.URL_FIELD_NAME: +            return self.build_url_field(field_name, model_class) -            if not isinstance(model_field, DjangoField): -                continue +        return self.build_unknown_field(field_name, model_class) -            # Include each of the `unique_for_*` field names. -            unique_constraint_names |= set([ -                model_field.unique_for_date, -                model_field.unique_for_month, -                model_field.unique_for_year -            ]) +    def build_standard_field(self, field_name, model_field): +        """ +        Create regular model fields. +        """ +        field_mapping = ClassLookupDict(self.serializer_field_mapping) + +        field_class = field_mapping[model_field] +        field_kwargs = get_field_kwargs(field_name, model_field) + +        if 'choices' in field_kwargs: +            # Fields with choices get coerced into `ChoiceField` +            # instead of using their regular typed field. +            field_class = ChoiceField +        if not issubclass(field_class, ModelField): +            # `model_field` is only valid for the fallback case of +            # `ModelField`, which is used when no other typed field +            # matched to the model field. +            field_kwargs.pop('model_field', None) +        if not issubclass(field_class, CharField) and not issubclass(field_class, ChoiceField): +            # `allow_blank` is only valid for textual fields. +            field_kwargs.pop('allow_blank', None) + +        return field_class, field_kwargs + +    def build_relational_field(self, field_name, relation_info): +        """ +        Create fields for forward and reverse relationships. +        """ +        field_class = self.serializer_related_class +        field_kwargs = get_relation_kwargs(field_name, relation_info) -        unique_constraint_names -= set([None]) +        # `view_name` is only valid for hyperlinked relationships. +        if not issubclass(field_class, HyperlinkedRelatedField): +            field_kwargs.pop('view_name', None) -        # Include each of the `unique_together` field names, -        # so long as all the field names are included on the serializer. -        for parent_class in [model] + list(model._meta.parents.keys()): -            for unique_together_list in parent_class._meta.unique_together: -                if set(fields).issuperset(set(unique_together_list)): -                    unique_constraint_names |= set(unique_together_list) +        return field_class, field_kwargs -        # Now we have all the field names that have uniqueness constraints -        # applied, we can add the extra 'required=...' or 'default=...' -        # arguments that are appropriate to these fields, or add a `HiddenField` for it. -        for unique_constraint_name in unique_constraint_names: -            # Get the model field that is referred too. -            unique_constraint_field = model._meta.get_field(unique_constraint_name) +    def build_nested_field(self, field_name, relation_info, nested_depth): +        """ +        Create nested fields for forward and reverse relationships. +        """ +        class NestedSerializer(ModelSerializer): +            class Meta: +                model = relation_info.related_model +                depth = nested_depth -            if getattr(unique_constraint_field, 'auto_now_add', None): -                default = CreateOnlyDefault(timezone.now) -            elif getattr(unique_constraint_field, 'auto_now', None): -                default = timezone.now -            elif unique_constraint_field.has_default(): -                default = unique_constraint_field.default -            else: -                default = empty +        field_class = NestedSerializer +        field_kwargs = get_nested_relation_kwargs(relation_info) -            if unique_constraint_name in model_field_mapping: -                # The corresponding field is present in the serializer -                if unique_constraint_name not in extra_kwargs: -                    extra_kwargs[unique_constraint_name] = {} -                if default is empty: -                    if 'required' not in extra_kwargs[unique_constraint_name]: -                        extra_kwargs[unique_constraint_name]['required'] = True -                else: -                    if 'default' not in extra_kwargs[unique_constraint_name]: -                        extra_kwargs[unique_constraint_name]['default'] = default -            elif default is not empty: -                # The corresponding field is not present in the, -                # serializer. We have a default to use for it, so -                # add in a hidden field that populates it. -                hidden_fields[unique_constraint_name] = HiddenField(default=default) +        return field_class, field_kwargs -        # Now determine the fields that should be included on the serializer. -        for field_name in fields: -            if field_name in declared_fields: -                # Field is explicitly declared on the class, use that. -                ret[field_name] = declared_fields[field_name] -                continue +    def build_property_field(self, field_name, model_class): +        """ +        Create a read only field for model methods and properties. +        """ +        field_class = ReadOnlyField +        field_kwargs = {} -            elif field_name in info.fields_and_pk: -                # Create regular model fields. -                model_field = info.fields_and_pk[field_name] -                field_cls = self._field_mapping[model_field] -                kwargs = get_field_kwargs(field_name, model_field) -                if 'choices' in kwargs: -                    # Fields with choices get coerced into `ChoiceField` -                    # instead of using their regular typed field. -                    field_cls = ChoiceField -                if not issubclass(field_cls, ModelField): -                    # `model_field` is only valid for the fallback case of -                    # `ModelField`, which is used when no other typed field -                    # matched to the model field. -                    kwargs.pop('model_field', None) -                if not issubclass(field_cls, CharField) and not issubclass(field_cls, ChoiceField): -                    # `allow_blank` is only valid for textual fields. -                    kwargs.pop('allow_blank', None) - -            elif field_name in info.relations: -                # Create forward and reverse relationships. -                relation_info = info.relations[field_name] -                if depth: -                    field_cls = self._get_nested_class(depth, relation_info) -                    kwargs = get_nested_relation_kwargs(relation_info) -                else: -                    field_cls = self._related_class -                    kwargs = get_relation_kwargs(field_name, relation_info) -                    # `view_name` is only valid for hyperlinked relationships. -                    if not issubclass(field_cls, HyperlinkedRelatedField): -                        kwargs.pop('view_name', None) - -            elif hasattr(model, field_name): -                # Create a read only field for model methods and properties. -                field_cls = ReadOnlyField -                kwargs = {} - -            elif field_name == api_settings.URL_FIELD_NAME: -                # Create the URL field. -                field_cls = HyperlinkedIdentityField -                kwargs = get_url_kwargs(model) +        return field_class, field_kwargs -            else: -                raise ImproperlyConfigured( -                    'Field name `%s` is not valid for model `%s`.' % -                    (field_name, model.__class__.__name__) -                ) +    def build_url_field(self, field_name, model_class): +        """ +        Create a field representing the object's own URL. +        """ +        field_class = HyperlinkedIdentityField +        field_kwargs = get_url_kwargs(model_class) + +        return field_class, field_kwargs -            # Populate any kwargs defined in `Meta.extra_kwargs` -            extras = extra_kwargs.get(field_name, {}) -            if extras.get('read_only', False): -                for attr in [ -                    'required', 'default', 'allow_blank', 'allow_null', -                    'min_length', 'max_length', 'min_value', 'max_value', -                    'validators', 'queryset' -                ]: -                    kwargs.pop(attr, None) +    def build_unknown_field(self, field_name, model_class): +        """ +        Raise an error on any unknown fields. +        """ +        raise ImproperlyConfigured( +            'Field name `%s` is not valid for model `%s`.' % +            (field_name, model_class.__name__) +        ) -            if extras.get('default') and kwargs.get('required') is False: -                kwargs.pop('required') +    def build_field_kwargs(self, kwargs, extra_kwargs, field_name): +        """ +        Include an 'extra_kwargs' that have been included for this field, +        possibly removing any incompatible existing keyword arguments. +        """ +        extras = extra_kwargs.get(field_name, {}) -            kwargs.update(extras) +        if extras.get('read_only', False): +            for attr in [ +                'required', 'default', 'allow_blank', 'allow_null', +                'min_length', 'max_length', 'min_value', 'max_value', +                'validators', 'queryset' +            ]: +                kwargs.pop(attr, None) -            # Create the serializer field. -            ret[field_name] = field_cls(**kwargs) +        if extras.get('default') and kwargs.get('required') is False: +            kwargs.pop('required') -        for field_name, field in hidden_fields.items(): -            ret[field_name] = field +        kwargs.update(extras) -        return ret +        return kwargs + +    # Methods for determining additional keyword arguments to apply... + +    def get_extra_kwargs(self): +        """ +        Return a dictionary mapping field names to a dictionary of +        additional keyword arguments. +        """ +        extra_kwargs = getattr(self.Meta, 'extra_kwargs', {}) -    def _include_additional_options(self, extra_kwargs):          read_only_fields = getattr(self.Meta, 'read_only_fields', None)          if read_only_fields is not None:              for field_name in read_only_fields: @@ -1117,21 +1129,204 @@ class ModelSerializer(Serializer):          return extra_kwargs -    def _get_default_field_names(self, declared_fields, model_info): +    def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs): +        """ +        Return any additional field options that need to be included as a +        result of uniqueness constraints on the model. This is returned as +        a two-tuple of: + +        ('dict of updated extra kwargs', 'mapping of hidden fields') +        """ +        model = getattr(self.Meta, 'model') +        model_fields = self._get_model_fields( +            field_names, declared_fields, extra_kwargs +        ) + +        # Determine if we need any additional `HiddenField` or extra keyword +        # arguments to deal with `unique_for` dates that are required to +        # be in the input data in order to validate it. +        unique_constraint_names = set() + +        for model_field in model_fields.values(): +            # Include each of the `unique_for_*` field names. +            unique_constraint_names |= set([ +                model_field.unique_for_date, +                model_field.unique_for_month, +                model_field.unique_for_year +            ]) + +        unique_constraint_names -= set([None]) + +        # Include each of the `unique_together` field names, +        # so long as all the field names are included on the serializer. +        for parent_class in [model] + list(model._meta.parents.keys()): +            for unique_together_list in parent_class._meta.unique_together: +                if set(field_names).issuperset(set(unique_together_list)): +                    unique_constraint_names |= set(unique_together_list) + +        # Now we have all the field names that have uniqueness constraints +        # applied, we can add the extra 'required=...' or 'default=...' +        # arguments that are appropriate to these fields, or add a `HiddenField` for it. +        hidden_fields = {} +        uniqueness_extra_kwargs = {} + +        for unique_constraint_name in unique_constraint_names: +            # Get the model field that is referred too. +            unique_constraint_field = model._meta.get_field(unique_constraint_name) + +            if getattr(unique_constraint_field, 'auto_now_add', None): +                default = CreateOnlyDefault(timezone.now) +            elif getattr(unique_constraint_field, 'auto_now', None): +                default = timezone.now +            elif unique_constraint_field.has_default(): +                default = unique_constraint_field.default +            else: +                default = empty + +            if unique_constraint_name in model_fields: +                # The corresponding field is present in the serializer +                if default is empty: +                    uniqueness_extra_kwargs[unique_constraint_name] = {'required': True} +                else: +                    uniqueness_extra_kwargs[unique_constraint_name] = {'default': default} +            elif default is not empty: +                # The corresponding field is not present in the, +                # serializer. We have a default to use for it, so +                # add in a hidden field that populates it. +                hidden_fields[unique_constraint_name] = HiddenField(default=default) + +        # Update `extra_kwargs` with any new options. +        for key, value in uniqueness_extra_kwargs.items(): +            if key in extra_kwargs: +                extra_kwargs[key].update(value) +            else: +                extra_kwargs[key] = value + +        return extra_kwargs, hidden_fields + +    def _get_model_fields(self, field_names, declared_fields, extra_kwargs): +        """ +        Returns all the model fields that are being mapped to by fields +        on the serializer class. +        Returned as a dict of 'model field name' -> 'model field'. +        Used internally by `get_uniqueness_field_options`. +        """ +        model = getattr(self.Meta, 'model') +        model_fields = {} + +        for field_name in field_names: +            if field_name in declared_fields: +                # If the field is declared on the serializer +                field = declared_fields[field_name] +                source = field.source or field_name +            else: +                try: +                    source = extra_kwargs[field_name]['source'] +                except KeyError: +                    source = field_name + +            if '.' in source or source == '*': +                # Model fields will always have a simple source mapping, +                # they can't be nested attribute lookups. +                continue + +            try: +                field = model._meta.get_field(source) +                if isinstance(field, DjangoModelField): +                    model_fields[source] = field +            except FieldDoesNotExist: +                pass + +        return model_fields + +    # Determine the validators to apply... + +    def get_validators(self): +        """ +        Determine the set of validators to use when instantiating serializer. +        """ +        # If the validators have been declared explicitly then use that. +        validators = getattr(getattr(self, 'Meta', None), 'validators', None) +        if validators is not None: +            return validators[:] + +        # Otherwise use the default set of validators.          return ( -            [model_info.pk.name] + -            list(declared_fields.keys()) + -            list(model_info.fields.keys()) + -            list(model_info.forward_relations.keys()) +            self.get_unique_together_validators() + +            self.get_unique_for_date_validators()          ) -    def _get_nested_class(self, nested_depth, relation_info): -        class NestedSerializer(ModelSerializer): -            class Meta: -                model = relation_info.related -                depth = nested_depth - 1 +    def get_unique_together_validators(self): +        """ +        Determine a default set of validators for any unique_together contraints. +        """ +        model_class_inheritance_tree = ( +            [self.Meta.model] + +            list(self.Meta.model._meta.parents.keys()) +        ) + +        # The field names we're passing though here only include fields +        # which may map onto a model field. Any dotted field name lookups +        # cannot map to a field, and must be a traversal, so we're not +        # including those. +        field_names = set([ +            field.source for field in self.fields.values() +            if (field.source != '*') and ('.' not in field.source) +        ]) + +        # Note that we make sure to check `unique_together` both on the +        # base model class, but also on any parent classes. +        validators = [] +        for parent_class in model_class_inheritance_tree: +            for unique_together in parent_class._meta.unique_together: +                if field_names.issuperset(set(unique_together)): +                    validator = UniqueTogetherValidator( +                        queryset=parent_class._default_manager, +                        fields=unique_together +                    ) +                    validators.append(validator) +        return validators + +    def get_unique_for_date_validators(self): +        """ +        Determine a default set of validators for the following contraints: + +        * unique_for_date +        * unique_for_month +        * unique_for_year +        """ +        info = model_meta.get_field_info(self.Meta.model) +        default_manager = self.Meta.model._default_manager +        field_names = [field.source for field in self.fields.values()] + +        validators = [] + +        for field_name, field in info.fields_and_pk.items(): +            if field.unique_for_date and field_name in field_names: +                validator = UniqueForDateValidator( +                    queryset=default_manager, +                    field=field_name, +                    date_field=field.unique_for_date +                ) +                validators.append(validator) -        return NestedSerializer +            if field.unique_for_month and field_name in field_names: +                validator = UniqueForMonthValidator( +                    queryset=default_manager, +                    field=field_name, +                    date_field=field.unique_for_month +                ) +                validators.append(validator) + +            if field.unique_for_year and field_name in field_names: +                validator = UniqueForYearValidator( +                    queryset=default_manager, +                    field=field_name, +                    date_field=field.unique_for_year +                ) +                validators.append(validator) + +        return validators  if hasattr(models, 'UUIDField'): @@ -1152,9 +1347,13 @@ class HyperlinkedModelSerializer(ModelSerializer):      * A 'url' field is included instead of the 'id' field.      * Relationships to other instances are hyperlinks, instead of primary keys.      """ -    _related_class = HyperlinkedRelatedField +    serializer_related_class = HyperlinkedRelatedField -    def _get_default_field_names(self, declared_fields, model_info): +    def get_default_field_names(self, declared_fields, model_info): +        """ +        Return the default list of field names that will be used if the +        `Meta.fields` option is not specified. +        """          return (              [api_settings.URL_FIELD_NAME] +              list(declared_fields.keys()) + @@ -1162,10 +1361,16 @@ class HyperlinkedModelSerializer(ModelSerializer):              list(model_info.forward_relations.keys())          ) -    def _get_nested_class(self, nested_depth, relation_info): +    def build_nested_field(self, field_name, relation_info, nested_depth): +        """ +        Create nested fields for forward and reverse relationships. +        """          class NestedSerializer(HyperlinkedModelSerializer):              class Meta: -                model = relation_info.related +                model = relation_info.related_model                  depth = nested_depth - 1 -        return NestedSerializer +        field_class = NestedSerializer +        field_kwargs = get_nested_relation_kwargs(relation_info) + +        return field_class, field_kwargs diff --git a/rest_framework/settings.py b/rest_framework/settings.py index e5e5edaf..7331f265 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -5,11 +5,11 @@ For example your project's `settings.py` file might look like this:  REST_FRAMEWORK = {      'DEFAULT_RENDERER_CLASSES': (          'rest_framework.renderers.JSONRenderer', -        'rest_framework.renderers.YAMLRenderer', +        'rest_framework.renderers.TemplateHTMLRenderer',      )      'DEFAULT_PARSER_CLASSES': (          'rest_framework.parsers.JSONParser', -        'rest_framework.parsers.YAMLParser', +        'rest_framework.parsers.TemplateHTMLRenderer',      )  } @@ -47,9 +47,10 @@ DEFAULTS = {      'DEFAULT_THROTTLE_CLASSES': (),      'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',      'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata', +    'DEFAULT_VERSIONING_CLASS': None,      # Generic view behavior -    'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer', +    'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',      'DEFAULT_FILTER_BACKENDS': (),      # Throttling @@ -68,6 +69,11 @@ DEFAULTS = {      'SEARCH_PARAM': 'search',      'ORDERING_PARAM': 'ordering', +    # Versioning +    'DEFAULT_VERSION': None, +    'ALLOWED_VERSIONS': None, +    'VERSION_PARAM': 'version', +      # Authentication      'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',      'UNAUTHENTICATED_TOKEN': None, @@ -124,7 +130,8 @@ IMPORT_STRINGS = (      'DEFAULT_THROTTLE_CLASSES',      'DEFAULT_CONTENT_NEGOTIATION_CLASS',      'DEFAULT_METADATA_CLASS', -    'DEFAULT_PAGINATION_SERIALIZER_CLASS', +    'DEFAULT_VERSIONING_CLASS', +    'DEFAULT_PAGINATION_CLASS',      'DEFAULT_FILTER_BACKENDS',      'EXCEPTION_HANDLER',      'TEST_REQUEST_RENDERER_CLASSES', @@ -140,7 +147,9 @@ def perform_import(val, setting_name):      If the given setting is a string import notation,      then perform the necessary import or imports.      """ -    if isinstance(val, six.string_types): +    if val is None: +        return None +    elif isinstance(val, six.string_types):          return import_from_string(val, setting_name)      elif isinstance(val, (list, tuple)):          return [import_from_string(item, setting_name) for item in val] diff --git a/rest_framework/static/rest_framework/css/bootstrap-tweaks.css b/rest_framework/static/rest_framework/css/bootstrap-tweaks.css index 36c7be48..04f12ed3 100644 --- a/rest_framework/static/rest_framework/css/bootstrap-tweaks.css +++ b/rest_framework/static/rest_framework/css/bootstrap-tweaks.css @@ -60,6 +60,23 @@ a single block in the template.    color: #C20000;  } +.pagination>.disabled>a, +.pagination>.disabled>a:hover, +.pagination>.disabled>a:focus { +  cursor: not-allowed; +  pointer-events: none; +} + +.pager>.disabled>a, +.pager>.disabled>a:hover, +.pager>.disabled>a:focus { +  pointer-events: none; +} + +.pager .next { +  margin-left: 10px; +} +  /*=== dabapps bootstrap styles ====*/  html { @@ -185,10 +202,6 @@ body a:hover {    color: #c20000;  } -#content a span { -  text-decoration: underline; - } -  .request-info {    clear:both;  } diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index e9668193..877387f2 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -119,9 +119,18 @@                      <div class="page-header">                          <h1>{{ name }}</h1>                      </div> +                    <div style="float:left">                      {% block description %}                          {{ description }}                      {% endblock %} +                    </div> + +                    {% if paginator %} +                        <nav style="float: right"> +                        {% get_pagination_html paginator %} +                        </nav> +                    {% endif %} +                      <div class="request-info" style="clear: both" >                          <pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre>                      </div> diff --git a/rest_framework/templates/rest_framework/pagination/numbers.html b/rest_framework/templates/rest_framework/pagination/numbers.html new file mode 100644 index 00000000..04045810 --- /dev/null +++ b/rest_framework/templates/rest_framework/pagination/numbers.html @@ -0,0 +1,27 @@ +<ul class="pagination" style="margin: 5px 0 10px 0"> +    {% if previous_url %} +        <li><a href="{{ previous_url }}" aria-label="Previous"><span aria-hidden="true">«</span></a></li> +    {% else %} +        <li class="disabled"><a href="#" aria-label="Previous"><span aria-hidden="true">«</span></a></li> +    {% endif %} + +    {% for page_link in page_links %} +        {% if page_link.is_break %} +            <li class="disabled"> +                <a href="#"><span aria-hidden="true">…</span></a> +            </li> +        {% else %} +            {% if page_link.is_active %} +                <li class="active"><a href="{{ page_link.url }}">{{ page_link.number }}</a></li> +            {% else %} +                <li><a href="{{ page_link.url }}">{{ page_link.number }}</a></li> +            {% endif %} +        {% endif %} +    {% endfor %} + +    {% if next_url %} +        <li><a href="{{ next_url }}" aria-label="Next"><span aria-hidden="true">»</span></a></li> +    {% else %} +        <li class="disabled"><a href="#" aria-label="Next"><span aria-hidden="true">»</span></a></li> +    {% endif %} +</ul> diff --git a/rest_framework/templates/rest_framework/pagination/previous_and_next.html b/rest_framework/templates/rest_framework/pagination/previous_and_next.html new file mode 100644 index 00000000..eacbfff4 --- /dev/null +++ b/rest_framework/templates/rest_framework/pagination/previous_and_next.html @@ -0,0 +1,12 @@ +<ul class="pager"> +{% if previous_url %} +    <li class="previous"><a href="{{ previous_url }}">« Previous</a></li> +{% else %} +    <li class="previous disabled"><a href="#">« Previous</a></li> +{% endif %} +{% if next_url %} +    <li class="next"><a href="{{ next_url }}">Next »</a></li> +{% else %} +    <li class="next disabled"><a href="#">Next »</li> +{% endif %} +</ul> diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 69e03af4..a969836f 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -1,36 +1,25 @@  from __future__ import unicode_literals, absolute_import  from django import template  from django.core.urlresolvers import reverse, NoReverseMatch -from django.http import QueryDict  from django.utils import six -from django.utils.six.moves.urllib import parse as urlparse  from django.utils.encoding import iri_to_uri, force_text  from django.utils.html import escape  from django.utils.safestring import SafeData, mark_safe  from django.utils.html import smart_urlquote  from rest_framework.renderers import HTMLFormRenderer +from rest_framework.utils.urls import replace_query_param  import re  register = template.Library() - -def replace_query_param(url, key, val): -    """ -    Given a URL and a key/val pair, set or replace an item in the query -    parameters of the URL, and return the new URL. -    """ -    (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url) -    query_dict = QueryDict(query).copy() -    query_dict[key] = val -    query = query_dict.urlencode() -    return urlparse.urlunsplit((scheme, netloc, path, query, fragment)) - -  # Regex for adding classes to html snippets  class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') -# And the template tags themselves... +@register.simple_tag +def get_pagination_html(pager): +    return pager.to_html() +  @register.simple_tag  def render_field(field, style=None): diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index bf753271..2160d18b 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -6,11 +6,9 @@ from django.db.models.query import QuerySet  from django.utils import six, timezone  from django.utils.encoding import force_text  from django.utils.functional import Promise -from rest_framework.compat import OrderedDict, total_seconds -from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList +from rest_framework.compat import total_seconds  import datetime  import decimal -import types  import json  import uuid @@ -61,65 +59,3 @@ class JSONEncoder(json.JSONEncoder):          elif hasattr(obj, '__iter__'):              return tuple(item for item in obj)          return super(JSONEncoder, self).default(obj) - - -try: -    import yaml -except ImportError: -    SafeDumper = None -else: -    # Adapted from http://pyyaml.org/attachment/ticket/161/use_ordered_dict.py -    class SafeDumper(yaml.SafeDumper): -        """ -        Handles decimals as strings. -        Handles OrderedDicts as usual dicts, but preserves field order, rather -        than the usual behaviour of sorting the keys. -        """ -        def represent_decimal(self, data): -            return self.represent_scalar('tag:yaml.org,2002:str', six.text_type(data)) - -        def represent_mapping(self, tag, mapping, flow_style=None): -            value = [] -            node = yaml.MappingNode(tag, value, flow_style=flow_style) -            if self.alias_key is not None: -                self.represented_objects[self.alias_key] = node -            best_style = True -            if hasattr(mapping, 'items'): -                mapping = list(mapping.items()) -                if not isinstance(mapping, OrderedDict): -                    mapping.sort() -            for item_key, item_value in mapping: -                node_key = self.represent_data(item_key) -                node_value = self.represent_data(item_value) -                if not (isinstance(node_key, yaml.ScalarNode) and not node_key.style): -                    best_style = False -                if not (isinstance(node_value, yaml.ScalarNode) and not node_value.style): -                    best_style = False -                value.append((node_key, node_value)) -            if flow_style is None: -                if self.default_flow_style is not None: -                    node.flow_style = self.default_flow_style -                else: -                    node.flow_style = best_style -            return node - -    SafeDumper.add_representer( -        decimal.Decimal, -        SafeDumper.represent_decimal -    ) -    SafeDumper.add_representer( -        OrderedDict, -        yaml.representer.SafeRepresenter.represent_dict -    ) -    SafeDumper.add_representer( -        ReturnDict, -        yaml.representer.SafeRepresenter.represent_dict -    ) -    SafeDumper.add_representer( -        ReturnList, -        yaml.representer.SafeRepresenter.represent_list -    ) -    SafeDumper.add_representer( -        types.GeneratorType, -        yaml.representer.SafeRepresenter.represent_list -    ) diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index 470af51b..8b6f005e 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -2,12 +2,10 @@  Utility functions to return a formatted name and description for a given view.  """  from __future__ import unicode_literals -import re -  from django.utils.html import escape  from django.utils.safestring import mark_safe -  from rest_framework.compat import apply_markdown, force_text +import re  def remove_trailing_string(content, trailing): @@ -59,4 +57,5 @@ def markup_description(description):          description = apply_markdown(description)      else:          description = escape(description).replace('\n', '<br />') +        description = '<p>' + description + '</p>'      return mark_safe(description) diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py index 6a5835f5..dd92f8b6 100644 --- a/rest_framework/utils/model_meta.py +++ b/rest_framework/utils/model_meta.py @@ -24,7 +24,7 @@ FieldInfo = namedtuple('FieldResult', [  RelationInfo = namedtuple('RelationInfo', [      'model_field', -    'related', +    'related_model',      'to_many',      'has_through_model'  ]) @@ -98,7 +98,7 @@ def _get_forward_relationships(opts):      for field in [field for field in opts.fields if field.serialize and field.rel]:          forward_relations[field.name] = RelationInfo(              model_field=field, -            related=_resolve_model(field.rel.to), +            related_model=_resolve_model(field.rel.to),              to_many=False,              has_through_model=False          ) @@ -107,7 +107,7 @@ def _get_forward_relationships(opts):      for field in [field for field in opts.many_to_many if field.serialize]:          forward_relations[field.name] = RelationInfo(              model_field=field, -            related=_resolve_model(field.rel.to), +            related_model=_resolve_model(field.rel.to),              to_many=True,              has_through_model=(                  not field.rel.through._meta.auto_created @@ -131,7 +131,7 @@ def _get_reverse_relationships(opts):          related = getattr(relation, 'related_model', relation.model)          reverse_relations[accessor_name] = RelationInfo(              model_field=None, -            related=related, +            related_model=related,              to_many=relation.field.rel.multiple,              has_through_model=False          ) @@ -142,7 +142,7 @@ def _get_reverse_relationships(opts):          related = getattr(relation, 'related_model', relation.model)          reverse_relations[accessor_name] = RelationInfo(              model_field=None, -            related=related, +            related_model=related,              to_many=True,              has_through_model=(                  (getattr(relation.field.rel, 'through', None) is not None) diff --git a/rest_framework/utils/urls.py b/rest_framework/utils/urls.py new file mode 100644 index 00000000..880ef9ed --- /dev/null +++ b/rest_framework/utils/urls.py @@ -0,0 +1,25 @@ +from django.utils.six.moves.urllib import parse as urlparse + + +def replace_query_param(url, key, val): +    """ +    Given a URL and a key/val pair, set or replace an item in the query +    parameters of the URL, and return the new URL. +    """ +    (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url) +    query_dict = urlparse.parse_qs(query) +    query_dict[key] = [val] +    query = urlparse.urlencode(sorted(list(query_dict.items())), doseq=True) +    return urlparse.urlunsplit((scheme, netloc, path, query, fragment)) + + +def remove_query_param(url, key): +    """ +    Given a URL and a key/val pair, remove an item in the query +    parameters of the URL, and return the new URL. +    """ +    (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url) +    query_dict = urlparse.parse_qs(query) +    query_dict.pop(key, None) +    query = urlparse.urlencode(sorted(list(query_dict.items())), doseq=True) +    return urlparse.urlunsplit((scheme, netloc, path, query, fragment)) diff --git a/rest_framework/versioning.py b/rest_framework/versioning.py new file mode 100644 index 00000000..a07b629f --- /dev/null +++ b/rest_framework/versioning.py @@ -0,0 +1,174 @@ +# coding: utf-8 +from __future__ import unicode_literals +from django.utils.translation import ugettext_lazy as _ +from rest_framework import exceptions +from rest_framework.compat import unicode_http_header +from rest_framework.reverse import _reverse +from rest_framework.settings import api_settings +from rest_framework.templatetags.rest_framework import replace_query_param +from rest_framework.utils.mediatypes import _MediaType +import re + + +class BaseVersioning(object): +    default_version = api_settings.DEFAULT_VERSION +    allowed_versions = api_settings.ALLOWED_VERSIONS +    version_param = api_settings.VERSION_PARAM + +    def determine_version(self, request, *args, **kwargs): +        msg = '{cls}.determine_version() must be implemented.' +        raise NotImplementedError(msg.format( +            cls=self.__class__.__name__ +        )) + +    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): +        return _reverse(viewname, args, kwargs, request, format, **extra) + +    def is_allowed_version(self, version): +        if not self.allowed_versions: +            return True +        return (version == self.default_version) or (version in self.allowed_versions) + + +class AcceptHeaderVersioning(BaseVersioning): +    """ +    GET /something/ HTTP/1.1 +    Host: example.com +    Accept: application/json; version=1.0 +    """ +    invalid_version_message = _('Invalid version in "Accept" header.') + +    def determine_version(self, request, *args, **kwargs): +        media_type = _MediaType(request.accepted_media_type) +        version = media_type.params.get(self.version_param, self.default_version) +        version = unicode_http_header(version) +        if not self.is_allowed_version(version): +            raise exceptions.NotAcceptable(self.invalid_version_message) +        return version + +    # We don't need to implement `reverse`, as the versioning is based +    # on the `Accept` header, not on the request URL. + + +class URLPathVersioning(BaseVersioning): +    """ +    To the client this is the same style as `NamespaceVersioning`. +    The difference is in the backend - this implementation uses +    Django's URL keyword arguments to determine the version. + +    An example URL conf for two views that accept two different versions. + +    urlpatterns = [ +        url(r'^(?P<version>{v1,v2})/users/$', users_list, name='users-list'), +        url(r'^(?P<version>{v1,v2})/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail') +    ] + +    GET /1.0/something/ HTTP/1.1 +    Host: example.com +    Accept: application/json +    """ +    invalid_version_message = _('Invalid version in URL path.') + +    def determine_version(self, request, *args, **kwargs): +        version = kwargs.get(self.version_param, self.default_version) +        if not self.is_allowed_version(version): +            raise exceptions.NotFound(self.invalid_version_message) +        return version + +    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): +        if request.version is not None: +            kwargs = {} if (kwargs is None) else kwargs +            kwargs[self.version_param] = request.version + +        return super(URLPathVersioning, self).reverse( +            viewname, args, kwargs, request, format, **extra +        ) + + +class NamespaceVersioning(BaseVersioning): +    """ +    To the client this is the same style as `URLPathVersioning`. +    The difference is in the backend - this implementation uses +    Django's URL namespaces to determine the version. + +    An example URL conf that is namespaced into two seperate versions + +    # users/urls.py +    urlpatterns = [ +        url(r'^/users/$', users_list, name='users-list'), +        url(r'^/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail') +    ] + +    # urls.py +    urlpatterns = [ +        url(r'^v1/', include('users.urls', namespace='v1')), +        url(r'^v2/', include('users.urls', namespace='v2')) +    ] + +    GET /1.0/something/ HTTP/1.1 +    Host: example.com +    Accept: application/json +    """ +    invalid_version_message = _('Invalid version in URL path.') + +    def determine_version(self, request, *args, **kwargs): +        resolver_match = getattr(request, 'resolver_match', None) +        if (resolver_match is None or not resolver_match.namespace): +            return self.default_version +        version = resolver_match.namespace +        if not self.is_allowed_version(version): +            raise exceptions.NotFound(self.invalid_version_message) +        return version + +    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): +        if request.version is not None: +            viewname = request.version + ':' + viewname +        return super(NamespaceVersioning, self).reverse( +            viewname, args, kwargs, request, format, **extra +        ) + + +class HostNameVersioning(BaseVersioning): +    """ +    GET /something/ HTTP/1.1 +    Host: v1.example.com +    Accept: application/json +    """ +    hostname_regex = re.compile(r'^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$') +    invalid_version_message = _('Invalid version in hostname.') + +    def determine_version(self, request, *args, **kwargs): +        hostname, seperator, port = request.get_host().partition(':') +        match = self.hostname_regex.match(hostname) +        if not match: +            return self.default_version +        version = match.group(1) +        if not self.is_allowed_version(version): +            raise exceptions.NotFound(self.invalid_version_message) +        return version + +    # We don't need to implement `reverse`, as the hostname will already be +    # preserved as part of the REST framework `reverse` implementation. + + +class QueryParameterVersioning(BaseVersioning): +    """ +    GET /something/?version=0.1 HTTP/1.1 +    Host: example.com +    Accept: application/json +    """ +    invalid_version_message = _('Invalid version in query parameter.') + +    def determine_version(self, request, *args, **kwargs): +        version = request.query_params.get(self.version_param) +        if not self.is_allowed_version(version): +            raise exceptions.NotFound(self.invalid_version_message) +        return version + +    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): +        url = super(QueryParameterVersioning, self).reverse( +            viewname, args, kwargs, request, format, **extra +        ) +        if request.version is not None: +            return replace_query_param(url, self.version_param, request.version) +        return url diff --git a/rest_framework/views.py b/rest_framework/views.py index bc870417..12bb78bd 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -2,6 +2,8 @@  Provides an APIView class that is the base of all views in REST framework.  """  from __future__ import unicode_literals +import inspect +import warnings  from django.core.exceptions import PermissionDenied  from django.http import Http404 @@ -46,7 +48,7 @@ def get_view_description(view_cls, html=False):      return description -def exception_handler(exc): +def exception_handler(exc, context):      """      Returns the response that should be used for any given exception. @@ -93,6 +95,7 @@ class APIView(View):      permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES      content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS      metadata_class = api_settings.DEFAULT_METADATA_CLASS +    versioning_class = api_settings.DEFAULT_VERSIONING_CLASS      # Allow dependency injection of other settings to make testing easier.      settings = api_settings @@ -184,6 +187,18 @@ class APIView(View):              'request': getattr(self, 'request', None)          } +    def get_exception_handler_context(self): +        """ +        Returns a dict that is passed through to EXCEPTION_HANDLER, +        as the `context` argument. +        """ +        return { +            'view': self, +            'args': getattr(self, 'args', ()), +            'kwargs': getattr(self, 'kwargs', {}), +            'request': getattr(self, 'request', None) +        } +      def get_view_name(self):          """          Return the view name, as used in OPTIONS responses and in the @@ -300,6 +315,16 @@ class APIView(View):              if not throttle.allow_request(request, self):                  self.throttled(request, throttle.wait()) +    def determine_version(self, request, *args, **kwargs): +        """ +        If versioning is being used, then determine any API version for the +        incoming request. Returns a two-tuple of (version, versioning_scheme) +        """ +        if self.versioning_class is None: +            return (None, None) +        scheme = self.versioning_class() +        return (scheme.determine_version(request, *args, **kwargs), scheme) +      # Dispatch methods      def initialize_request(self, request, *args, **kwargs): @@ -308,11 +333,13 @@ class APIView(View):          """          parser_context = self.get_parser_context(request) -        return Request(request, -                       parsers=self.get_parsers(), -                       authenticators=self.get_authenticators(), -                       negotiator=self.get_content_negotiator(), -                       parser_context=parser_context) +        return Request( +            request, +            parsers=self.get_parsers(), +            authenticators=self.get_authenticators(), +            negotiator=self.get_content_negotiator(), +            parser_context=parser_context +        )      def initial(self, request, *args, **kwargs):          """ @@ -329,6 +356,10 @@ class APIView(View):          neg = self.perform_content_negotiation(request)          request.accepted_renderer, request.accepted_media_type = neg +        # Determine the API version, if versioning is in use. +        version, scheme = self.determine_version(request, *args, **kwargs) +        request.version, request.versioning_scheme = version, scheme +      def finalize_response(self, request, response, *args, **kwargs):          """          Returns the final response object. @@ -369,7 +400,18 @@ class APIView(View):              else:                  exc.status_code = status.HTTP_403_FORBIDDEN -        response = self.settings.EXCEPTION_HANDLER(exc) +        exception_handler = self.settings.EXCEPTION_HANDLER + +        if len(inspect.getargspec(exception_handler).args) == 1: +            warnings.warn( +                'The `exception_handler(exc)` call signature is deprecated. ' +                'Use `exception_handler(exc, context) instead.', +                PendingDeprecationWarning +            ) +            response = exception_handler(exc) +        else: +            context = self.get_exception_handler_context() +            response = exception_handler(exc, context)          if response is None:              raise | 
