diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/fields.py | 31 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 7 | ||||
| -rw-r--r-- | rest_framework/tests/fields.py | 81 | ||||
| -rw-r--r-- | rest_framework/tests/generics.py | 36 | ||||
| -rw-r--r-- | rest_framework/tests/permissions.py | 45 | ||||
| -rw-r--r-- | rest_framework/tests/views.py | 5 | ||||
| -rw-r--r-- | rest_framework/views.py | 53 | 
7 files changed, 235 insertions, 23 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py index d772c400..cb5f9a40 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -23,7 +23,8 @@ from django.utils.translation import ugettext_lazy as _  from django.utils.datastructures import SortedDict  from rest_framework import ISO_8601 -from rest_framework.compat import timezone, parse_date, parse_datetime, parse_time +from rest_framework.compat import (timezone, parse_date, parse_datetime, +                                   parse_time)  from rest_framework.compat import BytesIO  from rest_framework.compat import six  from rest_framework.compat import smart_text, force_text, is_non_str_iterable @@ -61,7 +62,8 @@ def get_component(obj, attr_name):  def readable_datetime_formats(formats): -    format = ', '.join(formats).replace(ISO_8601, 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]') +    format = ', '.join(formats).replace(ISO_8601, +             'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]')      return humanize_strptime(format) @@ -70,6 +72,18 @@ def readable_date_formats(formats):      return humanize_strptime(format) +def humanize_form_fields(form): +    """Return a humanized description of all the fields in a form. + +    :param form: A Django form. +    :return: A dictionary of {field_label: humanized description} + +    """ +    fields = SortedDict([(name, humanize_field(field)) +                         for name, field in form.fields.iteritems()]) +    return fields + +  def readable_time_formats(formats):      format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')      return humanize_strptime(format) @@ -193,6 +207,19 @@ class Field(object):              return {'type': self.type_name}          return {} +    @property +    def humanized(self): +        humanized = { +            'type': self.type_name, +            'required': getattr(self, 'required', False), +        } +        optional_attrs = ['read_only', 'help_text', 'label', +                          'min_length', 'max_length'] +        for attr in optional_attrs: +            if getattr(self, attr, None) is not None: +                humanized[attr] = getattr(self, attr) +        return humanized +  class WritableField(Field):      """ diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 31f261e1..17da8c25 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -521,6 +521,13 @@ class BaseSerializer(WritableField):          return self.object +    @property +    def humanized(self): +        humanized_fields = SortedDict( +            [(name, field.humanized) +             for name, field in self.fields.iteritems()]) +        return humanized_fields +  class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)):      pass diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index a3104206..22c515a9 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -2,13 +2,20 @@  General serializer field tests.  """  from __future__ import unicode_literals -from django.utils.datastructures import SortedDict -import datetime + +from collections import namedtuple  from decimal import Decimal +from uuid import uuid4 + +import datetime +from django import forms +from django.core import validators  from django.db import models  from django.test import TestCase -from django.core import validators +from django.utils.datastructures import SortedDict +  from rest_framework import serializers +from rest_framework.fields import Field, CharField  from rest_framework.serializers import Serializer  from rest_framework.tests.models import RESTFrameworkModel @@ -760,14 +767,16 @@ class SlugFieldTests(TestCase):      def test_given_serializer_value(self):          class SlugFieldSerializer(serializers.ModelSerializer): -            slug_field = serializers.SlugField(source='slug_field', max_length=20, required=False) +            slug_field = serializers.SlugField(source='slug_field', +                                               max_length=20, required=False)              class Meta:                  model = self.SlugFieldModel          serializer = SlugFieldSerializer(data={})          self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 20) +        self.assertEqual(getattr(serializer.fields['slug_field'], +                                 'max_length'), 20)      def test_invalid_slug(self):          """ @@ -803,7 +812,8 @@ class URLFieldTests(TestCase):          serializer = URLFieldSerializer(data={})          self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(getattr(serializer.fields['url_field'], 'max_length'), 200) +        self.assertEqual(getattr(serializer.fields['url_field'], +                                 'max_length'), 200)      def test_given_model_value(self):          class URLFieldSerializer(serializers.ModelSerializer): @@ -812,15 +822,68 @@ class URLFieldTests(TestCase):          serializer = URLFieldSerializer(data={})          self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(getattr(serializer.fields['url_field'], 'max_length'), 128) +        self.assertEqual(getattr(serializer.fields['url_field'], +                                 'max_length'), 128)      def test_given_serializer_value(self):          class URLFieldSerializer(serializers.ModelSerializer): -            url_field = serializers.URLField(source='url_field', max_length=20, required=False) +            url_field = serializers.URLField(source='url_field', +                                             max_length=20, required=False)              class Meta:                  model = self.URLFieldWithGivenMaxLengthModel          serializer = URLFieldSerializer(data={})          self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(getattr(serializer.fields['url_field'], 'max_length'), 20) +        self.assertEqual(getattr(serializer.fields['url_field'], +                         'max_length'), 20) + + +class HumanizedField(TestCase): +    def setUp(self): +        self.required_field = Field() +        self.required_field.label = uuid4().hex +        self.required_field.required = True + +        self.optional_field = Field() +        self.optional_field.label = uuid4().hex +        self.optional_field.required = False + +    def test_type(self): +        for field in (self.required_field, self.optional_field): +            self.assertEqual(field.humanized['type'], field.type_name) + +    def test_required(self): +        self.assertEqual(self.required_field.humanized['required'], True) + +    def test_optional(self): +        self.assertEqual(self.optional_field.humanized['required'], False) + +    def test_label(self): +        for field in (self.required_field, self.optional_field): +            self.assertEqual(field.humanized['label'], field.label) + + +class HumanizableSerializer(Serializer): +    field1 = CharField(3, required=True) +    field2 = CharField(10, required=False) + + +class HumanizedSerializer(TestCase): +    def setUp(self): +        self.serializer = HumanizableSerializer() + +    def test_humanized(self): +        humanized = self.serializer.humanized +        expected = { +            'field1': {u'required': True, +                       u'max_length': 3, +                       u'type': u'CharField', +                       u'read_only': False}, +            'field2': {u'required': False, +                       u'max_length': 10, +                       u'type': u'CharField', +                       u'read_only': False}} +        self.assertEqual(set(expected.keys()), set(humanized.keys())) +        for k, v in humanized.iteritems(): +            self.assertEqual(v, expected[k]) diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 15d87e86..a2f8fb4b 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -121,7 +121,22 @@ class TestRootView(TestCase):                  'text/html'              ],              'name': 'Root', -            'description': 'Example description for OPTIONS.' +            'description': 'Example description for OPTIONS.', +            'actions': {} +        } +        expected['actions']['GET'] = {} +        expected['actions']['POST'] = { +            'text': { +                'max_length': 100, +                'read_only': False, +                'required': True, +                'type': 'String', +            }, +            'id': { +                'read_only': True, +                'required': False, +                'type': 'Integer', +            },          }          self.assertEqual(response.status_code, status.HTTP_200_OK)          self.assertEqual(response.data, expected) @@ -238,8 +253,25 @@ class TestInstanceView(TestCase):                  'text/html'              ],              'name': 'Instance', -            'description': 'Example description for OPTIONS.' +            'description': 'Example description for OPTIONS.', +            'actions': {}          } +        for method in ('GET', 'DELETE'): +            expected['actions'][method] = {} +        for method in ('PATCH', 'PUT'): +            expected['actions'][method] = { +                'text': { +                    'max_length': 100, +                    'read_only': False, +                    'required': True, +                    'type': 'String', +                }, +                'id': { +                    'read_only': True, +                    'required': False, +                    'type': 'Integer', +                }, +            }          self.assertEqual(response.status_code, status.HTTP_200_OK)          self.assertEqual(response.data, expected) diff --git a/rest_framework/tests/permissions.py b/rest_framework/tests/permissions.py index b3993be5..5a18182b 100644 --- a/rest_framework/tests/permissions.py +++ b/rest_framework/tests/permissions.py @@ -108,6 +108,51 @@ class ModelPermissionsIntegrationTests(TestCase):          response = instance_view(request, pk='2')          self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) +    def test_options_permitted(self): +        request = factory.options('/', content_type='application/json', +                               HTTP_AUTHORIZATION=self.permitted_credentials) +        response = root_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertIn('actions', response.data) +        self.assertEquals(response.data['actions'].keys(), ['POST', 'GET',]) + +        request = factory.options('/1', content_type='application/json', +                               HTTP_AUTHORIZATION=self.permitted_credentials) +        response = instance_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertIn('actions', response.data) +        self.assertEquals(response.data['actions'].keys(), ['PUT', 'PATCH', 'DELETE', 'GET',]) + +    def test_options_disallowed(self): +        request = factory.options('/', content_type='application/json', +                               HTTP_AUTHORIZATION=self.disallowed_credentials) +        response = root_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertIn('actions', response.data) +        self.assertEquals(response.data['actions'].keys(), ['GET',]) + +        request = factory.options('/1', content_type='application/json', +                               HTTP_AUTHORIZATION=self.disallowed_credentials) +        response = instance_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertIn('actions', response.data) +        self.assertEquals(response.data['actions'].keys(), ['GET',]) + +    def test_options_updateonly(self): +        request = factory.options('/', content_type='application/json', +                               HTTP_AUTHORIZATION=self.updateonly_credentials) +        response = root_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertIn('actions', response.data) +        self.assertEquals(response.data['actions'].keys(), ['GET',]) + +        request = factory.options('/1', content_type='application/json', +                               HTTP_AUTHORIZATION=self.updateonly_credentials) +        response = instance_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertIn('actions', response.data) +        self.assertEquals(response.data['actions'].keys(), ['PUT', 'PATCH', 'GET',]) +  class OwnerModel(models.Model):      text = models.CharField(max_length=100) diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py index 994cf6dc..2767d24c 100644 --- a/rest_framework/tests/views.py +++ b/rest_framework/tests/views.py @@ -1,12 +1,15 @@  from __future__ import unicode_literals + +import copy +  from django.test import TestCase  from django.test.client import RequestFactory +  from rest_framework import status  from rest_framework.decorators import api_view  from rest_framework.response import Response  from rest_framework.settings import api_settings  from rest_framework.views import APIView -import copy  factory = RequestFactory() diff --git a/rest_framework/views.py b/rest_framework/views.py index 555fa2f4..d1afbe89 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -2,13 +2,16 @@  Provides an APIView class that is the base of all views in REST framework.  """  from __future__ import unicode_literals +  from django.core.exceptions import PermissionDenied  from django.http import Http404, HttpResponse  from django.views.decorators.csrf import csrf_exempt +  from rest_framework import status, exceptions  from rest_framework.compat import View +from rest_framework.fields import humanize_form_fields +from rest_framework.request import clone_request, Request  from rest_framework.response import Response -from rest_framework.request import Request  from rest_framework.settings import api_settings  from rest_framework.utils.formatting import get_view_name, get_view_description @@ -52,19 +55,51 @@ class APIView(View):          }      def metadata(self, request): -        return { +        content = {              'name': get_view_name(self.__class__),              'description': get_view_description(self.__class__),              'renders': [renderer.media_type for renderer in self.renderer_classes],              'parses': [parser.media_type for parser in self.parser_classes],          } -        #  TODO: Add 'fields', from serializer info, if it exists. -        # serializer = self.get_serializer() -        # if serializer is not None: -        #     field_name_types = {} -        #     for name, field in form.fields.iteritems(): -        #         field_name_types[name] = field.__class__.__name__ -        #     content['fields'] = field_name_types +        content['actions'] = self.action_metadata(request) + +        return content + +    def action_metadata(self, request): +        """Return a dictionary with the fields required fo reach allowed method. If no method is allowed, +        return an empty dictionary. + +        :param request: Request for which to return the metadata of the allowed methods. +        :return: A dictionary of the form {method: {field: {field attribute: value}}} +        """ +        actions = {} +        for method in self.allowed_methods: +            # skip HEAD and OPTIONS +            if method in ('HEAD', 'OPTIONS'): +                continue + +            cloned_request = clone_request(request, method) +            try: +                self.check_permissions(cloned_request) + +                # TODO: discuss whether and how to expose parameters like e.g. filter or paginate +                if method in ('GET', 'DELETE'): +                    actions[method] = {} +                    continue + +                if not hasattr(self, 'get_serializer'): +                    continue +                serializer = self.get_serializer() +                if serializer is not None: +                    actions[method] = serializer.humanized +            except exceptions.PermissionDenied: +                # don't add this method +                pass +            except exceptions.NotAuthenticated: +                # don't add this method +                pass + +        return actions if len(actions) > 0 else None      def http_method_not_allowed(self, request, *args, **kwargs):          """  | 
