diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/__init__.py | 2 | ||||
| -rw-r--r-- | rest_framework/decorators.py | 10 | ||||
| -rw-r--r-- | rest_framework/fields.py | 51 | ||||
| -rw-r--r-- | rest_framework/generics.py | 56 | ||||
| -rw-r--r-- | rest_framework/relations.py | 67 | ||||
| -rw-r--r-- | rest_framework/renderers.py | 30 | ||||
| -rw-r--r-- | rest_framework/response.py | 2 | ||||
| -rw-r--r-- | rest_framework/routers.py | 10 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 28 | ||||
| -rw-r--r-- | rest_framework/templatetags/rest_framework.py | 2 | ||||
| -rw-r--r-- | rest_framework/tests/models.py | 2 | ||||
| -rw-r--r-- | rest_framework/tests/routers.py | 55 | ||||
| -rw-r--r-- | rest_framework/tests/test_authentication.py (renamed from rest_framework/tests/authentication.py) | 14 | ||||
| -rw-r--r-- | rest_framework/tests/test_breadcrumbs.py (renamed from rest_framework/tests/breadcrumbs.py) | 2 | ||||
| -rw-r--r-- | rest_framework/tests/test_decorators.py (renamed from rest_framework/tests/decorators.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_description.py (renamed from rest_framework/tests/description.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_fields.py (renamed from rest_framework/tests/fields.py) | 78 | ||||
| -rw-r--r-- | rest_framework/tests/test_files.py (renamed from rest_framework/tests/files.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_filters.py (renamed from rest_framework/tests/filters.py) | 2 | ||||
| -rw-r--r-- | rest_framework/tests/test_genericrelations.py (renamed from rest_framework/tests/genericrelations.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_generics.py (renamed from rest_framework/tests/generics.py) | 56 | ||||
| -rw-r--r-- | rest_framework/tests/test_htmlrenderer.py (renamed from rest_framework/tests/htmlrenderer.py) | 4 | ||||
| -rw-r--r-- | rest_framework/tests/test_hyperlinkedserializers.py (renamed from rest_framework/tests/hyperlinkedserializers.py) | 12 | ||||
| -rw-r--r-- | rest_framework/tests/test_multitable_inheritance.py (renamed from rest_framework/tests/multitable_inheritance.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_negotiation.py (renamed from rest_framework/tests/negotiation.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_pagination.py (renamed from rest_framework/tests/pagination.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_parsers.py (renamed from rest_framework/tests/parsers.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_permissions.py (renamed from rest_framework/tests/permissions.py) | 42 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations.py (renamed from rest_framework/tests/relations.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations_hyperlink.py (renamed from rest_framework/tests/relations_hyperlink.py) | 10 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations_nested.py (renamed from rest_framework/tests/relations_nested.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations_pk.py (renamed from rest_framework/tests/relations_pk.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations_slug.py (renamed from rest_framework/tests/relations_slug.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_renderers.py (renamed from rest_framework/tests/renderers.py) | 28 | ||||
| -rw-r--r-- | rest_framework/tests/test_request.py (renamed from rest_framework/tests/request.py) | 2 | ||||
| -rw-r--r-- | rest_framework/tests/test_response.py (renamed from rest_framework/tests/response.py) | 8 | ||||
| -rw-r--r-- | rest_framework/tests/test_reverse.py (renamed from rest_framework/tests/reverse.py) | 2 | ||||
| -rw-r--r-- | rest_framework/tests/test_routers.py | 121 | ||||
| -rw-r--r-- | rest_framework/tests/test_serializer.py (renamed from rest_framework/tests/serializer.py) | 51 | ||||
| -rw-r--r-- | rest_framework/tests/test_serializer_bulk_update.py (renamed from rest_framework/tests/serializer_bulk_update.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_serializer_nested.py (renamed from rest_framework/tests/serializer_nested.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_settings.py (renamed from rest_framework/tests/settings.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_throttling.py (renamed from rest_framework/tests/throttling.py) | 2 | ||||
| -rw-r--r-- | rest_framework/tests/test_urlpatterns.py (renamed from rest_framework/tests/urlpatterns.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_validation.py (renamed from rest_framework/tests/validation.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_views.py (renamed from rest_framework/tests/views.py) | 5 | ||||
| -rw-r--r-- | rest_framework/tests/testcases.py | 66 | ||||
| -rw-r--r-- | rest_framework/tests/tests.py | 6 | ||||
| -rw-r--r-- | rest_framework/utils/encoders.py | 7 | ||||
| -rw-r--r-- | rest_framework/views.py | 31 |
50 files changed, 583 insertions, 281 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 0b1e67fb..59046733 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.3.3' +__version__ = '2.3.4' VERSION = __version__ # synonym diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 81e585e1..c69756a4 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -1,5 +1,5 @@ """ -The most imporant decorator in this module is `@api_view`, which is used +The most important decorator in this module is `@api_view`, which is used for writing function-based views with REST framework. There are also various decorators for setting the API policies on function @@ -40,7 +40,7 @@ def api_view(http_method_names): # api_view applied with eg. string instead of list of strings assert isinstance(http_method_names, (list, tuple)), \ - '@api_view expected a list of strings, recieved %s' % type(http_method_names).__name__ + '@api_view expected a list of strings, received %s' % type(http_method_names).__name__ allowed_methods = set(http_method_names) | set(('options',)) WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] @@ -112,18 +112,18 @@ def link(**kwargs): Used to mark a method on a ViewSet that should be routed for GET requests. """ def decorator(func): - func.bind_to_method = 'get' + func.bind_to_methods = ['get'] func.kwargs = kwargs return func return decorator -def action(**kwargs): +def action(methods=['post'], **kwargs): """ Used to mark a method on a ViewSet that should be routed for POST requests. """ def decorator(func): - func.bind_to_method = 'post' + func.bind_to_methods = methods func.kwargs = kwargs return func return decorator diff --git a/rest_framework/fields.py b/rest_framework/fields.py index d772c400..535aa2ac 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -11,7 +11,6 @@ from decimal import Decimal, DecimalException import inspect import re import warnings - from django.core import validators from django.core.exceptions import ValidationError from django.conf import settings @@ -21,9 +20,9 @@ from django.forms import widgets from django.utils.encoding import is_protected_type from django.utils.translation import ugettext_lazy as _ from django.utils.datastructures import SortedDict - from rest_framework import ISO_8601 -from rest_framework.compat import timezone, parse_date, parse_datetime, parse_time +from rest_framework.compat import (timezone, parse_date, parse_datetime, + parse_time) from rest_framework.compat import BytesIO from rest_framework.compat import six from rest_framework.compat import smart_text, force_text, is_non_str_iterable @@ -45,6 +44,7 @@ def is_simple_callable(obj): len_defaults = len(defaults) if defaults else 0 return len_args <= len_defaults + def get_component(obj, attr_name): """ Given an object, and an attribute name, @@ -61,7 +61,8 @@ def get_component(obj, attr_name): def readable_datetime_formats(formats): - format = ', '.join(formats).replace(ISO_8601, 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]') + format = ', '.join(formats).replace(ISO_8601, + 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]') return humanize_strptime(format) @@ -108,6 +109,7 @@ class Field(object): partial = False use_files = False form_field_class = forms.CharField + type_label = 'field' def __init__(self, source=None, label=None, help_text=None): self.parent = None @@ -193,6 +195,18 @@ class Field(object): return {'type': self.type_name} return {} + def metadata(self): + metadata = SortedDict() + metadata['type'] = self.type_label + metadata['required'] = getattr(self, 'required', False) + optional_attrs = ['read_only', 'label', 'help_text', + 'min_length', 'max_length'] + for attr in optional_attrs: + value = getattr(self, attr, None) + if value is not None and value != '': + metadata[attr] = force_text(value, strings_only=True) + return metadata + class WritableField(Field): """ @@ -281,7 +295,10 @@ class WritableField(Field): except KeyError: if self.default is not None and not self.partial: # Note: partial updates shouldn't set defaults - native = self.default + if is_simple_callable(self.default): + native = self.default() + else: + native = self.default else: if self.required: raise ValidationError(self.error_messages['required']) @@ -348,6 +365,7 @@ class ModelField(WritableField): class BooleanField(WritableField): type_name = 'BooleanField' + type_label = 'boolean' form_field_class = forms.BooleanField widget = widgets.CheckboxInput default_error_messages = { @@ -370,6 +388,7 @@ class BooleanField(WritableField): class CharField(WritableField): type_name = 'CharField' + type_label = 'string' form_field_class = forms.CharField def __init__(self, max_length=None, min_length=None, *args, **kwargs): @@ -388,6 +407,7 @@ class CharField(WritableField): class URLField(CharField): type_name = 'URLField' + type_label = 'url' def __init__(self, **kwargs): kwargs['validators'] = [validators.URLValidator()] @@ -396,14 +416,15 @@ class URLField(CharField): class SlugField(CharField): type_name = 'SlugField' + type_label = 'slug' form_field_class = forms.SlugField - + default_error_messages = { 'invalid': _("Enter a valid 'slug' consisting of letters, numbers," " underscores or hyphens."), } default_validators = [validators.validate_slug] - + def __init__(self, *args, **kwargs): super(SlugField, self).__init__(*args, **kwargs) @@ -413,10 +434,11 @@ class SlugField(CharField): #result.widget = copy.deepcopy(self.widget, memo) result.validators = self.validators[:] return result - - + + class ChoiceField(WritableField): type_name = 'ChoiceField' + type_label = 'multiple choice' form_field_class = forms.ChoiceField widget = widgets.Select default_error_messages = { @@ -467,6 +489,7 @@ class ChoiceField(WritableField): class EmailField(CharField): type_name = 'EmailField' + type_label = 'email' form_field_class = forms.EmailField default_error_messages = { @@ -490,6 +513,7 @@ class EmailField(CharField): class RegexField(CharField): type_name = 'RegexField' + type_label = 'regex' form_field_class = forms.RegexField def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs): @@ -519,6 +543,7 @@ class RegexField(CharField): class DateField(WritableField): type_name = 'DateField' + type_label = 'date' widget = widgets.DateInput form_field_class = forms.DateField @@ -582,6 +607,7 @@ class DateField(WritableField): class DateTimeField(WritableField): type_name = 'DateTimeField' + type_label = 'datetime' widget = widgets.DateTimeInput form_field_class = forms.DateTimeField @@ -651,6 +677,7 @@ class DateTimeField(WritableField): class TimeField(WritableField): type_name = 'TimeField' + type_label = 'time' widget = widgets.TimeInput form_field_class = forms.TimeField @@ -707,6 +734,7 @@ class TimeField(WritableField): class IntegerField(WritableField): type_name = 'IntegerField' + type_label = 'integer' form_field_class = forms.IntegerField default_error_messages = { @@ -737,6 +765,7 @@ class IntegerField(WritableField): class FloatField(WritableField): type_name = 'FloatField' + type_label = 'float' form_field_class = forms.FloatField default_error_messages = { @@ -756,6 +785,7 @@ class FloatField(WritableField): class DecimalField(WritableField): type_name = 'DecimalField' + type_label = 'decimal' form_field_class = forms.DecimalField default_error_messages = { @@ -826,6 +856,7 @@ class DecimalField(WritableField): class FileField(WritableField): use_files = True type_name = 'FileField' + type_label = 'file upload' form_field_class = forms.FileField widget = widgets.FileInput @@ -869,6 +900,8 @@ class FileField(WritableField): class ImageField(FileField): use_files = True + type_name = 'ImageField' + type_label = 'image upload' form_field_class = forms.ImageField default_error_messages = { diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 05ec93d3..9ccc7898 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -3,17 +3,28 @@ Generic views that provide commonly needed behaviour. """ from __future__ import unicode_literals -from django.core.exceptions import ImproperlyConfigured +from django.core.exceptions import ImproperlyConfigured, PermissionDenied from django.core.paginator import Paginator, InvalidPage from django.http import Http404 -from django.shortcuts import get_object_or_404 +from django.shortcuts import get_object_or_404 as _get_object_or_404 from django.utils.translation import ugettext as _ -from rest_framework import views, mixins -from rest_framework.exceptions import ConfigurationError +from rest_framework import views, mixins, exceptions +from rest_framework.request import clone_request from rest_framework.settings import api_settings import warnings +def get_object_or_404(queryset, **filter_kwargs): + """ + Same as Django's standard shortcut, but make sure to raise 404 + if the filter_kwargs don't match the required types. + """ + try: + return _get_object_or_404(queryset, **filter_kwargs) + except (TypeError, ValueError): + raise Http404 + + class GenericAPIView(views.APIView): """ Base class for all other generic views. @@ -274,7 +285,7 @@ class GenericAPIView(views.APIView): ) filter_kwargs = {self.slug_field: slug} else: - raise ConfigurationError( + raise exceptions.ConfigurationError( 'Expected view %s to be called with a URL keyword argument ' 'named "%s". Fix your URL conf, or set the `.lookup_field` ' 'attribute on the view correctly.' % @@ -310,6 +321,41 @@ class GenericAPIView(views.APIView): """ pass + def metadata(self, request): + """ + Return a dictionary of metadata about the view. + Used to return responses for OPTIONS requests. + + We override the default behavior, and add some extra information + about the required request body for POST and PUT operations. + """ + ret = super(GenericAPIView, self).metadata(request) + + actions = {} + for method in ('PUT', 'POST'): + if method not in self.allowed_methods: + continue + + cloned_request = clone_request(request, method) + try: + # Test global permissions + self.check_permissions(cloned_request) + # Test object permissions + if method == 'PUT': + self.get_object() + except (exceptions.APIException, PermissionDenied, Http404): + pass + else: + # If user has appropriate permissions for the view, include + # appropriate metadata about the fields that should be supplied. + serializer = self.get_serializer() + actions[method] = serializer.metadata() + + if actions: + ret['actions'] = actions + + return ret + ########################################################## ### Concrete view classes that provide method handlers ### diff --git a/rest_framework/relations.py b/rest_framework/relations.py index c4271e33..e3675b51 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -72,7 +72,6 @@ class RelatedField(WritableField): else: # Reverse self.queryset = manager.field.rel.to._default_manager.all() except Exception: - raise msg = ('Serializer related fields must include a `queryset`' + ' argument or set `read_only=True') raise Exception(msg) @@ -488,13 +487,15 @@ class HyperlinkedIdentityField(Field): slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden def __init__(self, *args, **kwargs): - # TODO: Make view_name mandatory, and have the - # HyperlinkedModelSerializer set it on-the-fly - self.view_name = kwargs.pop('view_name', None) - # Optionally the format of the target hyperlink may be specified - self.format = kwargs.pop('format', None) + try: + self.view_name = kwargs.pop('view_name') + except KeyError: + msg = "HyperlinkedIdentityField requires 'view_name' argument" + raise ValueError(msg) - self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) + self.format = kwargs.pop('format', None) + lookup_field = kwargs.pop('lookup_field', None) + self.lookup_field = lookup_field or self.lookup_field # These are pending deprecation if 'pk_url_kwarg' in kwargs: @@ -517,9 +518,7 @@ class HyperlinkedIdentityField(Field): def field_to_native(self, obj, field_name): request = self.context.get('request', None) format = self.context.get('format', None) - view_name = self.view_name or self.parent.opts.view_name - lookup_field = getattr(obj, self.lookup_field) - kwargs = {self.lookup_field: lookup_field} + view_name = self.view_name if request is None: warnings.warn("Using `HyperlinkedIdentityField` without including the " @@ -539,29 +538,51 @@ class HyperlinkedIdentityField(Field): if format and self.format and self.format != format: format = self.format + # Return the hyperlink, or error if incorrectly configured. try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) + return self.get_url(obj, view_name, request, format) except NoReverseMatch: - pass - - slug = getattr(obj, self.slug_field, None) + msg = ( + 'Could not resolve URL for hyperlinked relationship using ' + 'view name "%s". You may have failed to include the related ' + 'model in your API, or incorrectly configured the ' + '`lookup_field` attribute on this field.' + ) + raise Exception(msg % view_name) - if not slug: - raise Exception('Could not resolve URL for field using view name "%s"' % view_name) + def get_url(self, obj, view_name, request, format): + """ + Given an object, return the URL that hyperlinks to the object. - kwargs = {self.slug_url_kwarg: slug} + May raise a `NoReverseMatch` if the `view_name` and `lookup_field` + attributes are not configured to correctly match the URL conf. + """ + lookup_field = getattr(obj, self.lookup_field) + kwargs = {self.lookup_field: lookup_field} try: return reverse(view_name, kwargs=kwargs, request=request, format=format) except NoReverseMatch: pass - kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except NoReverseMatch: - pass + if self.pk_url_kwarg != 'pk': + # Only try pk lookup if it has been explicitly set. + # Otherwise, the default `lookup_field = 'pk'` has us covered. + kwargs = {self.pk_url_kwarg: obj.pk} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except NoReverseMatch: + pass + + slug = getattr(obj, self.slug_field, None) + if slug: + # Only use slug lookup if a slug field exists on the model + kwargs = {self.slug_url_kwarg: slug} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except NoReverseMatch: + pass - raise Exception('Could not resolve URL for field using view name "%s"' % view_name) + raise NoReverseMatch() ### Old-style many classes for backwards compat diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index b4fa55bd..b2fe43ea 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -43,18 +43,21 @@ class BaseRenderer(object): class JSONRenderer(BaseRenderer): """ - Renderer which serializes to json. + Renderer which serializes to JSON. + Applies JSON's backslash-u character escaping for non-ascii characters. """ media_type = 'application/json' format = 'json' encoder_class = encoders.JSONEncoder ensure_ascii = True - charset = 'iso-8859-1' + charset = 'utf-8' + # Note that JSON encodings must be utf-8, utf-16 or utf-32. + # See: http://www.ietf.org/rfc/rfc4627.txt def render(self, data, accepted_media_type=None, renderer_context=None): """ - Render `obj` into json. + Render `data` into JSON. """ if data is None: return '' @@ -77,7 +80,11 @@ class JSONRenderer(BaseRenderer): ret = json.dumps(data, cls=self.encoder_class, indent=indent, ensure_ascii=self.ensure_ascii) - if not self.ensure_ascii: + # On python 2.x json.dumps() returns bytestrings if ensure_ascii=True, + # but if ensure_ascii=False, the return type is underspecified, + # and may (or may not) be unicode. + # On python 3.x json.dumps() returns unicode strings. + if isinstance(ret, six.text_type): return bytes(ret.encode(self.charset)) return ret @@ -85,6 +92,10 @@ class JSONRenderer(BaseRenderer): class UnicodeJSONRenderer(JSONRenderer): ensure_ascii = False charset = 'utf-8' + """ + Renderer which serializes to JSON. + Does *not* apply JSON's character escaping for non-ascii characters. + """ class JSONPRenderer(JSONRenderer): @@ -117,7 +128,7 @@ class JSONPRenderer(JSONRenderer): callback = self.get_callback(renderer_context) json = super(JSONPRenderer, self).render(data, accepted_media_type, renderer_context) - return "%s(%s);" % (callback, json) + return callback.encode(self.charset) + b'(' + json + b');' class XMLRenderer(BaseRenderer): @@ -138,7 +149,7 @@ class XMLRenderer(BaseRenderer): stream = StringIO() - xml = SimplerXMLGenerator(stream, "utf-8") + xml = SimplerXMLGenerator(stream, self.charset) xml.startDocument() xml.startElement("root", {}) @@ -188,7 +199,7 @@ class YAMLRenderer(BaseRenderer): if data is None: return '' - return yaml.dump(data, stream=None, Dumper=self.encoder) + return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder) class TemplateHTMLRenderer(BaseRenderer): @@ -496,10 +507,7 @@ class BrowsableAPIRenderer(BaseRenderer): def render(self, data, accepted_media_type=None, renderer_context=None): """ - Renders *obj* using the :attr:`template` set on the class. - - The context used in the template contains all the information - needed to self-document the response to this request. + Render the HTML for the browsable API representation. """ accepted_media_type = accepted_media_type or '' renderer_context = renderer_context or {} diff --git a/rest_framework/response.py b/rest_framework/response.py index 110ccb13..3ee52ae0 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -1,5 +1,5 @@ """ -The Response class in REST framework is similiar to HTTPResponse, except that +The Response class in REST framework is similar to HTTPResponse, except that it is initialized with unrendered data, instead of a pre-rendered string. The appropriate renderer is called during Django's template response rendering. diff --git a/rest_framework/routers.py b/rest_framework/routers.py index dba104c3..6c5fd004 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -131,20 +131,20 @@ class SimpleRouter(BaseRouter): dynamic_routes = [] for methodname in dir(viewset): attr = getattr(viewset, methodname) - httpmethod = getattr(attr, 'bind_to_method', None) - if httpmethod: - dynamic_routes.append((httpmethod, methodname)) + httpmethods = getattr(attr, 'bind_to_methods', None) + if httpmethods: + dynamic_routes.append((httpmethods, methodname)) ret = [] for route in self.routes: if route.mapping == {'{httpmethod}': '{methodname}'}: # Dynamic routes (@link or @action decorator) - for httpmethod, methodname in dynamic_routes: + for httpmethods, methodname in dynamic_routes: initkwargs = route.initkwargs.copy() initkwargs.update(getattr(viewset, methodname).kwargs) ret.append(Route( url=replace_methodname(route.url, methodname), - mapping={httpmethod: methodname}, + mapping=dict((httpmethod, methodname) for httpmethod in httpmethods), name=replace_methodname(route.name, methodname), initkwargs=initkwargs, )) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 31f261e1..a4969f60 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -25,7 +25,7 @@ from rest_framework.compat import get_concrete_model, six # # example_field = serializers.CharField(...) # -# This helps keep the seperation between model fields, form fields, and +# This helps keep the separation between model fields, form fields, and # serializer fields more explicit. from rest_framework.relations import * @@ -58,7 +58,7 @@ class DictWithMetadata(dict): def __getstate__(self): """ Used by pickle (e.g., caching). - Overriden to remove the metadata from the dict, since it shouldn't be + Overridden to remove the metadata from the dict, since it shouldn't be pickled and may in some instances be unpickleable. """ return dict(self) @@ -521,6 +521,17 @@ class BaseSerializer(WritableField): return self.object + def metadata(self): + """ + Return a dictionary of metadata about the fields on the serializer. + Useful for things like responding to OPTIONS requests, or generating + API schemas for auto-documentation. + """ + return SortedDict( + [(field_name, field.metadata()) + for field_name, field in six.iteritems(self.fields)] + ) + class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)): pass @@ -892,13 +903,24 @@ class HyperlinkedModelSerializer(ModelSerializer): _default_view_name = '%(model_name)s-detail' _hyperlink_field_class = HyperlinkedRelatedField - url = HyperlinkedIdentityField() + # Just a placeholder to ensure 'url' is the first field + # The field itself is actually created on initialization, + # when the view_name and lookup_field arguments are available. + url = Field() def __init__(self, *args, **kwargs): super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs) + if self.opts.view_name is None: self.opts.view_name = self._get_default_view_name(self.opts.model) + url_field = HyperlinkedIdentityField( + view_name=self.opts.view_name, + lookup_field=self.opts.lookup_field + ) + url_field.initialize(self, 'url') + self.fields['url'] = url_field + def _get_default_view_name(self, model): """ Return the view name to use if 'view_name' is not specified in 'Meta' diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index c86b6456..e9c1cdd5 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -15,7 +15,7 @@ register = template.Library() # When 1.3 becomes unsupported by REST framework, we can instead start to # use the {% load staticfiles %} tag, remove the following code, -# and add a dependancy that `django.contrib.staticfiles` must be installed. +# and add a dependency that `django.contrib.staticfiles` must be installed. # Note: We can't put this into the `compat` module because the compat import # from rest_framework.compat import ... diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index abf50a2d..e2d4eacd 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -162,8 +162,8 @@ class NullableOneToOneSource(RESTFrameworkModel): target = models.OneToOneField(OneToOneTarget, null=True, blank=True, related_name='nullable_source') + # Serializer used to test BasicModel class BasicModelSerializer(serializers.ModelSerializer): class Meta: model = BasicModel - diff --git a/rest_framework/tests/routers.py b/rest_framework/tests/routers.py deleted file mode 100644 index 4e4765cb..00000000 --- a/rest_framework/tests/routers.py +++ /dev/null @@ -1,55 +0,0 @@ -from __future__ import unicode_literals -from django.test import TestCase -from django.test.client import RequestFactory -from rest_framework import status -from rest_framework.response import Response -from rest_framework import viewsets -from rest_framework.decorators import link, action -from rest_framework.routers import SimpleRouter -import copy - -factory = RequestFactory() - - -class BasicViewSet(viewsets.ViewSet): - def list(self, request, *args, **kwargs): - return Response({'method': 'list'}) - - @action() - def action1(self, request, *args, **kwargs): - return Response({'method': 'action1'}) - - @action() - def action2(self, request, *args, **kwargs): - return Response({'method': 'action2'}) - - @link() - def link1(self, request, *args, **kwargs): - return Response({'method': 'link1'}) - - @link() - def link2(self, request, *args, **kwargs): - return Response({'method': 'link2'}) - - -class TestSimpleRouter(TestCase): - def setUp(self): - self.router = SimpleRouter() - - def test_link_and_action_decorator(self): - routes = self.router.get_routes(BasicViewSet) - decorator_routes = routes[2:] - # Make sure all these endpoints exist and none have been clobbered - for i, endpoint in enumerate(['action1', 'action2', 'link1', 'link2']): - route = decorator_routes[i] - # check url listing - self.assertEqual(route.url, - '^{{prefix}}/{{lookup}}/{0}/$'.format(endpoint)) - # check method to function mapping - if endpoint.startswith('action'): - method_map = 'post' - else: - method_map = 'get' - self.assertEqual(route.mapping[method_map], endpoint) - - diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/test_authentication.py index 8e6d3e51..05e9fbc3 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -48,7 +48,7 @@ urlpatterns = patterns('', (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])), - (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication], + (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication], permission_classes=[permissions.TokenHasReadWriteScope])) ) @@ -56,14 +56,14 @@ if oauth2_provider is not None: urlpatterns += patterns('', url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), - url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], + url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], permission_classes=[permissions.TokenHasReadWriteScope])), ) class BasicAuthTests(TestCase): """Basic authentication""" - urls = 'rest_framework.tests.authentication' + urls = 'rest_framework.tests.test_authentication' def setUp(self): self.csrf_client = Client(enforce_csrf_checks=True) @@ -102,7 +102,7 @@ class BasicAuthTests(TestCase): class SessionAuthTests(TestCase): """User session authentication""" - urls = 'rest_framework.tests.authentication' + urls = 'rest_framework.tests.test_authentication' def setUp(self): self.csrf_client = Client(enforce_csrf_checks=True) @@ -149,7 +149,7 @@ class SessionAuthTests(TestCase): class TokenAuthTests(TestCase): """Token authentication""" - urls = 'rest_framework.tests.authentication' + urls = 'rest_framework.tests.test_authentication' def setUp(self): self.csrf_client = Client(enforce_csrf_checks=True) @@ -243,7 +243,7 @@ class IncorrectCredentialsTests(TestCase): class OAuthTests(TestCase): """OAuth 1.0a authentication""" - urls = 'rest_framework.tests.authentication' + urls = 'rest_framework.tests.test_authentication' def setUp(self): # these imports are here because oauth is optional and hiding them in try..except block or compat @@ -429,7 +429,7 @@ class OAuthTests(TestCase): class OAuth2Tests(TestCase): """OAuth 2.0 authentication""" - urls = 'rest_framework.tests.authentication' + urls = 'rest_framework.tests.test_authentication' def setUp(self): self.csrf_client = Client(enforce_csrf_checks=True) diff --git a/rest_framework/tests/breadcrumbs.py b/rest_framework/tests/test_breadcrumbs.py index d9ed647e..41ddf2ce 100644 --- a/rest_framework/tests/breadcrumbs.py +++ b/rest_framework/tests/test_breadcrumbs.py @@ -36,7 +36,7 @@ urlpatterns = patterns('', class BreadcrumbTests(TestCase): """Tests the breadcrumb functionality used by the HTML renderer.""" - urls = 'rest_framework.tests.breadcrumbs' + urls = 'rest_framework.tests.test_breadcrumbs' def test_root_breadcrumbs(self): url = '/' diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/test_decorators.py index 1016fed3..1016fed3 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/test_decorators.py diff --git a/rest_framework/tests/description.py b/rest_framework/tests/test_description.py index 52c1a34c..52c1a34c 100644 --- a/rest_framework/tests/description.py +++ b/rest_framework/tests/test_description.py diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/test_fields.py index a3104206..de371001 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/test_fields.py @@ -2,14 +2,15 @@ General serializer field tests. """ from __future__ import unicode_literals -from django.utils.datastructures import SortedDict + import datetime from decimal import Decimal +from uuid import uuid4 +from django.core import validators from django.db import models from django.test import TestCase -from django.core import validators +from django.utils.datastructures import SortedDict from rest_framework import serializers -from rest_framework.serializers import Serializer from rest_framework.tests.models import RESTFrameworkModel @@ -587,7 +588,7 @@ class DecimalFieldTest(TestCase): """ Make sure the serializer works correctly """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(max_value=9010, min_value=9000, max_digits=6, @@ -605,7 +606,7 @@ class DecimalFieldTest(TestCase): """ Make sure max_value violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(max_value=100) s = DecimalSerializer(data={'decimal_field': '123'}) @@ -617,7 +618,7 @@ class DecimalFieldTest(TestCase): """ Make sure min_value violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(min_value=100) s = DecimalSerializer(data={'decimal_field': '99'}) @@ -629,7 +630,7 @@ class DecimalFieldTest(TestCase): """ Make sure max_digits violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(max_digits=5) s = DecimalSerializer(data={'decimal_field': '123.456'}) @@ -641,7 +642,7 @@ class DecimalFieldTest(TestCase): """ Make sure max_decimal_places violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(decimal_places=3) s = DecimalSerializer(data={'decimal_field': '123.4567'}) @@ -653,7 +654,7 @@ class DecimalFieldTest(TestCase): """ Make sure max_whole_digits violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3) s = DecimalSerializer(data={'decimal_field': '12345.6'}) @@ -760,14 +761,16 @@ class SlugFieldTests(TestCase): def test_given_serializer_value(self): class SlugFieldSerializer(serializers.ModelSerializer): - slug_field = serializers.SlugField(source='slug_field', max_length=20, required=False) + slug_field = serializers.SlugField(source='slug_field', + max_length=20, required=False) class Meta: model = self.SlugFieldModel serializer = SlugFieldSerializer(data={}) self.assertEqual(serializer.is_valid(), True) - self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 20) + self.assertEqual(getattr(serializer.fields['slug_field'], + 'max_length'), 20) def test_invalid_slug(self): """ @@ -775,12 +778,12 @@ class SlugFieldTests(TestCase): """ class SlugFieldSerializer(serializers.ModelSerializer): slug_field = serializers.SlugField(source='slug_field', max_length=20, required=True) - + class Meta: model = self.SlugFieldModel - + s = SlugFieldSerializer(data={'slug_field': 'a b'}) - + self.assertEqual(s.is_valid(), False) self.assertEqual(s.errors, {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]}) @@ -803,7 +806,8 @@ class URLFieldTests(TestCase): serializer = URLFieldSerializer(data={}) self.assertEqual(serializer.is_valid(), True) - self.assertEqual(getattr(serializer.fields['url_field'], 'max_length'), 200) + self.assertEqual(getattr(serializer.fields['url_field'], + 'max_length'), 200) def test_given_model_value(self): class URLFieldSerializer(serializers.ModelSerializer): @@ -812,15 +816,53 @@ class URLFieldTests(TestCase): serializer = URLFieldSerializer(data={}) self.assertEqual(serializer.is_valid(), True) - self.assertEqual(getattr(serializer.fields['url_field'], 'max_length'), 128) + self.assertEqual(getattr(serializer.fields['url_field'], + 'max_length'), 128) def test_given_serializer_value(self): class URLFieldSerializer(serializers.ModelSerializer): - url_field = serializers.URLField(source='url_field', max_length=20, required=False) + url_field = serializers.URLField(source='url_field', + max_length=20, required=False) class Meta: model = self.URLFieldWithGivenMaxLengthModel serializer = URLFieldSerializer(data={}) self.assertEqual(serializer.is_valid(), True) - self.assertEqual(getattr(serializer.fields['url_field'], 'max_length'), 20) + self.assertEqual(getattr(serializer.fields['url_field'], + 'max_length'), 20) + + +class FieldMetadata(TestCase): + def setUp(self): + self.required_field = serializers.Field() + self.required_field.label = uuid4().hex + self.required_field.required = True + + self.optional_field = serializers.Field() + self.optional_field.label = uuid4().hex + self.optional_field.required = False + + def test_required(self): + self.assertEqual(self.required_field.metadata()['required'], True) + + def test_optional(self): + self.assertEqual(self.optional_field.metadata()['required'], False) + + def test_label(self): + for field in (self.required_field, self.optional_field): + self.assertEqual(field.metadata()['label'], field.label) + + +class FieldCallableDefault(TestCase): + def setUp(self): + self.simple_callable = lambda: 'foo bar' + + def test_default_can_be_simple_callable(self): + """ + Ensure that the 'default' argument can also be a simple callable. + """ + field = serializers.WritableField(default=self.simple_callable) + into = {} + field.field_from_native({}, {}, 'field', into) + self.assertEquals(into, {'field': 'foo bar'}) diff --git a/rest_framework/tests/files.py b/rest_framework/tests/test_files.py index 487046ac..487046ac 100644 --- a/rest_framework/tests/files.py +++ b/rest_framework/tests/test_files.py diff --git a/rest_framework/tests/filters.py b/rest_framework/tests/test_filters.py index 8ae6d530..aaed6247 100644 --- a/rest_framework/tests/filters.py +++ b/rest_framework/tests/test_filters.py @@ -243,7 +243,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): """ Integration tests for filtered detail views. """ - urls = 'rest_framework.tests.filters' + urls = 'rest_framework.tests.test_filters' def _get_url(self, item): return reverse('detail-view', kwargs=dict(pk=item.pk)) diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/test_genericrelations.py index c38bfb9f..c38bfb9f 100644 --- a/rest_framework/tests/genericrelations.py +++ b/rest_framework/tests/test_genericrelations.py diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/test_generics.py index 15d87e86..37734195 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/test_generics.py @@ -121,7 +121,25 @@ class TestRootView(TestCase): 'text/html' ], 'name': 'Root', - 'description': 'Example description for OPTIONS.' + 'description': 'Example description for OPTIONS.', + 'actions': { + 'POST': { + 'text': { + 'max_length': 100, + 'read_only': False, + 'required': True, + 'type': 'string', + "label": "Text comes here", + "help_text": "Text description." + }, + 'id': { + 'read_only': True, + 'required': False, + 'type': 'integer', + 'label': 'ID', + }, + } + } } self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, expected) @@ -224,9 +242,9 @@ class TestInstanceView(TestCase): """ OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata """ - request = factory.options('/') - with self.assertNumQueries(0): - response = self.view(request).render() + request = factory.options('/1') + with self.assertNumQueries(1): + response = self.view(request, pk=1).render() expected = { 'parses': [ 'application/json', @@ -238,11 +256,39 @@ class TestInstanceView(TestCase): 'text/html' ], 'name': 'Instance', - 'description': 'Example description for OPTIONS.' + 'description': 'Example description for OPTIONS.', + 'actions': { + 'PUT': { + 'text': { + 'max_length': 100, + 'read_only': False, + 'required': True, + 'type': 'string', + 'label': 'Text comes here', + 'help_text': 'Text description.' + }, + 'id': { + 'read_only': True, + 'required': False, + 'type': 'integer', + 'label': 'ID', + }, + } + } } self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, expected) + def test_get_instance_view_incorrect_arg(self): + """ + GET requests with an incorrect pk type, should raise 404, not 500. + Regression test for #890. + """ + request = factory.get('/a') + with self.assertNumQueries(0): + response = self.view(request, pk='a').render() + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + def test_put_cannot_set_id(self): """ PUT requests to create a new object should not be able to set the id. diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/test_htmlrenderer.py index 5d18a6e8..8957a43c 100644 --- a/rest_framework/tests/htmlrenderer.py +++ b/rest_framework/tests/test_htmlrenderer.py @@ -42,7 +42,7 @@ urlpatterns = patterns('', class TemplateHTMLRendererTests(TestCase): - urls = 'rest_framework.tests.htmlrenderer' + urls = 'rest_framework.tests.test_htmlrenderer' def setUp(self): """ @@ -82,7 +82,7 @@ class TemplateHTMLRendererTests(TestCase): class TemplateHTMLRendererExceptionTests(TestCase): - urls = 'rest_framework.tests.htmlrenderer' + urls = 'rest_framework.tests.test_htmlrenderer' def setUp(self): """ diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py index 8fc6ba77..1894ddb2 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/test_hyperlinkedserializers.py @@ -106,7 +106,7 @@ urlpatterns = patterns('', class TestBasicHyperlinkedView(TestCase): - urls = 'rest_framework.tests.hyperlinkedserializers' + urls = 'rest_framework.tests.test_hyperlinkedserializers' def setUp(self): """ @@ -143,7 +143,7 @@ class TestBasicHyperlinkedView(TestCase): class TestManyToManyHyperlinkedView(TestCase): - urls = 'rest_framework.tests.hyperlinkedserializers' + urls = 'rest_framework.tests.test_hyperlinkedserializers' def setUp(self): """ @@ -191,7 +191,7 @@ class TestManyToManyHyperlinkedView(TestCase): class TestHyperlinkedIdentityFieldLookup(TestCase): - urls = 'rest_framework.tests.hyperlinkedserializers' + urls = 'rest_framework.tests.test_hyperlinkedserializers' def setUp(self): """ @@ -221,7 +221,7 @@ class TestHyperlinkedIdentityFieldLookup(TestCase): class TestCreateWithForeignKeys(TestCase): - urls = 'rest_framework.tests.hyperlinkedserializers' + urls = 'rest_framework.tests.test_hyperlinkedserializers' def setUp(self): """ @@ -246,7 +246,7 @@ class TestCreateWithForeignKeys(TestCase): class TestCreateWithForeignKeysAndCustomSlug(TestCase): - urls = 'rest_framework.tests.hyperlinkedserializers' + urls = 'rest_framework.tests.test_hyperlinkedserializers' def setUp(self): """ @@ -271,7 +271,7 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase): class TestOptionalRelationHyperlinkedView(TestCase): - urls = 'rest_framework.tests.hyperlinkedserializers' + urls = 'rest_framework.tests.test_hyperlinkedserializers' def setUp(self): """ diff --git a/rest_framework/tests/multitable_inheritance.py b/rest_framework/tests/test_multitable_inheritance.py index 00c15327..00c15327 100644 --- a/rest_framework/tests/multitable_inheritance.py +++ b/rest_framework/tests/test_multitable_inheritance.py diff --git a/rest_framework/tests/negotiation.py b/rest_framework/tests/test_negotiation.py index 7f84827f..7f84827f 100644 --- a/rest_framework/tests/negotiation.py +++ b/rest_framework/tests/test_negotiation.py diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/test_pagination.py index e538a78e..e538a78e 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/test_pagination.py diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/test_parsers.py index 7699e10c..7699e10c 100644 --- a/rest_framework/tests/parsers.py +++ b/rest_framework/tests/test_parsers.py diff --git a/rest_framework/tests/permissions.py b/rest_framework/tests/test_permissions.py index b3993be5..6caaf65b 100644 --- a/rest_framework/tests/permissions.py +++ b/rest_framework/tests/test_permissions.py @@ -108,6 +108,48 @@ class ModelPermissionsIntegrationTests(TestCase): response = instance_view(request, pk='2') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + def test_options_permitted(self): + request = factory.options('/', content_type='application/json', + HTTP_AUTHORIZATION=self.permitted_credentials) + response = root_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn('actions', response.data) + self.assertEqual(list(response.data['actions'].keys()), ['POST']) + + request = factory.options('/1', content_type='application/json', + HTTP_AUTHORIZATION=self.permitted_credentials) + response = instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn('actions', response.data) + self.assertEqual(list(response.data['actions'].keys()), ['PUT']) + + def test_options_disallowed(self): + request = factory.options('/', content_type='application/json', + HTTP_AUTHORIZATION=self.disallowed_credentials) + response = root_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertNotIn('actions', response.data) + + request = factory.options('/1', content_type='application/json', + HTTP_AUTHORIZATION=self.disallowed_credentials) + response = instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertNotIn('actions', response.data) + + def test_options_updateonly(self): + request = factory.options('/', content_type='application/json', + HTTP_AUTHORIZATION=self.updateonly_credentials) + response = root_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertNotIn('actions', response.data) + + request = factory.options('/1', content_type='application/json', + HTTP_AUTHORIZATION=self.updateonly_credentials) + response = instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn('actions', response.data) + self.assertEqual(list(response.data['actions'].keys()), ['PUT']) + class OwnerModel(models.Model): text = models.CharField(max_length=100) diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/test_relations.py index d19219c9..d19219c9 100644 --- a/rest_framework/tests/relations.py +++ b/rest_framework/tests/test_relations.py diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/test_relations_hyperlink.py index b3efbf52..2ca7f4f2 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/test_relations_hyperlink.py @@ -71,7 +71,7 @@ class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): # TODO: Add test that .data cannot be accessed prior to .is_valid class HyperlinkedManyToManyTests(TestCase): - urls = 'rest_framework.tests.relations_hyperlink' + urls = 'rest_framework.tests.test_relations_hyperlink' def setUp(self): for idx in range(1, 4): @@ -179,7 +179,7 @@ class HyperlinkedManyToManyTests(TestCase): class HyperlinkedForeignKeyTests(TestCase): - urls = 'rest_framework.tests.relations_hyperlink' + urls = 'rest_framework.tests.test_relations_hyperlink' def setUp(self): target = ForeignKeyTarget(name='target-1') @@ -307,7 +307,7 @@ class HyperlinkedForeignKeyTests(TestCase): class HyperlinkedNullableForeignKeyTests(TestCase): - urls = 'rest_framework.tests.relations_hyperlink' + urls = 'rest_framework.tests.test_relations_hyperlink' def setUp(self): target = ForeignKeyTarget(name='target-1') @@ -435,7 +435,7 @@ class HyperlinkedNullableForeignKeyTests(TestCase): class HyperlinkedNullableOneToOneTests(TestCase): - urls = 'rest_framework.tests.relations_hyperlink' + urls = 'rest_framework.tests.test_relations_hyperlink' def setUp(self): target = OneToOneTarget(name='target-1') @@ -458,7 +458,7 @@ class HyperlinkedNullableOneToOneTests(TestCase): # Regression tests for #694 (`source` attribute on related fields) class HyperlinkedRelatedFieldSourceTests(TestCase): - urls = 'rest_framework.tests.relations_hyperlink' + urls = 'rest_framework.tests.test_relations_hyperlink' def test_related_manager_source(self): """ diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/test_relations_nested.py index f6d006b3..f6d006b3 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/test_relations_nested.py diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/test_relations_pk.py index e2a1b815..e2a1b815 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/test_relations_pk.py diff --git a/rest_framework/tests/relations_slug.py b/rest_framework/tests/test_relations_slug.py index 435c821c..435c821c 100644 --- a/rest_framework/tests/relations_slug.py +++ b/rest_framework/tests/test_relations_slug.py diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/test_renderers.py index 1b2b9279..95b59741 100644 --- a/rest_framework/tests/renderers.py +++ b/rest_framework/tests/test_renderers.py @@ -6,6 +6,7 @@ from django.core.cache import cache from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest +from django.utils.translation import ugettext_lazy as _ from rest_framework import status, permissions from rest_framework.compat import yaml, etree, patterns, url, include from rest_framework.response import Response @@ -29,7 +30,7 @@ RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii') expected_results = [ - ((elem for elem in [1, 2, 3]), JSONRenderer, '[1, 2, 3]') # Generator + ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1, 2, 3]') # Generator ] @@ -132,7 +133,7 @@ class RendererEndToEndTests(TestCase): End-to-end testing of renderers using an RendererMixin on a generic view. """ - urls = 'rest_framework.tests.renderers' + urls = 'rest_framework.tests.test_renderers' def test_default_renderer_serializes_content(self): """If the Accept header is not set the default renderer should serialize the response.""" @@ -238,6 +239,13 @@ class JSONRendererTests(TestCase): Tests specific to the JSON Renderer """ + def test_render_lazy_strings(self): + """ + JSONRenderer should deal with lazy translated strings. + """ + ret = JSONRenderer().render(_('test')) + self.assertEqual(ret, b'"test"') + def test_without_content_type_args(self): """ Test basic JSON rendering. @@ -246,7 +254,7 @@ class JSONRendererTests(TestCase): renderer = JSONRenderer() content = renderer.render(obj, 'application/json') # Fix failing test case which depends on version of JSON library. - self.assertEqual(content, _flat_repr) + self.assertEqual(content.decode('utf-8'), _flat_repr) def test_with_content_type_args(self): """ @@ -255,13 +263,13 @@ class JSONRendererTests(TestCase): obj = {'foo': ['bar', 'baz']} renderer = JSONRenderer() content = renderer.render(obj, 'application/json; indent=2') - self.assertEqual(strip_trailing_whitespace(content), _indented_repr) + self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr) def test_check_ascii(self): obj = {'countries': ['United Kingdom', 'France', 'EspaƱa']} renderer = JSONRenderer() content = renderer.render(obj, 'application/json') - self.assertEqual(content, '{"countries": ["United Kingdom", "France", "Espa\\u00f1a"]}') + self.assertEqual(content, '{"countries": ["United Kingdom", "France", "Espa\\u00f1a"]}'.encode('utf-8')) class UnicodeJSONRendererTests(TestCase): @@ -280,7 +288,7 @@ class JSONPRendererTests(TestCase): Tests specific to the JSONP Renderer """ - urls = 'rest_framework.tests.renderers' + urls = 'rest_framework.tests.test_renderers' def test_without_callback_with_json_renderer(self): """ @@ -289,7 +297,7 @@ class JSONPRendererTests(TestCase): resp = self.client.get('/jsonp/jsonrenderer', HTTP_ACCEPT='application/javascript') self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp['Content-Type'], 'application/javascript; charset=iso-8859-1') + self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') self.assertEqual(resp.content, ('callback(%s);' % _flat_repr).encode('ascii')) @@ -300,7 +308,7 @@ class JSONPRendererTests(TestCase): resp = self.client.get('/jsonp/nojsonrenderer', HTTP_ACCEPT='application/javascript') self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp['Content-Type'], 'application/javascript; charset=iso-8859-1') + self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') self.assertEqual(resp.content, ('callback(%s);' % _flat_repr).encode('ascii')) @@ -312,7 +320,7 @@ class JSONPRendererTests(TestCase): resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func, HTTP_ACCEPT='application/javascript') self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp['Content-Type'], 'application/javascript; charset=iso-8859-1') + self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') self.assertEqual(resp.content, ('%s(%s);' % (callback_func, _flat_repr)).encode('ascii')) @@ -453,7 +461,7 @@ class CacheRenderTest(TestCase): Tests specific to caching responses """ - urls = 'rest_framework.tests.renderers' + urls = 'rest_framework.tests.test_renderers' cache_key = 'just_a_cache_key' diff --git a/rest_framework/tests/request.py b/rest_framework/tests/test_request.py index 97e5af20..a5c5e84c 100644 --- a/rest_framework/tests/request.py +++ b/rest_framework/tests/test_request.py @@ -254,7 +254,7 @@ urlpatterns = patterns('', class TestContentParsingWithAuthentication(TestCase): - urls = 'rest_framework.tests.request' + urls = 'rest_framework.tests.test_request' def setUp(self): self.csrf_client = Client(enforce_csrf_checks=True) diff --git a/rest_framework/tests/response.py b/rest_framework/tests/test_response.py index 4e04ac5c..eea3c641 100644 --- a/rest_framework/tests/response.py +++ b/rest_framework/tests/test_response.py @@ -118,7 +118,7 @@ class RendererIntegrationTests(TestCase): End-to-end testing of renderers using an ResponseMixin on a generic view. """ - urls = 'rest_framework.tests.response' + urls = 'rest_framework.tests.test_response' def test_default_renderer_serializes_content(self): """If the Accept header is not set the default renderer should serialize the response.""" @@ -198,7 +198,7 @@ class Issue122Tests(TestCase): """ Tests that covers #122. """ - urls = 'rest_framework.tests.response' + urls = 'rest_framework.tests.test_response' def test_only_html_renderer(self): """ @@ -218,7 +218,7 @@ class Issue467Tests(TestCase): Tests for #467 """ - urls = 'rest_framework.tests.response' + urls = 'rest_framework.tests.test_response' def test_form_has_label_and_help_text(self): resp = self.client.get('/html_new_model') @@ -232,7 +232,7 @@ class Issue807Tests(TestCase): Covers #807 """ - urls = 'rest_framework.tests.response' + urls = 'rest_framework.tests.test_response' def test_does_not_append_charset_by_default(self): """ diff --git a/rest_framework/tests/reverse.py b/rest_framework/tests/test_reverse.py index cb8d8132..93ef5637 100644 --- a/rest_framework/tests/reverse.py +++ b/rest_framework/tests/test_reverse.py @@ -19,7 +19,7 @@ class ReverseTests(TestCase): """ Tests for fully qualified URLs when using `reverse`. """ - urls = 'rest_framework.tests.reverse' + urls = 'rest_framework.tests.test_reverse' def test_reversed_urls_are_fully_qualified(self): request = factory.get('/view') diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py new file mode 100644 index 00000000..10d3cc25 --- /dev/null +++ b/rest_framework/tests/test_routers.py @@ -0,0 +1,121 @@ +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from django.test.client import RequestFactory +from rest_framework import serializers, viewsets +from rest_framework.compat import include, patterns, url +from rest_framework.decorators import link, action +from rest_framework.response import Response +from rest_framework.routers import SimpleRouter + +factory = RequestFactory() + +urlpatterns = patterns('',) + + +class BasicViewSet(viewsets.ViewSet): + def list(self, request, *args, **kwargs): + return Response({'method': 'list'}) + + @action() + def action1(self, request, *args, **kwargs): + return Response({'method': 'action1'}) + + @action() + def action2(self, request, *args, **kwargs): + return Response({'method': 'action2'}) + + @action(methods=['post', 'delete']) + def action3(self, request, *args, **kwargs): + return Response({'method': 'action2'}) + + @link() + def link1(self, request, *args, **kwargs): + return Response({'method': 'link1'}) + + @link() + def link2(self, request, *args, **kwargs): + return Response({'method': 'link2'}) + + +class TestSimpleRouter(TestCase): + def setUp(self): + self.router = SimpleRouter() + + def test_link_and_action_decorator(self): + routes = self.router.get_routes(BasicViewSet) + decorator_routes = routes[2:] + # Make sure all these endpoints exist and none have been clobbered + for i, endpoint in enumerate(['action1', 'action2', 'action3', 'link1', 'link2']): + route = decorator_routes[i] + # check url listing + self.assertEqual(route.url, + '^{{prefix}}/{{lookup}}/{0}/$'.format(endpoint)) + # check method to function mapping + if endpoint == 'action3': + methods_map = ['post', 'delete'] + elif endpoint.startswith('action'): + methods_map = ['post'] + else: + methods_map = ['get'] + for method in methods_map: + self.assertEqual(route.mapping[method], endpoint) + + +class RouterTestModel(models.Model): + uuid = models.CharField(max_length=20) + text = models.CharField(max_length=200) + + +class TestCustomLookupFields(TestCase): + """ + Ensure that custom lookup fields are correctly routed. + """ + urls = 'rest_framework.tests.test_routers' + + def setUp(self): + class NoteSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = RouterTestModel + lookup_field = 'uuid' + fields = ('url', 'uuid', 'text') + + class NoteViewSet(viewsets.ModelViewSet): + queryset = RouterTestModel.objects.all() + serializer_class = NoteSerializer + lookup_field = 'uuid' + + RouterTestModel.objects.create(uuid='123', text='foo bar') + + self.router = SimpleRouter() + self.router.register(r'notes', NoteViewSet) + + from rest_framework.tests import test_routers + urls = getattr(test_routers, 'urlpatterns') + urls += patterns('', + url(r'^', include(self.router.urls)), + ) + + def test_custom_lookup_field_route(self): + detail_route = self.router.urls[-1] + detail_url_pattern = detail_route.regex.pattern + self.assertIn('<uuid>', detail_url_pattern) + + def test_retrieve_lookup_field_list_view(self): + response = self.client.get('/notes/') + self.assertEquals(response.data, + [{ + "url": "http://testserver/notes/123/", + "uuid": "123", "text": "foo bar" + }] + ) + + def test_retrieve_lookup_field_detail_view(self): + response = self.client.get('/notes/123/') + self.assertEquals(response.data, + { + "url": "http://testserver/notes/123/", + "uuid": "123", "text": "foo bar" + } + ) + diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/test_serializer.py index 1772ee37..6cc913c5 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -129,11 +129,6 @@ class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer): fields = ['some_integer'] -class BrokenModelSerializer(serializers.ModelSerializer): - class Meta: - fields = ['some_field'] - - class BasicTests(TestCase): def setUp(self): self.comment = Comment( @@ -424,8 +419,12 @@ class ValidationTests(TestCase): Assert that a meaningful exception message is outputted when the model field is missing (e.g. when mistyping ``model``). """ + class BrokenModelSerializer(serializers.ModelSerializer): + class Meta: + fields = ['some_field'] + try: - serializer = BrokenModelSerializer() + BrokenModelSerializer() except AssertionError as e: self.assertEqual(e.args[0], "Serializer class 'BrokenModelSerializer' is missing 'model' Meta option") except: @@ -447,7 +446,7 @@ class CustomValidationTests(TestCase): class CommentSerializerWithFieldValidator(CommentSerializer): def validate_email(self, attrs, source): - value = attrs[source] + attrs[source] return attrs def validate_content(self, attrs, source): @@ -1365,16 +1364,16 @@ class FieldLabelTest(TestCase): serializer = self.serializer_class() text_field = serializer.fields['text'] - self.assertEquals('Text comes here', text_field.label) - self.assertEquals('Text description.', text_field.help_text) + self.assertEqual('Text comes here', text_field.label) + self.assertEqual('Text description.', text_field.help_text) def test_field_ctor(self): """ This is check that ctor supports both label and help_text. """ - self.assertEquals('Label', fields.Field(label='Label', help_text='Help').label) - self.assertEquals('Help', fields.CharField(label='Label', help_text='Help').help_text) - self.assertEquals('Label', relations.HyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help', many=True).label) + self.assertEqual('Label', fields.Field(label='Label', help_text='Help').label) + self.assertEqual('Help', fields.CharField(label='Label', help_text='Help').help_text) + self.assertEqual('Label', relations.HyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help', many=True).label) class AttributeMappingOnAutogeneratedFieldsTests(TestCase): @@ -1529,3 +1528,31 @@ class DefaultValuesOnAutogeneratedFieldsTests(TestCase): def test_url_field(self): self.field_test('url_field') + + +class MetadataSerializer(serializers.Serializer): + field1 = serializers.CharField(3, required=True) + field2 = serializers.CharField(10, required=False) + + +class MetadataSerializerTestCase(TestCase): + def setUp(self): + self.serializer = MetadataSerializer() + + def test_serializer_metadata(self): + metadata = self.serializer.metadata() + expected = { + 'field1': { + 'required': True, + 'max_length': 3, + 'type': 'string', + 'read_only': False + }, + 'field2': { + 'required': False, + 'max_length': 10, + 'type': 'string', + 'read_only': False + } + } + self.assertEqual(expected, metadata) diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/test_serializer_bulk_update.py index 8b0ded1a..8b0ded1a 100644 --- a/rest_framework/tests/serializer_bulk_update.py +++ b/rest_framework/tests/test_serializer_bulk_update.py diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/test_serializer_nested.py index 71d0e24b..71d0e24b 100644 --- a/rest_framework/tests/serializer_nested.py +++ b/rest_framework/tests/test_serializer_nested.py diff --git a/rest_framework/tests/settings.py b/rest_framework/tests/test_settings.py index 857375c2..857375c2 100644 --- a/rest_framework/tests/settings.py +++ b/rest_framework/tests/test_settings.py diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/test_throttling.py index 11cbd8eb..da400b2f 100644 --- a/rest_framework/tests/throttling.py +++ b/rest_framework/tests/test_throttling.py @@ -36,7 +36,7 @@ class MockView_MinuteThrottling(APIView): class ThrottlingTests(TestCase): - urls = 'rest_framework.tests.throttling' + urls = 'rest_framework.tests.test_throttling' def setUp(self): """ diff --git a/rest_framework/tests/urlpatterns.py b/rest_framework/tests/test_urlpatterns.py index 29ed4a96..29ed4a96 100644 --- a/rest_framework/tests/urlpatterns.py +++ b/rest_framework/tests/test_urlpatterns.py diff --git a/rest_framework/tests/validation.py b/rest_framework/tests/test_validation.py index cbdd6515..cbdd6515 100644 --- a/rest_framework/tests/validation.py +++ b/rest_framework/tests/test_validation.py diff --git a/rest_framework/tests/views.py b/rest_framework/tests/test_views.py index 994cf6dc..2767d24c 100644 --- a/rest_framework/tests/views.py +++ b/rest_framework/tests/test_views.py @@ -1,12 +1,15 @@ from __future__ import unicode_literals + +import copy + from django.test import TestCase from django.test.client import RequestFactory + from rest_framework import status from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.views import APIView -import copy factory = RequestFactory() diff --git a/rest_framework/tests/testcases.py b/rest_framework/tests/testcases.py deleted file mode 100644 index f8c2579e..00000000 --- a/rest_framework/tests/testcases.py +++ /dev/null @@ -1,66 +0,0 @@ -# http://djangosnippets.org/snippets/1011/ -from __future__ import unicode_literals -from django.conf import settings -from django.core.management import call_command -from django.db.models import loading -from django.test import TestCase - -NO_SETTING = ('!', None) - - -class TestSettingsManager(object): - """ - A class which can modify some Django settings temporarily for a - test and then revert them to their original values later. - - Automatically handles resyncing the DB if INSTALLED_APPS is - modified. - - """ - def __init__(self): - self._original_settings = {} - - def set(self, **kwargs): - for k, v in kwargs.iteritems(): - self._original_settings.setdefault(k, getattr(settings, k, - NO_SETTING)) - setattr(settings, k, v) - if 'INSTALLED_APPS' in kwargs: - self.syncdb() - - def syncdb(self): - loading.cache.loaded = False - call_command('syncdb', verbosity=0) - - def revert(self): - for k, v in self._original_settings.iteritems(): - if v == NO_SETTING: - delattr(settings, k) - else: - setattr(settings, k, v) - if 'INSTALLED_APPS' in self._original_settings: - self.syncdb() - self._original_settings = {} - - -class SettingsTestCase(TestCase): - """ - A subclass of the Django TestCase with a settings_manager - attribute which is an instance of TestSettingsManager. - - Comes with a tearDown() method that calls - self.settings_manager.revert(). - - """ - def __init__(self, *args, **kwargs): - super(SettingsTestCase, self).__init__(*args, **kwargs) - self.settings_manager = TestSettingsManager() - - def tearDown(self): - self.settings_manager.revert() - - -class TestModelsTestCase(SettingsTestCase): - def setUp(self, *args, **kwargs): - installed_apps = tuple(settings.INSTALLED_APPS) + ('rest_framework.tests',) - self.settings_manager.set(INSTALLED_APPS=installed_apps) diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py index 08f88e11..554ebd1a 100644 --- a/rest_framework/tests/tests.py +++ b/rest_framework/tests/tests.py @@ -4,11 +4,13 @@ runner to pick up the tests. Yowzers. """ from __future__ import unicode_literals import os +import django modules = [filename.rsplit('.', 1)[0] for filename in os.listdir(os.path.dirname(__file__)) if filename.endswith('.py') and not filename.startswith('_')] __test__ = dict() -for module in modules: - exec("from rest_framework.tests.%s import *" % module) +if django.VERSION < (1, 6): + for module in modules: + exec("from rest_framework.tests.%s import *" % module) diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index b6de18a8..b26a2085 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -3,7 +3,8 @@ Helper classes for parsers. """ from __future__ import unicode_literals from django.utils.datastructures import SortedDict -from rest_framework.compat import timezone +from django.utils.functional import Promise +from rest_framework.compat import timezone, force_text from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata import datetime import decimal @@ -19,7 +20,9 @@ class JSONEncoder(json.JSONEncoder): def default(self, o): # For Date Time string spec, see ECMA 262 # http://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15 - if isinstance(o, datetime.datetime): + if isinstance(o, Promise): + return force_text(o) + elif isinstance(o, datetime.datetime): r = o.isoformat() if o.microsecond: r = r[:23] + r[26:] diff --git a/rest_framework/views.py b/rest_framework/views.py index 555fa2f4..e1b6705b 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -2,13 +2,15 @@ Provides an APIView class that is the base of all views in REST framework. """ from __future__ import unicode_literals + from django.core.exceptions import PermissionDenied from django.http import Http404, HttpResponse +from django.utils.datastructures import SortedDict from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions from rest_framework.compat import View -from rest_framework.response import Response from rest_framework.request import Request +from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.utils.formatting import get_view_name, get_view_description @@ -51,21 +53,6 @@ class APIView(View): 'Vary': 'Accept' } - def metadata(self, request): - return { - 'name': get_view_name(self.__class__), - 'description': get_view_description(self.__class__), - 'renders': [renderer.media_type for renderer in self.renderer_classes], - 'parses': [parser.media_type for parser in self.parser_classes], - } - # TODO: Add 'fields', from serializer info, if it exists. - # serializer = self.get_serializer() - # if serializer is not None: - # field_name_types = {} - # for name, field in form.fields.iteritems(): - # field_name_types[name] = field.__class__.__name__ - # content['fields'] = field_name_types - def http_method_not_allowed(self, request, *args, **kwargs): """ If `request.method` does not correspond to a handler method, @@ -348,3 +335,15 @@ class APIView(View): a less useful default implementation. """ return Response(self.metadata(request), status=status.HTTP_200_OK) + + def metadata(self, request): + """ + Return a dictionary of metadata about the view. + Used to return responses for OPTIONS requests. + """ + ret = SortedDict() + ret['name'] = get_view_name(self.__class__) + ret['description'] = get_view_description(self.__class__) + ret['renders'] = [renderer.media_type for renderer in self.renderer_classes] + ret['parses'] = [parser.media_type for parser in self.parser_classes] + return ret |
