diff options
| author | Tom Christie | 2013-08-29 20:52:46 +0100 | 
|---|---|---|
| committer | Tom Christie | 2013-08-29 20:52:46 +0100 | 
| commit | 19f9adacb254841d02f43295baf81406ce3c60eb (patch) | |
| tree | f77644b5515c15e09d49d12aef0855c67262f9ba /rest_framework | |
| parent | e4d2f54529bcf538be93da5770e05b88a32da1c7 (diff) | |
| parent | 02b6836ee88498861521dfff743467b0456ad109 (diff) | |
| download | django-rest-framework-19f9adacb254841d02f43295baf81406ce3c60eb.tar.bz2 | |
Merge branch 'master' into display-raw-data
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/fields.py | 5 | ||||
| -rw-r--r-- | rest_framework/generics.py | 11 | ||||
| -rw-r--r-- | rest_framework/settings.py | 10 | ||||
| -rw-r--r-- | rest_framework/tests/test_fields.py | 8 | ||||
| -rw-r--r-- | rest_framework/tests/test_pagination.py | 47 | ||||
| -rw-r--r-- | rest_framework/throttling.py | 11 | ||||
| -rw-r--r-- | rest_framework/utils/breadcrumbs.py | 7 | ||||
| -rw-r--r-- | rest_framework/views.py | 88 | 
8 files changed, 148 insertions, 39 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 3e0ca1a1..210c2537 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -514,6 +514,11 @@ class ChoiceField(WritableField):                      return True          return False +    def from_native(self, value): +        if value in validators.EMPTY_VALUES: +            return None +        return super(ChoiceField, self).from_native(value) +  class EmailField(CharField):      type_name = 'EmailField' diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 5ecf6310..14feed20 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -14,13 +14,15 @@ from rest_framework.settings import api_settings  import warnings -def strict_positive_int(integer_string): +def strict_positive_int(integer_string, cutoff=None):      """      Cast a string to a strictly positive integer.      """      ret = int(integer_string)      if ret <= 0:          raise ValueError() +    if cutoff: +        ret = min(ret, cutoff)      return ret  def get_object_or_404(queryset, **filter_kwargs): @@ -56,6 +58,7 @@ class GenericAPIView(views.APIView):      # Pagination settings      paginate_by = api_settings.PAGINATE_BY      paginate_by_param = api_settings.PAGINATE_BY_PARAM +    max_paginate_by = api_settings.MAX_PAGINATE_BY      pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS      page_kwarg = 'page' @@ -205,9 +208,11 @@ class GenericAPIView(views.APIView):                            PendingDeprecationWarning, stacklevel=2)          if self.paginate_by_param: -            query_params = self.request.QUERY_PARAMS              try: -                return int(query_params[self.paginate_by_param]) +                return strict_positive_int( +                    self.request.QUERY_PARAMS[self.paginate_by_param], +                    cutoff=self.max_paginate_by +                )              except (KeyError, ValueError):                  pass diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 7d25e513..8c084751 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -48,7 +48,6 @@ DEFAULTS = {      ),      'DEFAULT_THROTTLE_CLASSES': (      ), -      'DEFAULT_CONTENT_NEGOTIATION_CLASS':          'rest_framework.negotiation.DefaultContentNegotiation', @@ -68,15 +67,16 @@ DEFAULTS = {      # Pagination      'PAGINATE_BY': None,      'PAGINATE_BY_PARAM': None, - -    # View configuration -    'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name', -    'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description', +    'MAX_PAGINATE_BY': None,      # Authentication      'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',      'UNAUTHENTICATED_TOKEN': None, +    # View configuration +    'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name', +    'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description', +      # Testing      'TEST_REQUEST_RENDERER_CLASSES': (          'rest_framework.renderers.MultiPartRenderer', diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py index ebccba7d..34fbab9c 100644 --- a/rest_framework/tests/test_fields.py +++ b/rest_framework/tests/test_fields.py @@ -688,6 +688,14 @@ class ChoiceFieldTests(TestCase):          f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES)          self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES) +    def test_from_native_empty(self): +        """ +        Make sure from_native() returns None on empty param. +        """ +        f = serializers.ChoiceField(choices=self.SAMPLE_CHOICES) +        result = f.from_native('') +        self.assertEqual(result, None) +  class EmailFieldTests(TestCase):      """ diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index 85d4640e..4170d4b6 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -42,6 +42,16 @@ class PaginateByParamView(generics.ListAPIView):      paginate_by_param = 'page_size' +class MaxPaginateByView(generics.ListAPIView): +    """ +    View for testing custom max_paginate_by usage +    """ +    model = BasicModel +    paginate_by = 3 +    max_paginate_by = 5 +    paginate_by_param = 'page_size' + +  class IntegrationTestPagination(TestCase):      """      Integration tests for paginated list views. @@ -313,6 +323,43 @@ class TestCustomPaginateByParam(TestCase):          self.assertEqual(response.data['results'], self.data[:5]) +class TestMaxPaginateByParam(TestCase): +    """ +    Tests for list views with max_paginate_by kwarg +    """ + +    def setUp(self): +        """ +        Create 13 BasicModel instances. +        """ +        for i in range(13): +            BasicModel(text=i).save() +        self.objects = BasicModel.objects +        self.data = [ +            {'id': obj.id, 'text': obj.text} +            for obj in self.objects.all() +        ] +        self.view = MaxPaginateByView.as_view() + +    def test_max_paginate_by(self): +        """ +        If max_paginate_by is set, it should limit page size for the view. +        """ +        request = factory.get('/?page_size=10') +        response = self.view(request).render() +        self.assertEqual(response.data['count'], 13) +        self.assertEqual(response.data['results'], self.data[:5]) + +    def test_max_paginate_by_without_page_size_param(self): +        """ +        If max_paginate_by is set, but client does not specifiy page_size, +        standard `paginate_by` behavior should be used. +        """ +        request = factory.get('/') +        response = self.view(request).render() +        self.assertEqual(response.data['results'], self.data[:3]) + +  ### Tests for context in pagination serializers  class CustomField(serializers.Field): diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 65b45593..a946d837 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -2,7 +2,7 @@  Provides various throttling policies.  """  from __future__ import unicode_literals -from django.core.cache import cache +from django.core.cache import cache as default_cache  from django.core.exceptions import ImproperlyConfigured  from rest_framework.settings import api_settings  import time @@ -39,6 +39,7 @@ class SimpleRateThrottle(BaseThrottle):      Previous request information used for throttling is stored in the cache.      """ +    cache = default_cache      timer = time.time      cache_format = 'throtte_%(scope)s_%(ident)s'      scope = None @@ -99,7 +100,7 @@ class SimpleRateThrottle(BaseThrottle):          if self.key is None:              return True -        self.history = cache.get(self.key, []) +        self.history = self.cache.get(self.key, [])          self.now = self.timer()          # Drop any requests from the history which have now passed the @@ -116,7 +117,7 @@ class SimpleRateThrottle(BaseThrottle):          into the cache.          """          self.history.insert(0, self.now) -        cache.set(self.key, self.history, self.duration) +        self.cache.set(self.key, self.history, self.duration)          return True      def throttle_failure(self): @@ -151,7 +152,9 @@ class AnonRateThrottle(SimpleRateThrottle):          if request.user.is_authenticated():              return None  # Only throttle unauthenticated requests. -        ident = request.META.get('REMOTE_ADDR', None) +        ident = request.META.get('HTTP_X_FORWARDED_FOR') +        if ident is None: +            ident = request.META.get('REMOTE_ADDR')          return self.cache_format % {              'scope': self.scope, diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index 0384faba..e6690d17 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -8,8 +8,11 @@ def get_breadcrumbs(url):      tuple of (name, url).      """ +    from rest_framework.settings import api_settings      from rest_framework.views import APIView +    view_name_func = api_settings.VIEW_NAME_FUNCTION +      def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen):          """          Add tuples of (name, url) to the breadcrumbs list, @@ -28,8 +31,8 @@ def get_breadcrumbs(url):                  # Don't list the same view twice in a row.                  # Probably an optional trailing slash.                  if not seen or seen[-1] != view: -                    instance = view.cls() -                    name = instance.get_view_name() +                    suffix = getattr(view, 'suffix', None) +                    name = view_name_func(cls, suffix)                      breadcrumbs_list.insert(0, (name, prefix + url))                      seen.append(view) diff --git a/rest_framework/views.py b/rest_framework/views.py index 727a9f95..4cff0422 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -15,8 +15,14 @@ from rest_framework.settings import api_settings  from rest_framework.utils import formatting -def get_view_name(cls, suffix=None): -    name = cls.__name__ +def get_view_name(view_cls, suffix=None): +    """ +    Given a view class, return a textual name to represent the view. +    This name is used in the browsable API, and in OPTIONS responses. + +    This function is the default for the `VIEW_NAME_FUNCTION` setting. +    """ +    name = view_cls.__name__      name = formatting.remove_trailing_string(name, 'View')      name = formatting.remove_trailing_string(name, 'ViewSet')      name = formatting.camelcase_to_spaces(name) @@ -25,17 +31,56 @@ def get_view_name(cls, suffix=None):      return name -def get_view_description(cls, html=False): -    description = cls.__doc__ or '' +def get_view_description(view_cls, html=False): +    """ +    Given a view class, return a textual description to represent the view. +    This name is used in the browsable API, and in OPTIONS responses. + +    This function is the default for the `VIEW_DESCRIPTION_FUNCTION` setting. +    """ +    description = view_cls.__doc__ or ''      description = formatting.dedent(smart_text(description))      if html:          return formatting.markup_description(description)      return description +def exception_handler(exc): +    """ +    Returns the response that should be used for any given exception. + +    By default we handle the REST framework `APIException`, and also +    Django's builtin `Http404` and `PermissionDenied` exceptions. + +    Any unhandled exceptions may return `None`, which will cause a 500 error +    to be raised. +    """ +    if isinstance(exc, exceptions.APIException): +        headers = {} +        if getattr(exc, 'auth_header', None): +            headers['WWW-Authenticate'] = exc.auth_header +        if getattr(exc, 'wait', None): +            headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait + +        return Response({'detail': exc.detail}, +                        status=exc.status_code, +                        headers=headers) + +    elif isinstance(exc, Http404): +        return Response({'detail': 'Not found'}, +                        status=status.HTTP_404_NOT_FOUND) + +    elif isinstance(exc, PermissionDenied): +        return Response({'detail': 'Permission denied'}, +                        status=status.HTTP_403_FORBIDDEN) + +    # Note: Unhandled exceptions will raise a 500 error. +    return None + +  class APIView(View): -    settings = api_settings +    # The following policies may be set at either globally, or per-view.      renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES      parser_classes = api_settings.DEFAULT_PARSER_CLASSES      authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES @@ -43,6 +88,9 @@ class APIView(View):      permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES      content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS +    # Allow dependancy injection of other settings to make testing easier. +    settings = api_settings +      @classmethod      def as_view(cls, **initkwargs):          """ @@ -133,7 +181,7 @@ class APIView(View):          Return the view name, as used in OPTIONS responses and in the          browsable API.          """ -        func = api_settings.VIEW_NAME_FUNCTION +        func = self.settings.VIEW_NAME_FUNCTION          return func(self.__class__, getattr(self, 'suffix', None))      def get_view_description(self, html=False): @@ -141,7 +189,7 @@ class APIView(View):          Return some descriptive text for the view, as used in OPTIONS responses          and in the browsable API.          """ -        func = api_settings.VIEW_DESCRIPTION_FUNCTION +        func = self.settings.VIEW_DESCRIPTION_FUNCTION          return func(self.__class__, html)      # API policy instantiation methods @@ -303,33 +351,23 @@ class APIView(View):          Handle any exception that occurs, by returning an appropriate response,          or re-raising the error.          """ -        if isinstance(exc, exceptions.Throttled) and exc.wait is not None: -            # Throttle wait header -            self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait -          if isinstance(exc, (exceptions.NotAuthenticated,                              exceptions.AuthenticationFailed)):              # WWW-Authenticate header for 401 responses, else coerce to 403              auth_header = self.get_authenticate_header(self.request)              if auth_header: -                self.headers['WWW-Authenticate'] = auth_header +                exc.auth_header = auth_header              else:                  exc.status_code = status.HTTP_403_FORBIDDEN -        if isinstance(exc, exceptions.APIException): -            return Response({'detail': exc.detail}, -                            status=exc.status_code, -                            exception=True) -        elif isinstance(exc, Http404): -            return Response({'detail': 'Not found'}, -                            status=status.HTTP_404_NOT_FOUND, -                            exception=True) -        elif isinstance(exc, PermissionDenied): -            return Response({'detail': 'Permission denied'}, -                            status=status.HTTP_403_FORBIDDEN, -                            exception=True) -        raise +        response = exception_handler(exc) + +        if response is None: +            raise + +        response.exception = True +        return response      # Note: session based authentication is explicitly CSRF validated,      # all other authentication is CSRF exempt.  | 
