aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTom Christie2013-05-24 21:21:56 +0100
committerTom Christie2013-05-24 21:21:56 +0100
commit760e8642bd04b5e03409601a8d378799c36eac1b (patch)
tree69857f97ba5702fe040167a058fae0431abfe588
parent78c53d530ff3d7a4a443b104ad73952d0b5b5b8b (diff)
parenta1deb5eac7d6d00c6269d88fce1cc6818d8ec04a (diff)
downloaddjango-rest-framework-760e8642bd04b5e03409601a8d378799c36eac1b.tar.bz2
Merge branch 'issue-192-expose-fields-for-options' of https://github.com/grimborg/django-rest-framework into improved-options-support
-rw-r--r--rest_framework/fields.py31
-rw-r--r--rest_framework/serializers.py7
-rw-r--r--rest_framework/tests/fields.py81
-rw-r--r--rest_framework/tests/generics.py36
-rw-r--r--rest_framework/tests/permissions.py45
-rw-r--r--rest_framework/tests/views.py5
-rw-r--r--rest_framework/views.py53
7 files changed, 235 insertions, 23 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index d772c400..cb5f9a40 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -23,7 +23,8 @@ from django.utils.translation import ugettext_lazy as _
from django.utils.datastructures import SortedDict
from rest_framework import ISO_8601
-from rest_framework.compat import timezone, parse_date, parse_datetime, parse_time
+from rest_framework.compat import (timezone, parse_date, parse_datetime,
+ parse_time)
from rest_framework.compat import BytesIO
from rest_framework.compat import six
from rest_framework.compat import smart_text, force_text, is_non_str_iterable
@@ -61,7 +62,8 @@ def get_component(obj, attr_name):
def readable_datetime_formats(formats):
- format = ', '.join(formats).replace(ISO_8601, 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]')
+ format = ', '.join(formats).replace(ISO_8601,
+ 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]')
return humanize_strptime(format)
@@ -70,6 +72,18 @@ def readable_date_formats(formats):
return humanize_strptime(format)
+def humanize_form_fields(form):
+ """Return a humanized description of all the fields in a form.
+
+ :param form: A Django form.
+ :return: A dictionary of {field_label: humanized description}
+
+ """
+ fields = SortedDict([(name, humanize_field(field))
+ for name, field in form.fields.iteritems()])
+ return fields
+
+
def readable_time_formats(formats):
format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
return humanize_strptime(format)
@@ -193,6 +207,19 @@ class Field(object):
return {'type': self.type_name}
return {}
+ @property
+ def humanized(self):
+ humanized = {
+ 'type': self.type_name,
+ 'required': getattr(self, 'required', False),
+ }
+ optional_attrs = ['read_only', 'help_text', 'label',
+ 'min_length', 'max_length']
+ for attr in optional_attrs:
+ if getattr(self, attr, None) is not None:
+ humanized[attr] = getattr(self, attr)
+ return humanized
+
class WritableField(Field):
"""
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 31f261e1..17da8c25 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -521,6 +521,13 @@ class BaseSerializer(WritableField):
return self.object
+ @property
+ def humanized(self):
+ humanized_fields = SortedDict(
+ [(name, field.humanized)
+ for name, field in self.fields.iteritems()])
+ return humanized_fields
+
class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)):
pass
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py
index a3104206..22c515a9 100644
--- a/rest_framework/tests/fields.py
+++ b/rest_framework/tests/fields.py
@@ -2,13 +2,20 @@
General serializer field tests.
"""
from __future__ import unicode_literals
-from django.utils.datastructures import SortedDict
-import datetime
+
+from collections import namedtuple
from decimal import Decimal
+from uuid import uuid4
+
+import datetime
+from django import forms
+from django.core import validators
from django.db import models
from django.test import TestCase
-from django.core import validators
+from django.utils.datastructures import SortedDict
+
from rest_framework import serializers
+from rest_framework.fields import Field, CharField
from rest_framework.serializers import Serializer
from rest_framework.tests.models import RESTFrameworkModel
@@ -760,14 +767,16 @@ class SlugFieldTests(TestCase):
def test_given_serializer_value(self):
class SlugFieldSerializer(serializers.ModelSerializer):
- slug_field = serializers.SlugField(source='slug_field', max_length=20, required=False)
+ slug_field = serializers.SlugField(source='slug_field',
+ max_length=20, required=False)
class Meta:
model = self.SlugFieldModel
serializer = SlugFieldSerializer(data={})
self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 20)
+ self.assertEqual(getattr(serializer.fields['slug_field'],
+ 'max_length'), 20)
def test_invalid_slug(self):
"""
@@ -803,7 +812,8 @@ class URLFieldTests(TestCase):
serializer = URLFieldSerializer(data={})
self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['url_field'], 'max_length'), 200)
+ self.assertEqual(getattr(serializer.fields['url_field'],
+ 'max_length'), 200)
def test_given_model_value(self):
class URLFieldSerializer(serializers.ModelSerializer):
@@ -812,15 +822,68 @@ class URLFieldTests(TestCase):
serializer = URLFieldSerializer(data={})
self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['url_field'], 'max_length'), 128)
+ self.assertEqual(getattr(serializer.fields['url_field'],
+ 'max_length'), 128)
def test_given_serializer_value(self):
class URLFieldSerializer(serializers.ModelSerializer):
- url_field = serializers.URLField(source='url_field', max_length=20, required=False)
+ url_field = serializers.URLField(source='url_field',
+ max_length=20, required=False)
class Meta:
model = self.URLFieldWithGivenMaxLengthModel
serializer = URLFieldSerializer(data={})
self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['url_field'], 'max_length'), 20)
+ self.assertEqual(getattr(serializer.fields['url_field'],
+ 'max_length'), 20)
+
+
+class HumanizedField(TestCase):
+ def setUp(self):
+ self.required_field = Field()
+ self.required_field.label = uuid4().hex
+ self.required_field.required = True
+
+ self.optional_field = Field()
+ self.optional_field.label = uuid4().hex
+ self.optional_field.required = False
+
+ def test_type(self):
+ for field in (self.required_field, self.optional_field):
+ self.assertEqual(field.humanized['type'], field.type_name)
+
+ def test_required(self):
+ self.assertEqual(self.required_field.humanized['required'], True)
+
+ def test_optional(self):
+ self.assertEqual(self.optional_field.humanized['required'], False)
+
+ def test_label(self):
+ for field in (self.required_field, self.optional_field):
+ self.assertEqual(field.humanized['label'], field.label)
+
+
+class HumanizableSerializer(Serializer):
+ field1 = CharField(3, required=True)
+ field2 = CharField(10, required=False)
+
+
+class HumanizedSerializer(TestCase):
+ def setUp(self):
+ self.serializer = HumanizableSerializer()
+
+ def test_humanized(self):
+ humanized = self.serializer.humanized
+ expected = {
+ 'field1': {u'required': True,
+ u'max_length': 3,
+ u'type': u'CharField',
+ u'read_only': False},
+ 'field2': {u'required': False,
+ u'max_length': 10,
+ u'type': u'CharField',
+ u'read_only': False}}
+ self.assertEqual(set(expected.keys()), set(humanized.keys()))
+ for k, v in humanized.iteritems():
+ self.assertEqual(v, expected[k])
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
index 15d87e86..a2f8fb4b 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/generics.py
@@ -121,7 +121,22 @@ class TestRootView(TestCase):
'text/html'
],
'name': 'Root',
- 'description': 'Example description for OPTIONS.'
+ 'description': 'Example description for OPTIONS.',
+ 'actions': {}
+ }
+ expected['actions']['GET'] = {}
+ expected['actions']['POST'] = {
+ 'text': {
+ 'max_length': 100,
+ 'read_only': False,
+ 'required': True,
+ 'type': 'String',
+ },
+ 'id': {
+ 'read_only': True,
+ 'required': False,
+ 'type': 'Integer',
+ },
}
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected)
@@ -238,8 +253,25 @@ class TestInstanceView(TestCase):
'text/html'
],
'name': 'Instance',
- 'description': 'Example description for OPTIONS.'
+ 'description': 'Example description for OPTIONS.',
+ 'actions': {}
}
+ for method in ('GET', 'DELETE'):
+ expected['actions'][method] = {}
+ for method in ('PATCH', 'PUT'):
+ expected['actions'][method] = {
+ 'text': {
+ 'max_length': 100,
+ 'read_only': False,
+ 'required': True,
+ 'type': 'String',
+ },
+ 'id': {
+ 'read_only': True,
+ 'required': False,
+ 'type': 'Integer',
+ },
+ }
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected)
diff --git a/rest_framework/tests/permissions.py b/rest_framework/tests/permissions.py
index b3993be5..5a18182b 100644
--- a/rest_framework/tests/permissions.py
+++ b/rest_framework/tests/permissions.py
@@ -108,6 +108,51 @@ class ModelPermissionsIntegrationTests(TestCase):
response = instance_view(request, pk='2')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ def test_options_permitted(self):
+ request = factory.options('/', content_type='application/json',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEquals(response.data['actions'].keys(), ['POST', 'GET',])
+
+ request = factory.options('/1', content_type='application/json',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEquals(response.data['actions'].keys(), ['PUT', 'PATCH', 'DELETE', 'GET',])
+
+ def test_options_disallowed(self):
+ request = factory.options('/', content_type='application/json',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEquals(response.data['actions'].keys(), ['GET',])
+
+ request = factory.options('/1', content_type='application/json',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEquals(response.data['actions'].keys(), ['GET',])
+
+ def test_options_updateonly(self):
+ request = factory.options('/', content_type='application/json',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEquals(response.data['actions'].keys(), ['GET',])
+
+ request = factory.options('/1', content_type='application/json',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEquals(response.data['actions'].keys(), ['PUT', 'PATCH', 'GET',])
+
class OwnerModel(models.Model):
text = models.CharField(max_length=100)
diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py
index 994cf6dc..2767d24c 100644
--- a/rest_framework/tests/views.py
+++ b/rest_framework/tests/views.py
@@ -1,12 +1,15 @@
from __future__ import unicode_literals
+
+import copy
+
from django.test import TestCase
from django.test.client import RequestFactory
+
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.views import APIView
-import copy
factory = RequestFactory()
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 555fa2f4..d1afbe89 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -2,13 +2,16 @@
Provides an APIView class that is the base of all views in REST framework.
"""
from __future__ import unicode_literals
+
from django.core.exceptions import PermissionDenied
from django.http import Http404, HttpResponse
from django.views.decorators.csrf import csrf_exempt
+
from rest_framework import status, exceptions
from rest_framework.compat import View
+from rest_framework.fields import humanize_form_fields
+from rest_framework.request import clone_request, Request
from rest_framework.response import Response
-from rest_framework.request import Request
from rest_framework.settings import api_settings
from rest_framework.utils.formatting import get_view_name, get_view_description
@@ -52,19 +55,51 @@ class APIView(View):
}
def metadata(self, request):
- return {
+ content = {
'name': get_view_name(self.__class__),
'description': get_view_description(self.__class__),
'renders': [renderer.media_type for renderer in self.renderer_classes],
'parses': [parser.media_type for parser in self.parser_classes],
}
- # TODO: Add 'fields', from serializer info, if it exists.
- # serializer = self.get_serializer()
- # if serializer is not None:
- # field_name_types = {}
- # for name, field in form.fields.iteritems():
- # field_name_types[name] = field.__class__.__name__
- # content['fields'] = field_name_types
+ content['actions'] = self.action_metadata(request)
+
+ return content
+
+ def action_metadata(self, request):
+ """Return a dictionary with the fields required fo reach allowed method. If no method is allowed,
+ return an empty dictionary.
+
+ :param request: Request for which to return the metadata of the allowed methods.
+ :return: A dictionary of the form {method: {field: {field attribute: value}}}
+ """
+ actions = {}
+ for method in self.allowed_methods:
+ # skip HEAD and OPTIONS
+ if method in ('HEAD', 'OPTIONS'):
+ continue
+
+ cloned_request = clone_request(request, method)
+ try:
+ self.check_permissions(cloned_request)
+
+ # TODO: discuss whether and how to expose parameters like e.g. filter or paginate
+ if method in ('GET', 'DELETE'):
+ actions[method] = {}
+ continue
+
+ if not hasattr(self, 'get_serializer'):
+ continue
+ serializer = self.get_serializer()
+ if serializer is not None:
+ actions[method] = serializer.humanized
+ except exceptions.PermissionDenied:
+ # don't add this method
+ pass
+ except exceptions.NotAuthenticated:
+ # don't add this method
+ pass
+
+ return actions if len(actions) > 0 else None
def http_method_not_allowed(self, request, *args, **kwargs):
"""