""" Provides an APIView class that is used as the base of all class-based views. """ import re from django.core.exceptions import PermissionDenied from django.http import Http404 from django.utils.html import escape from django.utils.safestring import mark_safe from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions from rest_framework.compat import View, apply_markdown from rest_framework.response import Response from rest_framework.request import Request from rest_framework.settings import api_settings def _remove_trailing_string(content, trailing): """ Strip trailing component `trailing` from `content` if it exists. Used when generating names from view classes. """ if content.endswith(trailing) and content != trailing: return content[:-len(trailing)] return content def _remove_leading_indent(content): """ Remove leading indent from a block of text. Used when generating descriptions from docstrings. """ whitespace_counts = [len(line) - len(line.lstrip(' ')) for line in content.splitlines()[1:] if line.lstrip()] # unindent the content if needed if whitespace_counts: whitespace_pattern = '^' + (' ' * min(whitespace_counts)) content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) content = content.strip('\n') return content def _camelcase_to_spaces(content): """ Translate 'CamelCaseNames' to 'Camel Case Names'. Used when generating names from view classes. 
""" camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))' content = re.sub(camelcase_boundry, ' \\1', content).strip() return ' '.join(content.split('_')).title() class APIView(View): settings = api_settings renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES parser_classes = api_settings.DEFAULT_PARSER_CLASSES authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS @classmethod def as_view(cls, **initkwargs): """ Override the default :meth:`as_view` to store an instance of the view as an attribute on the callable function. This allows us to discover information about the view when we do URL reverse lookups. """ # TODO: deprecate? view = super(APIView, cls).as_view(**initkwargs) view.cls_instance = cls(**initkwargs) return view @property def allowed_methods(self): """ Return the list of allowed HTTP methods, uppercased. """ return [method.upper() for method in self.http_method_names if hasattr(self, method)] @property def default_response_headers(self): # TODO: deprecate? # TODO: Only vary by accept if multiple renderers return { 'Allow': ', '.join(self.allowed_methods), 'Vary': 'Accept' } def get_name(self): """ Return the resource or view class name for use as this view's name. Override to customize. """ # TODO: deprecate? name = self.__class__.__name__ name = _remove_trailing_string(name, 'View') return _camelcase_to_spaces(name) def get_description(self, html=False): """ Return the resource or view docstring for use as this view's description. Override to customize. """ # TODO: deprecate? description = self.__doc__ or '' description = _remove_leading_indent(description) if html: return self.markup_description(description) return description def markup_description(self, description): """ Apply HTML markup to the description of this view. 
""" # TODO: deprecate? if apply_markdown: description = apply_markdown(description) else: description = escape(description).replace('\n', '
') return mark_safe(description) def metadata(self, request): return { 'name': self.get_name(), 'description': self.get_description(), 'renders': [renderer.media_type for renderer in self.renderer_classes], 'parses': [parser.media_type for parser in self.parser_classes], } # TODO: Add 'fields', from serializer info, if it exists. # serializer = self.get_serializer() # if serializer is not None: # field_name_types = {} # for name, field in form.fields.iteritems(): # field_name_types[name] = field.__class__.__name__ # content['fields'] = field_name_types def http_method_not_allowed(self, request, *args, **kwargs): """ Called if `request.method` does not correspond to a handler method. """ raise exceptions.MethodNotAllowed(request.method) def permission_denied(self, request): """ If request is not permitted, determine what kind of exception to raise. """ raise exceptions.PermissionDenied() def throttled(self, request, wait): """ If request is throttled, determine what kind of exception to raise. """ raise exceptions.Throttled(wait) def get_parser_context(self, http_request): """ Returns a dict that is passed through to Parser.parse(), as the `parser_context` keyword argument. """ # Note: Additionally `request` will also be added to the context # by the Request object. return { 'view': self, 'args': getattr(self, 'args', ()), 'kwargs': getattr(self, 'kwargs', {}) } def get_renderer_context(self): """ Returns a dict that is passed through to Renderer.render(), as the `renderer_context` keyword argument. """ # Note: Additionally 'response' will also be added to the context, # by the Response object. 
return { 'view': self, 'args': getattr(self, 'args', ()), 'kwargs': getattr(self, 'kwargs', {}), 'request': getattr(self, 'request', None) } # API policy instantiation methods def get_format_suffix(self, **kwargs): """ Determine if the request includes a '.json' style format suffix """ if self.settings.FORMAT_SUFFIX_KWARG: return kwargs.get(self.settings.FORMAT_SUFFIX_KWARG) def get_renderers(self): """ Instantiates and returns the list of renderers that this view can use. """ return [renderer() for renderer in self.renderer_classes] def get_parsers(self): """ Instantiates and returns the list of renderers that this view can use. """ return [parser() for parser in self.parser_classes] def get_authenticators(self): """ Instantiates and returns the list of renderers that this view can use. """ return [auth() for auth in self.authentication_classes] def get_permissions(self): """ Instantiates and returns the list of permissions that this view requires. """ return [permission() for permission in self.permission_classes] def get_throttles(self): """ Instantiates and returns the list of throttles that this view uses. """ return [throttle() for throttle in self.throttle_classes] def get_content_negotiator(self): """ Instantiate and return the content negotiation class to use. """ if not getattr(self, '_negotiator', None): self._negotiator = self.content_negotiation_class() return self._negotiator # API policy implementation methods def perform_content_negotiation(self, request, force=False): """ Determine which renderer and media type to use render the response. """ renderers = self.get_renderers() conneg = self.get_content_negotiator() try: return conneg.select_renderer(request, renderers, self.format_kwarg) except: if force: return (renderers[0], renderers[0].media_type) raise def has_permission(self, request, obj=None): """ Return `True` if the request should be permitted. 
""" for permission in self.get_permissions(): if not permission.has_permission(request, self, obj): return False return True def check_throttles(self, request): """ Check if request should be throttled. """ for throttle in self.get_throttles(): if not throttle.allow_request(request, self): self.throttled(request, throttle.wait()) # Dispatch methods def initialize_request(self, request, *args, **kargs): """ Returns the initial request object. """ parser_context = self.get_parser_context(request) return Request(request, parsers=self.get_parsers(), authenticators=self.get_authenticators(), negotiator=self.get_content_negotiator(), parser_context=parser_context)
from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.compat import patterns, url, include
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework import status
from rest_framework.renderers import (
    BaseRenderer,
    JSONRenderer,
    BrowsableAPIRenderer
)
from rest_framework.settings import api_settings
from rest_framework.compat import six


class MockPickleRenderer(BaseRenderer):
    """Stub renderer that only declares the 'application/pickle' media type."""
    media_type = "application/pickle"


class MockJsonRenderer(BaseRenderer):
    """Stub renderer that only declares the 'application/json' media type."""
    media_type = "application/json"


# Fixed status code and payload that MockView.get responds with.
DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = "dummycontent"

def _ascii_serializer(label):
    """
    Build a serializer that %-formats its argument into '<label>: %s' and
    returns the result ASCII-encoded.
    """
    template = label + ': %s'

    def serialize(data):
        return (template % data).encode('ascii')

    return serialize


RENDERER_A_SERIALIZER = _ascii_serializer('Renderer A')
RENDERER_B_SERIALIZER = _ascii_serializer('Renderer B')


class RendererA(BaseRenderer):
    """Mock renderer for 'mock/renderera' with format name 'formata'."""
    media_type = 'mock/renderera'
    format = 'formata'

    def render(self, data, media_type=None, renderer_context=None):
        # media_type and renderer_context are ignored; output depends only on data.
        serialized = RENDERER_A_SERIALIZER(data)
        return serialized


class RendererB(BaseRenderer):
    """Mock renderer for 'mock/rendererb' with format name 'formatb'."""
    media_type = 'mock/rendererb'
    format = 'formatb'

    def render(self, data, media_type=None, renderer_context=None):
        # media_type and renderer_context are ignored; output depends only on data.
        serialized = RENDERER_B_SERIALIZER(data)
        return serialized


class MockView(APIView):
    """View offering RendererA and RendererB; GET returns the dummy payload."""
    renderer_classes = (RendererA, RendererB)

    def get(self, request, **kwargs):
        payload = DUMMYCONTENT
        return Response(payload, status=DUMMYSTATUS)


class HTMLView(APIView):
    """View whose only renderer is the BrowsableAPIRenderer."""
    renderer_classes = (BrowsableAPIRenderer, )

    def get(self, request, **kwargs):
        body = 'text'
        return Response(body)


class HTMLView1(APIView):
    """View listing the BrowsableAPIRenderer first, followed by JSONRenderer."""
    renderer_classes = (BrowsableAPIRenderer, JSONRenderer)

    def get(self, request, **kwargs):
        body = 'text'
        return Response(body)


# The first route captures a trailing '.<suffix>' into the `format` kwarg;
# the last mounts rest_framework's own URLs under the 'rest_framework'
# namespace.
_routes = [
    url(r'^.*\.(?P<format>.+)$',
        MockView.as_view(renderer_classes=[RendererA, RendererB])),
    url(r'^$',
        MockView.as_view(renderer_classes=[RendererA, RendererB])),
    url(r'^html$', HTMLView.as_view()),
    url(r'^html1$', HTMLView1.as_view()),
    url(r'^restframework',
        include('rest_framework.urls', namespace='rest_framework')),
]

urlpatterns = patterns('', *_routes)


# TODO: Clean up the tests below - remove duplicates of the tests above, improve unit testing, ...
class RendererIntegrationTests(TestCase):
    """
    End-to-end testing of renderers through MockView.  Covers renderer
    selection via the Accept header, the accept-override query parameter,
    the `format` query parameter and the `.format` URL suffix.
    """

    urls = 'rest_framework.tests.response'

    def _assert_rendered_by(self, resp, renderer, serializer):
        """
        Assert that `resp` was produced by `renderer`.

        Checks, in order, that the Content-Type is the renderer's media type,
        that the body is `serializer(DUMMYCONTENT)`, and that the status is
        DUMMYSTATUS.
        """
        self.assertEqual(resp['Content-Type'], renderer.media_type)
        self.assertEqual(resp.content, serializer(DUMMYCONTENT))
        self.assertEqual(resp.status_code, DUMMYSTATUS)

    def test_default_renderer_serializes_content(self):
        """If the Accept header is not set the default renderer should serialize the response."""
        response = self.client.get('/')
        self._assert_rendered_by(response, RendererA, RENDERER_A_SERIALIZER)

    def test_head_method_serializes_no_content(self):
        """No response body must be included in HEAD requests."""
        response = self.client.head('/')
        self.assertEqual(response.status_code, DUMMYSTATUS)
        self.assertEqual(response['Content-Type'], RendererA.media_type)
        self.assertEqual(response.content, six.b(''))

    def test_default_renderer_serializes_content_on_accept_any(self):
        """If the Accept header is set to */* the default renderer should serialize the response."""
        response = self.client.get('/', HTTP_ACCEPT='*/*')
        self._assert_rendered_by(response, RendererA, RENDERER_A_SERIALIZER)

    def test_specified_renderer_serializes_content_default_case(self):
        """If the Accept header is set the specified renderer should serialize the response.
        (This case checks that it works for the default renderer.)"""
        response = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
        self._assert_rendered_by(response, RendererA, RENDERER_A_SERIALIZER)

    def test_specified_renderer_serializes_content_non_default_case(self):
        """If the Accept header is set the specified renderer should serialize the response.
        (This case checks that it works for a non-default renderer.)"""
        response = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
        self._assert_rendered_by(response, RendererB, RENDERER_B_SERIALIZER)

    def test_specified_renderer_serializes_content_on_accept_query(self):
        """The accept-override query parameter (api_settings.URL_ACCEPT_OVERRIDE)
        should behave in the same way as the Accept header."""
        override_name = api_settings.URL_ACCEPT_OVERRIDE
        query = '?%s=%s' % (override_name, RendererB.media_type)
        response = self.client.get('/' + query)
        self._assert_rendered_by(response, RendererB, RENDERER_B_SERIALIZER)

    def test_specified_renderer_serializes_content_on_format_query(self):
        """If a 'format' query is specified, the renderer with the matching
        format attribute should serialize the response."""
        response = self.client.get('/?format=%s' % RendererB.format)
        self._assert_rendered_by(response, RendererB, RENDERER_B_SERIALIZER)

    def test_specified_renderer_serializes_content_on_format_kwargs(self):
        """If a 'format' keyword arg is specified, the renderer with the matching
        format attribute should serialize the response."""
        response = self.client.get('/something.formatb')
        self._assert_rendered_by(response, RendererB, RENDERER_B_SERIALIZER)

    def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
        """If both a 'format' query and a matching Accept header are specified,
        the renderer with the matching format attribute should serialize the response."""
        response = self.client.get('/?format=%s' % RendererB.format,
                                   HTTP_ACCEPT=RendererB.media_type)
        self._assert_rendered_by(response, RendererB, RENDERER_B_SERIALIZER)


class Issue122Tests(TestCase):
    """
    Regression tests for #122.  Views whose renderers are only the
    BrowsableAPIRenderer, or list it first, must render without infinite
    recursion.
    """
    urls = 'rest_framework.tests.response'

    def test_only_html_renderer(self):
        """
        Requesting a view whose only renderer is BrowsableAPIRenderer must
        not recurse infinitely.
        """
        # No assertion needed: runaway recursion would raise and fail the test.
        self.client.get('/html')

    def test_html_renderer_is_first(self):
        """
        Requesting a view that lists BrowsableAPIRenderer before
        JSONRenderer must not recurse infinitely.
        """
        # No assertion needed: runaway recursion would raise and fail the test.
        self.client.get('/html1')