diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/__init__.py | 2 | ||||
| -rw-r--r-- | rest_framework/compat.py | 13 | ||||
| -rw-r--r-- | rest_framework/decorators.py | 10 | ||||
| -rw-r--r-- | rest_framework/fields.py | 90 | ||||
| -rw-r--r-- | rest_framework/filters.py | 82 | ||||
| -rw-r--r-- | rest_framework/generics.py | 56 | ||||
| -rw-r--r-- | rest_framework/mixins.py | 32 | ||||
| -rw-r--r-- | rest_framework/permissions.py | 5 | ||||
| -rw-r--r-- | rest_framework/relations.py | 122 | ||||
| -rw-r--r-- | rest_framework/renderers.py | 101 | ||||
| -rw-r--r-- | rest_framework/request.py | 33 | ||||
| -rw-r--r-- | rest_framework/response.py | 24 | ||||
| -rw-r--r-- | rest_framework/routers.py | 48 | ||||
| -rwxr-xr-x | rest_framework/runtests/runtests.py | 7 | ||||
| -rw-r--r-- | rest_framework/runtests/settings.py | 2 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 130 | ||||
| -rw-r--r-- | rest_framework/static/rest_framework/css/bootstrap-tweaks.css | 165 | ||||
| -rw-r--r-- | rest_framework/static/rest_framework/css/default.css | 149 | ||||
| -rw-r--r-- | rest_framework/templates/rest_framework/base.html | 13 | ||||
| -rw-r--r-- | rest_framework/templates/rest_framework/form.html | 2 | ||||
| -rw-r--r-- | rest_framework/templates/rest_framework/login_base.html | 74 | ||||
| -rw-r--r-- | rest_framework/templatetags/rest_framework.py | 2 | ||||
| -rw-r--r-- | rest_framework/tests/models.py | 17 | ||||
| -rw-r--r-- | rest_framework/tests/relations.py | 47 | ||||
| -rw-r--r-- | rest_framework/tests/test_authentication.py (renamed from rest_framework/tests/authentication.py) | 53 | ||||
| -rw-r--r-- | rest_framework/tests/test_breadcrumbs.py (renamed from rest_framework/tests/breadcrumbs.py) | 2 | ||||
| -rw-r--r-- | rest_framework/tests/test_decorators.py (renamed from rest_framework/tests/decorators.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_description.py (renamed from rest_framework/tests/description.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_fields.py (renamed from rest_framework/tests/fields.py) | 242 | ||||
| -rw-r--r-- | rest_framework/tests/test_files.py (renamed from rest_framework/tests/files.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_filters.py (renamed from rest_framework/tests/filterset.py) | 230 | ||||
| -rw-r--r-- | rest_framework/tests/test_genericrelations.py (renamed from rest_framework/tests/genericrelations.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_generics.py (renamed from rest_framework/tests/generics.py) | 115 | ||||
| -rw-r--r-- | rest_framework/tests/test_htmlrenderer.py (renamed from rest_framework/tests/htmlrenderer.py) | 14 | ||||
| -rw-r--r-- | rest_framework/tests/test_hyperlinkedserializers.py (renamed from rest_framework/tests/hyperlinkedserializers.py) | 50 | ||||
| -rw-r--r-- | rest_framework/tests/test_multitable_inheritance.py (renamed from rest_framework/tests/multitable_inheritance.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_negotiation.py (renamed from rest_framework/tests/negotiation.py) | 9 | ||||
| -rw-r--r-- | rest_framework/tests/test_pagination.py (renamed from rest_framework/tests/pagination.py) | 14 | ||||
| -rw-r--r-- | rest_framework/tests/test_parsers.py (renamed from rest_framework/tests/parsers.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_permissions.py (renamed from rest_framework/tests/permissions.py) | 42 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations.py | 100 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations_hyperlink.py (renamed from rest_framework/tests/relations_hyperlink.py) | 79 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations_nested.py (renamed from rest_framework/tests/relations_nested.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations_pk.py (renamed from rest_framework/tests/relations_pk.py) | 121 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations_slug.py (renamed from rest_framework/tests/relations_slug.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_renderers.py (renamed from rest_framework/tests/renderers.py) | 66 | ||||
| -rw-r--r-- | rest_framework/tests/test_request.py (renamed from rest_framework/tests/request.py) | 2 | ||||
| -rw-r--r-- | rest_framework/tests/test_response.py (renamed from rest_framework/tests/response.py) | 131 | ||||
| -rw-r--r-- | rest_framework/tests/test_reverse.py (renamed from rest_framework/tests/reverse.py) | 2 | ||||
| -rw-r--r-- | rest_framework/tests/test_routers.py | 150 | ||||
| -rw-r--r-- | rest_framework/tests/test_serializer.py (renamed from rest_framework/tests/serializer.py) | 560 | ||||
| -rw-r--r-- | rest_framework/tests/test_serializer_bulk_update.py (renamed from rest_framework/tests/serializer_bulk_update.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_serializer_nested.py (renamed from rest_framework/tests/serializer_nested.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_settings.py (renamed from rest_framework/tests/settings.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_throttling.py (renamed from rest_framework/tests/throttling.py) | 2 | ||||
| -rw-r--r-- | rest_framework/tests/test_urlpatterns.py (renamed from rest_framework/tests/urlpatterns.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_validation.py (renamed from rest_framework/tests/validation.py) | 22 | ||||
| -rw-r--r-- | rest_framework/tests/test_views.py (renamed from rest_framework/tests/views.py) | 5 | ||||
| -rw-r--r-- | rest_framework/tests/testcases.py | 66 | ||||
| -rw-r--r-- | rest_framework/tests/tests.py | 6 | ||||
| -rw-r--r-- | rest_framework/utils/encoders.py | 7 | ||||
| -rw-r--r-- | rest_framework/views.py | 31 | ||||
| -rw-r--r-- | rest_framework/viewsets.py | 15 |
63 files changed, 2711 insertions, 651 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index b4961e2f..0a210186 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.3.2' +__version__ = '2.3.5' VERSION = __version__ # synonym diff --git a/rest_framework/compat.py b/rest_framework/compat.py index cd39f544..76dc0052 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -495,3 +495,16 @@ except ImportError: oauth2_provider_forms = None oauth2_provider_scope = None oauth2_constants = None + +# Handle lazy strings +from django.utils.functional import Promise + +if six.PY3: + def is_non_str_iterable(obj): + if (isinstance(obj, str) or + (isinstance(obj, Promise) and obj._delegate_text)): + return False + return hasattr(obj, '__iter__') +else: + def is_non_str_iterable(obj): + return hasattr(obj, '__iter__') diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 81e585e1..c69756a4 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -1,5 +1,5 @@ """ -The most imporant decorator in this module is `@api_view`, which is used +The most important decorator in this module is `@api_view`, which is used for writing function-based views with REST framework. There are also various decorators for setting the API policies on function @@ -40,7 +40,7 @@ def api_view(http_method_names): # api_view applied with eg. string instead of list of strings assert isinstance(http_method_names, (list, tuple)), \ - '@api_view expected a list of strings, recieved %s' % type(http_method_names).__name__ + '@api_view expected a list of strings, received %s' % type(http_method_names).__name__ allowed_methods = set(http_method_names) | set(('options',)) WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] @@ -112,18 +112,18 @@ def link(**kwargs): Used to mark a method on a ViewSet that should be routed for GET requests. """ def decorator(func): - func.bind_to_method = 'get' + func.bind_to_methods = ['get'] func.kwargs = kwargs return func return decorator -def action(**kwargs): +def action(methods=['post'], **kwargs): """ Used to mark a method on a ViewSet that should be routed for POST requests. """ def decorator(func): - func.bind_to_method = 'post' + func.bind_to_methods = methods func.kwargs = kwargs return func return decorator diff --git a/rest_framework/fields.py b/rest_framework/fields.py index c83ee5ec..535aa2ac 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -11,20 +11,21 @@ from decimal import Decimal, DecimalException import inspect import re import warnings - from django.core import validators from django.core.exceptions import ValidationError from django.conf import settings +from django.db.models.fields import BLANK_CHOICE_DASH from django import forms from django.forms import widgets from django.utils.encoding import is_protected_type from django.utils.translation import ugettext_lazy as _ - +from django.utils.datastructures import SortedDict from rest_framework import ISO_8601 -from rest_framework.compat import timezone, parse_date, parse_datetime, parse_time +from rest_framework.compat import (timezone, parse_date, parse_datetime, + parse_time) from rest_framework.compat import BytesIO from rest_framework.compat import six -from rest_framework.compat import smart_text +from rest_framework.compat import smart_text, force_text, is_non_str_iterable from rest_framework.settings import api_settings @@ -50,7 +51,7 @@ def get_component(obj, attr_name): return that attribute on the object. """ if isinstance(obj, dict): - val = obj[attr_name] + val = obj.get(attr_name) else: val = getattr(obj, attr_name) @@ -60,7 +61,8 @@ def get_component(obj, attr_name): def readable_datetime_formats(formats): - format = ', '.join(formats).replace(ISO_8601, 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]') + format = ', '.join(formats).replace(ISO_8601, + 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]') return humanize_strptime(format) @@ -107,8 +109,9 @@ class Field(object): partial = False use_files = False form_field_class = forms.CharField + type_label = 'field' - def __init__(self, source=None): + def __init__(self, source=None, label=None, help_text=None): self.parent = None self.creation_counter = Field.creation_counter @@ -116,6 +119,12 @@ class Field(object): self.source = source + if label is not None: + self.label = smart_text(label) + + if help_text is not None: + self.help_text = smart_text(help_text) + def initialize(self, parent, field_name): """ Called to set up a field prior to field_to_native or field_from_native. @@ -167,11 +176,16 @@ class Field(object): if is_protected_type(value): return value - elif hasattr(value, '__iter__') and not isinstance(value, (dict, six.string_types)): + elif (is_non_str_iterable(value) and + not isinstance(value, (dict, six.string_types))): return [self.to_native(item) for item in value] elif isinstance(value, dict): - return dict(map(self.to_native, (k, v)) for k, v in value.items()) - return smart_text(value) + # Make sure we preserve field ordering, if it exists + ret = SortedDict() + for key, val in value.items(): + ret[key] = self.to_native(val) + return ret + return force_text(value) def attributes(self): """ @@ -181,6 +195,18 @@ class Field(object): return {'type': self.type_name} return {} + def metadata(self): + metadata = SortedDict() + metadata['type'] = self.type_label + metadata['required'] = getattr(self, 'required', False) + optional_attrs = ['read_only', 'label', 'help_text', + 'min_length', 'max_length'] + for attr in optional_attrs: + value = getattr(self, attr, None) + if value is not None and value != '': + metadata[attr] = force_text(value, strings_only=True) + return metadata + class WritableField(Field): """ @@ -194,7 +220,8 @@ class WritableField(Field): widget = widgets.TextInput default = None - def __init__(self, source=None, read_only=False, required=None, + def __init__(self, source=None, label=None, help_text=None, + read_only=False, required=None, validators=[], error_messages=None, widget=None, default=None, blank=None): @@ -205,7 +232,7 @@ class WritableField(Field): DeprecationWarning, stacklevel=2) required = not(blank) - super(WritableField, self).__init__(source=source) + super(WritableField, self).__init__(source=source, label=label, help_text=help_text) self.read_only = read_only if required is None: @@ -268,7 +295,10 @@ class WritableField(Field): except KeyError: if self.default is not None and not self.partial: # Note: partial updates shouldn't set defaults - native = self.default + if is_simple_callable(self.default): + native = self.default() + else: + native = self.default else: if self.required: raise ValidationError(self.error_messages['required']) @@ -335,6 +365,7 @@ class ModelField(WritableField): class BooleanField(WritableField): type_name = 'BooleanField' + type_label = 'boolean' form_field_class = forms.BooleanField widget = widgets.CheckboxInput default_error_messages = { @@ -357,6 +388,7 @@ class BooleanField(WritableField): class CharField(WritableField): type_name = 'CharField' + type_label = 'string' form_field_class = forms.CharField def __init__(self, max_length=None, min_length=None, *args, **kwargs): @@ -375,23 +407,38 @@ class CharField(WritableField): class URLField(CharField): type_name = 'URLField' + type_label = 'url' def __init__(self, **kwargs): - kwargs['max_length'] = kwargs.get('max_length', 200) kwargs['validators'] = [validators.URLValidator()] super(URLField, self).__init__(**kwargs) class SlugField(CharField): type_name = 'SlugField' + type_label = 'slug' + form_field_class = forms.SlugField + + default_error_messages = { + 'invalid': _("Enter a valid 'slug' consisting of letters, numbers," + " underscores or hyphens."), + } + default_validators = [validators.validate_slug] def __init__(self, *args, **kwargs): - kwargs['max_length'] = kwargs.get('max_length', 50) super(SlugField, self).__init__(*args, **kwargs) + def __deepcopy__(self, memo): + result = copy.copy(self) + memo[id(self)] = result + #result.widget = copy.deepcopy(self.widget, memo) + result.validators = self.validators[:] + return result + class ChoiceField(WritableField): type_name = 'ChoiceField' + type_label = 'multiple choice' form_field_class = forms.ChoiceField widget = widgets.Select default_error_messages = { @@ -402,6 +449,8 @@ class ChoiceField(WritableField): def __init__(self, choices=(), *args, **kwargs): super(ChoiceField, self).__init__(*args, **kwargs) self.choices = choices + if not self.required: + self.choices = BLANK_CHOICE_DASH + self.choices def _get_choices(self): return self._choices @@ -440,6 +489,7 @@ class ChoiceField(WritableField): class EmailField(CharField): type_name = 'EmailField' + type_label = 'email' form_field_class = forms.EmailField default_error_messages = { @@ -463,6 +513,7 @@ class EmailField(CharField): class RegexField(CharField): type_name = 'RegexField' + type_label = 'regex' form_field_class = forms.RegexField def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs): @@ -492,6 +543,7 @@ class RegexField(CharField): class DateField(WritableField): type_name = 'DateField' + type_label = 'date' widget = widgets.DateInput form_field_class = forms.DateField @@ -555,6 +607,7 @@ class DateField(WritableField): class DateTimeField(WritableField): type_name = 'DateTimeField' + type_label = 'datetime' widget = widgets.DateTimeInput form_field_class = forms.DateTimeField @@ -624,6 +677,7 @@ class DateTimeField(WritableField): class TimeField(WritableField): type_name = 'TimeField' + type_label = 'time' widget = widgets.TimeInput form_field_class = forms.TimeField @@ -680,6 +734,7 @@ class TimeField(WritableField): class IntegerField(WritableField): type_name = 'IntegerField' + type_label = 'integer' form_field_class = forms.IntegerField default_error_messages = { @@ -710,6 +765,7 @@ class IntegerField(WritableField): class FloatField(WritableField): type_name = 'FloatField' + type_label = 'float' form_field_class = forms.FloatField default_error_messages = { @@ -729,6 +785,7 @@ class FloatField(WritableField): class DecimalField(WritableField): type_name = 'DecimalField' + type_label = 'decimal' form_field_class = forms.DecimalField default_error_messages = { @@ -799,6 +856,7 @@ class DecimalField(WritableField): class FileField(WritableField): use_files = True type_name = 'FileField' + type_label = 'file upload' form_field_class = forms.FileField widget = widgets.FileInput @@ -842,6 +900,8 @@ class FileField(WritableField): class ImageField(FileField): use_files = True + type_name = 'ImageField' + type_label = 'image upload' form_field_class = forms.ImageField default_error_messages = { diff --git a/rest_framework/filters.py b/rest_framework/filters.py index f2163f6f..c058bc71 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -3,9 +3,9 @@ Provides generic filtering backends that can be used to filter the results returned by list views. """ from __future__ import unicode_literals - from django.db import models -from rest_framework.compat import django_filters +from rest_framework.compat import django_filters, six +from functools import reduce import operator FilterSet = django_filters and django_filters.FilterSet or None @@ -32,40 +32,33 @@ class DjangoFilterBackend(BaseFilterBackend): def __init__(self): assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed' - def get_filter_class(self, view): + def get_filter_class(self, view, queryset=None): """ Return the django-filters `FilterSet` used to filter the queryset. """ filter_class = getattr(view, 'filter_class', None) filter_fields = getattr(view, 'filter_fields', None) - model_cls = getattr(view, 'model', None) - queryset = getattr(view, 'queryset', None) - if model_cls is None and queryset is not None: - model_cls = queryset.model if filter_class: filter_model = filter_class.Meta.model - assert issubclass(filter_model, model_cls), \ - 'FilterSet model %s does not match view model %s' % \ - (filter_model, model_cls) + assert issubclass(filter_model, queryset.model), \ + 'FilterSet model %s does not match queryset model %s' % \ + (filter_model, queryset.model) return filter_class if filter_fields: - assert model_cls is not None, 'Cannot use DjangoFilterBackend ' \ - 'on a view which does not have a .model or .queryset attribute.' - class AutoFilterSet(self.default_filter_set): class Meta: - model = model_cls + model = queryset.model fields = filter_fields return AutoFilterSet return None def filter_queryset(self, request, queryset, view): - filter_class = self.get_filter_class(view) + filter_class = self.get_filter_class(view, queryset) if filter_class: return filter_class(request.QUERY_PARAMS, queryset=queryset).qs @@ -74,6 +67,16 @@ class DjangoFilterBackend(BaseFilterBackend): class SearchFilter(BaseFilterBackend): + search_param = 'search' # The URL query parameter used for the search. + + def get_search_terms(self, request): + """ + Search terms are set by a ?search=... query parameter, + and may be comma and/or whitespace delimited. + """ + params = request.QUERY_PARAMS.get(self.search_param, '') + return params.replace(',', ' ').split() + def construct_search(self, field_name): if field_name.startswith('^'): return "%s__istartswith" % field_name[1:] @@ -88,12 +91,53 @@ class SearchFilter(BaseFilterBackend): search_fields = getattr(view, 'search_fields', None) if not search_fields: - return None + return queryset orm_lookups = [self.construct_search(str(search_field)) - for search_field in self.search_fields] - for bit in self.query.split(): - or_queries = [models.Q(**{orm_lookup: bit}) + for search_field in search_fields] + + for search_term in self.get_search_terms(request): + or_queries = [models.Q(**{orm_lookup: search_term}) for orm_lookup in orm_lookups] queryset = queryset.filter(reduce(operator.or_, or_queries)) + + return queryset + + +class OrderingFilter(BaseFilterBackend): + ordering_param = 'ordering' # The URL query parameter used for the ordering. + + def get_ordering(self, request): + """ + Search terms are set by a ?search=... query parameter, + and may be comma and/or whitespace delimited. + """ + params = request.QUERY_PARAMS.get(self.ordering_param) + if params: + return [param.strip() for param in params.split(',')] + + def get_default_ordering(self, view): + ordering = getattr(view, 'ordering', None) + if isinstance(ordering, six.string_types): + return (ordering,) + return ordering + + def remove_invalid_fields(self, queryset, ordering): + field_names = [field.name for field in queryset.model._meta.fields] + return [term for term in ordering if term.lstrip('-') in field_names] + + def filter_queryset(self, request, queryset, view): + ordering = self.get_ordering(request) + + if ordering: + # Skip any incorrect parameters + ordering = self.remove_invalid_fields(queryset, ordering) + + if not ordering: + # Use 'ordering' attribtue by default + ordering = self.get_default_ordering(view) + + if ordering: + return queryset.order_by(*ordering) + return queryset diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 05ec93d3..9ccc7898 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -3,17 +3,28 @@ Generic views that provide commonly needed behaviour. """ from __future__ import unicode_literals -from django.core.exceptions import ImproperlyConfigured +from django.core.exceptions import ImproperlyConfigured, PermissionDenied from django.core.paginator import Paginator, InvalidPage from django.http import Http404 -from django.shortcuts import get_object_or_404 +from django.shortcuts import get_object_or_404 as _get_object_or_404 from django.utils.translation import ugettext as _ -from rest_framework import views, mixins -from rest_framework.exceptions import ConfigurationError +from rest_framework import views, mixins, exceptions +from rest_framework.request import clone_request from rest_framework.settings import api_settings import warnings +def get_object_or_404(queryset, **filter_kwargs): + """ + Same as Django's standard shortcut, but make sure to raise 404 + if the filter_kwargs don't match the required types. + """ + try: + return _get_object_or_404(queryset, **filter_kwargs) + except (TypeError, ValueError): + raise Http404 + + class GenericAPIView(views.APIView): """ Base class for all other generic views. @@ -274,7 +285,7 @@ class GenericAPIView(views.APIView): ) filter_kwargs = {self.slug_field: slug} else: - raise ConfigurationError( + raise exceptions.ConfigurationError( 'Expected view %s to be called with a URL keyword argument ' 'named "%s". Fix your URL conf, or set the `.lookup_field` ' 'attribute on the view correctly.' % @@ -310,6 +321,41 @@ class GenericAPIView(views.APIView): """ pass + def metadata(self, request): + """ + Return a dictionary of metadata about the view. + Used to return responses for OPTIONS requests. + + We override the default behavior, and add some extra information + about the required request body for POST and PUT operations. + """ + ret = super(GenericAPIView, self).metadata(request) + + actions = {} + for method in ('PUT', 'POST'): + if method not in self.allowed_methods: + continue + + cloned_request = clone_request(request, method) + try: + # Test global permissions + self.check_permissions(cloned_request) + # Test object permissions + if method == 'PUT': + self.get_object() + except (exceptions.APIException, PermissionDenied, Http404): + pass + else: + # If user has appropriate permissions for the view, include + # appropriate metadata about the fields that should be supplied. + serializer = self.get_serializer() + actions[method] = serializer.metadata() + + if actions: + ret['actions'] = actions + + return ret + ########################################################## ### Concrete view classes that provide method handlers ### diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index ae703771..f11def6d 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -10,6 +10,7 @@ from django.http import Http404 from rest_framework import status from rest_framework.response import Response from rest_framework.request import clone_request +import warnings def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None): @@ -42,7 +43,6 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None) class CreateModelMixin(object): """ Create a model instance. - Should be mixed in with any `GenericAPIView`. """ def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.DATA, files=request.FILES) @@ -67,7 +67,6 @@ class CreateModelMixin(object): class ListModelMixin(object): """ List a queryset. - Should be mixed in with `MultipleObjectAPIView`. """ empty_error = "Empty list and '%(class_name)s.allow_empty' is False." @@ -77,6 +76,12 @@ class ListModelMixin(object): # Default is to allow empty querysets. This can be altered by setting # `.allow_empty = False`, to raise 404 errors on empty querysets. if not self.allow_empty and not self.object_list: + warnings.warn( + 'The `allow_empty` parameter is due to be deprecated. ' + 'To use `allow_empty=False` style behavior, You should override ' + '`get_queryset()` and explicitly raise a 404 on empty querysets.', + PendingDeprecationWarning + ) class_name = self.__class__.__name__ error_msg = self.empty_error % {'class_name': class_name} raise Http404(error_msg) @@ -94,7 +99,6 @@ class ListModelMixin(object): class RetrieveModelMixin(object): """ Retrieve a model instance. - Should be mixed in with `SingleObjectAPIView`. """ def retrieve(self, request, *args, **kwargs): self.object = self.get_object() @@ -105,17 +109,12 @@ class RetrieveModelMixin(object): class UpdateModelMixin(object): """ Update a model instance. - Should be mixed in with `SingleObjectAPIView`. """ def update(self, request, *args, **kwargs): partial = kwargs.pop('partial', False) - self.object = None - try: - self.object = self.get_object() - except Http404: - # If this is a PUT-as-create operation, we need to ensure that - # we have relevant permissions, as if this was a POST request. - self.check_permissions(clone_request(request, 'POST')) + self.object = self.get_object_or_none() + + if self.object is None: created = True save_kwargs = {'force_insert': True} success_status_code = status.HTTP_201_CREATED @@ -139,6 +138,16 @@ class UpdateModelMixin(object): kwargs['partial'] = True return self.update(request, *args, **kwargs) + def get_object_or_none(self): + try: + return self.get_object() + except Http404: + # If this is a PUT-as-create operation, we need to ensure that + # we have relevant permissions, as if this was a POST request. + # This will either raise a PermissionDenied exception, + # or simply return None + self.check_permissions(clone_request(self.request, 'POST')) + def pre_save(self, obj): """ Set any attributes on the object that are implicit in the request. @@ -168,7 +177,6 @@ class UpdateModelMixin(object): class DestroyModelMixin(object): """ Destroy a model instance. - Should be mixed in with `SingleObjectAPIView`. """ def destroy(self, request, *args, **kwargs): obj = self.get_object() diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 751f31a7..45fcfd66 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -126,6 +126,11 @@ class DjangoModelPermissions(BasePermission): if model_cls is None and queryset is not None: model_cls = queryset.model + # Workaround to ensure DjangoModelPermissions are not applied + # to the root view when using DefaultRouter. + if model_cls is None and getattr(view, '_ignore_model_permissions'): + return True + assert model_cls, ('Cannot apply DjangoModelPermissions on a view that' ' does not have `.model` or `.queryset` property.') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index fc5054b2..edaf76d6 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -8,10 +8,11 @@ from __future__ import unicode_literals from django.core.exceptions import ObjectDoesNotExist, ValidationError from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch from django import forms +from django.db.models.fields import BLANK_CHOICE_DASH from django.forms import widgets from django.forms.models import ModelChoiceIterator from django.utils.translation import ugettext_lazy as _ -from rest_framework.fields import Field, WritableField, get_component +from rest_framework.fields import Field, WritableField, get_component, is_simple_callable from rest_framework.reverse import reverse from rest_framework.compat import urlparse from rest_framework.compat import smart_text @@ -47,7 +48,7 @@ class RelatedField(WritableField): DeprecationWarning, stacklevel=2) kwargs['required'] = not kwargs.pop('null') - self.queryset = kwargs.pop('queryset', None) + queryset = kwargs.pop('queryset', None) self.many = kwargs.pop('many', self.many) if self.many: self.widget = self.many_widget @@ -56,6 +57,11 @@ class RelatedField(WritableField): kwargs['read_only'] = kwargs.pop('read_only', self.read_only) super(RelatedField, self).__init__(*args, **kwargs) + if not self.required: + self.empty_label = BLANK_CHOICE_DASH[0][1] + + self.queryset = queryset + def initialize(self, parent, field_name): super(RelatedField, self).initialize(parent, field_name) if self.queryset is None and not self.read_only: @@ -66,7 +72,6 @@ class RelatedField(WritableField): else: # Reverse self.queryset = manager.field.rel.to._default_manager.all() except Exception: - raise msg = ('Serializer related fields must include a `queryset`' + ' argument or set `read_only=True') raise Exception(msg) @@ -139,7 +144,12 @@ class RelatedField(WritableField): return None if self.many: - return [self.to_native(item) for item in value.all()] + if is_simple_callable(getattr(value, 'all', None)): + return [self.to_native(item) for item in value.all()] + else: + # Also support non-queryset iterables. + # This allows us to also support plain lists of related items. + return [self.to_native(item) for item in value] return self.to_native(value) def field_from_native(self, data, files, field_name, into): @@ -221,15 +231,28 @@ class PrimaryKeyRelatedField(RelatedField): def field_to_native(self, obj, field_name): if self.many: # To-many relationship - try: + + queryset = None + if not self.source: # Prefer obj.serializable_value for performance reasons - queryset = obj.serializable_value(self.source or field_name) - except AttributeError: + try: + queryset = obj.serializable_value(field_name) + except AttributeError: + pass + if queryset is None: # RelatedManager (reverse relationship) - queryset = getattr(obj, self.source or field_name) + source = self.source or field_name + queryset = obj + for component in source.split('.'): + queryset = get_component(queryset, component) # Forward relationship - return [self.to_native(item.pk) for item in queryset.all()] + if is_simple_callable(getattr(queryset, 'all', None)): + return [self.to_native(item.pk) for item in queryset.all()] + else: + # Also support non-queryset iterables. + # This allows us to also support plain lists of related items. + return [self.to_native(item.pk) for item in queryset] # To-one relationship try: @@ -434,7 +457,7 @@ class HyperlinkedRelatedField(RelatedField): raise Exception('Writable related fields must include a `queryset` argument') try: - http_prefix = value.startswith('http:') or value.startswith('https:') + http_prefix = value.startswith(('http:', 'https:')) except AttributeError: msg = self.error_messages['incorrect_type'] raise ValidationError(msg % type(value).__name__) @@ -465,17 +488,35 @@ class HyperlinkedIdentityField(Field): """ Represents the instance, or a property on the instance, using hyperlinking. """ + lookup_field = 'pk' + read_only = True + + # These are all pending deprecation pk_url_kwarg = 'pk' slug_field = 'slug' slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden - read_only = True def __init__(self, *args, **kwargs): - # TODO: Make view_name mandatory, and have the - # HyperlinkedModelSerializer set it on-the-fly - self.view_name = kwargs.pop('view_name', None) - # Optionally the format of the target hyperlink may be specified + try: + self.view_name = kwargs.pop('view_name') + except KeyError: + msg = "HyperlinkedIdentityField requires 'view_name' argument" + raise ValueError(msg) + self.format = kwargs.pop('format', None) + lookup_field = kwargs.pop('lookup_field', None) + self.lookup_field = lookup_field or self.lookup_field + + # These are pending deprecation + if 'pk_url_kwarg' in kwargs: + msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_url_kwarg' in kwargs: + msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_field' in kwargs: + msg = 'slug_field is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) self.slug_field = kwargs.pop('slug_field', self.slug_field) default_slug_kwarg = self.slug_url_kwarg or self.slug_field @@ -487,8 +528,7 @@ class HyperlinkedIdentityField(Field): def field_to_native(self, obj, field_name): request = self.context.get('request', None) format = self.context.get('format', None) - view_name = self.view_name or self.parent.opts.view_name - kwargs = {self.pk_url_kwarg: obj.pk} + view_name = self.view_name if request is None: warnings.warn("Using `HyperlinkedIdentityField` without including the " @@ -508,29 +548,51 @@ class HyperlinkedIdentityField(Field): if format and self.format and self.format != format: format = self.format + # Return the hyperlink, or error if incorrectly configured. try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) + return self.get_url(obj, view_name, request, format) except NoReverseMatch: - pass - - slug = getattr(obj, self.slug_field, None) + msg = ( + 'Could not resolve URL for hyperlinked relationship using ' + 'view name "%s". You may have failed to include the related ' + 'model in your API, or incorrectly configured the ' + '`lookup_field` attribute on this field.' + ) + raise Exception(msg % view_name) - if not slug: - raise Exception('Could not resolve URL for field using view name "%s"' % view_name) + def get_url(self, obj, view_name, request, format): + """ + Given an object, return the URL that hyperlinks to the object. - kwargs = {self.slug_url_kwarg: slug} + May raise a `NoReverseMatch` if the `view_name` and `lookup_field` + attributes are not configured to correctly match the URL conf. + """ + lookup_field = getattr(obj, self.lookup_field) + kwargs = {self.lookup_field: lookup_field} try: return reverse(view_name, kwargs=kwargs, request=request, format=format) except NoReverseMatch: pass - kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except NoReverseMatch: - pass + if self.pk_url_kwarg != 'pk': + # Only try pk lookup if it has been explicitly set. + # Otherwise, the default `lookup_field = 'pk'` has us covered. + kwargs = {self.pk_url_kwarg: obj.pk} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except NoReverseMatch: + pass - raise Exception('Could not resolve URL for field using view name "%s"' % view_name) + slug = getattr(obj, self.slug_field, None) + if slug: + # Only use slug lookup if a slug field exists on the model + kwargs = {self.slug_url_kwarg: slug} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except NoReverseMatch: + pass + + raise NoReverseMatch() ### Old-style many classes for backwards compat diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 1917a080..b2fe43ea 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -9,7 +9,6 @@ REST framework also provides an HTML renderer the renders the browsable API. from __future__ import unicode_literals import copy -import string import json from django import forms from django.http.multipartparser import parse_header @@ -36,6 +35,7 @@ class BaseRenderer(object): media_type = None format = None + charset = 'utf-8' def render(self, data, accepted_media_type=None, renderer_context=None): raise NotImplemented('Renderer class requires .render() to be implemented') @@ -43,16 +43,21 @@ class BaseRenderer(object): class JSONRenderer(BaseRenderer): """ - Renderer which serializes to json. + Renderer which serializes to JSON. + Applies JSON's backslash-u character escaping for non-ascii characters. """ media_type = 'application/json' format = 'json' encoder_class = encoders.JSONEncoder + ensure_ascii = True + charset = 'utf-8' + # Note that JSON encodings must be utf-8, utf-16 or utf-32. + # See: http://www.ietf.org/rfc/rfc4627.txt def render(self, data, accepted_media_type=None, renderer_context=None): """ - Render `obj` into json. + Render `data` into JSON. """ if data is None: return '' @@ -72,7 +77,25 @@ class JSONRenderer(BaseRenderer): except (ValueError, TypeError): indent = None - return json.dumps(data, cls=self.encoder_class, indent=indent) + ret = json.dumps(data, cls=self.encoder_class, + indent=indent, ensure_ascii=self.ensure_ascii) + + # On python 2.x json.dumps() returns bytestrings if ensure_ascii=True, + # but if ensure_ascii=False, the return type is underspecified, + # and may (or may not) be unicode. + # On python 3.x json.dumps() returns unicode strings. + if isinstance(ret, six.text_type): + return bytes(ret.encode(self.charset)) + return ret + + +class UnicodeJSONRenderer(JSONRenderer): + ensure_ascii = False + charset = 'utf-8' + """ + Renderer which serializes to JSON. + Does *not* apply JSON's character escaping for non-ascii characters. + """ class JSONPRenderer(JSONRenderer): @@ -105,7 +128,7 @@ class JSONPRenderer(JSONRenderer): callback = self.get_callback(renderer_context) json = super(JSONPRenderer, self).render(data, accepted_media_type, renderer_context) - return "%s(%s);" % (callback, json) + return callback.encode(self.charset) + b'(' + json + b');' class XMLRenderer(BaseRenderer): @@ -115,6 +138,7 @@ class XMLRenderer(BaseRenderer): media_type = 'application/xml' format = 'xml' + charset = 'utf-8' def render(self, data, accepted_media_type=None, renderer_context=None): """ @@ -125,7 +149,7 @@ class XMLRenderer(BaseRenderer): stream = StringIO() - xml = SimplerXMLGenerator(stream, "utf-8") + xml = SimplerXMLGenerator(stream, self.charset) xml.startDocument() xml.startElement("root", {}) @@ -164,6 +188,7 @@ class YAMLRenderer(BaseRenderer): media_type = 'application/yaml' format = 'yaml' encoder = encoders.SafeDumper + charset = 'utf-8' def render(self, data, accepted_media_type=None, renderer_context=None): """ @@ -174,7 +199,7 @@ class YAMLRenderer(BaseRenderer): if data is None: return '' - return yaml.dump(data, stream=None, Dumper=self.encoder) + return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder) class TemplateHTMLRenderer(BaseRenderer): @@ -204,6 +229,7 @@ class TemplateHTMLRenderer(BaseRenderer): '%(status_code)s.html', 'api_exception.html' ] + charset = 'utf-8' def render(self, data, accepted_media_type=None, renderer_context=None): """ @@ -275,6 +301,7 @@ class StaticHTMLRenderer(TemplateHTMLRenderer): """ media_type = 'text/html' format = 'html' + charset = 'utf-8' def render(self, data, accepted_media_type=None, renderer_context=None): renderer_context = renderer_context or {} @@ -296,6 +323,7 @@ class BrowsableAPIRenderer(BaseRenderer): media_type = 'text/html' format = 'api' template = 'rest_framework/api.html' + charset = 'utf-8' def get_default_renderer(self, view): """ @@ -320,8 +348,8 @@ class BrowsableAPIRenderer(BaseRenderer): renderer_context['indent'] = 4 content = renderer.render(data, accepted_media_type, renderer_context) - if not all(char in string.printable for char in content): - return '[%d bytes of binary content]' + if renderer.charset is None: + return '[%d bytes of binary content]' % len(content) return content @@ -336,7 +364,9 @@ class BrowsableAPIRenderer(BaseRenderer): return # Cannot use form overloading try: - view.check_permissions(clone_request(request, method)) + view.check_permissions(request) + if obj is not None: + view.check_object_permissions(request, obj) except exceptions.APIException: return False # Doesn't have permissions return True @@ -366,12 +396,40 @@ class BrowsableAPIRenderer(BaseRenderer): if getattr(v, 'default', None) is not None: kwargs['initial'] = v.default - kwargs['label'] = k + if getattr(v, 'label', None) is not None: + kwargs['label'] = v.label + + if getattr(v, 'help_text', None) is not None: + kwargs['help_text'] = v.help_text fields[k] = v.form_field_class(**kwargs) return fields + def _get_form(self, view, method, request): + # We need to impersonate a request with the correct method, + # so that eg. any dynamic get_serializer_class methods return the + # correct form for each method. + restore = view.request + request = clone_request(request, method) + view.request = request + try: + return self.get_form(view, method, request) + finally: + view.request = restore + + def _get_raw_data_form(self, view, method, request, media_types): + # We need to impersonate a request with the correct method, + # so that eg. any dynamic get_serializer_class methods return the + # correct form for each method. + restore = view.request + request = clone_request(request, method) + view.request = request + try: + return self.get_raw_data_form(view, method, request, media_types) + finally: + view.request = restore + def get_form(self, view, method, request): """ Get a form, possibly bound to either the input or output data. @@ -449,10 +507,7 @@ class BrowsableAPIRenderer(BaseRenderer): def render(self, data, accepted_media_type=None, renderer_context=None): """ - Renders *obj* using the :attr:`template` set on the class. - - The context used in the template contains all the information - needed to self-document the response to this request. + Render the HTML for the browsable API representation. """ accepted_media_type = accepted_media_type or '' renderer_context = renderer_context or {} @@ -465,15 +520,15 @@ class BrowsableAPIRenderer(BaseRenderer): renderer = self.get_default_renderer(view) content = self.get_content(renderer, data, accepted_media_type, renderer_context) - put_form = self.get_form(view, 'PUT', request) - post_form = self.get_form(view, 'POST', request) - patch_form = self.get_form(view, 'PATCH', request) - delete_form = self.get_form(view, 'DELETE', request) - options_form = self.get_form(view, 'OPTIONS', request) + put_form = self._get_form(view, 'PUT', request) + post_form = self._get_form(view, 'POST', request) + patch_form = self._get_form(view, 'PATCH', request) + delete_form = self._get_form(view, 'DELETE', request) + options_form = self._get_form(view, 'OPTIONS', request) - raw_data_put_form = self.get_raw_data_form(view, 'PUT', request, media_types) - raw_data_post_form = self.get_raw_data_form(view, 'POST', request, media_types) - raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request, media_types) + raw_data_put_form = self._get_raw_data_form(view, 'PUT', request, media_types) + raw_data_post_form = self._get_raw_data_form(view, 'POST', request, media_types) + raw_data_patch_form = self._get_raw_data_form(view, 'PATCH', request, media_types) raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form name = self.get_name(view) diff --git a/rest_framework/request.py b/rest_framework/request.py index a434659c..0d88ebc7 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -173,7 +173,7 @@ class Request(object): by the authentication classes provided to the request. """ if not hasattr(self, '_user'): - self._authenticator, self._user, self._auth = self._authenticate() + self._authenticate() return self._user @user.setter @@ -192,7 +192,7 @@ class Request(object): request, such as an authentication token. """ if not hasattr(self, '_auth'): - self._authenticator, self._user, self._auth = self._authenticate() + self._authenticate() return self._auth @auth.setter @@ -210,7 +210,7 @@ class Request(object): to authenticate the request, or `None`. """ if not hasattr(self, '_authenticator'): - self._authenticator, self._user, self._auth = self._authenticate() + self._authenticate() return self._authenticator def _load_data_and_files(self): @@ -330,11 +330,18 @@ class Request(object): Returns a three-tuple of (authenticator, user, authtoken). """ for authenticator in self.authenticators: - user_auth_tuple = authenticator.authenticate(self) + try: + user_auth_tuple = authenticator.authenticate(self) + except exceptions.APIException: + self._not_authenticated() + raise + if not user_auth_tuple is None: - user, auth = user_auth_tuple - return (authenticator, user, auth) - return self._not_authenticated() + self._authenticator = authenticator + self._user, self._auth = user_auth_tuple + return + + self._not_authenticated() def _not_authenticated(self): """ @@ -343,17 +350,17 @@ class Request(object): By default this will be (None, AnonymousUser, None). """ + self._authenticator = None + if api_settings.UNAUTHENTICATED_USER: - user = api_settings.UNAUTHENTICATED_USER() + self._user = api_settings.UNAUTHENTICATED_USER() else: - user = None + self._user = None if api_settings.UNAUTHENTICATED_TOKEN: - auth = api_settings.UNAUTHENTICATED_TOKEN() + self._auth = api_settings.UNAUTHENTICATED_TOKEN() else: - auth = None - - return (None, user, auth) + self._auth = None def __getattr__(self, attr): """ diff --git a/rest_framework/response.py b/rest_framework/response.py index 26e4ab37..5877c8a3 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -1,5 +1,5 @@ """ -The Response class in REST framework is similiar to HTTPResponse, except that +The Response class in REST framework is similar to HTTPResponse, except that it is initialized with unrendered data, instead of a pre-rendered string. The appropriate renderer is called during Django's template response rendering. @@ -12,13 +12,13 @@ from rest_framework.compat import six class Response(SimpleTemplateResponse): """ - An HttpResponse that allows it's data to be rendered into + An HttpResponse that allows its data to be rendered into arbitrary media types. """ def __init__(self, data=None, status=200, template_name=None, headers=None, - exception=False): + exception=False, content_type=None): """ Alters the init arguments slightly. For example, drop 'template_name', and instead use 'data'. @@ -30,6 +30,7 @@ class Response(SimpleTemplateResponse): self.data = data self.template_name = template_name self.exception = exception + self.content_type = content_type if headers: for name, value in six.iteritems(headers): @@ -46,8 +47,21 @@ class Response(SimpleTemplateResponse): assert context, ".renderer_context not set on Response" context['response'] = self - self['Content-Type'] = media_type - return renderer.render(self.data, media_type, context) + charset = renderer.charset + content_type = self.content_type + + if content_type is None and charset is not None: + content_type = "{0}; charset={1}".format(media_type, charset) + elif content_type is None: + content_type = media_type + self['Content-Type'] = content_type + + ret = renderer.render(self.data, media_type, context) + if isinstance(ret, six.text_type): + assert charset, 'renderer returned unicode, and did not specify ' \ + 'a charset value.' + return bytes(ret.encode(charset)) + return ret @property def status_text(self): diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 0707635a..9764e569 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -16,8 +16,8 @@ For example, you might have a `urls.py` that looks something like this: from __future__ import unicode_literals from collections import namedtuple -from django.conf.urls import url, patterns -from rest_framework.decorators import api_view +from rest_framework import views +from rest_framework.compat import patterns, url from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.urlpatterns import format_suffix_patterns @@ -71,7 +71,7 @@ class SimpleRouter(BaseRouter): routes = [ # List route. Route( - url=r'^{prefix}/$', + url=r'^{prefix}{trailing_slash}$', mapping={ 'get': 'list', 'post': 'create' @@ -81,7 +81,7 @@ class SimpleRouter(BaseRouter): ), # Detail route. Route( - url=r'^{prefix}/{lookup}/$', + url=r'^{prefix}/{lookup}{trailing_slash}$', mapping={ 'get': 'retrieve', 'put': 'update', @@ -94,7 +94,7 @@ class SimpleRouter(BaseRouter): # Dynamically generated routes. # Generated using @action or @link decorators on methods of the viewset. Route( - url=r'^{prefix}/{lookup}/{methodname}/$', + url=r'^{prefix}/{lookup}/{methodname}{trailing_slash}$', mapping={ '{httpmethod}': '{methodname}', }, @@ -103,6 +103,10 @@ class SimpleRouter(BaseRouter): ), ] + def __init__(self, trailing_slash=True): + self.trailing_slash = trailing_slash and '/' or '' + super(SimpleRouter, self).__init__() + def get_default_base_name(self, viewset): """ If `base_name` is not specified, attempt to automatically determine @@ -127,23 +131,23 @@ class SimpleRouter(BaseRouter): """ # Determine any `@action` or `@link` decorated methods on the viewset - dynamic_routes = {} + dynamic_routes = [] for methodname in dir(viewset): attr = getattr(viewset, methodname) - httpmethod = getattr(attr, 'bind_to_method', None) - if httpmethod: - dynamic_routes[httpmethod] = methodname + httpmethods = getattr(attr, 'bind_to_methods', None) + if httpmethods: + dynamic_routes.append((httpmethods, methodname)) ret = [] for route in self.routes: if route.mapping == {'{httpmethod}': '{methodname}'}: # Dynamic routes (@link or @action decorator) - for httpmethod, methodname in dynamic_routes.items(): + for httpmethods, methodname in dynamic_routes: initkwargs = route.initkwargs.copy() initkwargs.update(getattr(viewset, methodname).kwargs) ret.append(Route( url=replace_methodname(route.url, methodname), - mapping={httpmethod: methodname}, + mapping=dict((httpmethod, methodname) for httpmethod in httpmethods), name=replace_methodname(route.name, methodname), initkwargs=initkwargs, )) @@ -192,7 +196,11 @@ class SimpleRouter(BaseRouter): continue # Build the url pattern - regex = route.url.format(prefix=prefix, lookup=lookup) + regex = route.url.format( + prefix=prefix, + lookup=lookup, + trailing_slash=self.trailing_slash + ) view = viewset.as_view(mapping, **route.initkwargs) name = route.name.format(basename=basename) ret.append(url(regex, view, name=name)) @@ -217,14 +225,16 @@ class DefaultRouter(SimpleRouter): for prefix, viewset, basename in self.registry: api_root_dict[prefix] = list_name.format(basename=basename) - @api_view(('GET',)) - def api_root(request, format=None): - ret = {} - for key, url_name in api_root_dict.items(): - ret[key] = reverse(url_name, request=request, format=format) - return Response(ret) + class APIRoot(views.APIView): + _ignore_model_permissions = True + + def get(self, request, format=None): + ret = {} + for key, url_name in api_root_dict.items(): + ret[key] = reverse(url_name, request=request, format=format) + return Response(ret) - return api_root + return APIRoot.as_view() def get_urls(self): """ diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py index 4a333fb3..da36d23f 100755 --- a/rest_framework/runtests/runtests.py +++ b/rest_framework/runtests/runtests.py @@ -10,6 +10,7 @@ import sys sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings' +import django from django.conf import settings from django.test.utils import get_runner @@ -35,7 +36,11 @@ def main(): else: print(usage()) sys.exit(1) - failures = test_runner.run_tests(['tests' + test_case]) + test_module_name = 'rest_framework.tests' + if django.VERSION[0] == 1 and django.VERSION[1] < 6: + test_module_name = 'tests' + + failures = test_runner.run_tests([test_module_name + test_case]) sys.exit(failures) diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py index 9b519f27..9dd7b545 100644 --- a/rest_framework/runtests/settings.py +++ b/rest_framework/runtests/settings.py @@ -4,6 +4,8 @@ DEBUG = True TEMPLATE_DEBUG = DEBUG DEBUG_PROPAGATE_EXCEPTIONS = True +ALLOWED_HOSTS = ['*'] + ADMINS = ( # ('Your Name', 'your_email@domain.com'), ) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 87f08374..4acbc704 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -25,7 +25,7 @@ from rest_framework.compat import get_concrete_model, six # # example_field = serializers.CharField(...) # -# This helps keep the seperation between model fields, form fields, and +# This helps keep the separation between model fields, form fields, and # serializer fields more explicit. from rest_framework.relations import * @@ -61,7 +61,7 @@ class DictWithMetadata(dict): def __getstate__(self): """ Used by pickle (e.g., caching). - Overriden to remove the metadata from the dict, since it shouldn't be + Overridden to remove the metadata from the dict, since it shouldn't be pickled and may in some instances be unpickleable. """ return dict(self) @@ -202,7 +202,7 @@ class BaseSerializer(WritableField): # If 'fields' is specified, use those fields, in that order. if self.opts.fields: - assert isinstance(self.opts.fields, (list, tuple)), '`include` must be a list or tuple' + assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple' new = SortedDict() for key in self.opts.fields: new[key] = ret[key] @@ -210,7 +210,7 @@ class BaseSerializer(WritableField): # Remove anything in 'exclude' if self.opts.exclude: - assert isinstance(self.opts.fields, (list, tuple)), '`exclude` must be a list or tuple' + assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple' for key in self.opts.exclude: ret.pop(key, None) @@ -317,7 +317,8 @@ class BaseSerializer(WritableField): self._errors = {} if data is not None or files is not None: attrs = self.restore_fields(data, files) - attrs = self.perform_validation(attrs) + if attrs is not None: + attrs = self.perform_validation(attrs) else: self._errors['non_field_errors'] = ['No input provided'] @@ -381,24 +382,28 @@ class BaseSerializer(WritableField): obj = getattr(self.parent.object, field_name) if self.parent.object else None obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj - if value in (None, ''): - into[(self.source or field_name)] = None + if self.source == '*': + if value: + into.update(value) else: - kwargs = { - 'instance': obj, - 'data': value, - 'context': self.context, - 'partial': self.partial, - 'many': self.many, - 'allow_add_remove': self.allow_add_remove - } - serializer = self.__class__(**kwargs) - - if serializer.is_valid(): - into[self.source or field_name] = serializer.object + if value in (None, ''): + into[(self.source or field_name)] = None else: - # Propagate errors up to our parent - raise NestedValidationError(serializer.errors) + kwargs = { + 'instance': obj, + 'data': value, + 'context': self.context, + 'partial': self.partial, + 'many': self.many, + 'allow_add_remove': self.allow_add_remove + } + serializer = self.__class__(**kwargs) + + if serializer.is_valid(): + into[self.source or field_name] = serializer.object + else: + # Propagate errors up to our parent + raise NestedValidationError(serializer.errors) def get_identity(self, data): """ @@ -521,6 +526,17 @@ class BaseSerializer(WritableField): return self.object + def metadata(self): + """ + Return a dictionary of metadata about the fields on the serializer. + Useful for things like responding to OPTIONS requests, or generating + API schemas for auto-documentation. + """ + return SortedDict( + [(field_name, field.metadata()) + for field_name, field in six.iteritems(self.fields)] + ) + class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)): pass @@ -591,11 +607,16 @@ class ModelSerializer(Serializer): forward_rels += [field for field in opts.many_to_many if field.serialize] for model_field in forward_rels: + has_through_model = False + if model_field.rel: to_many = isinstance(model_field, models.fields.related.ManyToManyField) related_model = model_field.rel.to + if to_many and not model_field.rel.through._meta.auto_created: + has_through_model = True + if model_field.rel and nested: if len(inspect.getargspec(self.get_nested_field).args) == 2: warnings.warn( @@ -624,6 +645,9 @@ class ModelSerializer(Serializer): field = self.get_field(model_field) if field: + if has_through_model: + field.read_only = True + ret[model_field.name] = field # Deal with reverse relationships @@ -641,6 +665,12 @@ class ModelSerializer(Serializer): continue related_model = relation.model to_many = relation.field.rel.multiple + has_through_model = False + is_m2m = isinstance(relation.field, + models.fields.related.ManyToManyField) + + if is_m2m and not relation.field.rel.through._meta.auto_created: + has_through_model = True if nested: field = self.get_nested_field(None, related_model, to_many) @@ -648,13 +678,22 @@ class ModelSerializer(Serializer): field = self.get_related_field(None, related_model, to_many) if field: + if has_through_model: + field.read_only = True + ret[accessor_name] = field # Add the `read_only` flag to any fields that have bee specified # in the `read_only_fields` option for field_name in self.opts.read_only_fields: + assert field_name not in self.base_fields.keys(), \ + "field '%s' on serializer '%s' specfied in " \ + "`read_only_fields`, but also added " \ + "as an explict field. Remove it from `read_only_fields`." % \ + (field_name, self.__class__.__name__) assert field_name in ret, \ - "read_only_fields on '%s' included invalid item '%s'" % \ + "Noexistant field '%s' specified in `read_only_fields` " \ + "on serializer '%s'." % \ (self.__class__.__name__, field_name) ret[field_name].read_only = True @@ -703,25 +742,51 @@ class ModelSerializer(Serializer): Creates a default instance of a basic non-relational field. """ kwargs = {} - has_default = model_field.has_default() - if model_field.null or model_field.blank or has_default: + if model_field.null or model_field.blank: kwargs['required'] = False if isinstance(model_field, models.AutoField) or not model_field.editable: kwargs['read_only'] = True - if has_default: + if model_field.has_default(): kwargs['default'] = model_field.get_default() if issubclass(model_field.__class__, models.TextField): kwargs['widget'] = widgets.Textarea + if model_field.verbose_name is not None: + kwargs['label'] = model_field.verbose_name + + if model_field.help_text is not None: + kwargs['help_text'] = model_field.help_text + # TODO: TypedChoiceField? if model_field.flatchoices: # This ModelField contains choices kwargs['choices'] = model_field.flatchoices return ChoiceField(**kwargs) + # put this below the ChoiceField because min_value isn't a valid initializer + if issubclass(model_field.__class__, models.PositiveIntegerField) or\ + issubclass(model_field.__class__, models.PositiveSmallIntegerField): + kwargs['min_value'] = 0 + + attribute_dict = { + models.CharField: ['max_length'], + models.CommaSeparatedIntegerField: ['max_length'], + models.DecimalField: ['max_digits', 'decimal_places'], + models.EmailField: ['max_length'], + models.FileField: ['max_length'], + models.ImageField: ['max_length'], + models.SlugField: ['max_length'], + models.URLField: ['max_length'], + } + + if model_field.__class__ in attribute_dict: + attributes = attribute_dict[model_field.__class__] + for attribute in attributes: + kwargs.update({attribute: getattr(model_field, attribute)}) + try: return self.field_mapping[model_field.__class__](**kwargs) except KeyError: @@ -867,7 +932,7 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): def __init__(self, meta): super(HyperlinkedModelSerializerOptions, self).__init__(meta) self.view_name = getattr(meta, 'view_name', None) - self.lookup_field = getattr(meta, 'slug_field', None) + self.lookup_field = getattr(meta, 'lookup_field', None) class HyperlinkedModelSerializer(ModelSerializer): @@ -879,13 +944,24 @@ class HyperlinkedModelSerializer(ModelSerializer): _default_view_name = '%(model_name)s-detail' _hyperlink_field_class = HyperlinkedRelatedField - url = HyperlinkedIdentityField() + # Just a placeholder to ensure 'url' is the first field + # The field itself is actually created on initialization, + # when the view_name and lookup_field arguments are available. + url = Field() def __init__(self, *args, **kwargs): super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs) + if self.opts.view_name is None: self.opts.view_name = self._get_default_view_name(self.opts.model) + url_field = HyperlinkedIdentityField( + view_name=self.opts.view_name, + lookup_field=self.opts.lookup_field + ) + url_field.initialize(self, 'url') + self.fields['url'] = url_field + def _get_default_view_name(self, model): """ Return the view name to use if 'view_name' is not specified in 'Meta' diff --git a/rest_framework/static/rest_framework/css/bootstrap-tweaks.css b/rest_framework/static/rest_framework/css/bootstrap-tweaks.css index c650ef2e..6bfb778c 100644 --- a/rest_framework/static/rest_framework/css/bootstrap-tweaks.css +++ b/rest_framework/static/rest_framework/css/bootstrap-tweaks.css @@ -19,4 +19,167 @@ a single block in the template. .navbar-inverse .brand:hover a { color: white; text-decoration: none; -}
\ No newline at end of file +} + +/* custom navigation styles */ +.wrapper .navbar{ + width: 100%; + position: absolute; + left: 0; + top: 0; +} + +.navbar .navbar-inner{ + background: #2C2C2C; + color: white; + border: none; + border-top: 5px solid #A30000; + border-radius: 0px; +} + +.navbar .navbar-inner .nav li, .navbar .navbar-inner .nav li a, .navbar .navbar-inner .brand:hover{ + color: white; +} + +.nav-list > .active > a, .nav-list > .active > a:hover { + background: #2c2c2c; +} + +.navbar .navbar-inner .dropdown-menu li a, .navbar .navbar-inner .dropdown-menu li{ + color: #A30000; +} +.navbar .navbar-inner .dropdown-menu li a:hover{ + background: #eeeeee; + color: #c20000; +} + +/*=== dabapps bootstrap styles ====*/ + +html{ + width:100%; + background: none; +} + +body, .navbar .navbar-inner .container-fluid { + max-width: 1150px; + margin: 0 auto; +} + +body{ + background: url("../img/grid.png") repeat-x; + background-attachment: fixed; +} + +#content{ + margin: 0; +} + +/* sticky footer and footer */ +html, body { + height: 100%; +} +.wrapper { + min-height: 100%; + height: auto !important; + height: 100%; + margin: 0 auto -60px; +} + +.form-switcher { + margin-bottom: 0; +} + +.well { + -webkit-box-shadow: none; + -moz-box-shadow: none; + box-shadow: none; +} + +.well .form-actions { + padding-bottom: 0; + margin-bottom: 0; +} + +.well form { + margin-bottom: 0; +} + +.well form .help-block { + color: #999; +} + +.nav-tabs { + border: 0; +} + +.nav-tabs > li { + float: right; +} + +.nav-tabs li a { + margin-right: 0; +} + +.nav-tabs > .active > a { + background: #f5f5f5; +} + +.nav-tabs > .active > a:hover { + background: #f5f5f5; +} + +.tabbable.first-tab-active .tab-content +{ + border-top-right-radius: 0; +} + +#footer, #push { + height: 60px; /* .push must be the same height as .footer */ +} + +#footer{ + text-align: right; +} + +#footer p { + text-align: center; + color: gray; + border-top: 1px solid #DDD; + padding-top: 10px; +} + +#footer a { + color: gray; + font-weight: bold; +} + +#footer a:hover { + color: gray; +} + +.page-header { + border-bottom: none; + padding-bottom: 0px; + margin-bottom: 20px; +} + +/* custom general page styles */ +.hero-unit h2, .hero-unit h1{ + color: #A30000; +} + +body a, body a{ + color: #A30000; +} + +body a:hover{ + color: #c20000; +} + +#content a span{ + text-decoration: underline; + } + +.request-info { + clear:both; +} diff --git a/rest_framework/static/rest_framework/css/default.css b/rest_framework/static/rest_framework/css/default.css index d806267b..0261a303 100644 --- a/rest_framework/static/rest_framework/css/default.css +++ b/rest_framework/static/rest_framework/css/default.css @@ -69,152 +69,3 @@ pre { margin-bottom: 20px; } - -/*=== dabapps bootstrap styles ====*/ - -html{ - width:100%; - background: none; -} - -body, .navbar .navbar-inner .container-fluid { - max-width: 1150px; - margin: 0 auto; -} - -body{ - background: url("../img/grid.png") repeat-x; - background-attachment: fixed; -} - -#content{ - margin: 0; -} -/* custom navigation styles */ -.wrapper .navbar{ - width: 100%; - position: absolute; - left: 0; - top: 0; -} - -.navbar .navbar-inner{ - background: #2C2C2C; - color: white; - border: none; - border-top: 5px solid #A30000; - border-radius: 0px; -} - -.navbar .navbar-inner .nav li, .navbar .navbar-inner .nav li a, .navbar .navbar-inner .brand{ - color: white; -} - -.nav-list > .active > a, .nav-list > .active > a:hover { - background: #2c2c2c; -} - -.navbar .navbar-inner .dropdown-menu li a, .navbar .navbar-inner .dropdown-menu li{ - color: #A30000; -} -.navbar .navbar-inner .dropdown-menu li a:hover{ - background: #eeeeee; - color: #c20000; -} - -/* custom general page styles */ -.hero-unit h2, .hero-unit h1{ - color: #A30000; -} - -body a, body a{ - color: #A30000; -} - -body a:hover{ - color: #c20000; -} - -#content a span{ - text-decoration: underline; - } - -/* sticky footer and footer */ -html, body { - height: 100%; -} -.wrapper { - min-height: 100%; - height: auto !important; - height: 100%; - margin: 0 auto -60px; -} - -.form-switcher { - margin-bottom: 0; -} - -.well { - -webkit-box-shadow: none; - -moz-box-shadow: none; - box-shadow: none; -} - -.well .form-actions { - padding-bottom: 0; - margin-bottom: 0; -} - -.well form { - margin-bottom: 0; -} - -.nav-tabs { - border: 0; -} - -.nav-tabs > li { - float: right; -} - -.nav-tabs li a { - margin-right: 0; -} - -.nav-tabs > .active > a { - background: #f5f5f5; -} - -.nav-tabs > .active > a:hover { - background: #f5f5f5; -} - -.tabbable.first-tab-active .tab-content -{ - border-top-right-radius: 0; -} - -#footer, #push { - height: 60px; /* .push must be the same height as .footer */ -} - -#footer{ - text-align: right; -} - -#footer p { - text-align: center; - color: gray; - border-top: 1px solid #DDD; - padding-top: 10px; -} - -#footer a { - color: gray; - font-weight: bold; -} - -#footer a:hover { - color: gray; -} - diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 4410f285..9d939e73 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -13,8 +13,10 @@ <title>{% block title %}Django REST framework{% endblock %}</title> {% block style %} - {% block bootstrap_theme %}<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>{% endblock %} - <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/> + {% block bootstrap_theme %} + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/> + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/> + {% endblock %} <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/prettify.css" %}"/> <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/> {% endblock %} @@ -30,8 +32,8 @@ <div class="navbar {% block bootstrap_navbar_variant %}navbar-inverse{% endblock %}"> <div class="navbar-inner"> <div class="container-fluid"> - <span class="brand" href="/"> - {% block branding %}<a href='http://django-rest-framework.org'>Django REST framework <span class="version">{{ version }}</span></a>{% endblock %} + <span href="/"> + {% block branding %}<a class='brand' href='http://django-rest-framework.org'>Django REST framework <span class="version">{{ version }}</span></a>{% endblock %} </span> <ul class="nav pull-right"> {% block userlinks %} @@ -109,8 +111,7 @@ <div class="content-main"> <div class="page-header"><h1>{{ name }}</h1></div> {{ description }} - - <div class="request-info"> + <div class="request-info" style="clear: both" > <pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre> </div> <div class="response-info"> diff --git a/rest_framework/templates/rest_framework/form.html b/rest_framework/templates/rest_framework/form.html index dc7acc70..b27f652e 100644 --- a/rest_framework/templates/rest_framework/form.html +++ b/rest_framework/templates/rest_framework/form.html @@ -6,7 +6,7 @@ {{ field.label_tag|add_class:"control-label" }} <div class="controls"> {{ field }} - <span class="help-inline">{{ field.help_text }}</span> + <span class="help-block">{{ field.help_text }}</span> <!--{{ field.errors|add_class:"help-block" }}--> </div> </div> diff --git a/rest_framework/templates/rest_framework/login_base.html b/rest_framework/templates/rest_framework/login_base.html index 380d5820..be9a0072 100644 --- a/rest_framework/templates/rest_framework/login_base.html +++ b/rest_framework/templates/rest_framework/login_base.html @@ -4,52 +4,50 @@ <head> {% block style %} - {% block bootstrap_theme %}<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>{% endblock %} - <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/> + {% block bootstrap_theme %} + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/> + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/> + {% endblock %} <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/> {% endblock %} </head> <body class="container"> -<div class="container-fluid" style="margin-top: 30px"> - <div class="row-fluid"> - - <div class="well" style="width: 320px; margin-left: auto; margin-right: auto"> - <div class="row-fluid"> - <div> - {% block branding %}<h3 style="margin: 0 0 20px;">Django REST framework</h3>{% endblock %} - </div> - </div><!-- /row fluid --> - + <div class="container-fluid" style="margin-top: 30px"> <div class="row-fluid"> - <div> - <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post"> - {% csrf_token %} - <div id="div_id_username" class="clearfix control-group"> - <div class="controls"> - <Label class="span4">Username:</label> - <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username"> - </div> + <div class="well" style="width: 320px; margin-left: auto; margin-right: auto"> + <div class="row-fluid"> + <div> + {% block branding %}<h3 style="margin: 0 0 20px;">Django REST framework</h3>{% endblock %} </div> - <div id="div_id_password" class="clearfix control-group"> - <div class="controls"> - <Label class="span4">Password:</label> - <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password"> - </div> + </div><!-- /row fluid --> + + <div class="row-fluid"> + <div> + <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post"> + {% csrf_token %} + <div id="div_id_username" class="clearfix control-group"> + <div class="controls"> + <Label class="span4">Username:</label> + <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username"> + </div> + </div> + <div id="div_id_password" class="clearfix control-group"> + <div class="controls"> + <Label class="span4">Password:</label> + <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password"> + </div> + </div> + <input type="hidden" name="next" value="{{ next }}" /> + <div class="form-actions-no-box"> + <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit"> + </div> + </form> </div> - <input type="hidden" name="next" value="{{ next }}" /> - <div class="form-actions-no-box"> - <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit"> - </div> - </form> - </div> - </div><!-- /row fluid --> - </div><!--/span--> - - </div><!-- /.row-fluid --> - </div> - - </div> + </div><!-- /.row-fluid --> + </div><!--/.well--> + </div><!-- /.row-fluid --> + </div><!-- /.container-fluid --> </body> </html> diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index c86b6456..e9c1cdd5 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -15,7 +15,7 @@ register = template.Library() # When 1.3 becomes unsupported by REST framework, we can instead start to # use the {% load staticfiles %} tag, remove the following code, -# and add a dependancy that `django.contrib.staticfiles` must be installed. +# and add a dependency that `django.contrib.staticfiles` must be installed. # Note: We can't put this into the `compat` module because the compat import # from rest_framework.compat import ... diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index f2117538..e2d4eacd 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -1,5 +1,7 @@ from __future__ import unicode_literals from django.db import models +from django.utils.translation import ugettext_lazy as _ +from rest_framework import serializers def foobar(): @@ -32,7 +34,7 @@ class Anchor(RESTFrameworkModel): class BasicModel(RESTFrameworkModel): - text = models.CharField(max_length=100) + text = models.CharField(max_length=100, verbose_name=_("Text comes here"), help_text=_("Text description.")) class SlugBasedModel(RESTFrameworkModel): @@ -58,13 +60,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel): rel = models.ManyToManyField(Anchor) -# Model to test filtering. -class FilterableItem(RESTFrameworkModel): - text = models.CharField(max_length=100) - decimal = models.DecimalField(max_digits=4, decimal_places=2) - date = models.DateField() - - # Model for regression test for #285 class Comment(RESTFrameworkModel): @@ -166,3 +161,9 @@ class NullableOneToOneSource(RESTFrameworkModel): name = models.CharField(max_length=100) target = models.OneToOneField(OneToOneTarget, null=True, blank=True, related_name='nullable_source') + + +# Serializer used to test BasicModel +class BasicModelSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py deleted file mode 100644 index cbf93c65..00000000 --- a/rest_framework/tests/relations.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -General tests for relational fields. -""" -from __future__ import unicode_literals -from django.db import models -from django.test import TestCase -from rest_framework import serializers - - -class NullModel(models.Model): - pass - - -class FieldTests(TestCase): - def test_pk_related_field_with_empty_string(self): - """ - Regression test for #446 - - https://github.com/tomchristie/django-rest-framework/issues/446 - """ - field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) - self.assertRaises(serializers.ValidationError, field.from_native, '') - self.assertRaises(serializers.ValidationError, field.from_native, []) - - def test_hyperlinked_related_field_with_empty_string(self): - field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') - self.assertRaises(serializers.ValidationError, field.from_native, '') - self.assertRaises(serializers.ValidationError, field.from_native, []) - - def test_slug_related_field_with_empty_string(self): - field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') - self.assertRaises(serializers.ValidationError, field.from_native, '') - self.assertRaises(serializers.ValidationError, field.from_native, []) - - -class TestManyRelateMixin(TestCase): - def test_missing_many_to_many_related_field(self): - ''' - Regression test for #632 - - https://github.com/tomchristie/django-rest-framework/pull/632 - ''' - field = serializers.RelatedField(many=True, read_only=False) - - into = {} - field.field_from_native({}, None, 'field_name', into) - self.assertEqual(into['field_name'], []) diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/test_authentication.py index 8e6d3e51..d46ac079 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -6,6 +6,8 @@ from django.utils import unittest from rest_framework import HTTP_HEADER_ENCODING from rest_framework import exceptions from rest_framework import permissions +from rest_framework import renderers +from rest_framework.response import Response from rest_framework import status from rest_framework.authentication import ( BaseAuthentication, @@ -48,7 +50,7 @@ urlpatterns = patterns('', (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])), - (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication], + (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication], permission_classes=[permissions.TokenHasReadWriteScope])) ) @@ -56,14 +58,14 @@ if oauth2_provider is not None: urlpatterns += patterns('', url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), - url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], + url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], permission_classes=[permissions.TokenHasReadWriteScope])), ) class BasicAuthTests(TestCase): """Basic authentication""" - urls = 'rest_framework.tests.authentication' + urls = 'rest_framework.tests.test_authentication' def setUp(self): self.csrf_client = Client(enforce_csrf_checks=True) @@ -102,7 +104,7 @@ class BasicAuthTests(TestCase): class SessionAuthTests(TestCase): """User session authentication""" - urls = 'rest_framework.tests.authentication' + urls = 'rest_framework.tests.test_authentication' def setUp(self): self.csrf_client = Client(enforce_csrf_checks=True) @@ -149,7 +151,7 @@ class SessionAuthTests(TestCase): class TokenAuthTests(TestCase): """Token authentication""" - urls = 'rest_framework.tests.authentication' + urls = 'rest_framework.tests.test_authentication' def setUp(self): self.csrf_client = Client(enforce_csrf_checks=True) @@ -243,7 +245,7 @@ class IncorrectCredentialsTests(TestCase): class OAuthTests(TestCase): """OAuth 1.0a authentication""" - urls = 'rest_framework.tests.authentication' + urls = 'rest_framework.tests.test_authentication' def setUp(self): # these imports are here because oauth is optional and hiding them in try..except block or compat @@ -429,7 +431,7 @@ class OAuthTests(TestCase): class OAuth2Tests(TestCase): """OAuth 2.0 authentication""" - urls = 'rest_framework.tests.authentication' + urls = 'rest_framework.tests.test_authentication' def setUp(self): self.csrf_client = Client(enforce_csrf_checks=True) @@ -553,3 +555,40 @@ class OAuth2Tests(TestCase): auth = self._create_authorization_header(token=read_write_access_token.token) response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) + + +class FailingAuthAccessedInRenderer(TestCase): + def setUp(self): + class AuthAccessingRenderer(renderers.BaseRenderer): + media_type = 'text/plain' + format = 'txt' + + def render(self, data, media_type=None, renderer_context=None): + request = renderer_context['request'] + if request.user.is_authenticated(): + return b'authenticated' + return b'not authenticated' + + class FailingAuth(BaseAuthentication): + def authenticate(self, request): + raise exceptions.AuthenticationFailed('authentication failed') + + class ExampleView(APIView): + authentication_classes = (FailingAuth,) + renderer_classes = (AuthAccessingRenderer,) + + def get(self, request): + return Response({'foo': 'bar'}) + + self.view = ExampleView.as_view() + + def test_failing_auth_accessed_in_renderer(self): + """ + When authentication fails the renderer should still be able to access + `request.user` without raising an exception. Particularly relevant + to HTML responses that might reasonably access `request.user`. + """ + request = factory.get('/') + response = self.view(request) + content = response.render().content + self.assertEqual(content, b'not authenticated') diff --git a/rest_framework/tests/breadcrumbs.py b/rest_framework/tests/test_breadcrumbs.py index d9ed647e..41ddf2ce 100644 --- a/rest_framework/tests/breadcrumbs.py +++ b/rest_framework/tests/test_breadcrumbs.py @@ -36,7 +36,7 @@ urlpatterns = patterns('', class BreadcrumbTests(TestCase): """Tests the breadcrumb functionality used by the HTML renderer.""" - urls = 'rest_framework.tests.breadcrumbs' + urls = 'rest_framework.tests.test_breadcrumbs' def test_root_breadcrumbs(self): url = '/' diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/test_decorators.py index 1016fed3..1016fed3 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/test_decorators.py diff --git a/rest_framework/tests/description.py b/rest_framework/tests/test_description.py index 52c1a34c..52c1a34c 100644 --- a/rest_framework/tests/description.py +++ b/rest_framework/tests/test_description.py diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/test_fields.py index 3cdfa0f6..69a0468e 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/test_fields.py @@ -2,15 +2,16 @@ General serializer field tests. """ from __future__ import unicode_literals + import datetime from decimal import Decimal - +from uuid import uuid4 +from django.core import validators from django.db import models from django.test import TestCase -from django.core import validators - +from django.utils.datastructures import SortedDict from rest_framework import serializers -from rest_framework.serializers import Serializer +from rest_framework.tests.models import RESTFrameworkModel class TimestampedModel(models.Model): @@ -63,6 +64,20 @@ class BasicFieldTests(TestCase): serializer = CharPrimaryKeyModelSerializer() self.assertEqual(serializer.fields['id'].read_only, False) + def test_dict_field_ordering(self): + """ + Field should preserve dictionary ordering, if it exists. + See: https://github.com/tomchristie/django-rest-framework/issues/832 + """ + ret = SortedDict() + ret['c'] = 1 + ret['b'] = 1 + ret['a'] = 1 + ret['z'] = 1 + field = serializers.Field() + keys = list(field.to_native(ret).keys()) + self.assertEqual(keys, ['c', 'b', 'a', 'z']) + class DateFieldTest(TestCase): """ @@ -573,7 +588,7 @@ class DecimalFieldTest(TestCase): """ Make sure the serializer works correctly """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(max_value=9010, min_value=9000, max_digits=6, @@ -591,7 +606,7 @@ class DecimalFieldTest(TestCase): """ Make sure max_value violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(max_value=100) s = DecimalSerializer(data={'decimal_field': '123'}) @@ -603,7 +618,7 @@ class DecimalFieldTest(TestCase): """ Make sure min_value violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(min_value=100) s = DecimalSerializer(data={'decimal_field': '99'}) @@ -615,7 +630,7 @@ class DecimalFieldTest(TestCase): """ Make sure max_digits violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(max_digits=5) s = DecimalSerializer(data={'decimal_field': '123.456'}) @@ -627,7 +642,7 @@ class DecimalFieldTest(TestCase): """ Make sure max_decimal_places violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(decimal_places=3) s = DecimalSerializer(data={'decimal_field': '123.4567'}) @@ -639,10 +654,215 @@ class DecimalFieldTest(TestCase): """ Make sure max_whole_digits violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3) s = DecimalSerializer(data={'decimal_field': '12345.6'}) self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
\ No newline at end of file + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) + + +class ChoiceFieldTests(TestCase): + """ + Tests for the ChoiceField options generator + """ + + SAMPLE_CHOICES = [ + ('red', 'Red'), + ('green', 'Green'), + ('blue', 'Blue'), + ] + + def test_choices_required(self): + """ + Make sure proper choices are rendered if field is required + """ + f = serializers.ChoiceField(required=True, choices=self.SAMPLE_CHOICES) + self.assertEqual(f.choices, self.SAMPLE_CHOICES) + + def test_choices_not_required(self): + """ + Make sure proper choices (plus blank) are rendered if the field isn't required + """ + f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES) + self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES) + + +class EmailFieldTests(TestCase): + """ + Tests for EmailField attribute values + """ + + class EmailFieldModel(RESTFrameworkModel): + email_field = models.EmailField(blank=True) + + class EmailFieldWithGivenMaxLengthModel(RESTFrameworkModel): + email_field = models.EmailField(max_length=150, blank=True) + + def test_default_model_value(self): + class EmailFieldSerializer(serializers.ModelSerializer): + class Meta: + model = self.EmailFieldModel + + serializer = EmailFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 75) + + def test_given_model_value(self): + class EmailFieldSerializer(serializers.ModelSerializer): + class Meta: + model = self.EmailFieldWithGivenMaxLengthModel + + serializer = EmailFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 150) + + def test_given_serializer_value(self): + class EmailFieldSerializer(serializers.ModelSerializer): + email_field = serializers.EmailField(source='email_field', max_length=20, required=False) + + class Meta: + model = self.EmailFieldModel + + serializer = EmailFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 20) + + +class SlugFieldTests(TestCase): + """ + Tests for SlugField attribute values + """ + + class SlugFieldModel(RESTFrameworkModel): + slug_field = models.SlugField(blank=True) + + class SlugFieldWithGivenMaxLengthModel(RESTFrameworkModel): + slug_field = models.SlugField(max_length=84, blank=True) + + def test_default_model_value(self): + class SlugFieldSerializer(serializers.ModelSerializer): + class Meta: + model = self.SlugFieldModel + + serializer = SlugFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 50) + + def test_given_model_value(self): + class SlugFieldSerializer(serializers.ModelSerializer): + class Meta: + model = self.SlugFieldWithGivenMaxLengthModel + + serializer = SlugFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 84) + + def test_given_serializer_value(self): + class SlugFieldSerializer(serializers.ModelSerializer): + slug_field = serializers.SlugField(source='slug_field', + max_length=20, required=False) + + class Meta: + model = self.SlugFieldModel + + serializer = SlugFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['slug_field'], + 'max_length'), 20) + + def test_invalid_slug(self): + """ + Make sure an invalid slug raises ValidationError + """ + class SlugFieldSerializer(serializers.ModelSerializer): + slug_field = serializers.SlugField(source='slug_field', max_length=20, required=True) + + class Meta: + model = self.SlugFieldModel + + s = SlugFieldSerializer(data={'slug_field': 'a b'}) + + self.assertEqual(s.is_valid(), False) + self.assertEqual(s.errors, {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]}) + + +class URLFieldTests(TestCase): + """ + Tests for URLField attribute values + """ + + class URLFieldModel(RESTFrameworkModel): + url_field = models.URLField(blank=True) + + class URLFieldWithGivenMaxLengthModel(RESTFrameworkModel): + url_field = models.URLField(max_length=128, blank=True) + + def test_default_model_value(self): + class URLFieldSerializer(serializers.ModelSerializer): + class Meta: + model = self.URLFieldModel + + serializer = URLFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['url_field'], + 'max_length'), 200) + + def test_given_model_value(self): + class URLFieldSerializer(serializers.ModelSerializer): + class Meta: + model = self.URLFieldWithGivenMaxLengthModel + + serializer = URLFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['url_field'], + 'max_length'), 128) + + def test_given_serializer_value(self): + class URLFieldSerializer(serializers.ModelSerializer): + url_field = serializers.URLField(source='url_field', + max_length=20, required=False) + + class Meta: + model = self.URLFieldWithGivenMaxLengthModel + + serializer = URLFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['url_field'], + 'max_length'), 20) + + +class FieldMetadata(TestCase): + def setUp(self): + self.required_field = serializers.Field() + self.required_field.label = uuid4().hex + self.required_field.required = True + + self.optional_field = serializers.Field() + self.optional_field.label = uuid4().hex + self.optional_field.required = False + + def test_required(self): + self.assertEqual(self.required_field.metadata()['required'], True) + + def test_optional(self): + self.assertEqual(self.optional_field.metadata()['required'], False) + + def test_label(self): + for field in (self.required_field, self.optional_field): + self.assertEqual(field.metadata()['label'], field.label) + + +class FieldCallableDefault(TestCase): + def setUp(self): + self.simple_callable = lambda: 'foo bar' + + def test_default_can_be_simple_callable(self): + """ + Ensure that the 'default' argument can also be a simple callable. + """ + field = serializers.WritableField(default=self.simple_callable) + into = {} + field.field_from_native({}, {}, 'field', into) + self.assertEqual(into, {'field': 'foo bar'}) diff --git a/rest_framework/tests/files.py b/rest_framework/tests/test_files.py index 487046ac..487046ac 100644 --- a/rest_framework/tests/files.py +++ b/rest_framework/tests/test_files.py diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/test_filters.py index 023bd016..aaed6247 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/test_filters.py @@ -1,23 +1,30 @@ from __future__ import unicode_literals import datetime from decimal import Decimal +from django.db import models from django.core.urlresolvers import reverse from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters, patterns, url -from rest_framework.tests.models import FilterableItem, BasicModel +from rest_framework.tests.models import BasicModel factory = RequestFactory() +class FilterableItem(models.Model): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + if django_filters: # Basic filter on a list view. class FilterFieldsRootView(generics.ListCreateAPIView): model = FilterableItem filter_fields = ['decimal', 'date'] - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) # These class are used to test a filter class. class SeveralFieldsFilter(django_filters.FilterSet): @@ -32,7 +39,7 @@ if django_filters: class FilterClassRootView(generics.ListCreateAPIView): model = FilterableItem filter_class = SeveralFieldsFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) # These classes are used to test a misconfigured filter class. class MisconfiguredFilter(django_filters.FilterSet): @@ -45,12 +52,12 @@ if django_filters: class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): model = FilterableItem filter_class = MisconfiguredFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) class FilterClassDetailView(generics.RetrieveAPIView): model = FilterableItem filter_class = SeveralFieldsFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) # Regression test for #814 class FilterableItemSerializer(serializers.ModelSerializer): @@ -61,11 +68,21 @@ if django_filters: queryset = FilterableItem.objects.all() serializer_class = FilterableItemSerializer filter_fields = ['decimal', 'date'] - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) + + class GetQuerysetView(generics.ListCreateAPIView): + serializer_class = FilterableItemSerializer + filter_class = SeveralFieldsFilter + filter_backends = (filters.DjangoFilterBackend,) + + def get_queryset(self): + return FilterableItem.objects.all() urlpatterns = patterns('', url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), url(r'^$', FilterClassRootView.as_view(), name='root-view'), + url(r'^get-queryset/$', GetQuerysetView.as_view(), + name='get-queryset-view'), ) @@ -141,6 +158,17 @@ class IntegrationTestFiltering(CommonFilteringTestCase): self.assertEqual(response.data, expected_data) @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_filter_with_get_queryset_only(self): + """ + Regression test for #834. + """ + view = GetQuerysetView.as_view() + request = factory.get('/get-queryset/') + view(request).render() + # Used to raise "issubclass() arg 2 must be a class or tuple of classes" + # here when neither `model' nor `queryset' was specified. + + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_filtered_class_root_view(self): """ GET requests to filtered ListCreateAPIView that have a filter_class set @@ -215,7 +243,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): """ Integration tests for filtered detail views. """ - urls = 'rest_framework.tests.filterset' + urls = 'rest_framework.tests.test_filters' def _get_url(self, item): return reverse('detail-view', kwargs=dict(pk=item.pk)) @@ -256,3 +284,191 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, valid_item_data) + + +class SearchFilterModel(models.Model): + title = models.CharField(max_length=20) + text = models.CharField(max_length=100) + + +class SearchFilterTests(TestCase): + def setUp(self): + # Sequence of title/text is: + # + # z abc + # zz bcd + # zzz cde + # ... + for idx in range(10): + title = 'z' * (idx + 1) + text = ( + chr(idx + ord('a')) + + chr(idx + ord('b')) + + chr(idx + ord('c')) + ) + SearchFilterModel(title=title, text=text).save() + + def test_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('title', 'text') + + view = SearchListView.as_view() + request = factory.get('?search=b') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 1, 'title': 'z', 'text': 'abc'}, + {'id': 2, 'title': 'zz', 'text': 'bcd'} + ] + ) + + def test_exact_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('=title', 'text') + + view = SearchListView.as_view() + request = factory.get('?search=zzz') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'zzz', 'text': 'cde'} + ] + ) + + def test_startswith_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('title', '^text') + + view = SearchListView.as_view() + request = factory.get('?search=b') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 2, 'title': 'zz', 'text': 'bcd'} + ] + ) + + +class OrdringFilterModel(models.Model): + title = models.CharField(max_length=20) + text = models.CharField(max_length=100) + + +class OrderingFilterTests(TestCase): + def setUp(self): + # Sequence of title/text is: + # + # zyx abc + # yxw bcd + # xwv cde + for idx in range(3): + title = ( + chr(ord('z') - idx) + + chr(ord('y') - idx) + + chr(ord('x') - idx) + ) + text = ( + chr(idx + ord('a')) + + chr(idx + ord('b')) + + chr(idx + ord('c')) + ) + OrdringFilterModel(title=title, text=text).save() + + def test_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?ordering=text') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + ] + ) + + def test_reverse_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?ordering=-text') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_incorrectfield_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?ordering=foobar') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_default_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_default_ordering_using_string(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = 'title' + + view = OrderingListView.as_view() + request = factory.get('') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/test_genericrelations.py index c38bfb9f..c38bfb9f 100644 --- a/rest_framework/tests/genericrelations.py +++ b/rest_framework/tests/test_genericrelations.py diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/test_generics.py index eca50d82..37734195 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/test_generics.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals from django.db import models from django.shortcuts import get_object_or_404 from django.test import TestCase -from rest_framework import generics, serializers, status +from rest_framework import generics, renderers, serializers, status from rest_framework.tests.utils import RequestFactory from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel from rest_framework.compat import six @@ -39,6 +39,7 @@ class SlugBasedInstanceView(InstanceView): """ model = SlugBasedModel serializer_class = SlugSerializer + lookup_field = 'slug' class TestRootView(TestCase): @@ -120,7 +121,25 @@ class TestRootView(TestCase): 'text/html' ], 'name': 'Root', - 'description': 'Example description for OPTIONS.' + 'description': 'Example description for OPTIONS.', + 'actions': { + 'POST': { + 'text': { + 'max_length': 100, + 'read_only': False, + 'required': True, + 'type': 'string', + "label": "Text comes here", + "help_text": "Text description." + }, + 'id': { + 'read_only': True, + 'required': False, + 'type': 'integer', + 'label': 'ID', + }, + } + } } self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, expected) @@ -223,9 +242,9 @@ class TestInstanceView(TestCase): """ OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata """ - request = factory.options('/') - with self.assertNumQueries(0): - response = self.view(request).render() + request = factory.options('/1') + with self.assertNumQueries(1): + response = self.view(request, pk=1).render() expected = { 'parses': [ 'application/json', @@ -237,11 +256,39 @@ class TestInstanceView(TestCase): 'text/html' ], 'name': 'Instance', - 'description': 'Example description for OPTIONS.' + 'description': 'Example description for OPTIONS.', + 'actions': { + 'PUT': { + 'text': { + 'max_length': 100, + 'read_only': False, + 'required': True, + 'type': 'string', + 'label': 'Text comes here', + 'help_text': 'Text description.' + }, + 'id': { + 'read_only': True, + 'required': False, + 'type': 'integer', + 'label': 'ID', + }, + } + } } self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, expected) + def test_get_instance_view_incorrect_arg(self): + """ + GET requests with an incorrect pk type, should raise 404, not 500. + Regression test for #890. + """ + request = factory.get('/a') + with self.assertNumQueries(0): + response = self.view(request, pk='a').render() + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + def test_put_cannot_set_id(self): """ PUT requests to create a new object should not be able to set the id. @@ -434,22 +481,14 @@ class TestFilterBackendAppliedToViews(TestCase): {'id': obj.id, 'text': obj.text} for obj in self.objects.all() ] - self.root_view = RootView.as_view() - self.instance_view = InstanceView.as_view() - self.original_root_backend = getattr(RootView, 'filter_backend') - self.original_instance_backend = getattr(InstanceView, 'filter_backend') - - def tearDown(self): - setattr(RootView, 'filter_backend', self.original_root_backend) - setattr(InstanceView, 'filter_backend', self.original_instance_backend) def test_get_root_view_filters_by_name_with_filter_backend(self): """ GET requests to ListCreateAPIView should return filtered list. """ - setattr(RootView, 'filter_backend', InclusiveFilterBackend) + root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,)) request = factory.get('/') - response = self.root_view(request).render() + response = root_view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data), 1) self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}]) @@ -458,9 +497,9 @@ class TestFilterBackendAppliedToViews(TestCase): """ GET requests to ListCreateAPIView should return empty list when all models are filtered out. """ - setattr(RootView, 'filter_backend', ExclusiveFilterBackend) + root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,)) request = factory.get('/') - response = self.root_view(request).render() + response = root_view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, []) @@ -468,9 +507,9 @@ class TestFilterBackendAppliedToViews(TestCase): """ GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out. """ - setattr(InstanceView, 'filter_backend', ExclusiveFilterBackend) + instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,)) request = factory.get('/1') - response = self.instance_view(request, pk=1).render() + response = instance_view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.data, {'detail': 'Not found'}) @@ -478,8 +517,40 @@ class TestFilterBackendAppliedToViews(TestCase): """ GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded """ - setattr(InstanceView, 'filter_backend', InclusiveFilterBackend) + instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,)) request = factory.get('/1') - response = self.instance_view(request, pk=1).render() + response = instance_view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, {'id': 1, 'text': 'foo'}) + + +class TwoFieldModel(models.Model): + field_a = models.CharField(max_length=100) + field_b = models.CharField(max_length=100) + + +class DynamicSerializerView(generics.ListCreateAPIView): + model = TwoFieldModel + renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) + + def get_serializer_class(self): + if self.request.method == 'POST': + class DynamicSerializer(serializers.ModelSerializer): + class Meta: + model = TwoFieldModel + fields = ('field_b',) + return DynamicSerializer + return super(DynamicSerializerView, self).get_serializer_class() + + +class TestFilterBackendAppliedToViews(TestCase): + + def test_dynamic_serializer_form_in_browsable_api(self): + """ + GET requests to ListCreateAPIView should return filtered list. + """ + view = DynamicSerializerView.as_view() + request = factory.get('/') + response = view(request).render() + self.assertContains(response, 'field_b') + self.assertNotContains(response, 'field_a') diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/test_htmlrenderer.py index 8f2e2b5a..8957a43c 100644 --- a/rest_framework/tests/htmlrenderer.py +++ b/rest_framework/tests/test_htmlrenderer.py @@ -42,7 +42,7 @@ urlpatterns = patterns('', class TemplateHTMLRendererTests(TestCase): - urls = 'rest_framework.tests.htmlrenderer' + urls = 'rest_framework.tests.test_htmlrenderer' def setUp(self): """ @@ -66,23 +66,23 @@ class TemplateHTMLRendererTests(TestCase): def test_simple_html_view(self): response = self.client.get('/') self.assertContains(response, "example: foobar") - self.assertEqual(response['Content-Type'], 'text/html') + self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') def test_not_found_html_view(self): response = self.client.get('/not_found') self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.content, six.b("404 Not Found")) - self.assertEqual(response['Content-Type'], 'text/html') + self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') def test_permission_denied_html_view(self): response = self.client.get('/permission_denied') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.content, six.b("403 Forbidden")) - self.assertEqual(response['Content-Type'], 'text/html') + self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') class TemplateHTMLRendererExceptionTests(TestCase): - urls = 'rest_framework.tests.htmlrenderer' + urls = 'rest_framework.tests.test_htmlrenderer' def setUp(self): """ @@ -109,10 +109,10 @@ class TemplateHTMLRendererExceptionTests(TestCase): response = self.client.get('/not_found') self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.content, six.b("404: Not found")) - self.assertEqual(response['Content-Type'], 'text/html') + self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') def test_permission_denied_html_view_with_template(self): response = self.client.get('/permission_denied') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.content, six.b("403: Permission denied")) - self.assertEqual(response['Content-Type'], 'text/html') + self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py index 9a61f299..1894ddb2 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/test_hyperlinkedserializers.py @@ -27,6 +27,14 @@ class PhotoSerializer(serializers.Serializer): return Photo(**attrs) +class AlbumSerializer(serializers.ModelSerializer): + url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title') + + class Meta: + model = Album + fields = ('title', 'url') + + class BasicList(generics.ListCreateAPIView): model = BasicModel model_serializer_class = serializers.HyperlinkedModelSerializer @@ -73,6 +81,8 @@ class PhotoListCreate(generics.ListCreateAPIView): class AlbumDetail(generics.RetrieveAPIView): model = Album + serializer_class = AlbumSerializer + lookup_field = 'title' class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView): @@ -96,7 +106,7 @@ urlpatterns = patterns('', class TestBasicHyperlinkedView(TestCase): - urls = 'rest_framework.tests.hyperlinkedserializers' + urls = 'rest_framework.tests.test_hyperlinkedserializers' def setUp(self): """ @@ -133,7 +143,7 @@ class TestBasicHyperlinkedView(TestCase): class TestManyToManyHyperlinkedView(TestCase): - urls = 'rest_framework.tests.hyperlinkedserializers' + urls = 'rest_framework.tests.test_hyperlinkedserializers' def setUp(self): """ @@ -180,8 +190,38 @@ class TestManyToManyHyperlinkedView(TestCase): self.assertEqual(response.data, self.data[0]) +class TestHyperlinkedIdentityFieldLookup(TestCase): + urls = 'rest_framework.tests.test_hyperlinkedserializers' + + def setUp(self): + """ + Create 3 Album instances. + """ + titles = ['foo', 'bar', 'baz'] + for title in titles: + album = Album(title=title) + album.save() + self.detail_view = AlbumDetail.as_view() + self.data = { + 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'}, + 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'}, + 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'} + } + + def test_lookup_field(self): + """ + GET requests to AlbumDetail view should return serialized Albums + with a url field keyed by `title`. + """ + for album in Album.objects.all(): + request = factory.get('/albums/{0}/'.format(album.title)) + response = self.detail_view(request, title=album.title) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[album.title]) + + class TestCreateWithForeignKeys(TestCase): - urls = 'rest_framework.tests.hyperlinkedserializers' + urls = 'rest_framework.tests.test_hyperlinkedserializers' def setUp(self): """ @@ -206,7 +246,7 @@ class TestCreateWithForeignKeys(TestCase): class TestCreateWithForeignKeysAndCustomSlug(TestCase): - urls = 'rest_framework.tests.hyperlinkedserializers' + urls = 'rest_framework.tests.test_hyperlinkedserializers' def setUp(self): """ @@ -231,7 +271,7 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase): class TestOptionalRelationHyperlinkedView(TestCase): - urls = 'rest_framework.tests.hyperlinkedserializers' + urls = 'rest_framework.tests.test_hyperlinkedserializers' def setUp(self): """ diff --git a/rest_framework/tests/multitable_inheritance.py b/rest_framework/tests/test_multitable_inheritance.py index 00c15327..00c15327 100644 --- a/rest_framework/tests/multitable_inheritance.py +++ b/rest_framework/tests/test_multitable_inheritance.py diff --git a/rest_framework/tests/negotiation.py b/rest_framework/tests/test_negotiation.py index 43721b84..7f84827f 100644 --- a/rest_framework/tests/negotiation.py +++ b/rest_framework/tests/test_negotiation.py @@ -3,19 +3,24 @@ from django.test import TestCase from django.test.client import RequestFactory from rest_framework.negotiation import DefaultContentNegotiation from rest_framework.request import Request +from rest_framework.renderers import BaseRenderer factory = RequestFactory() -class MockJSONRenderer(object): +class MockJSONRenderer(BaseRenderer): media_type = 'application/json' -class MockHTMLRenderer(object): +class MockHTMLRenderer(BaseRenderer): media_type = 'text/html' +class NoCharsetSpecifiedRenderer(BaseRenderer): + media_type = 'my/media' + + class TestAcceptedMediaType(TestCase): def setUp(self): self.renderers = [MockJSONRenderer(), MockHTMLRenderer()] diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/test_pagination.py index 6b8ef02f..e538a78e 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -1,18 +1,24 @@ from __future__ import unicode_literals import datetime from decimal import Decimal -import django +from django.db import models from django.core.paginator import Paginator from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, status, pagination, filters, serializers from rest_framework.compat import django_filters -from rest_framework.tests.models import BasicModel, FilterableItem +from rest_framework.tests.models import BasicModel factory = RequestFactory() +class FilterableItem(models.Model): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + class RootView(generics.ListCreateAPIView): """ Example description for OPTIONS. @@ -124,7 +130,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): model = FilterableItem paginate_by = 10 filter_class = DecimalFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) view = FilterFieldsRootView.as_view() @@ -171,7 +177,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): class BasicFilterFieldsRootView(generics.ListCreateAPIView): model = FilterableItem paginate_by = 10 - filter_backend = DecimalFilterBackend + filter_backends = (DecimalFilterBackend,) view = BasicFilterFieldsRootView.as_view() diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/test_parsers.py index 7699e10c..7699e10c 100644 --- a/rest_framework/tests/parsers.py +++ b/rest_framework/tests/test_parsers.py diff --git a/rest_framework/tests/permissions.py b/rest_framework/tests/test_permissions.py index b3993be5..6caaf65b 100644 --- a/rest_framework/tests/permissions.py +++ b/rest_framework/tests/test_permissions.py @@ -108,6 +108,48 @@ class ModelPermissionsIntegrationTests(TestCase): response = instance_view(request, pk='2') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + def test_options_permitted(self): + request = factory.options('/', content_type='application/json', + HTTP_AUTHORIZATION=self.permitted_credentials) + response = root_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn('actions', response.data) + self.assertEqual(list(response.data['actions'].keys()), ['POST']) + + request = factory.options('/1', content_type='application/json', + HTTP_AUTHORIZATION=self.permitted_credentials) + response = instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn('actions', response.data) + self.assertEqual(list(response.data['actions'].keys()), ['PUT']) + + def test_options_disallowed(self): + request = factory.options('/', content_type='application/json', + HTTP_AUTHORIZATION=self.disallowed_credentials) + response = root_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertNotIn('actions', response.data) + + request = factory.options('/1', content_type='application/json', + HTTP_AUTHORIZATION=self.disallowed_credentials) + response = instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertNotIn('actions', response.data) + + def test_options_updateonly(self): + request = factory.options('/', content_type='application/json', + HTTP_AUTHORIZATION=self.updateonly_credentials) + response = root_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertNotIn('actions', response.data) + + request = factory.options('/1', content_type='application/json', + HTTP_AUTHORIZATION=self.updateonly_credentials) + response = instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn('actions', response.data) + self.assertEqual(list(response.data['actions'].keys()), ['PUT']) + class OwnerModel(models.Model): text = models.CharField(max_length=100) diff --git a/rest_framework/tests/test_relations.py b/rest_framework/tests/test_relations.py new file mode 100644 index 00000000..d19219c9 --- /dev/null +++ b/rest_framework/tests/test_relations.py @@ -0,0 +1,100 @@ +""" +General tests for relational fields. +""" +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from rest_framework import serializers +from rest_framework.tests.models import BlogPost + + +class NullModel(models.Model): + pass + + +class FieldTests(TestCase): + def test_pk_related_field_with_empty_string(self): + """ + Regression test for #446 + + https://github.com/tomchristie/django-rest-framework/issues/446 + """ + field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) + self.assertRaises(serializers.ValidationError, field.from_native, '') + self.assertRaises(serializers.ValidationError, field.from_native, []) + + def test_hyperlinked_related_field_with_empty_string(self): + field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') + self.assertRaises(serializers.ValidationError, field.from_native, '') + self.assertRaises(serializers.ValidationError, field.from_native, []) + + def test_slug_related_field_with_empty_string(self): + field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') + self.assertRaises(serializers.ValidationError, field.from_native, '') + self.assertRaises(serializers.ValidationError, field.from_native, []) + + +class TestManyRelatedMixin(TestCase): + def test_missing_many_to_many_related_field(self): + ''' + Regression test for #632 + + https://github.com/tomchristie/django-rest-framework/pull/632 + ''' + field = serializers.RelatedField(many=True, read_only=False) + + into = {} + field.field_from_native({}, None, 'field_name', into) + self.assertEqual(into['field_name'], []) + + +# Regression tests for #694 (`source` attribute on related fields) + +class RelatedFieldSourceTests(TestCase): + def test_related_manager_source(self): + """ + Relational fields should be able to use manager-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.RelatedField(many=True, source='get_blogposts_manager') + + class ClassWithManagerMethod(object): + def get_blogposts_manager(self): + return BlogPost.objects + + obj = ClassWithManagerMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['BlogPost object']) + + def test_related_queryset_source(self): + """ + Relational fields should be able to use queryset-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.RelatedField(many=True, source='get_blogposts_queryset') + + class ClassWithQuerysetMethod(object): + def get_blogposts_queryset(self): + return BlogPost.objects.all() + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['BlogPost object']) + + def test_dotted_source(self): + """ + Source argument should support dotted.source notation. + """ + BlogPost.objects.create(title='blah') + field = serializers.RelatedField(many=True, source='a.b.c') + + class ClassWithQuerysetMethod(object): + a = { + 'b': { + 'c': BlogPost.objects.all() + } + } + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['BlogPost object']) diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/test_relations_hyperlink.py index b1eed9a7..2ca7f4f2 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/test_relations_hyperlink.py @@ -4,6 +4,7 @@ from django.test.client import RequestFactory from rest_framework import serializers from rest_framework.compat import patterns, url from rest_framework.tests.models import ( + BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource ) @@ -16,6 +17,7 @@ def dummy_view(request, pk): pass urlpatterns = patterns('', + url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'), url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'), url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'), @@ -69,7 +71,7 @@ class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): # TODO: Add test that .data cannot be accessed prior to .is_valid class HyperlinkedManyToManyTests(TestCase): - urls = 'rest_framework.tests.relations_hyperlink' + urls = 'rest_framework.tests.test_relations_hyperlink' def setUp(self): for idx in range(1, 4): @@ -177,7 +179,7 @@ class HyperlinkedManyToManyTests(TestCase): class HyperlinkedForeignKeyTests(TestCase): - urls = 'rest_framework.tests.relations_hyperlink' + urls = 'rest_framework.tests.test_relations_hyperlink' def setUp(self): target = ForeignKeyTarget(name='target-1') @@ -305,7 +307,7 @@ class HyperlinkedForeignKeyTests(TestCase): class HyperlinkedNullableForeignKeyTests(TestCase): - urls = 'rest_framework.tests.relations_hyperlink' + urls = 'rest_framework.tests.test_relations_hyperlink' def setUp(self): target = ForeignKeyTarget(name='target-1') @@ -433,7 +435,7 @@ class HyperlinkedNullableForeignKeyTests(TestCase): class HyperlinkedNullableOneToOneTests(TestCase): - urls = 'rest_framework.tests.relations_hyperlink' + urls = 'rest_framework.tests.test_relations_hyperlink' def setUp(self): target = OneToOneTarget(name='target-1') @@ -451,3 +453,72 @@ class HyperlinkedNullableOneToOneTests(TestCase): {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, ] self.assertEqual(serializer.data, expected) + + +# Regression tests for #694 (`source` attribute on related fields) + +class HyperlinkedRelatedFieldSourceTests(TestCase): + urls = 'rest_framework.tests.test_relations_hyperlink' + + def test_related_manager_source(self): + """ + Relational fields should be able to use manager-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.HyperlinkedRelatedField( + many=True, + source='get_blogposts_manager', + view_name='dummy-url', + ) + field.context = {'request': request} + + class ClassWithManagerMethod(object): + def get_blogposts_manager(self): + return BlogPost.objects + + obj = ClassWithManagerMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['http://testserver/dummyurl/1/']) + + def test_related_queryset_source(self): + """ + Relational fields should be able to use queryset-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.HyperlinkedRelatedField( + many=True, + source='get_blogposts_queryset', + view_name='dummy-url', + ) + field.context = {'request': request} + + class ClassWithQuerysetMethod(object): + def get_blogposts_queryset(self): + return BlogPost.objects.all() + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['http://testserver/dummyurl/1/']) + + def test_dotted_source(self): + """ + Source argument should support dotted.source notation. + """ + BlogPost.objects.create(title='blah') + field = serializers.HyperlinkedRelatedField( + many=True, + source='a.b.c', + view_name='dummy-url', + ) + field.context = {'request': request} + + class ClassWithQuerysetMethod(object): + a = { + 'b': { + 'c': BlogPost.objects.all() + } + } + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['http://testserver/dummyurl/1/']) diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/test_relations_nested.py index 8325580f..8325580f 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/test_relations_nested.py diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/test_relations_pk.py index 5ce8b567..e2a1b815 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/test_relations_pk.py @@ -1,7 +1,11 @@ from __future__ import unicode_literals +from django.db import models from django.test import TestCase from rest_framework import serializers -from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource +from rest_framework.tests.models import ( + BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, + NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource, +) from rest_framework.compat import six @@ -124,6 +128,7 @@ class PKManyToManyTests(TestCase): # Ensure source 4 is added, and everything else is as expected queryset = ManyToManySource.objects.all() serializer = ManyToManySourceSerializer(queryset, many=True) + self.assertFalse(serializer.fields['targets'].read_only) expected = [ {'id': 1, 'name': 'source-1', 'targets': [1]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, @@ -135,6 +140,7 @@ class PKManyToManyTests(TestCase): def test_reverse_many_to_many_create(self): data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]} serializer = ManyToManyTargetSerializer(data=data) + self.assertFalse(serializer.fields['sources'].read_only) self.assertTrue(serializer.is_valid()) obj = serializer.save() self.assertEqual(serializer.data, data) @@ -421,3 +427,116 @@ class PKNullableOneToOneTests(TestCase): {'id': 2, 'name': 'target-2', 'nullable_source': 1}, ] self.assertEqual(serializer.data, expected) + + +# The below models and tests ensure that serializer fields corresponding +# to a ManyToManyField field with a user-specified ``through`` model are +# set to read only + + +class ManyToManyThroughTarget(models.Model): + name = models.CharField(max_length=100) + + +class ManyToManyThrough(models.Model): + source = models.ForeignKey('ManyToManyThroughSource') + target = models.ForeignKey(ManyToManyThroughTarget) + + +class ManyToManyThroughSource(models.Model): + name = models.CharField(max_length=100) + targets = models.ManyToManyField(ManyToManyThroughTarget, + related_name='sources', + through='ManyToManyThrough') + + +class ManyToManyThroughTargetSerializer(serializers.ModelSerializer): + class Meta: + model = ManyToManyThroughTarget + fields = ('id', 'name', 'sources') + + +class ManyToManyThroughSourceSerializer(serializers.ModelSerializer): + class Meta: + model = ManyToManyThroughSource + fields = ('id', 'name', 'targets') + + +class PKManyToManyThroughTests(TestCase): + def setUp(self): + self.source = ManyToManyThroughSource.objects.create( + name='through-source-1') + self.target = ManyToManyThroughTarget.objects.create( + name='through-target-1') + + def test_many_to_many_create(self): + data = {'id': 2, 'name': 'source-2', 'targets': [self.target.pk]} + serializer = ManyToManyThroughSourceSerializer(data=data) + self.assertTrue(serializer.fields['targets'].read_only) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(obj.name, 'source-2') + self.assertEqual(obj.targets.count(), 0) + + def test_many_to_many_reverse_create(self): + data = {'id': 2, 'name': 'target-2', 'sources': [self.source.pk]} + serializer = ManyToManyThroughTargetSerializer(data=data) + self.assertTrue(serializer.fields['sources'].read_only) + self.assertTrue(serializer.is_valid()) + serializer.save() + obj = serializer.save() + self.assertEqual(obj.name, 'target-2') + self.assertEqual(obj.sources.count(), 0) + + +# Regression tests for #694 (`source` attribute on related fields) + + +class PrimaryKeyRelatedFieldSourceTests(TestCase): + def test_related_manager_source(self): + """ + Relational fields should be able to use manager-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager') + + class ClassWithManagerMethod(object): + def get_blogposts_manager(self): + return BlogPost.objects + + obj = ClassWithManagerMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, [1]) + + def test_related_queryset_source(self): + """ + Relational fields should be able to use queryset-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset') + + class ClassWithQuerysetMethod(object): + def get_blogposts_queryset(self): + return BlogPost.objects.all() + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, [1]) + + def test_dotted_source(self): + """ + Source argument should support dotted.source notation. + """ + BlogPost.objects.create(title='blah') + field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c') + + class ClassWithQuerysetMethod(object): + a = { + 'b': { + 'c': BlogPost.objects.all() + } + } + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, [1]) diff --git a/rest_framework/tests/relations_slug.py b/rest_framework/tests/test_relations_slug.py index 435c821c..435c821c 100644 --- a/rest_framework/tests/relations_slug.py +++ b/rest_framework/tests/test_relations_slug.py diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/test_renderers.py index 40bac9cb..95b59741 100644 --- a/rest_framework/tests/renderers.py +++ b/rest_framework/tests/test_renderers.py @@ -1,14 +1,18 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + from decimal import Decimal from django.core.cache import cache from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest +from django.utils.translation import ugettext_lazy as _ from rest_framework import status, permissions from rest_framework.compat import yaml, etree, patterns, url, include from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ - XMLRenderer, JSONPRenderer, BrowsableAPIRenderer + XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer from rest_framework.parsers import YAMLParser, XMLParser from rest_framework.settings import api_settings from rest_framework.compat import StringIO @@ -26,7 +30,7 @@ RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii') expected_results = [ - ((elem for elem in [1, 2, 3]), JSONRenderer, '[1, 2, 3]') # Generator + ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1, 2, 3]') # Generator ] @@ -129,12 +133,12 @@ class RendererEndToEndTests(TestCase): End-to-end testing of renderers using an RendererMixin on a generic view. """ - urls = 'rest_framework.tests.renderers' + urls = 'rest_framework.tests.test_renderers' def test_default_renderer_serializes_content(self): """If the Accept header is not set the default renderer should serialize the response.""" resp = self.client.get('/') - self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -142,13 +146,13 @@ class RendererEndToEndTests(TestCase): """No response must be included in HEAD requests.""" resp = self.client.head('/') self.assertEqual(resp.status_code, DUMMYSTATUS) - self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp.content, six.b('')) def test_default_renderer_serializes_content_on_accept_any(self): """If the Accept header is set to */* the default renderer should serialize the response.""" resp = self.client.get('/', HTTP_ACCEPT='*/*') - self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -156,7 +160,7 @@ class RendererEndToEndTests(TestCase): """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for the default renderer)""" resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) - self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -164,7 +168,7 @@ class RendererEndToEndTests(TestCase): """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for a non-default renderer)""" resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -175,7 +179,7 @@ class RendererEndToEndTests(TestCase): RendererB.media_type ) resp = self.client.get('/' + param) - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -192,7 +196,7 @@ class RendererEndToEndTests(TestCase): RendererB.format ) resp = self.client.get('/' + param) - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -200,7 +204,7 @@ class RendererEndToEndTests(TestCase): """If a 'format' keyword arg is specified, the renderer with the matching format attribute should serialize the response.""" resp = self.client.get('/something.formatb') - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -213,7 +217,7 @@ class RendererEndToEndTests(TestCase): ) resp = self.client.get('/' + param, HTTP_ACCEPT=RendererB.media_type) - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -235,6 +239,13 @@ class JSONRendererTests(TestCase): Tests specific to the JSON Renderer """ + def test_render_lazy_strings(self): + """ + JSONRenderer should deal with lazy translated strings. + """ + ret = JSONRenderer().render(_('test')) + self.assertEqual(ret, b'"test"') + def test_without_content_type_args(self): """ Test basic JSON rendering. @@ -243,7 +254,7 @@ class JSONRendererTests(TestCase): renderer = JSONRenderer() content = renderer.render(obj, 'application/json') # Fix failing test case which depends on version of JSON library. - self.assertEqual(content, _flat_repr) + self.assertEqual(content.decode('utf-8'), _flat_repr) def test_with_content_type_args(self): """ @@ -252,7 +263,24 @@ class JSONRendererTests(TestCase): obj = {'foo': ['bar', 'baz']} renderer = JSONRenderer() content = renderer.render(obj, 'application/json; indent=2') - self.assertEqual(strip_trailing_whitespace(content), _indented_repr) + self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr) + + def test_check_ascii(self): + obj = {'countries': ['United Kingdom', 'France', 'España']} + renderer = JSONRenderer() + content = renderer.render(obj, 'application/json') + self.assertEqual(content, '{"countries": ["United Kingdom", "France", "Espa\\u00f1a"]}'.encode('utf-8')) + + +class UnicodeJSONRendererTests(TestCase): + """ + Tests specific for the Unicode JSON Renderer + """ + def test_proper_encoding(self): + obj = {'countries': ['United Kingdom', 'France', 'España']} + renderer = UnicodeJSONRenderer() + content = renderer.render(obj, 'application/json') + self.assertEqual(content, '{"countries": ["United Kingdom", "France", "España"]}'.encode('utf-8')) class JSONPRendererTests(TestCase): @@ -260,7 +288,7 @@ class JSONPRendererTests(TestCase): Tests specific to the JSONP Renderer """ - urls = 'rest_framework.tests.renderers' + urls = 'rest_framework.tests.test_renderers' def test_without_callback_with_json_renderer(self): """ @@ -269,7 +297,7 @@ class JSONPRendererTests(TestCase): resp = self.client.get('/jsonp/jsonrenderer', HTTP_ACCEPT='application/javascript') self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp['Content-Type'], 'application/javascript') + self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') self.assertEqual(resp.content, ('callback(%s);' % _flat_repr).encode('ascii')) @@ -280,7 +308,7 @@ class JSONPRendererTests(TestCase): resp = self.client.get('/jsonp/nojsonrenderer', HTTP_ACCEPT='application/javascript') self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp['Content-Type'], 'application/javascript') + self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') self.assertEqual(resp.content, ('callback(%s);' % _flat_repr).encode('ascii')) @@ -292,7 +320,7 @@ class JSONPRendererTests(TestCase): resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func, HTTP_ACCEPT='application/javascript') self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp['Content-Type'], 'application/javascript') + self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') self.assertEqual(resp.content, ('%s(%s);' % (callback_func, _flat_repr)).encode('ascii')) @@ -433,7 +461,7 @@ class CacheRenderTest(TestCase): Tests specific to caching responses """ - urls = 'rest_framework.tests.renderers' + urls = 'rest_framework.tests.test_renderers' cache_key = 'just_a_cache_key' diff --git a/rest_framework/tests/request.py b/rest_framework/tests/test_request.py index 97e5af20..a5c5e84c 100644 --- a/rest_framework/tests/request.py +++ b/rest_framework/tests/test_request.py @@ -254,7 +254,7 @@ urlpatterns = patterns('', class TestContentParsingWithAuthentication(TestCase): - urls = 'rest_framework.tests.request' + urls = 'rest_framework.tests.test_request' def setUp(self): self.csrf_client = Client(enforce_csrf_checks=True) diff --git a/rest_framework/tests/response.py b/rest_framework/tests/test_response.py index aecf83f4..eea3c641 100644 --- a/rest_framework/tests/response.py +++ b/rest_framework/tests/test_response.py @@ -1,14 +1,18 @@ from __future__ import unicode_literals from django.test import TestCase +from rest_framework.tests.models import BasicModel, BasicModelSerializer from rest_framework.compat import patterns, url, include from rest_framework.response import Response from rest_framework.views import APIView +from rest_framework import generics +from rest_framework import routers from rest_framework import status from rest_framework.renderers import ( BaseRenderer, JSONRenderer, BrowsableAPIRenderer ) +from rest_framework import viewsets from rest_framework.settings import api_settings from rest_framework.compat import six @@ -21,6 +25,9 @@ class MockJsonRenderer(BaseRenderer): media_type = 'application/json' +class MockTextMediaRenderer(BaseRenderer): + media_type = 'text/html' + DUMMYSTATUS = status.HTTP_200_OK DUMMYCONTENT = 'dummycontent' @@ -44,13 +51,26 @@ class RendererB(BaseRenderer): return RENDERER_B_SERIALIZER(data) +class RendererC(RendererB): + media_type = 'mock/rendererc' + format = 'formatc' + charset = "rendererc" + + class MockView(APIView): - renderer_classes = (RendererA, RendererB) + renderer_classes = (RendererA, RendererB, RendererC) def get(self, request, **kwargs): return Response(DUMMYCONTENT, status=DUMMYSTATUS) +class MockViewSettingContentType(APIView): + renderer_classes = (RendererA, RendererB, RendererC) + + def get(self, request, **kwargs): + return Response(DUMMYCONTENT, status=DUMMYSTATUS, content_type='setbyview') + + class HTMLView(APIView): renderer_classes = (BrowsableAPIRenderer, ) @@ -65,11 +85,29 @@ class HTMLView1(APIView): return Response('text') +class HTMLNewModelViewSet(viewsets.ModelViewSet): + model = BasicModel + + +class HTMLNewModelView(generics.ListCreateAPIView): + renderer_classes = (BrowsableAPIRenderer,) + permission_classes = [] + serializer_class = BasicModelSerializer + model = BasicModel + + +new_model_viewset_router = routers.DefaultRouter() +new_model_viewset_router.register(r'', HTMLNewModelViewSet) + + urlpatterns = patterns('', - url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), - url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), + url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])), + url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])), + url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])), url(r'^html$', HTMLView.as_view()), url(r'^html1$', HTMLView1.as_view()), + url(r'^html_new_model$', HTMLNewModelView.as_view()), + url(r'^html_new_model_viewset', include(new_model_viewset_router.urls)), url(r'^restframework', include('rest_framework.urls', namespace='rest_framework')) ) @@ -80,12 +118,12 @@ class RendererIntegrationTests(TestCase): End-to-end testing of renderers using an ResponseMixin on a generic view. """ - urls = 'rest_framework.tests.response' + urls = 'rest_framework.tests.test_response' def test_default_renderer_serializes_content(self): """If the Accept header is not set the default renderer should serialize the response.""" resp = self.client.get('/') - self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -93,13 +131,13 @@ class RendererIntegrationTests(TestCase): """No response must be included in HEAD requests.""" resp = self.client.head('/') self.assertEqual(resp.status_code, DUMMYSTATUS) - self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp.content, six.b('')) def test_default_renderer_serializes_content_on_accept_any(self): """If the Accept header is set to */* the default renderer should serialize the response.""" resp = self.client.get('/', HTTP_ACCEPT='*/*') - self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -107,7 +145,7 @@ class RendererIntegrationTests(TestCase): """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for the default renderer)""" resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) - self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -115,7 +153,7 @@ class RendererIntegrationTests(TestCase): """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for a non-default renderer)""" resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -126,7 +164,7 @@ class RendererIntegrationTests(TestCase): RendererB.media_type ) resp = self.client.get('/' + param) - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -134,7 +172,7 @@ class RendererIntegrationTests(TestCase): """If a 'format' query is specified, the renderer with the matching format attribute should serialize the response.""" resp = self.client.get('/?format=%s' % RendererB.format) - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -142,7 +180,7 @@ class RendererIntegrationTests(TestCase): """If a 'format' keyword arg is specified, the renderer with the matching format attribute should serialize the response.""" resp = self.client.get('/something.formatb') - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -151,7 +189,7 @@ class RendererIntegrationTests(TestCase): the renderer with the matching format attribute should serialize the response.""" resp = self.client.get('/?format=%s' % RendererB.format, HTTP_ACCEPT=RendererB.media_type) - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -160,7 +198,7 @@ class Issue122Tests(TestCase): """ Tests that covers #122. """ - urls = 'rest_framework.tests.response' + urls = 'rest_framework.tests.test_response' def test_only_html_renderer(self): """ @@ -173,3 +211,68 @@ class Issue122Tests(TestCase): Test if no infinite recursion occurs. """ self.client.get('/html1') + + +class Issue467Tests(TestCase): + """ + Tests for #467 + """ + + urls = 'rest_framework.tests.test_response' + + def test_form_has_label_and_help_text(self): + resp = self.client.get('/html_new_model') + self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') + self.assertContains(resp, 'Text comes here') + self.assertContains(resp, 'Text description.') + + +class Issue807Tests(TestCase): + """ + Covers #807 + """ + + urls = 'rest_framework.tests.test_response' + + def test_does_not_append_charset_by_default(self): + """ + Renderers don't include a charset unless set explicitly. + """ + headers = {"HTTP_ACCEPT": RendererA.media_type} + resp = self.client.get('/', **headers) + expected = "{0}; charset={1}".format(RendererA.media_type, 'utf-8') + self.assertEqual(expected, resp['Content-Type']) + + def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self): + """ + If renderer class has charset attribute declared, it gets appended + to Response's Content-Type + """ + headers = {"HTTP_ACCEPT": RendererC.media_type} + resp = self.client.get('/', **headers) + expected = "{0}; charset={1}".format(RendererC.media_type, RendererC.charset) + self.assertEqual(expected, resp['Content-Type']) + + def test_content_type_set_explictly_on_response(self): + """ + The content type may be set explictly on the response. + """ + headers = {"HTTP_ACCEPT": RendererC.media_type} + resp = self.client.get('/setbyview', **headers) + self.assertEqual('setbyview', resp['Content-Type']) + + def test_viewset_label_help_text(self): + param = '?%s=%s' % ( + api_settings.URL_ACCEPT_OVERRIDE, + 'text/html' + ) + resp = self.client.get('/html_new_model_viewset/' + param) + self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') + self.assertContains(resp, 'Text comes here') + self.assertContains(resp, 'Text description.') + + def test_form_has_label_and_help_text(self): + resp = self.client.get('/html_new_model') + self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') + self.assertContains(resp, 'Text comes here') + self.assertContains(resp, 'Text description.') diff --git a/rest_framework/tests/reverse.py b/rest_framework/tests/test_reverse.py index cb8d8132..93ef5637 100644 --- a/rest_framework/tests/reverse.py +++ b/rest_framework/tests/test_reverse.py @@ -19,7 +19,7 @@ class ReverseTests(TestCase): """ Tests for fully qualified URLs when using `reverse`. """ - urls = 'rest_framework.tests.reverse' + urls = 'rest_framework.tests.test_reverse' def test_reversed_urls_are_fully_qualified(self): request = factory.get('/view') diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py new file mode 100644 index 00000000..a7534f70 --- /dev/null +++ b/rest_framework/tests/test_routers.py @@ -0,0 +1,150 @@ +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from django.test.client import RequestFactory +from rest_framework import serializers, viewsets +from rest_framework.compat import include, patterns, url +from rest_framework.decorators import link, action +from rest_framework.response import Response +from rest_framework.routers import SimpleRouter + +factory = RequestFactory() + +urlpatterns = patterns('',) + + +class BasicViewSet(viewsets.ViewSet): + def list(self, request, *args, **kwargs): + return Response({'method': 'list'}) + + @action() + def action1(self, request, *args, **kwargs): + return Response({'method': 'action1'}) + + @action() + def action2(self, request, *args, **kwargs): + return Response({'method': 'action2'}) + + @action(methods=['post', 'delete']) + def action3(self, request, *args, **kwargs): + return Response({'method': 'action2'}) + + @link() + def link1(self, request, *args, **kwargs): + return Response({'method': 'link1'}) + + @link() + def link2(self, request, *args, **kwargs): + return Response({'method': 'link2'}) + + +class TestSimpleRouter(TestCase): + def setUp(self): + self.router = SimpleRouter() + + def test_link_and_action_decorator(self): + routes = self.router.get_routes(BasicViewSet) + decorator_routes = routes[2:] + # Make sure all these endpoints exist and none have been clobbered + for i, endpoint in enumerate(['action1', 'action2', 'action3', 'link1', 'link2']): + route = decorator_routes[i] + # check url listing + self.assertEqual(route.url, + '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint)) + # check method to function mapping + if endpoint == 'action3': + methods_map = ['post', 'delete'] + elif endpoint.startswith('action'): + methods_map = ['post'] + else: + methods_map = ['get'] + for method in methods_map: + self.assertEqual(route.mapping[method], endpoint) + + +class RouterTestModel(models.Model): + uuid = models.CharField(max_length=20) + text = models.CharField(max_length=200) + + +class TestCustomLookupFields(TestCase): + """ + Ensure that custom lookup fields are correctly routed. + """ + urls = 'rest_framework.tests.test_routers' + + def setUp(self): + class NoteSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = RouterTestModel + lookup_field = 'uuid' + fields = ('url', 'uuid', 'text') + + class NoteViewSet(viewsets.ModelViewSet): + queryset = RouterTestModel.objects.all() + serializer_class = NoteSerializer + lookup_field = 'uuid' + + RouterTestModel.objects.create(uuid='123', text='foo bar') + + self.router = SimpleRouter() + self.router.register(r'notes', NoteViewSet) + + from rest_framework.tests import test_routers + urls = getattr(test_routers, 'urlpatterns') + urls += patterns('', + url(r'^', include(self.router.urls)), + ) + + def test_custom_lookup_field_route(self): + detail_route = self.router.urls[-1] + detail_url_pattern = detail_route.regex.pattern + self.assertIn('<uuid>', detail_url_pattern) + + def test_retrieve_lookup_field_list_view(self): + response = self.client.get('/notes/') + self.assertEqual(response.data, + [{ + "url": "http://testserver/notes/123/", + "uuid": "123", "text": "foo bar" + }] + ) + + def test_retrieve_lookup_field_detail_view(self): + response = self.client.get('/notes/123/') + self.assertEqual(response.data, + { + "url": "http://testserver/notes/123/", + "uuid": "123", "text": "foo bar" + } + ) + + +class TestTrailingSlash(TestCase): + def setUp(self): + class NoteViewSet(viewsets.ModelViewSet): + model = RouterTestModel + + self.router = SimpleRouter() + self.router.register(r'notes', NoteViewSet) + self.urls = self.router.urls + + def test_urls_have_trailing_slash_by_default(self): + expected = ['^notes/$', '^notes/(?P<pk>[^/]+)/$'] + for idx in range(len(expected)): + self.assertEqual(expected[idx], self.urls[idx].regex.pattern) + + +class TestTrailingSlash(TestCase): + def setUp(self): + class NoteViewSet(viewsets.ModelViewSet): + model = RouterTestModel + + self.router = SimpleRouter(trailing_slash=False) + self.router.register(r'notes', NoteViewSet) + self.urls = self.router.urls + + def test_urls_can_have_trailing_slash_removed(self): + expected = ['^notes$', '^notes/(?P<pk>[^/]+)$'] + for idx in range(len(expected)): + self.assertEqual(expected[idx], self.urls[idx].regex.pattern) diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/test_serializer.py index 84e1ee4e..8b87a084 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -1,10 +1,14 @@ from __future__ import unicode_literals -from django.utils.datastructures import MultiValueDict +from django.db import models +from django.db.models.fields import BLANK_CHOICE_DASH from django.test import TestCase -from rest_framework import serializers +from django.utils.datastructures import MultiValueDict +from django.utils.translation import ugettext_lazy as _ +from rest_framework import serializers, fields, relations from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel, - ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo) + ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel) +from rest_framework.tests.models import BasicModelSerializer import datetime import pickle @@ -43,6 +47,17 @@ class CommentSerializer(serializers.Serializer): return instance +class NamesSerializer(serializers.Serializer): + first = serializers.CharField() + last = serializers.CharField(required=False, default='') + initials = serializers.CharField(required=False, default='') + + +class PersonIdentifierSerializer(serializers.Serializer): + ssn = serializers.CharField() + names = NamesSerializer(source='names', required=False) + + class BookSerializer(serializers.ModelSerializer): isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'}) @@ -78,6 +93,29 @@ class PersonSerializer(serializers.ModelSerializer): read_only_fields = ('age',) +class NestedSerializer(serializers.Serializer): + info = serializers.Field() + + +class ModelSerializerWithNestedSerializer(serializers.ModelSerializer): + nested = NestedSerializer(source='*') + + class Meta: + model = Person + + +class PersonSerializerInvalidReadOnly(serializers.ModelSerializer): + """ + Testing for #652. + """ + info = serializers.Field(source='info') + + class Meta: + model = Person + fields = ('name', 'age', 'info') + read_only_fields = ('age', 'info') + + class AlbumsSerializer(serializers.ModelSerializer): class Meta: @@ -91,11 +129,6 @@ class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer): fields = ['some_integer'] -class BrokenModelSerializer(serializers.ModelSerializer): - class Meta: - fields = ['some_field'] - - class BasicTests(TestCase): def setUp(self): self.comment = Comment( @@ -141,6 +174,42 @@ class BasicTests(TestCase): self.assertFalse(serializer.object is expected) self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!') + def test_create_nested(self): + """Test a serializer with nested data.""" + names = {'first': 'John', 'last': 'Doe', 'initials': 'jd'} + data = {'ssn': '1234567890', 'names': names} + serializer = PersonIdentifierSerializer(data=data) + + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, data) + self.assertFalse(serializer.object is data) + self.assertEqual(serializer.data['names'], names) + + def test_create_partial_nested(self): + """Test a serializer with nested data which has missing fields.""" + names = {'first': 'John'} + data = {'ssn': '1234567890', 'names': names} + serializer = PersonIdentifierSerializer(data=data) + + expected_names = {'first': 'John', 'last': '', 'initials': ''} + data['names'] = expected_names + + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, data) + self.assertFalse(serializer.object is expected_names) + self.assertEqual(serializer.data['names'], expected_names) + + def test_null_nested(self): + """Test a serializer with a nonexistent nested field""" + data = {'ssn': '1234567890'} + serializer = PersonIdentifierSerializer(data=data) + + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, data) + self.assertFalse(serializer.object is data) + expected = {'ssn': '1234567890', 'names': None} + self.assertEqual(serializer.data, expected) + def test_update(self): serializer = CommentSerializer(self.comment, data=self.data) expected = self.comment @@ -189,6 +258,12 @@ class BasicTests(TestCase): # Assert age is unchanged (35) self.assertEqual(instance.age, self.person_data['age']) + def test_invalid_read_only_fields(self): + """ + Regression test for #652. + """ + self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, []) + class DictStyleSerializer(serializers.Serializer): """ @@ -344,19 +419,34 @@ class ValidationTests(TestCase): Assert that a meaningful exception message is outputted when the model field is missing (e.g. when mistyping ``model``). """ + class BrokenModelSerializer(serializers.ModelSerializer): + class Meta: + fields = ['some_field'] + try: - serializer = BrokenModelSerializer() + BrokenModelSerializer() except AssertionError as e: self.assertEqual(e.args[0], "Serializer class 'BrokenModelSerializer' is missing 'model' Meta option") except: self.fail('Wrong exception type thrown.') + def test_writable_star_source_on_nested_serializer(self): + """ + Assert that a nested serializer instantiated with source='*' correctly + expands the data into the outer serializer. + """ + serializer = ModelSerializerWithNestedSerializer(data={ + 'name': 'marko', + 'nested': {'info': 'hi'}}, + ) + self.assertEqual(serializer.is_valid(), True) + class CustomValidationTests(TestCase): class CommentSerializerWithFieldValidator(CommentSerializer): def validate_email(self, attrs, source): - value = attrs[source] + attrs[source] return attrs def validate_content(self, attrs, source): @@ -853,23 +943,6 @@ class RelatedTraversalTest(TestCase): self.assertEqual(serializer.data, expected) - def test_queryset_nested_traversal(self): - """ - Relational fields should be able to use methods as their source. - """ - BlogPost.objects.create(title='blah') - - class QuerysetMethodSerializer(serializers.Serializer): - blogposts = serializers.RelatedField(many=True, source='get_all_blogposts') - - class ClassWithQuerysetMethod(object): - def get_all_blogposts(self): - return BlogPost.objects - - obj = ClassWithQuerysetMethod() - serializer = QuerysetMethodSerializer(obj) - self.assertEqual(serializer.data, {'blogposts': ['BlogPost object']}) - class SerializerMethodFieldTests(TestCase): def setUp(self): @@ -1000,6 +1073,130 @@ class SerializerPickleTests(TestCase): repr(pickle.loads(pickle.dumps(data, 0))) +# test for issue #725 +class SeveralChoicesModel(models.Model): + color = models.CharField( + max_length=10, + choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')], + blank=False + ) + drink = models.CharField( + max_length=10, + choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')], + blank=False, + default='beer' + ) + os = models.CharField( + max_length=10, + choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')], + blank=True + ) + music_genre = models.CharField( + max_length=10, + choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')], + blank=True, + default='metal' + ) + + +class SerializerChoiceFields(TestCase): + + def setUp(self): + super(SerializerChoiceFields, self).setUp() + + class SeveralChoicesSerializer(serializers.ModelSerializer): + class Meta: + model = SeveralChoicesModel + fields = ('color', 'drink', 'os', 'music_genre') + + self.several_choices_serializer = SeveralChoicesSerializer + + def test_choices_blank_false_not_default(self): + serializer = self.several_choices_serializer() + self.assertEqual( + serializer.fields['color'].choices, + [('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')] + ) + + def test_choices_blank_false_with_default(self): + serializer = self.several_choices_serializer() + self.assertEqual( + serializer.fields['drink'].choices, + [('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')] + ) + + def test_choices_blank_true_not_default(self): + serializer = self.several_choices_serializer() + self.assertEqual( + serializer.fields['os'].choices, + BLANK_CHOICE_DASH + [('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')] + ) + + def test_choices_blank_true_with_default(self): + serializer = self.several_choices_serializer() + self.assertEqual( + serializer.fields['music_genre'].choices, + BLANK_CHOICE_DASH + [('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')] + ) + + +# Regression tests for #675 +class Ticket(models.Model): + assigned = models.ForeignKey( + Person, related_name='assigned_tickets') + reviewer = models.ForeignKey( + Person, blank=True, null=True, related_name='reviewed_tickets') + + +class SerializerRelatedChoicesTest(TestCase): + + def setUp(self): + super(SerializerRelatedChoicesTest, self).setUp() + + class RelatedChoicesSerializer(serializers.ModelSerializer): + class Meta: + model = Ticket + fields = ('assigned', 'reviewer') + + self.related_fields_serializer = RelatedChoicesSerializer + + def test_empty_queryset_required(self): + serializer = self.related_fields_serializer() + self.assertEqual(serializer.fields['assigned'].queryset.count(), 0) + self.assertEqual( + [x for x in serializer.fields['assigned'].widget.choices], + [] + ) + + def test_empty_queryset_not_required(self): + serializer = self.related_fields_serializer() + self.assertEqual(serializer.fields['reviewer'].queryset.count(), 0) + self.assertEqual( + [x for x in serializer.fields['reviewer'].widget.choices], + [('', '---------')] + ) + + def test_with_some_persons_required(self): + Person.objects.create(name="Lionel Messi") + Person.objects.create(name="Xavi Hernandez") + serializer = self.related_fields_serializer() + self.assertEqual(serializer.fields['assigned'].queryset.count(), 2) + self.assertEqual( + [x for x in serializer.fields['assigned'].widget.choices], + [(1, 'Person object - 1'), (2, 'Person object - 2')] + ) + + def test_with_some_persons_not_required(self): + Person.objects.create(name="Lionel Messi") + Person.objects.create(name="Xavi Hernandez") + serializer = self.related_fields_serializer() + self.assertEqual(serializer.fields['reviewer'].queryset.count(), 2) + self.assertEqual( + [x for x in serializer.fields['reviewer'].widget.choices], + [('', '---------'), (1, 'Person object - 1'), (2, 'Person object - 2')] + ) + + class DepthTest(TestCase): def test_implicit_nesting(self): @@ -1125,3 +1322,312 @@ class DeserializeListTestCase(TestCase): self.assertFalse(serializer.is_valid()) expected = [{}, {'email': ['This field is required.']}, {}] self.assertEqual(serializer.errors, expected) + + +# Test for issue 747 + +class LazyStringModel(object): + def __init__(self, lazystring): + self.lazystring = lazystring + + +class LazyStringSerializer(serializers.Serializer): + lazystring = serializers.Field() + + def restore_object(self, attrs, instance=None): + if instance is not None: + instance.lazystring = attrs.get('lazystring', instance.lazystring) + return instance + return LazyStringModel(**attrs) + + +class LazyStringsTestCase(TestCase): + def setUp(self): + self.model = LazyStringModel(lazystring=_('lazystring')) + + def test_lazy_strings_are_translated(self): + serializer = LazyStringSerializer(self.model) + self.assertEqual(type(serializer.data['lazystring']), + type('lazystring')) + + +# Test for issue #467 + +class FieldLabelTest(TestCase): + def setUp(self): + self.serializer_class = BasicModelSerializer + + def test_label_from_model(self): + """ + Validates that label and help_text are correctly copied from the model class. + """ + serializer = self.serializer_class() + text_field = serializer.fields['text'] + + self.assertEqual('Text comes here', text_field.label) + self.assertEqual('Text description.', text_field.help_text) + + def test_field_ctor(self): + """ + This is check that ctor supports both label and help_text. + """ + self.assertEqual('Label', fields.Field(label='Label', help_text='Help').label) + self.assertEqual('Help', fields.CharField(label='Label', help_text='Help').help_text) + self.assertEqual('Label', relations.HyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help', many=True).label) + + +class AttributeMappingOnAutogeneratedFieldsTests(TestCase): + + def setUp(self): + class AMOAFModel(RESTFrameworkModel): + char_field = models.CharField(max_length=1024, blank=True) + comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True) + decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True) + email_field = models.EmailField(max_length=1024, blank=True) + file_field = models.FileField(max_length=1024, blank=True) + image_field = models.ImageField(max_length=1024, blank=True) + slug_field = models.SlugField(max_length=1024, blank=True) + url_field = models.URLField(max_length=1024, blank=True) + + class AMOAFSerializer(serializers.ModelSerializer): + class Meta: + model = AMOAFModel + + self.serializer_class = AMOAFSerializer + self.fields_attributes = { + 'char_field': [ + ('max_length', 1024), + ], + 'comma_separated_integer_field': [ + ('max_length', 1024), + ], + 'decimal_field': [ + ('max_digits', 64), + ('decimal_places', 32), + ], + 'email_field': [ + ('max_length', 1024), + ], + 'file_field': [ + ('max_length', 1024), + ], + 'image_field': [ + ('max_length', 1024), + ], + 'slug_field': [ + ('max_length', 1024), + ], + 'url_field': [ + ('max_length', 1024), + ], + } + + def field_test(self, field): + serializer = self.serializer_class(data={}) + self.assertEqual(serializer.is_valid(), True) + + for attribute in self.fields_attributes[field]: + self.assertEqual( + getattr(serializer.fields[field], attribute[0]), + attribute[1] + ) + + def test_char_field(self): + self.field_test('char_field') + + def test_comma_separated_integer_field(self): + self.field_test('comma_separated_integer_field') + + def test_decimal_field(self): + self.field_test('decimal_field') + + def test_email_field(self): + self.field_test('email_field') + + def test_file_field(self): + self.field_test('file_field') + + def test_image_field(self): + self.field_test('image_field') + + def test_slug_field(self): + self.field_test('slug_field') + + def test_url_field(self): + self.field_test('url_field') + + +class DefaultValuesOnAutogeneratedFieldsTests(TestCase): + + def setUp(self): + class DVOAFModel(RESTFrameworkModel): + positive_integer_field = models.PositiveIntegerField(blank=True) + positive_small_integer_field = models.PositiveSmallIntegerField(blank=True) + email_field = models.EmailField(blank=True) + file_field = models.FileField(blank=True) + image_field = models.ImageField(blank=True) + slug_field = models.SlugField(blank=True) + url_field = models.URLField(blank=True) + + class DVOAFSerializer(serializers.ModelSerializer): + class Meta: + model = DVOAFModel + + self.serializer_class = DVOAFSerializer + self.fields_attributes = { + 'positive_integer_field': [ + ('min_value', 0), + ], + 'positive_small_integer_field': [ + ('min_value', 0), + ], + 'email_field': [ + ('max_length', 75), + ], + 'file_field': [ + ('max_length', 100), + ], + 'image_field': [ + ('max_length', 100), + ], + 'slug_field': [ + ('max_length', 50), + ], + 'url_field': [ + ('max_length', 200), + ], + } + + def field_test(self, field): + serializer = self.serializer_class(data={}) + self.assertEqual(serializer.is_valid(), True) + + for attribute in self.fields_attributes[field]: + self.assertEqual( + getattr(serializer.fields[field], attribute[0]), + attribute[1] + ) + + def test_positive_integer_field(self): + self.field_test('positive_integer_field') + + def test_positive_small_integer_field(self): + self.field_test('positive_small_integer_field') + + def test_email_field(self): + self.field_test('email_field') + + def test_file_field(self): + self.field_test('file_field') + + def test_image_field(self): + self.field_test('image_field') + + def test_slug_field(self): + self.field_test('slug_field') + + def test_url_field(self): + self.field_test('url_field') + + +class MetadataSerializer(serializers.Serializer): + field1 = serializers.CharField(3, required=True) + field2 = serializers.CharField(10, required=False) + + +class MetadataSerializerTestCase(TestCase): + def setUp(self): + self.serializer = MetadataSerializer() + + def test_serializer_metadata(self): + metadata = self.serializer.metadata() + expected = { + 'field1': { + 'required': True, + 'max_length': 3, + 'type': 'string', + 'read_only': False + }, + 'field2': { + 'required': False, + 'max_length': 10, + 'type': 'string', + 'read_only': False + } + } + self.assertEqual(expected, metadata) + + +### Regression test for #840 + +class SimpleModel(models.Model): + text = models.CharField(max_length=100) + + +class SimpleModelSerializer(serializers.ModelSerializer): + text = serializers.CharField() + other = serializers.CharField() + + class Meta: + model = SimpleModel + + def validate_other(self, attrs, source): + del attrs['other'] + return attrs + + +class FieldValidationRemovingAttr(TestCase): + def test_removing_non_model_field_in_validation(self): + """ + Removing an attr during field valiation should ensure that it is not + passed through when restoring the object. + + This allows additional non-model fields to be supported. + + Regression test for #840. + """ + serializer = SimpleModelSerializer(data={'text': 'foo', 'other': 'bar'}) + self.assertTrue(serializer.is_valid()) + serializer.save() + self.assertEqual(serializer.object.text, 'foo') + + +### Regression test for #878 + +class SimpleTargetModel(models.Model): + text = models.CharField(max_length=100) + + +class SimplePKSourceModelSerializer(serializers.Serializer): + targets = serializers.PrimaryKeyRelatedField(queryset=SimpleTargetModel.objects.all(), many=True) + text = serializers.CharField() + + +class SimpleSlugSourceModelSerializer(serializers.Serializer): + targets = serializers.SlugRelatedField(queryset=SimpleTargetModel.objects.all(), many=True, slug_field='pk') + text = serializers.CharField() + + +class SerializerSupportsManyRelationships(TestCase): + def setUp(self): + SimpleTargetModel.objects.create(text='foo') + SimpleTargetModel.objects.create(text='bar') + + def test_serializer_supports_pk_many_relationships(self): + """ + Regression test for #878. + + Note that pk behavior has a different code path to usual cases, + for performance reasons. + """ + serializer = SimplePKSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]}) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]}) + + def test_serializer_supports_slug_many_relationships(self): + """ + Regression test for #878. + """ + serializer = SimpleSlugSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]}) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]}) diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/test_serializer_bulk_update.py index 8b0ded1a..8b0ded1a 100644 --- a/rest_framework/tests/serializer_bulk_update.py +++ b/rest_framework/tests/test_serializer_bulk_update.py diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/test_serializer_nested.py index 71d0e24b..71d0e24b 100644 --- a/rest_framework/tests/serializer_nested.py +++ b/rest_framework/tests/test_serializer_nested.py diff --git a/rest_framework/tests/settings.py b/rest_framework/tests/test_settings.py index 857375c2..857375c2 100644 --- a/rest_framework/tests/settings.py +++ b/rest_framework/tests/test_settings.py diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/test_throttling.py index 11cbd8eb..da400b2f 100644 --- a/rest_framework/tests/throttling.py +++ b/rest_framework/tests/test_throttling.py @@ -36,7 +36,7 @@ class MockView_MinuteThrottling(APIView): class ThrottlingTests(TestCase): - urls = 'rest_framework.tests.throttling' + urls = 'rest_framework.tests.test_throttling' def setUp(self): """ diff --git a/rest_framework/tests/urlpatterns.py b/rest_framework/tests/test_urlpatterns.py index 29ed4a96..29ed4a96 100644 --- a/rest_framework/tests/urlpatterns.py +++ b/rest_framework/tests/test_urlpatterns.py diff --git a/rest_framework/tests/validation.py b/rest_framework/tests/test_validation.py index cbdd6515..a6ec0e99 100644 --- a/rest_framework/tests/validation.py +++ b/rest_framework/tests/test_validation.py @@ -63,3 +63,25 @@ class TestPreSaveValidationExclusions(TestCase): # does not have `blank=True`, so this serializer should not validate. serializer = ShouldValidateModelSerializer(data={'renamed': ''}) self.assertEqual(serializer.is_valid(), False) + + +class ValidationSerializer(serializers.Serializer): + foo = serializers.CharField() + + def validate_foo(self, attrs, source): + raise serializers.ValidationError("foo invalid") + + def validate(self, attrs): + raise serializers.ValidationError("serializer invalid") + + +class TestAvoidValidation(TestCase): + """ + If serializer was initialized with invalid data (None or non dict-like), it + should avoid validation layer (validate_<field> and validate methods) + """ + def test_serializer_errors_has_only_invalid_data_error(self): + serializer = ValidationSerializer(data='invalid data') + self.assertFalse(serializer.is_valid()) + self.assertDictEqual(serializer.errors, + {'non_field_errors': ['Invalid data']}) diff --git a/rest_framework/tests/views.py b/rest_framework/tests/test_views.py index 994cf6dc..2767d24c 100644 --- a/rest_framework/tests/views.py +++ b/rest_framework/tests/test_views.py @@ -1,12 +1,15 @@ from __future__ import unicode_literals + +import copy + from django.test import TestCase from django.test.client import RequestFactory + from rest_framework import status from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.views import APIView -import copy factory = RequestFactory() diff --git a/rest_framework/tests/testcases.py b/rest_framework/tests/testcases.py deleted file mode 100644 index f8c2579e..00000000 --- a/rest_framework/tests/testcases.py +++ /dev/null @@ -1,66 +0,0 @@ -# http://djangosnippets.org/snippets/1011/ -from __future__ import unicode_literals -from django.conf import settings -from django.core.management import call_command -from django.db.models import loading -from django.test import TestCase - -NO_SETTING = ('!', None) - - -class TestSettingsManager(object): - """ - A class which can modify some Django settings temporarily for a - test and then revert them to their original values later. - - Automatically handles resyncing the DB if INSTALLED_APPS is - modified. - - """ - def __init__(self): - self._original_settings = {} - - def set(self, **kwargs): - for k, v in kwargs.iteritems(): - self._original_settings.setdefault(k, getattr(settings, k, - NO_SETTING)) - setattr(settings, k, v) - if 'INSTALLED_APPS' in kwargs: - self.syncdb() - - def syncdb(self): - loading.cache.loaded = False - call_command('syncdb', verbosity=0) - - def revert(self): - for k, v in self._original_settings.iteritems(): - if v == NO_SETTING: - delattr(settings, k) - else: - setattr(settings, k, v) - if 'INSTALLED_APPS' in self._original_settings: - self.syncdb() - self._original_settings = {} - - -class SettingsTestCase(TestCase): - """ - A subclass of the Django TestCase with a settings_manager - attribute which is an instance of TestSettingsManager. - - Comes with a tearDown() method that calls - self.settings_manager.revert(). - - """ - def __init__(self, *args, **kwargs): - super(SettingsTestCase, self).__init__(*args, **kwargs) - self.settings_manager = TestSettingsManager() - - def tearDown(self): - self.settings_manager.revert() - - -class TestModelsTestCase(SettingsTestCase): - def setUp(self, *args, **kwargs): - installed_apps = tuple(settings.INSTALLED_APPS) + ('rest_framework.tests',) - self.settings_manager.set(INSTALLED_APPS=installed_apps) diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py index 08f88e11..554ebd1a 100644 --- a/rest_framework/tests/tests.py +++ b/rest_framework/tests/tests.py @@ -4,11 +4,13 @@ runner to pick up the tests. Yowzers. """ from __future__ import unicode_literals import os +import django modules = [filename.rsplit('.', 1)[0] for filename in os.listdir(os.path.dirname(__file__)) if filename.endswith('.py') and not filename.startswith('_')] __test__ = dict() -for module in modules: - exec("from rest_framework.tests.%s import *" % module) +if django.VERSION < (1, 6): + for module in modules: + exec("from rest_framework.tests.%s import *" % module) diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index b6de18a8..b26a2085 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -3,7 +3,8 @@ Helper classes for parsers. """ from __future__ import unicode_literals from django.utils.datastructures import SortedDict -from rest_framework.compat import timezone +from django.utils.functional import Promise +from rest_framework.compat import timezone, force_text from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata import datetime import decimal @@ -19,7 +20,9 @@ class JSONEncoder(json.JSONEncoder): def default(self, o): # For Date Time string spec, see ECMA 262 # http://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15 - if isinstance(o, datetime.datetime): + if isinstance(o, Promise): + return force_text(o) + elif isinstance(o, datetime.datetime): r = o.isoformat() if o.microsecond: r = r[:23] + r[26:] diff --git a/rest_framework/views.py b/rest_framework/views.py index 555fa2f4..e1b6705b 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -2,13 +2,15 @@ Provides an APIView class that is the base of all views in REST framework. """ from __future__ import unicode_literals + from django.core.exceptions import PermissionDenied from django.http import Http404, HttpResponse +from django.utils.datastructures import SortedDict from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions from rest_framework.compat import View -from rest_framework.response import Response from rest_framework.request import Request +from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.utils.formatting import get_view_name, get_view_description @@ -51,21 +53,6 @@ class APIView(View): 'Vary': 'Accept' } - def metadata(self, request): - return { - 'name': get_view_name(self.__class__), - 'description': get_view_description(self.__class__), - 'renders': [renderer.media_type for renderer in self.renderer_classes], - 'parses': [parser.media_type for parser in self.parser_classes], - } - # TODO: Add 'fields', from serializer info, if it exists. - # serializer = self.get_serializer() - # if serializer is not None: - # field_name_types = {} - # for name, field in form.fields.iteritems(): - # field_name_types[name] = field.__class__.__name__ - # content['fields'] = field_name_types - def http_method_not_allowed(self, request, *args, **kwargs): """ If `request.method` does not correspond to a handler method, @@ -348,3 +335,15 @@ class APIView(View): a less useful default implementation. """ return Response(self.metadata(request), status=status.HTTP_200_OK) + + def metadata(self, request): + """ + Return a dictionary of metadata about the view. + Used to return responses for OPTIONS requests. + """ + ret = SortedDict() + ret['name'] = get_view_name(self.__class__) + ret['description'] = get_view_description(self.__class__) + ret['renders'] = [renderer.media_type for renderer in self.renderer_classes] + ret['parses'] = [parser.media_type for parser in self.parser_classes] + return ret diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 0eb3e86d..d91323f2 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -108,10 +108,18 @@ class ViewSet(ViewSetMixin, views.APIView): pass +class GenericViewSet(ViewSetMixin, generics.GenericAPIView): + """ + The GenericViewSet class does not provide any actions by default, + but does include the base set of generic view behavior, such as + the `get_object` and `get_queryset` methods. + """ + pass + + class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, mixins.ListModelMixin, - ViewSetMixin, - generics.GenericAPIView): + GenericViewSet): """ A viewset that provides default `list()` and `retrieve()` actions. """ @@ -123,8 +131,7 @@ class ModelViewSet(mixins.CreateModelMixin, mixins.UpdateModelMixin, mixins.DestroyModelMixin, mixins.ListModelMixin, - ViewSetMixin, - generics.GenericAPIView): + GenericViewSet): """ A viewset that provides default `create()`, `retrieve()`, `update()`, `partial_update()`, `destroy()` and `list()` actions. |
