aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py2
-rw-r--r--rest_framework/compat.py13
-rw-r--r--rest_framework/decorators.py10
-rw-r--r--rest_framework/fields.py90
-rw-r--r--rest_framework/filters.py82
-rw-r--r--rest_framework/generics.py56
-rw-r--r--rest_framework/mixins.py32
-rw-r--r--rest_framework/permissions.py5
-rw-r--r--rest_framework/relations.py122
-rw-r--r--rest_framework/renderers.py101
-rw-r--r--rest_framework/request.py33
-rw-r--r--rest_framework/response.py24
-rw-r--r--rest_framework/routers.py48
-rwxr-xr-xrest_framework/runtests/runtests.py7
-rw-r--r--rest_framework/runtests/settings.py2
-rw-r--r--rest_framework/serializers.py130
-rw-r--r--rest_framework/static/rest_framework/css/bootstrap-tweaks.css165
-rw-r--r--rest_framework/static/rest_framework/css/default.css149
-rw-r--r--rest_framework/templates/rest_framework/base.html13
-rw-r--r--rest_framework/templates/rest_framework/form.html2
-rw-r--r--rest_framework/templates/rest_framework/login_base.html74
-rw-r--r--rest_framework/templatetags/rest_framework.py2
-rw-r--r--rest_framework/tests/models.py17
-rw-r--r--rest_framework/tests/relations.py47
-rw-r--r--rest_framework/tests/test_authentication.py (renamed from rest_framework/tests/authentication.py)53
-rw-r--r--rest_framework/tests/test_breadcrumbs.py (renamed from rest_framework/tests/breadcrumbs.py)2
-rw-r--r--rest_framework/tests/test_decorators.py (renamed from rest_framework/tests/decorators.py)0
-rw-r--r--rest_framework/tests/test_description.py (renamed from rest_framework/tests/description.py)0
-rw-r--r--rest_framework/tests/test_fields.py (renamed from rest_framework/tests/fields.py)242
-rw-r--r--rest_framework/tests/test_files.py (renamed from rest_framework/tests/files.py)0
-rw-r--r--rest_framework/tests/test_filters.py (renamed from rest_framework/tests/filterset.py)230
-rw-r--r--rest_framework/tests/test_genericrelations.py (renamed from rest_framework/tests/genericrelations.py)0
-rw-r--r--rest_framework/tests/test_generics.py (renamed from rest_framework/tests/generics.py)115
-rw-r--r--rest_framework/tests/test_htmlrenderer.py (renamed from rest_framework/tests/htmlrenderer.py)14
-rw-r--r--rest_framework/tests/test_hyperlinkedserializers.py (renamed from rest_framework/tests/hyperlinkedserializers.py)50
-rw-r--r--rest_framework/tests/test_multitable_inheritance.py (renamed from rest_framework/tests/multitable_inheritance.py)0
-rw-r--r--rest_framework/tests/test_negotiation.py (renamed from rest_framework/tests/negotiation.py)9
-rw-r--r--rest_framework/tests/test_pagination.py (renamed from rest_framework/tests/pagination.py)14
-rw-r--r--rest_framework/tests/test_parsers.py (renamed from rest_framework/tests/parsers.py)0
-rw-r--r--rest_framework/tests/test_permissions.py (renamed from rest_framework/tests/permissions.py)42
-rw-r--r--rest_framework/tests/test_relations.py100
-rw-r--r--rest_framework/tests/test_relations_hyperlink.py (renamed from rest_framework/tests/relations_hyperlink.py)79
-rw-r--r--rest_framework/tests/test_relations_nested.py (renamed from rest_framework/tests/relations_nested.py)0
-rw-r--r--rest_framework/tests/test_relations_pk.py (renamed from rest_framework/tests/relations_pk.py)121
-rw-r--r--rest_framework/tests/test_relations_slug.py (renamed from rest_framework/tests/relations_slug.py)0
-rw-r--r--rest_framework/tests/test_renderers.py (renamed from rest_framework/tests/renderers.py)66
-rw-r--r--rest_framework/tests/test_request.py (renamed from rest_framework/tests/request.py)2
-rw-r--r--rest_framework/tests/test_response.py (renamed from rest_framework/tests/response.py)131
-rw-r--r--rest_framework/tests/test_reverse.py (renamed from rest_framework/tests/reverse.py)2
-rw-r--r--rest_framework/tests/test_routers.py150
-rw-r--r--rest_framework/tests/test_serializer.py (renamed from rest_framework/tests/serializer.py)560
-rw-r--r--rest_framework/tests/test_serializer_bulk_update.py (renamed from rest_framework/tests/serializer_bulk_update.py)0
-rw-r--r--rest_framework/tests/test_serializer_nested.py (renamed from rest_framework/tests/serializer_nested.py)0
-rw-r--r--rest_framework/tests/test_settings.py (renamed from rest_framework/tests/settings.py)0
-rw-r--r--rest_framework/tests/test_throttling.py (renamed from rest_framework/tests/throttling.py)2
-rw-r--r--rest_framework/tests/test_urlpatterns.py (renamed from rest_framework/tests/urlpatterns.py)0
-rw-r--r--rest_framework/tests/test_validation.py (renamed from rest_framework/tests/validation.py)22
-rw-r--r--rest_framework/tests/test_views.py (renamed from rest_framework/tests/views.py)5
-rw-r--r--rest_framework/tests/testcases.py66
-rw-r--r--rest_framework/tests/tests.py6
-rw-r--r--rest_framework/utils/encoders.py7
-rw-r--r--rest_framework/views.py31
-rw-r--r--rest_framework/viewsets.py15
63 files changed, 2711 insertions, 651 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index b4961e2f..0a210186 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '2.3.2'
+__version__ = '2.3.5'
VERSION = __version__ # synonym
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index cd39f544..76dc0052 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -495,3 +495,16 @@ except ImportError:
oauth2_provider_forms = None
oauth2_provider_scope = None
oauth2_constants = None
+
+# Handle lazy strings
+from django.utils.functional import Promise
+
+if six.PY3:
+ def is_non_str_iterable(obj):
+ if (isinstance(obj, str) or
+ (isinstance(obj, Promise) and obj._delegate_text)):
+ return False
+ return hasattr(obj, '__iter__')
+else:
+ def is_non_str_iterable(obj):
+ return hasattr(obj, '__iter__')
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 81e585e1..c69756a4 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -1,5 +1,5 @@
"""
-The most imporant decorator in this module is `@api_view`, which is used
+The most important decorator in this module is `@api_view`, which is used
for writing function-based views with REST framework.
There are also various decorators for setting the API policies on function
@@ -40,7 +40,7 @@ def api_view(http_method_names):
# api_view applied with eg. string instead of list of strings
assert isinstance(http_method_names, (list, tuple)), \
- '@api_view expected a list of strings, recieved %s' % type(http_method_names).__name__
+ '@api_view expected a list of strings, received %s' % type(http_method_names).__name__
allowed_methods = set(http_method_names) | set(('options',))
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
@@ -112,18 +112,18 @@ def link(**kwargs):
Used to mark a method on a ViewSet that should be routed for GET requests.
"""
def decorator(func):
- func.bind_to_method = 'get'
+ func.bind_to_methods = ['get']
func.kwargs = kwargs
return func
return decorator
-def action(**kwargs):
+def action(methods=['post'], **kwargs):
"""
Used to mark a method on a ViewSet that should be routed for POST requests.
"""
def decorator(func):
- func.bind_to_method = 'post'
+ func.bind_to_methods = methods
func.kwargs = kwargs
return func
return decorator
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index c83ee5ec..535aa2ac 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -11,20 +11,21 @@ from decimal import Decimal, DecimalException
import inspect
import re
import warnings
-
from django.core import validators
from django.core.exceptions import ValidationError
from django.conf import settings
+from django.db.models.fields import BLANK_CHOICE_DASH
from django import forms
from django.forms import widgets
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
-
+from django.utils.datastructures import SortedDict
from rest_framework import ISO_8601
-from rest_framework.compat import timezone, parse_date, parse_datetime, parse_time
+from rest_framework.compat import (timezone, parse_date, parse_datetime,
+ parse_time)
from rest_framework.compat import BytesIO
from rest_framework.compat import six
-from rest_framework.compat import smart_text
+from rest_framework.compat import smart_text, force_text, is_non_str_iterable
from rest_framework.settings import api_settings
@@ -50,7 +51,7 @@ def get_component(obj, attr_name):
return that attribute on the object.
"""
if isinstance(obj, dict):
- val = obj[attr_name]
+ val = obj.get(attr_name)
else:
val = getattr(obj, attr_name)
@@ -60,7 +61,8 @@ def get_component(obj, attr_name):
def readable_datetime_formats(formats):
- format = ', '.join(formats).replace(ISO_8601, 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]')
+ format = ', '.join(formats).replace(ISO_8601,
+ 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]')
return humanize_strptime(format)
@@ -107,8 +109,9 @@ class Field(object):
partial = False
use_files = False
form_field_class = forms.CharField
+ type_label = 'field'
- def __init__(self, source=None):
+ def __init__(self, source=None, label=None, help_text=None):
self.parent = None
self.creation_counter = Field.creation_counter
@@ -116,6 +119,12 @@ class Field(object):
self.source = source
+ if label is not None:
+ self.label = smart_text(label)
+
+ if help_text is not None:
+ self.help_text = smart_text(help_text)
+
def initialize(self, parent, field_name):
"""
Called to set up a field prior to field_to_native or field_from_native.
@@ -167,11 +176,16 @@ class Field(object):
if is_protected_type(value):
return value
- elif hasattr(value, '__iter__') and not isinstance(value, (dict, six.string_types)):
+ elif (is_non_str_iterable(value) and
+ not isinstance(value, (dict, six.string_types))):
return [self.to_native(item) for item in value]
elif isinstance(value, dict):
- return dict(map(self.to_native, (k, v)) for k, v in value.items())
- return smart_text(value)
+ # Make sure we preserve field ordering, if it exists
+ ret = SortedDict()
+ for key, val in value.items():
+ ret[key] = self.to_native(val)
+ return ret
+ return force_text(value)
def attributes(self):
"""
@@ -181,6 +195,18 @@ class Field(object):
return {'type': self.type_name}
return {}
+ def metadata(self):
+ metadata = SortedDict()
+ metadata['type'] = self.type_label
+ metadata['required'] = getattr(self, 'required', False)
+ optional_attrs = ['read_only', 'label', 'help_text',
+ 'min_length', 'max_length']
+ for attr in optional_attrs:
+ value = getattr(self, attr, None)
+ if value is not None and value != '':
+ metadata[attr] = force_text(value, strings_only=True)
+ return metadata
+
class WritableField(Field):
"""
@@ -194,7 +220,8 @@ class WritableField(Field):
widget = widgets.TextInput
default = None
- def __init__(self, source=None, read_only=False, required=None,
+ def __init__(self, source=None, label=None, help_text=None,
+ read_only=False, required=None,
validators=[], error_messages=None, widget=None,
default=None, blank=None):
@@ -205,7 +232,7 @@ class WritableField(Field):
DeprecationWarning, stacklevel=2)
required = not(blank)
- super(WritableField, self).__init__(source=source)
+ super(WritableField, self).__init__(source=source, label=label, help_text=help_text)
self.read_only = read_only
if required is None:
@@ -268,7 +295,10 @@ class WritableField(Field):
except KeyError:
if self.default is not None and not self.partial:
# Note: partial updates shouldn't set defaults
- native = self.default
+ if is_simple_callable(self.default):
+ native = self.default()
+ else:
+ native = self.default
else:
if self.required:
raise ValidationError(self.error_messages['required'])
@@ -335,6 +365,7 @@ class ModelField(WritableField):
class BooleanField(WritableField):
type_name = 'BooleanField'
+ type_label = 'boolean'
form_field_class = forms.BooleanField
widget = widgets.CheckboxInput
default_error_messages = {
@@ -357,6 +388,7 @@ class BooleanField(WritableField):
class CharField(WritableField):
type_name = 'CharField'
+ type_label = 'string'
form_field_class = forms.CharField
def __init__(self, max_length=None, min_length=None, *args, **kwargs):
@@ -375,23 +407,38 @@ class CharField(WritableField):
class URLField(CharField):
type_name = 'URLField'
+ type_label = 'url'
def __init__(self, **kwargs):
- kwargs['max_length'] = kwargs.get('max_length', 200)
kwargs['validators'] = [validators.URLValidator()]
super(URLField, self).__init__(**kwargs)
class SlugField(CharField):
type_name = 'SlugField'
+ type_label = 'slug'
+ form_field_class = forms.SlugField
+
+ default_error_messages = {
+ 'invalid': _("Enter a valid 'slug' consisting of letters, numbers,"
+ " underscores or hyphens."),
+ }
+ default_validators = [validators.validate_slug]
def __init__(self, *args, **kwargs):
- kwargs['max_length'] = kwargs.get('max_length', 50)
super(SlugField, self).__init__(*args, **kwargs)
+ def __deepcopy__(self, memo):
+ result = copy.copy(self)
+ memo[id(self)] = result
+ #result.widget = copy.deepcopy(self.widget, memo)
+ result.validators = self.validators[:]
+ return result
+
class ChoiceField(WritableField):
type_name = 'ChoiceField'
+ type_label = 'multiple choice'
form_field_class = forms.ChoiceField
widget = widgets.Select
default_error_messages = {
@@ -402,6 +449,8 @@ class ChoiceField(WritableField):
def __init__(self, choices=(), *args, **kwargs):
super(ChoiceField, self).__init__(*args, **kwargs)
self.choices = choices
+ if not self.required:
+ self.choices = BLANK_CHOICE_DASH + self.choices
def _get_choices(self):
return self._choices
@@ -440,6 +489,7 @@ class ChoiceField(WritableField):
class EmailField(CharField):
type_name = 'EmailField'
+ type_label = 'email'
form_field_class = forms.EmailField
default_error_messages = {
@@ -463,6 +513,7 @@ class EmailField(CharField):
class RegexField(CharField):
type_name = 'RegexField'
+ type_label = 'regex'
form_field_class = forms.RegexField
def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs):
@@ -492,6 +543,7 @@ class RegexField(CharField):
class DateField(WritableField):
type_name = 'DateField'
+ type_label = 'date'
widget = widgets.DateInput
form_field_class = forms.DateField
@@ -555,6 +607,7 @@ class DateField(WritableField):
class DateTimeField(WritableField):
type_name = 'DateTimeField'
+ type_label = 'datetime'
widget = widgets.DateTimeInput
form_field_class = forms.DateTimeField
@@ -624,6 +677,7 @@ class DateTimeField(WritableField):
class TimeField(WritableField):
type_name = 'TimeField'
+ type_label = 'time'
widget = widgets.TimeInput
form_field_class = forms.TimeField
@@ -680,6 +734,7 @@ class TimeField(WritableField):
class IntegerField(WritableField):
type_name = 'IntegerField'
+ type_label = 'integer'
form_field_class = forms.IntegerField
default_error_messages = {
@@ -710,6 +765,7 @@ class IntegerField(WritableField):
class FloatField(WritableField):
type_name = 'FloatField'
+ type_label = 'float'
form_field_class = forms.FloatField
default_error_messages = {
@@ -729,6 +785,7 @@ class FloatField(WritableField):
class DecimalField(WritableField):
type_name = 'DecimalField'
+ type_label = 'decimal'
form_field_class = forms.DecimalField
default_error_messages = {
@@ -799,6 +856,7 @@ class DecimalField(WritableField):
class FileField(WritableField):
use_files = True
type_name = 'FileField'
+ type_label = 'file upload'
form_field_class = forms.FileField
widget = widgets.FileInput
@@ -842,6 +900,8 @@ class FileField(WritableField):
class ImageField(FileField):
use_files = True
+ type_name = 'ImageField'
+ type_label = 'image upload'
form_field_class = forms.ImageField
default_error_messages = {
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index f2163f6f..c058bc71 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -3,9 +3,9 @@ Provides generic filtering backends that can be used to filter the results
returned by list views.
"""
from __future__ import unicode_literals
-
from django.db import models
-from rest_framework.compat import django_filters
+from rest_framework.compat import django_filters, six
+from functools import reduce
import operator
FilterSet = django_filters and django_filters.FilterSet or None
@@ -32,40 +32,33 @@ class DjangoFilterBackend(BaseFilterBackend):
def __init__(self):
assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed'
- def get_filter_class(self, view):
+ def get_filter_class(self, view, queryset=None):
"""
Return the django-filters `FilterSet` used to filter the queryset.
"""
filter_class = getattr(view, 'filter_class', None)
filter_fields = getattr(view, 'filter_fields', None)
- model_cls = getattr(view, 'model', None)
- queryset = getattr(view, 'queryset', None)
- if model_cls is None and queryset is not None:
- model_cls = queryset.model
if filter_class:
filter_model = filter_class.Meta.model
- assert issubclass(filter_model, model_cls), \
- 'FilterSet model %s does not match view model %s' % \
- (filter_model, model_cls)
+ assert issubclass(filter_model, queryset.model), \
+ 'FilterSet model %s does not match queryset model %s' % \
+ (filter_model, queryset.model)
return filter_class
if filter_fields:
- assert model_cls is not None, 'Cannot use DjangoFilterBackend ' \
- 'on a view which does not have a .model or .queryset attribute.'
-
class AutoFilterSet(self.default_filter_set):
class Meta:
- model = model_cls
+ model = queryset.model
fields = filter_fields
return AutoFilterSet
return None
def filter_queryset(self, request, queryset, view):
- filter_class = self.get_filter_class(view)
+ filter_class = self.get_filter_class(view, queryset)
if filter_class:
return filter_class(request.QUERY_PARAMS, queryset=queryset).qs
@@ -74,6 +67,16 @@ class DjangoFilterBackend(BaseFilterBackend):
class SearchFilter(BaseFilterBackend):
+ search_param = 'search' # The URL query parameter used for the search.
+
+ def get_search_terms(self, request):
+ """
+ Search terms are set by a ?search=... query parameter,
+ and may be comma and/or whitespace delimited.
+ """
+ params = request.QUERY_PARAMS.get(self.search_param, '')
+ return params.replace(',', ' ').split()
+
def construct_search(self, field_name):
if field_name.startswith('^'):
return "%s__istartswith" % field_name[1:]
@@ -88,12 +91,53 @@ class SearchFilter(BaseFilterBackend):
search_fields = getattr(view, 'search_fields', None)
if not search_fields:
- return None
+ return queryset
orm_lookups = [self.construct_search(str(search_field))
- for search_field in self.search_fields]
- for bit in self.query.split():
- or_queries = [models.Q(**{orm_lookup: bit})
+ for search_field in search_fields]
+
+ for search_term in self.get_search_terms(request):
+ or_queries = [models.Q(**{orm_lookup: search_term})
for orm_lookup in orm_lookups]
queryset = queryset.filter(reduce(operator.or_, or_queries))
+
+ return queryset
+
+
+class OrderingFilter(BaseFilterBackend):
+ ordering_param = 'ordering' # The URL query parameter used for the ordering.
+
+ def get_ordering(self, request):
+ """
+ Search terms are set by a ?search=... query parameter,
+ and may be comma and/or whitespace delimited.
+ """
+ params = request.QUERY_PARAMS.get(self.ordering_param)
+ if params:
+ return [param.strip() for param in params.split(',')]
+
+ def get_default_ordering(self, view):
+ ordering = getattr(view, 'ordering', None)
+ if isinstance(ordering, six.string_types):
+ return (ordering,)
+ return ordering
+
+ def remove_invalid_fields(self, queryset, ordering):
+ field_names = [field.name for field in queryset.model._meta.fields]
+ return [term for term in ordering if term.lstrip('-') in field_names]
+
+ def filter_queryset(self, request, queryset, view):
+ ordering = self.get_ordering(request)
+
+ if ordering:
+ # Skip any incorrect parameters
+ ordering = self.remove_invalid_fields(queryset, ordering)
+
+ if not ordering:
+ # Use 'ordering' attribtue by default
+ ordering = self.get_default_ordering(view)
+
+ if ordering:
+ return queryset.order_by(*ordering)
+
return queryset
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 05ec93d3..9ccc7898 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -3,17 +3,28 @@ Generic views that provide commonly needed behaviour.
"""
from __future__ import unicode_literals
-from django.core.exceptions import ImproperlyConfigured
+from django.core.exceptions import ImproperlyConfigured, PermissionDenied
from django.core.paginator import Paginator, InvalidPage
from django.http import Http404
-from django.shortcuts import get_object_or_404
+from django.shortcuts import get_object_or_404 as _get_object_or_404
from django.utils.translation import ugettext as _
-from rest_framework import views, mixins
-from rest_framework.exceptions import ConfigurationError
+from rest_framework import views, mixins, exceptions
+from rest_framework.request import clone_request
from rest_framework.settings import api_settings
import warnings
+def get_object_or_404(queryset, **filter_kwargs):
+ """
+ Same as Django's standard shortcut, but make sure to raise 404
+ if the filter_kwargs don't match the required types.
+ """
+ try:
+ return _get_object_or_404(queryset, **filter_kwargs)
+ except (TypeError, ValueError):
+ raise Http404
+
+
class GenericAPIView(views.APIView):
"""
Base class for all other generic views.
@@ -274,7 +285,7 @@ class GenericAPIView(views.APIView):
)
filter_kwargs = {self.slug_field: slug}
else:
- raise ConfigurationError(
+ raise exceptions.ConfigurationError(
'Expected view %s to be called with a URL keyword argument '
'named "%s". Fix your URL conf, or set the `.lookup_field` '
'attribute on the view correctly.' %
@@ -310,6 +321,41 @@ class GenericAPIView(views.APIView):
"""
pass
+ def metadata(self, request):
+ """
+ Return a dictionary of metadata about the view.
+ Used to return responses for OPTIONS requests.
+
+ We override the default behavior, and add some extra information
+ about the required request body for POST and PUT operations.
+ """
+ ret = super(GenericAPIView, self).metadata(request)
+
+ actions = {}
+ for method in ('PUT', 'POST'):
+ if method not in self.allowed_methods:
+ continue
+
+ cloned_request = clone_request(request, method)
+ try:
+ # Test global permissions
+ self.check_permissions(cloned_request)
+ # Test object permissions
+ if method == 'PUT':
+ self.get_object()
+ except (exceptions.APIException, PermissionDenied, Http404):
+ pass
+ else:
+ # If user has appropriate permissions for the view, include
+ # appropriate metadata about the fields that should be supplied.
+ serializer = self.get_serializer()
+ actions[method] = serializer.metadata()
+
+ if actions:
+ ret['actions'] = actions
+
+ return ret
+
##########################################################
### Concrete view classes that provide method handlers ###
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index ae703771..f11def6d 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -10,6 +10,7 @@ from django.http import Http404
from rest_framework import status
from rest_framework.response import Response
from rest_framework.request import clone_request
+import warnings
def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None):
@@ -42,7 +43,6 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None)
class CreateModelMixin(object):
"""
Create a model instance.
- Should be mixed in with any `GenericAPIView`.
"""
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA, files=request.FILES)
@@ -67,7 +67,6 @@ class CreateModelMixin(object):
class ListModelMixin(object):
"""
List a queryset.
- Should be mixed in with `MultipleObjectAPIView`.
"""
empty_error = "Empty list and '%(class_name)s.allow_empty' is False."
@@ -77,6 +76,12 @@ class ListModelMixin(object):
# Default is to allow empty querysets. This can be altered by setting
# `.allow_empty = False`, to raise 404 errors on empty querysets.
if not self.allow_empty and not self.object_list:
+ warnings.warn(
+ 'The `allow_empty` parameter is due to be deprecated. '
+ 'To use `allow_empty=False` style behavior, You should override '
+ '`get_queryset()` and explicitly raise a 404 on empty querysets.',
+ PendingDeprecationWarning
+ )
class_name = self.__class__.__name__
error_msg = self.empty_error % {'class_name': class_name}
raise Http404(error_msg)
@@ -94,7 +99,6 @@ class ListModelMixin(object):
class RetrieveModelMixin(object):
"""
Retrieve a model instance.
- Should be mixed in with `SingleObjectAPIView`.
"""
def retrieve(self, request, *args, **kwargs):
self.object = self.get_object()
@@ -105,17 +109,12 @@ class RetrieveModelMixin(object):
class UpdateModelMixin(object):
"""
Update a model instance.
- Should be mixed in with `SingleObjectAPIView`.
"""
def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
- self.object = None
- try:
- self.object = self.get_object()
- except Http404:
- # If this is a PUT-as-create operation, we need to ensure that
- # we have relevant permissions, as if this was a POST request.
- self.check_permissions(clone_request(request, 'POST'))
+ self.object = self.get_object_or_none()
+
+ if self.object is None:
created = True
save_kwargs = {'force_insert': True}
success_status_code = status.HTTP_201_CREATED
@@ -139,6 +138,16 @@ class UpdateModelMixin(object):
kwargs['partial'] = True
return self.update(request, *args, **kwargs)
+ def get_object_or_none(self):
+ try:
+ return self.get_object()
+ except Http404:
+ # If this is a PUT-as-create operation, we need to ensure that
+ # we have relevant permissions, as if this was a POST request.
+ # This will either raise a PermissionDenied exception,
+ # or simply return None
+ self.check_permissions(clone_request(self.request, 'POST'))
+
def pre_save(self, obj):
"""
Set any attributes on the object that are implicit in the request.
@@ -168,7 +177,6 @@ class UpdateModelMixin(object):
class DestroyModelMixin(object):
"""
Destroy a model instance.
- Should be mixed in with `SingleObjectAPIView`.
"""
def destroy(self, request, *args, **kwargs):
obj = self.get_object()
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index 751f31a7..45fcfd66 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -126,6 +126,11 @@ class DjangoModelPermissions(BasePermission):
if model_cls is None and queryset is not None:
model_cls = queryset.model
+ # Workaround to ensure DjangoModelPermissions are not applied
+ # to the root view when using DefaultRouter.
+ if model_cls is None and getattr(view, '_ignore_model_permissions'):
+ return True
+
assert model_cls, ('Cannot apply DjangoModelPermissions on a view that'
' does not have `.model` or `.queryset` property.')
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index fc5054b2..edaf76d6 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -8,10 +8,11 @@ from __future__ import unicode_literals
from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch
from django import forms
+from django.db.models.fields import BLANK_CHOICE_DASH
from django.forms import widgets
from django.forms.models import ModelChoiceIterator
from django.utils.translation import ugettext_lazy as _
-from rest_framework.fields import Field, WritableField, get_component
+from rest_framework.fields import Field, WritableField, get_component, is_simple_callable
from rest_framework.reverse import reverse
from rest_framework.compat import urlparse
from rest_framework.compat import smart_text
@@ -47,7 +48,7 @@ class RelatedField(WritableField):
DeprecationWarning, stacklevel=2)
kwargs['required'] = not kwargs.pop('null')
- self.queryset = kwargs.pop('queryset', None)
+ queryset = kwargs.pop('queryset', None)
self.many = kwargs.pop('many', self.many)
if self.many:
self.widget = self.many_widget
@@ -56,6 +57,11 @@ class RelatedField(WritableField):
kwargs['read_only'] = kwargs.pop('read_only', self.read_only)
super(RelatedField, self).__init__(*args, **kwargs)
+ if not self.required:
+ self.empty_label = BLANK_CHOICE_DASH[0][1]
+
+ self.queryset = queryset
+
def initialize(self, parent, field_name):
super(RelatedField, self).initialize(parent, field_name)
if self.queryset is None and not self.read_only:
@@ -66,7 +72,6 @@ class RelatedField(WritableField):
else: # Reverse
self.queryset = manager.field.rel.to._default_manager.all()
except Exception:
- raise
msg = ('Serializer related fields must include a `queryset`' +
' argument or set `read_only=True')
raise Exception(msg)
@@ -139,7 +144,12 @@ class RelatedField(WritableField):
return None
if self.many:
- return [self.to_native(item) for item in value.all()]
+ if is_simple_callable(getattr(value, 'all', None)):
+ return [self.to_native(item) for item in value.all()]
+ else:
+ # Also support non-queryset iterables.
+ # This allows us to also support plain lists of related items.
+ return [self.to_native(item) for item in value]
return self.to_native(value)
def field_from_native(self, data, files, field_name, into):
@@ -221,15 +231,28 @@ class PrimaryKeyRelatedField(RelatedField):
def field_to_native(self, obj, field_name):
if self.many:
# To-many relationship
- try:
+
+ queryset = None
+ if not self.source:
# Prefer obj.serializable_value for performance reasons
- queryset = obj.serializable_value(self.source or field_name)
- except AttributeError:
+ try:
+ queryset = obj.serializable_value(field_name)
+ except AttributeError:
+ pass
+ if queryset is None:
# RelatedManager (reverse relationship)
- queryset = getattr(obj, self.source or field_name)
+ source = self.source or field_name
+ queryset = obj
+ for component in source.split('.'):
+ queryset = get_component(queryset, component)
# Forward relationship
- return [self.to_native(item.pk) for item in queryset.all()]
+ if is_simple_callable(getattr(queryset, 'all', None)):
+ return [self.to_native(item.pk) for item in queryset.all()]
+ else:
+ # Also support non-queryset iterables.
+ # This allows us to also support plain lists of related items.
+ return [self.to_native(item.pk) for item in queryset]
# To-one relationship
try:
@@ -434,7 +457,7 @@ class HyperlinkedRelatedField(RelatedField):
raise Exception('Writable related fields must include a `queryset` argument')
try:
- http_prefix = value.startswith('http:') or value.startswith('https:')
+ http_prefix = value.startswith(('http:', 'https:'))
except AttributeError:
msg = self.error_messages['incorrect_type']
raise ValidationError(msg % type(value).__name__)
@@ -465,17 +488,35 @@ class HyperlinkedIdentityField(Field):
"""
Represents the instance, or a property on the instance, using hyperlinking.
"""
+ lookup_field = 'pk'
+ read_only = True
+
+ # These are all pending deprecation
pk_url_kwarg = 'pk'
slug_field = 'slug'
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
- read_only = True
def __init__(self, *args, **kwargs):
- # TODO: Make view_name mandatory, and have the
- # HyperlinkedModelSerializer set it on-the-fly
- self.view_name = kwargs.pop('view_name', None)
- # Optionally the format of the target hyperlink may be specified
+ try:
+ self.view_name = kwargs.pop('view_name')
+ except KeyError:
+ msg = "HyperlinkedIdentityField requires 'view_name' argument"
+ raise ValueError(msg)
+
self.format = kwargs.pop('format', None)
+ lookup_field = kwargs.pop('lookup_field', None)
+ self.lookup_field = lookup_field or self.lookup_field
+
+ # These are pending deprecation
+ if 'pk_url_kwarg' in kwargs:
+ msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ if 'slug_url_kwarg' in kwargs:
+ msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ if 'slug_field' in kwargs:
+ msg = 'slug_field is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
self.slug_field = kwargs.pop('slug_field', self.slug_field)
default_slug_kwarg = self.slug_url_kwarg or self.slug_field
@@ -487,8 +528,7 @@ class HyperlinkedIdentityField(Field):
def field_to_native(self, obj, field_name):
request = self.context.get('request', None)
format = self.context.get('format', None)
- view_name = self.view_name or self.parent.opts.view_name
- kwargs = {self.pk_url_kwarg: obj.pk}
+ view_name = self.view_name
if request is None:
warnings.warn("Using `HyperlinkedIdentityField` without including the "
@@ -508,29 +548,51 @@ class HyperlinkedIdentityField(Field):
if format and self.format and self.format != format:
format = self.format
+ # Return the hyperlink, or error if incorrectly configured.
try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ return self.get_url(obj, view_name, request, format)
except NoReverseMatch:
- pass
-
- slug = getattr(obj, self.slug_field, None)
+ msg = (
+ 'Could not resolve URL for hyperlinked relationship using '
+ 'view name "%s". You may have failed to include the related '
+ 'model in your API, or incorrectly configured the '
+ '`lookup_field` attribute on this field.'
+ )
+ raise Exception(msg % view_name)
- if not slug:
- raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
+ def get_url(self, obj, view_name, request, format):
+ """
+ Given an object, return the URL that hyperlinks to the object.
- kwargs = {self.slug_url_kwarg: slug}
+ May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
+ attributes are not configured to correctly match the URL conf.
+ """
+ lookup_field = getattr(obj, self.lookup_field)
+ kwargs = {self.lookup_field: lookup_field}
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
except NoReverseMatch:
pass
- kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
+ if self.pk_url_kwarg != 'pk':
+ # Only try pk lookup if it has been explicitly set.
+ # Otherwise, the default `lookup_field = 'pk'` has us covered.
+ kwargs = {self.pk_url_kwarg: obj.pk}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except NoReverseMatch:
+ pass
- raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
+ slug = getattr(obj, self.slug_field, None)
+ if slug:
+ # Only use slug lookup if a slug field exists on the model
+ kwargs = {self.slug_url_kwarg: slug}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except NoReverseMatch:
+ pass
+
+ raise NoReverseMatch()
### Old-style many classes for backwards compat
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 1917a080..b2fe43ea 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -9,7 +9,6 @@ REST framework also provides an HTML renderer the renders the browsable API.
from __future__ import unicode_literals
import copy
-import string
import json
from django import forms
from django.http.multipartparser import parse_header
@@ -36,6 +35,7 @@ class BaseRenderer(object):
media_type = None
format = None
+ charset = 'utf-8'
def render(self, data, accepted_media_type=None, renderer_context=None):
raise NotImplemented('Renderer class requires .render() to be implemented')
@@ -43,16 +43,21 @@ class BaseRenderer(object):
class JSONRenderer(BaseRenderer):
"""
- Renderer which serializes to json.
+ Renderer which serializes to JSON.
+ Applies JSON's backslash-u character escaping for non-ascii characters.
"""
media_type = 'application/json'
format = 'json'
encoder_class = encoders.JSONEncoder
+ ensure_ascii = True
+ charset = 'utf-8'
+ # Note that JSON encodings must be utf-8, utf-16 or utf-32.
+ # See: http://www.ietf.org/rfc/rfc4627.txt
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
- Render `obj` into json.
+ Render `data` into JSON.
"""
if data is None:
return ''
@@ -72,7 +77,25 @@ class JSONRenderer(BaseRenderer):
except (ValueError, TypeError):
indent = None
- return json.dumps(data, cls=self.encoder_class, indent=indent)
+ ret = json.dumps(data, cls=self.encoder_class,
+ indent=indent, ensure_ascii=self.ensure_ascii)
+
+ # On python 2.x json.dumps() returns bytestrings if ensure_ascii=True,
+ # but if ensure_ascii=False, the return type is underspecified,
+ # and may (or may not) be unicode.
+ # On python 3.x json.dumps() returns unicode strings.
+ if isinstance(ret, six.text_type):
+ return bytes(ret.encode(self.charset))
+ return ret
+
+
+class UnicodeJSONRenderer(JSONRenderer):
+ ensure_ascii = False
+ charset = 'utf-8'
+ """
+ Renderer which serializes to JSON.
+ Does *not* apply JSON's character escaping for non-ascii characters.
+ """
class JSONPRenderer(JSONRenderer):
@@ -105,7 +128,7 @@ class JSONPRenderer(JSONRenderer):
callback = self.get_callback(renderer_context)
json = super(JSONPRenderer, self).render(data, accepted_media_type,
renderer_context)
- return "%s(%s);" % (callback, json)
+ return callback.encode(self.charset) + b'(' + json + b');'
class XMLRenderer(BaseRenderer):
@@ -115,6 +138,7 @@ class XMLRenderer(BaseRenderer):
media_type = 'application/xml'
format = 'xml'
+ charset = 'utf-8'
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
@@ -125,7 +149,7 @@ class XMLRenderer(BaseRenderer):
stream = StringIO()
- xml = SimplerXMLGenerator(stream, "utf-8")
+ xml = SimplerXMLGenerator(stream, self.charset)
xml.startDocument()
xml.startElement("root", {})
@@ -164,6 +188,7 @@ class YAMLRenderer(BaseRenderer):
media_type = 'application/yaml'
format = 'yaml'
encoder = encoders.SafeDumper
+ charset = 'utf-8'
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
@@ -174,7 +199,7 @@ class YAMLRenderer(BaseRenderer):
if data is None:
return ''
- return yaml.dump(data, stream=None, Dumper=self.encoder)
+ return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder)
class TemplateHTMLRenderer(BaseRenderer):
@@ -204,6 +229,7 @@ class TemplateHTMLRenderer(BaseRenderer):
'%(status_code)s.html',
'api_exception.html'
]
+ charset = 'utf-8'
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
@@ -275,6 +301,7 @@ class StaticHTMLRenderer(TemplateHTMLRenderer):
"""
media_type = 'text/html'
format = 'html'
+ charset = 'utf-8'
def render(self, data, accepted_media_type=None, renderer_context=None):
renderer_context = renderer_context or {}
@@ -296,6 +323,7 @@ class BrowsableAPIRenderer(BaseRenderer):
media_type = 'text/html'
format = 'api'
template = 'rest_framework/api.html'
+ charset = 'utf-8'
def get_default_renderer(self, view):
"""
@@ -320,8 +348,8 @@ class BrowsableAPIRenderer(BaseRenderer):
renderer_context['indent'] = 4
content = renderer.render(data, accepted_media_type, renderer_context)
- if not all(char in string.printable for char in content):
- return '[%d bytes of binary content]'
+ if renderer.charset is None:
+ return '[%d bytes of binary content]' % len(content)
return content
@@ -336,7 +364,9 @@ class BrowsableAPIRenderer(BaseRenderer):
return # Cannot use form overloading
try:
- view.check_permissions(clone_request(request, method))
+ view.check_permissions(request)
+ if obj is not None:
+ view.check_object_permissions(request, obj)
except exceptions.APIException:
return False # Doesn't have permissions
return True
@@ -366,12 +396,40 @@ class BrowsableAPIRenderer(BaseRenderer):
if getattr(v, 'default', None) is not None:
kwargs['initial'] = v.default
- kwargs['label'] = k
+ if getattr(v, 'label', None) is not None:
+ kwargs['label'] = v.label
+
+ if getattr(v, 'help_text', None) is not None:
+ kwargs['help_text'] = v.help_text
fields[k] = v.form_field_class(**kwargs)
return fields
+ def _get_form(self, view, method, request):
+ # We need to impersonate a request with the correct method,
+ # so that eg. any dynamic get_serializer_class methods return the
+ # correct form for each method.
+ restore = view.request
+ request = clone_request(request, method)
+ view.request = request
+ try:
+ return self.get_form(view, method, request)
+ finally:
+ view.request = restore
+
+ def _get_raw_data_form(self, view, method, request, media_types):
+ # We need to impersonate a request with the correct method,
+ # so that eg. any dynamic get_serializer_class methods return the
+ # correct form for each method.
+ restore = view.request
+ request = clone_request(request, method)
+ view.request = request
+ try:
+ return self.get_raw_data_form(view, method, request, media_types)
+ finally:
+ view.request = restore
+
def get_form(self, view, method, request):
"""
Get a form, possibly bound to either the input or output data.
@@ -449,10 +507,7 @@ class BrowsableAPIRenderer(BaseRenderer):
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
- Renders *obj* using the :attr:`template` set on the class.
-
- The context used in the template contains all the information
- needed to self-document the response to this request.
+ Render the HTML for the browsable API representation.
"""
accepted_media_type = accepted_media_type or ''
renderer_context = renderer_context or {}
@@ -465,15 +520,15 @@ class BrowsableAPIRenderer(BaseRenderer):
renderer = self.get_default_renderer(view)
content = self.get_content(renderer, data, accepted_media_type, renderer_context)
- put_form = self.get_form(view, 'PUT', request)
- post_form = self.get_form(view, 'POST', request)
- patch_form = self.get_form(view, 'PATCH', request)
- delete_form = self.get_form(view, 'DELETE', request)
- options_form = self.get_form(view, 'OPTIONS', request)
+ put_form = self._get_form(view, 'PUT', request)
+ post_form = self._get_form(view, 'POST', request)
+ patch_form = self._get_form(view, 'PATCH', request)
+ delete_form = self._get_form(view, 'DELETE', request)
+ options_form = self._get_form(view, 'OPTIONS', request)
- raw_data_put_form = self.get_raw_data_form(view, 'PUT', request, media_types)
- raw_data_post_form = self.get_raw_data_form(view, 'POST', request, media_types)
- raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request, media_types)
+ raw_data_put_form = self._get_raw_data_form(view, 'PUT', request, media_types)
+ raw_data_post_form = self._get_raw_data_form(view, 'POST', request, media_types)
+ raw_data_patch_form = self._get_raw_data_form(view, 'PATCH', request, media_types)
raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form
name = self.get_name(view)
diff --git a/rest_framework/request.py b/rest_framework/request.py
index a434659c..0d88ebc7 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -173,7 +173,7 @@ class Request(object):
by the authentication classes provided to the request.
"""
if not hasattr(self, '_user'):
- self._authenticator, self._user, self._auth = self._authenticate()
+ self._authenticate()
return self._user
@user.setter
@@ -192,7 +192,7 @@ class Request(object):
request, such as an authentication token.
"""
if not hasattr(self, '_auth'):
- self._authenticator, self._user, self._auth = self._authenticate()
+ self._authenticate()
return self._auth
@auth.setter
@@ -210,7 +210,7 @@ class Request(object):
to authenticate the request, or `None`.
"""
if not hasattr(self, '_authenticator'):
- self._authenticator, self._user, self._auth = self._authenticate()
+ self._authenticate()
return self._authenticator
def _load_data_and_files(self):
@@ -330,11 +330,18 @@ class Request(object):
Returns a three-tuple of (authenticator, user, authtoken).
"""
for authenticator in self.authenticators:
- user_auth_tuple = authenticator.authenticate(self)
+ try:
+ user_auth_tuple = authenticator.authenticate(self)
+ except exceptions.APIException:
+ self._not_authenticated()
+ raise
+
if not user_auth_tuple is None:
- user, auth = user_auth_tuple
- return (authenticator, user, auth)
- return self._not_authenticated()
+ self._authenticator = authenticator
+ self._user, self._auth = user_auth_tuple
+ return
+
+ self._not_authenticated()
def _not_authenticated(self):
"""
@@ -343,17 +350,17 @@ class Request(object):
By default this will be (None, AnonymousUser, None).
"""
+ self._authenticator = None
+
if api_settings.UNAUTHENTICATED_USER:
- user = api_settings.UNAUTHENTICATED_USER()
+ self._user = api_settings.UNAUTHENTICATED_USER()
else:
- user = None
+ self._user = None
if api_settings.UNAUTHENTICATED_TOKEN:
- auth = api_settings.UNAUTHENTICATED_TOKEN()
+ self._auth = api_settings.UNAUTHENTICATED_TOKEN()
else:
- auth = None
-
- return (None, user, auth)
+ self._auth = None
def __getattr__(self, attr):
"""
diff --git a/rest_framework/response.py b/rest_framework/response.py
index 26e4ab37..5877c8a3 100644
--- a/rest_framework/response.py
+++ b/rest_framework/response.py
@@ -1,5 +1,5 @@
"""
-The Response class in REST framework is similiar to HTTPResponse, except that
+The Response class in REST framework is similar to HTTPResponse, except that
it is initialized with unrendered data, instead of a pre-rendered string.
The appropriate renderer is called during Django's template response rendering.
@@ -12,13 +12,13 @@ from rest_framework.compat import six
class Response(SimpleTemplateResponse):
"""
- An HttpResponse that allows it's data to be rendered into
+ An HttpResponse that allows its data to be rendered into
arbitrary media types.
"""
def __init__(self, data=None, status=200,
template_name=None, headers=None,
- exception=False):
+ exception=False, content_type=None):
"""
Alters the init arguments slightly.
For example, drop 'template_name', and instead use 'data'.
@@ -30,6 +30,7 @@ class Response(SimpleTemplateResponse):
self.data = data
self.template_name = template_name
self.exception = exception
+ self.content_type = content_type
if headers:
for name, value in six.iteritems(headers):
@@ -46,8 +47,21 @@ class Response(SimpleTemplateResponse):
assert context, ".renderer_context not set on Response"
context['response'] = self
- self['Content-Type'] = media_type
- return renderer.render(self.data, media_type, context)
+ charset = renderer.charset
+ content_type = self.content_type
+
+ if content_type is None and charset is not None:
+ content_type = "{0}; charset={1}".format(media_type, charset)
+ elif content_type is None:
+ content_type = media_type
+ self['Content-Type'] = content_type
+
+ ret = renderer.render(self.data, media_type, context)
+ if isinstance(ret, six.text_type):
+ assert charset, 'renderer returned unicode, and did not specify ' \
+ 'a charset value.'
+ return bytes(ret.encode(charset))
+ return ret
@property
def status_text(self):
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index 0707635a..9764e569 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -16,8 +16,8 @@ For example, you might have a `urls.py` that looks something like this:
from __future__ import unicode_literals
from collections import namedtuple
-from django.conf.urls import url, patterns
-from rest_framework.decorators import api_view
+from rest_framework import views
+from rest_framework.compat import patterns, url
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.urlpatterns import format_suffix_patterns
@@ -71,7 +71,7 @@ class SimpleRouter(BaseRouter):
routes = [
# List route.
Route(
- url=r'^{prefix}/$',
+ url=r'^{prefix}{trailing_slash}$',
mapping={
'get': 'list',
'post': 'create'
@@ -81,7 +81,7 @@ class SimpleRouter(BaseRouter):
),
# Detail route.
Route(
- url=r'^{prefix}/{lookup}/$',
+ url=r'^{prefix}/{lookup}{trailing_slash}$',
mapping={
'get': 'retrieve',
'put': 'update',
@@ -94,7 +94,7 @@ class SimpleRouter(BaseRouter):
# Dynamically generated routes.
# Generated using @action or @link decorators on methods of the viewset.
Route(
- url=r'^{prefix}/{lookup}/{methodname}/$',
+ url=r'^{prefix}/{lookup}/{methodname}{trailing_slash}$',
mapping={
'{httpmethod}': '{methodname}',
},
@@ -103,6 +103,10 @@ class SimpleRouter(BaseRouter):
),
]
+ def __init__(self, trailing_slash=True):
+ self.trailing_slash = trailing_slash and '/' or ''
+ super(SimpleRouter, self).__init__()
+
def get_default_base_name(self, viewset):
"""
If `base_name` is not specified, attempt to automatically determine
@@ -127,23 +131,23 @@ class SimpleRouter(BaseRouter):
"""
# Determine any `@action` or `@link` decorated methods on the viewset
- dynamic_routes = {}
+ dynamic_routes = []
for methodname in dir(viewset):
attr = getattr(viewset, methodname)
- httpmethod = getattr(attr, 'bind_to_method', None)
- if httpmethod:
- dynamic_routes[httpmethod] = methodname
+ httpmethods = getattr(attr, 'bind_to_methods', None)
+ if httpmethods:
+ dynamic_routes.append((httpmethods, methodname))
ret = []
for route in self.routes:
if route.mapping == {'{httpmethod}': '{methodname}'}:
# Dynamic routes (@link or @action decorator)
- for httpmethod, methodname in dynamic_routes.items():
+ for httpmethods, methodname in dynamic_routes:
initkwargs = route.initkwargs.copy()
initkwargs.update(getattr(viewset, methodname).kwargs)
ret.append(Route(
url=replace_methodname(route.url, methodname),
- mapping={httpmethod: methodname},
+ mapping=dict((httpmethod, methodname) for httpmethod in httpmethods),
name=replace_methodname(route.name, methodname),
initkwargs=initkwargs,
))
@@ -192,7 +196,11 @@ class SimpleRouter(BaseRouter):
continue
# Build the url pattern
- regex = route.url.format(prefix=prefix, lookup=lookup)
+ regex = route.url.format(
+ prefix=prefix,
+ lookup=lookup,
+ trailing_slash=self.trailing_slash
+ )
view = viewset.as_view(mapping, **route.initkwargs)
name = route.name.format(basename=basename)
ret.append(url(regex, view, name=name))
@@ -217,14 +225,16 @@ class DefaultRouter(SimpleRouter):
for prefix, viewset, basename in self.registry:
api_root_dict[prefix] = list_name.format(basename=basename)
- @api_view(('GET',))
- def api_root(request, format=None):
- ret = {}
- for key, url_name in api_root_dict.items():
- ret[key] = reverse(url_name, request=request, format=format)
- return Response(ret)
+ class APIRoot(views.APIView):
+ _ignore_model_permissions = True
+
+ def get(self, request, format=None):
+ ret = {}
+ for key, url_name in api_root_dict.items():
+ ret[key] = reverse(url_name, request=request, format=format)
+ return Response(ret)
- return api_root
+ return APIRoot.as_view()
def get_urls(self):
"""
diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py
index 4a333fb3..da36d23f 100755
--- a/rest_framework/runtests/runtests.py
+++ b/rest_framework/runtests/runtests.py
@@ -10,6 +10,7 @@ import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings'
+import django
from django.conf import settings
from django.test.utils import get_runner
@@ -35,7 +36,11 @@ def main():
else:
print(usage())
sys.exit(1)
- failures = test_runner.run_tests(['tests' + test_case])
+ test_module_name = 'rest_framework.tests'
+ if django.VERSION[0] == 1 and django.VERSION[1] < 6:
+ test_module_name = 'tests'
+
+ failures = test_runner.run_tests([test_module_name + test_case])
sys.exit(failures)
diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py
index 9b519f27..9dd7b545 100644
--- a/rest_framework/runtests/settings.py
+++ b/rest_framework/runtests/settings.py
@@ -4,6 +4,8 @@ DEBUG = True
TEMPLATE_DEBUG = DEBUG
DEBUG_PROPAGATE_EXCEPTIONS = True
+ALLOWED_HOSTS = ['*']
+
ADMINS = (
# ('Your Name', 'your_email@domain.com'),
)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 87f08374..4acbc704 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -25,7 +25,7 @@ from rest_framework.compat import get_concrete_model, six
#
# example_field = serializers.CharField(...)
#
-# This helps keep the seperation between model fields, form fields, and
+# This helps keep the separation between model fields, form fields, and
# serializer fields more explicit.
from rest_framework.relations import *
@@ -61,7 +61,7 @@ class DictWithMetadata(dict):
def __getstate__(self):
"""
Used by pickle (e.g., caching).
- Overriden to remove the metadata from the dict, since it shouldn't be
+ Overridden to remove the metadata from the dict, since it shouldn't be
pickled and may in some instances be unpickleable.
"""
return dict(self)
@@ -202,7 +202,7 @@ class BaseSerializer(WritableField):
# If 'fields' is specified, use those fields, in that order.
if self.opts.fields:
- assert isinstance(self.opts.fields, (list, tuple)), '`include` must be a list or tuple'
+ assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple'
new = SortedDict()
for key in self.opts.fields:
new[key] = ret[key]
@@ -210,7 +210,7 @@ class BaseSerializer(WritableField):
# Remove anything in 'exclude'
if self.opts.exclude:
- assert isinstance(self.opts.fields, (list, tuple)), '`exclude` must be a list or tuple'
+ assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple'
for key in self.opts.exclude:
ret.pop(key, None)
@@ -317,7 +317,8 @@ class BaseSerializer(WritableField):
self._errors = {}
if data is not None or files is not None:
attrs = self.restore_fields(data, files)
- attrs = self.perform_validation(attrs)
+ if attrs is not None:
+ attrs = self.perform_validation(attrs)
else:
self._errors['non_field_errors'] = ['No input provided']
@@ -381,24 +382,28 @@ class BaseSerializer(WritableField):
obj = getattr(self.parent.object, field_name) if self.parent.object else None
obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj
- if value in (None, ''):
- into[(self.source or field_name)] = None
+ if self.source == '*':
+ if value:
+ into.update(value)
else:
- kwargs = {
- 'instance': obj,
- 'data': value,
- 'context': self.context,
- 'partial': self.partial,
- 'many': self.many,
- 'allow_add_remove': self.allow_add_remove
- }
- serializer = self.__class__(**kwargs)
-
- if serializer.is_valid():
- into[self.source or field_name] = serializer.object
+ if value in (None, ''):
+ into[(self.source or field_name)] = None
else:
- # Propagate errors up to our parent
- raise NestedValidationError(serializer.errors)
+ kwargs = {
+ 'instance': obj,
+ 'data': value,
+ 'context': self.context,
+ 'partial': self.partial,
+ 'many': self.many,
+ 'allow_add_remove': self.allow_add_remove
+ }
+ serializer = self.__class__(**kwargs)
+
+ if serializer.is_valid():
+ into[self.source or field_name] = serializer.object
+ else:
+ # Propagate errors up to our parent
+ raise NestedValidationError(serializer.errors)
def get_identity(self, data):
"""
@@ -521,6 +526,17 @@ class BaseSerializer(WritableField):
return self.object
+ def metadata(self):
+ """
+ Return a dictionary of metadata about the fields on the serializer.
+ Useful for things like responding to OPTIONS requests, or generating
+ API schemas for auto-documentation.
+ """
+ return SortedDict(
+ [(field_name, field.metadata())
+ for field_name, field in six.iteritems(self.fields)]
+ )
+
class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)):
pass
@@ -591,11 +607,16 @@ class ModelSerializer(Serializer):
forward_rels += [field for field in opts.many_to_many if field.serialize]
for model_field in forward_rels:
+ has_through_model = False
+
if model_field.rel:
to_many = isinstance(model_field,
models.fields.related.ManyToManyField)
related_model = model_field.rel.to
+ if to_many and not model_field.rel.through._meta.auto_created:
+ has_through_model = True
+
if model_field.rel and nested:
if len(inspect.getargspec(self.get_nested_field).args) == 2:
warnings.warn(
@@ -624,6 +645,9 @@ class ModelSerializer(Serializer):
field = self.get_field(model_field)
if field:
+ if has_through_model:
+ field.read_only = True
+
ret[model_field.name] = field
# Deal with reverse relationships
@@ -641,6 +665,12 @@ class ModelSerializer(Serializer):
continue
related_model = relation.model
to_many = relation.field.rel.multiple
+ has_through_model = False
+ is_m2m = isinstance(relation.field,
+ models.fields.related.ManyToManyField)
+
+ if is_m2m and not relation.field.rel.through._meta.auto_created:
+ has_through_model = True
if nested:
field = self.get_nested_field(None, related_model, to_many)
@@ -648,13 +678,22 @@ class ModelSerializer(Serializer):
field = self.get_related_field(None, related_model, to_many)
if field:
+ if has_through_model:
+ field.read_only = True
+
ret[accessor_name] = field
# Add the `read_only` flag to any fields that have bee specified
# in the `read_only_fields` option
for field_name in self.opts.read_only_fields:
+ assert field_name not in self.base_fields.keys(), \
+ "field '%s' on serializer '%s' specfied in " \
+ "`read_only_fields`, but also added " \
+ "as an explict field. Remove it from `read_only_fields`." % \
+ (field_name, self.__class__.__name__)
assert field_name in ret, \
- "read_only_fields on '%s' included invalid item '%s'" % \
+ "Noexistant field '%s' specified in `read_only_fields` " \
+ "on serializer '%s'." % \
(self.__class__.__name__, field_name)
ret[field_name].read_only = True
@@ -703,25 +742,51 @@ class ModelSerializer(Serializer):
Creates a default instance of a basic non-relational field.
"""
kwargs = {}
- has_default = model_field.has_default()
- if model_field.null or model_field.blank or has_default:
+ if model_field.null or model_field.blank:
kwargs['required'] = False
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
- if has_default:
+ if model_field.has_default():
kwargs['default'] = model_field.get_default()
if issubclass(model_field.__class__, models.TextField):
kwargs['widget'] = widgets.Textarea
+ if model_field.verbose_name is not None:
+ kwargs['label'] = model_field.verbose_name
+
+ if model_field.help_text is not None:
+ kwargs['help_text'] = model_field.help_text
+
# TODO: TypedChoiceField?
if model_field.flatchoices: # This ModelField contains choices
kwargs['choices'] = model_field.flatchoices
return ChoiceField(**kwargs)
+ # put this below the ChoiceField because min_value isn't a valid initializer
+ if issubclass(model_field.__class__, models.PositiveIntegerField) or\
+ issubclass(model_field.__class__, models.PositiveSmallIntegerField):
+ kwargs['min_value'] = 0
+
+ attribute_dict = {
+ models.CharField: ['max_length'],
+ models.CommaSeparatedIntegerField: ['max_length'],
+ models.DecimalField: ['max_digits', 'decimal_places'],
+ models.EmailField: ['max_length'],
+ models.FileField: ['max_length'],
+ models.ImageField: ['max_length'],
+ models.SlugField: ['max_length'],
+ models.URLField: ['max_length'],
+ }
+
+ if model_field.__class__ in attribute_dict:
+ attributes = attribute_dict[model_field.__class__]
+ for attribute in attributes:
+ kwargs.update({attribute: getattr(model_field, attribute)})
+
try:
return self.field_mapping[model_field.__class__](**kwargs)
except KeyError:
@@ -867,7 +932,7 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
def __init__(self, meta):
super(HyperlinkedModelSerializerOptions, self).__init__(meta)
self.view_name = getattr(meta, 'view_name', None)
- self.lookup_field = getattr(meta, 'slug_field', None)
+ self.lookup_field = getattr(meta, 'lookup_field', None)
class HyperlinkedModelSerializer(ModelSerializer):
@@ -879,13 +944,24 @@ class HyperlinkedModelSerializer(ModelSerializer):
_default_view_name = '%(model_name)s-detail'
_hyperlink_field_class = HyperlinkedRelatedField
- url = HyperlinkedIdentityField()
+ # Just a placeholder to ensure 'url' is the first field
+ # The field itself is actually created on initialization,
+ # when the view_name and lookup_field arguments are available.
+ url = Field()
def __init__(self, *args, **kwargs):
super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs)
+
if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name(self.opts.model)
+ url_field = HyperlinkedIdentityField(
+ view_name=self.opts.view_name,
+ lookup_field=self.opts.lookup_field
+ )
+ url_field.initialize(self, 'url')
+ self.fields['url'] = url_field
+
def _get_default_view_name(self, model):
"""
Return the view name to use if 'view_name' is not specified in 'Meta'
diff --git a/rest_framework/static/rest_framework/css/bootstrap-tweaks.css b/rest_framework/static/rest_framework/css/bootstrap-tweaks.css
index c650ef2e..6bfb778c 100644
--- a/rest_framework/static/rest_framework/css/bootstrap-tweaks.css
+++ b/rest_framework/static/rest_framework/css/bootstrap-tweaks.css
@@ -19,4 +19,167 @@ a single block in the template.
.navbar-inverse .brand:hover a {
color: white;
text-decoration: none;
-} \ No newline at end of file
+}
+
+/* custom navigation styles */
+.wrapper .navbar{
+ width: 100%;
+ position: absolute;
+ left: 0;
+ top: 0;
+}
+
+.navbar .navbar-inner{
+ background: #2C2C2C;
+ color: white;
+ border: none;
+ border-top: 5px solid #A30000;
+ border-radius: 0px;
+}
+
+.navbar .navbar-inner .nav li, .navbar .navbar-inner .nav li a, .navbar .navbar-inner .brand:hover{
+ color: white;
+}
+
+.nav-list > .active > a, .nav-list > .active > a:hover {
+ background: #2c2c2c;
+}
+
+.navbar .navbar-inner .dropdown-menu li a, .navbar .navbar-inner .dropdown-menu li{
+ color: #A30000;
+}
+.navbar .navbar-inner .dropdown-menu li a:hover{
+ background: #eeeeee;
+ color: #c20000;
+}
+
+/*=== dabapps bootstrap styles ====*/
+
+html{
+ width:100%;
+ background: none;
+}
+
+body, .navbar .navbar-inner .container-fluid {
+ max-width: 1150px;
+ margin: 0 auto;
+}
+
+body{
+ background: url("../img/grid.png") repeat-x;
+ background-attachment: fixed;
+}
+
+#content{
+ margin: 0;
+}
+
+/* sticky footer and footer */
+html, body {
+ height: 100%;
+}
+.wrapper {
+ min-height: 100%;
+ height: auto !important;
+ height: 100%;
+ margin: 0 auto -60px;
+}
+
+.form-switcher {
+ margin-bottom: 0;
+}
+
+.well {
+ -webkit-box-shadow: none;
+ -moz-box-shadow: none;
+ box-shadow: none;
+}
+
+.well .form-actions {
+ padding-bottom: 0;
+ margin-bottom: 0;
+}
+
+.well form {
+ margin-bottom: 0;
+}
+
+.well form .help-block {
+ color: #999;
+}
+
+.nav-tabs {
+ border: 0;
+}
+
+.nav-tabs > li {
+ float: right;
+}
+
+.nav-tabs li a {
+ margin-right: 0;
+}
+
+.nav-tabs > .active > a {
+ background: #f5f5f5;
+}
+
+.nav-tabs > .active > a:hover {
+ background: #f5f5f5;
+}
+
+.tabbable.first-tab-active .tab-content
+{
+ border-top-right-radius: 0;
+}
+
+#footer, #push {
+ height: 60px; /* .push must be the same height as .footer */
+}
+
+#footer{
+ text-align: right;
+}
+
+#footer p {
+ text-align: center;
+ color: gray;
+ border-top: 1px solid #DDD;
+ padding-top: 10px;
+}
+
+#footer a {
+ color: gray;
+ font-weight: bold;
+}
+
+#footer a:hover {
+ color: gray;
+}
+
+.page-header {
+ border-bottom: none;
+ padding-bottom: 0px;
+ margin-bottom: 20px;
+}
+
+/* custom general page styles */
+.hero-unit h2, .hero-unit h1{
+ color: #A30000;
+}
+
+body a, body a{
+ color: #A30000;
+}
+
+body a:hover{
+ color: #c20000;
+}
+
+#content a span{
+ text-decoration: underline;
+ }
+
+.request-info {
+ clear:both;
+}
diff --git a/rest_framework/static/rest_framework/css/default.css b/rest_framework/static/rest_framework/css/default.css
index d806267b..0261a303 100644
--- a/rest_framework/static/rest_framework/css/default.css
+++ b/rest_framework/static/rest_framework/css/default.css
@@ -69,152 +69,3 @@ pre {
margin-bottom: 20px;
}
-
-/*=== dabapps bootstrap styles ====*/
-
-html{
- width:100%;
- background: none;
-}
-
-body, .navbar .navbar-inner .container-fluid {
- max-width: 1150px;
- margin: 0 auto;
-}
-
-body{
- background: url("../img/grid.png") repeat-x;
- background-attachment: fixed;
-}
-
-#content{
- margin: 0;
-}
-/* custom navigation styles */
-.wrapper .navbar{
- width: 100%;
- position: absolute;
- left: 0;
- top: 0;
-}
-
-.navbar .navbar-inner{
- background: #2C2C2C;
- color: white;
- border: none;
- border-top: 5px solid #A30000;
- border-radius: 0px;
-}
-
-.navbar .navbar-inner .nav li, .navbar .navbar-inner .nav li a, .navbar .navbar-inner .brand{
- color: white;
-}
-
-.nav-list > .active > a, .nav-list > .active > a:hover {
- background: #2c2c2c;
-}
-
-.navbar .navbar-inner .dropdown-menu li a, .navbar .navbar-inner .dropdown-menu li{
- color: #A30000;
-}
-.navbar .navbar-inner .dropdown-menu li a:hover{
- background: #eeeeee;
- color: #c20000;
-}
-
-/* custom general page styles */
-.hero-unit h2, .hero-unit h1{
- color: #A30000;
-}
-
-body a, body a{
- color: #A30000;
-}
-
-body a:hover{
- color: #c20000;
-}
-
-#content a span{
- text-decoration: underline;
- }
-
-/* sticky footer and footer */
-html, body {
- height: 100%;
-}
-.wrapper {
- min-height: 100%;
- height: auto !important;
- height: 100%;
- margin: 0 auto -60px;
-}
-
-.form-switcher {
- margin-bottom: 0;
-}
-
-.well {
- -webkit-box-shadow: none;
- -moz-box-shadow: none;
- box-shadow: none;
-}
-
-.well .form-actions {
- padding-bottom: 0;
- margin-bottom: 0;
-}
-
-.well form {
- margin-bottom: 0;
-}
-
-.nav-tabs {
- border: 0;
-}
-
-.nav-tabs > li {
- float: right;
-}
-
-.nav-tabs li a {
- margin-right: 0;
-}
-
-.nav-tabs > .active > a {
- background: #f5f5f5;
-}
-
-.nav-tabs > .active > a:hover {
- background: #f5f5f5;
-}
-
-.tabbable.first-tab-active .tab-content
-{
- border-top-right-radius: 0;
-}
-
-#footer, #push {
- height: 60px; /* .push must be the same height as .footer */
-}
-
-#footer{
- text-align: right;
-}
-
-#footer p {
- text-align: center;
- color: gray;
- border-top: 1px solid #DDD;
- padding-top: 10px;
-}
-
-#footer a {
- color: gray;
- font-weight: bold;
-}
-
-#footer a:hover {
- color: gray;
-}
-
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 4410f285..9d939e73 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -13,8 +13,10 @@
<title>{% block title %}Django REST framework{% endblock %}</title>
{% block style %}
- {% block bootstrap_theme %}<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>{% endblock %}
- <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/>
+ {% block bootstrap_theme %}
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/>
+ {% endblock %}
<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/prettify.css" %}"/>
<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/>
{% endblock %}
@@ -30,8 +32,8 @@
<div class="navbar {% block bootstrap_navbar_variant %}navbar-inverse{% endblock %}">
<div class="navbar-inner">
<div class="container-fluid">
- <span class="brand" href="/">
- {% block branding %}<a href='http://django-rest-framework.org'>Django REST framework <span class="version">{{ version }}</span></a>{% endblock %}
+ <span href="/">
+ {% block branding %}<a class='brand' href='http://django-rest-framework.org'>Django REST framework <span class="version">{{ version }}</span></a>{% endblock %}
</span>
<ul class="nav pull-right">
{% block userlinks %}
@@ -109,8 +111,7 @@
<div class="content-main">
<div class="page-header"><h1>{{ name }}</h1></div>
{{ description }}
-
- <div class="request-info">
+ <div class="request-info" style="clear: both" >
<pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre>
</div>
<div class="response-info">
diff --git a/rest_framework/templates/rest_framework/form.html b/rest_framework/templates/rest_framework/form.html
index dc7acc70..b27f652e 100644
--- a/rest_framework/templates/rest_framework/form.html
+++ b/rest_framework/templates/rest_framework/form.html
@@ -6,7 +6,7 @@
{{ field.label_tag|add_class:"control-label" }}
<div class="controls">
{{ field }}
- <span class="help-inline">{{ field.help_text }}</span>
+ <span class="help-block">{{ field.help_text }}</span>
<!--{{ field.errors|add_class:"help-block" }}-->
</div>
</div>
diff --git a/rest_framework/templates/rest_framework/login_base.html b/rest_framework/templates/rest_framework/login_base.html
index 380d5820..be9a0072 100644
--- a/rest_framework/templates/rest_framework/login_base.html
+++ b/rest_framework/templates/rest_framework/login_base.html
@@ -4,52 +4,50 @@
<head>
{% block style %}
- {% block bootstrap_theme %}<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>{% endblock %}
- <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/>
+ {% block bootstrap_theme %}
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/>
+ {% endblock %}
<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/>
{% endblock %}
</head>
<body class="container">
-<div class="container-fluid" style="margin-top: 30px">
- <div class="row-fluid">
-
- <div class="well" style="width: 320px; margin-left: auto; margin-right: auto">
- <div class="row-fluid">
- <div>
- {% block branding %}<h3 style="margin: 0 0 20px;">Django REST framework</h3>{% endblock %}
- </div>
- </div><!-- /row fluid -->
-
+ <div class="container-fluid" style="margin-top: 30px">
<div class="row-fluid">
- <div>
- <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post">
- {% csrf_token %}
- <div id="div_id_username" class="clearfix control-group">
- <div class="controls">
- <Label class="span4">Username:</label>
- <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username">
- </div>
+ <div class="well" style="width: 320px; margin-left: auto; margin-right: auto">
+ <div class="row-fluid">
+ <div>
+ {% block branding %}<h3 style="margin: 0 0 20px;">Django REST framework</h3>{% endblock %}
</div>
- <div id="div_id_password" class="clearfix control-group">
- <div class="controls">
- <Label class="span4">Password:</label>
- <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password">
- </div>
+ </div><!-- /row fluid -->
+
+ <div class="row-fluid">
+ <div>
+ <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post">
+ {% csrf_token %}
+ <div id="div_id_username" class="clearfix control-group">
+ <div class="controls">
+ <Label class="span4">Username:</label>
+ <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username">
+ </div>
+ </div>
+ <div id="div_id_password" class="clearfix control-group">
+ <div class="controls">
+ <Label class="span4">Password:</label>
+ <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password">
+ </div>
+ </div>
+ <input type="hidden" name="next" value="{{ next }}" />
+ <div class="form-actions-no-box">
+ <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit">
+ </div>
+ </form>
</div>
- <input type="hidden" name="next" value="{{ next }}" />
- <div class="form-actions-no-box">
- <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit">
- </div>
- </form>
- </div>
- </div><!-- /row fluid -->
- </div><!--/span-->
-
- </div><!-- /.row-fluid -->
- </div>
-
- </div>
+ </div><!-- /.row-fluid -->
+ </div><!--/.well-->
+ </div><!-- /.row-fluid -->
+ </div><!-- /.container-fluid -->
</body>
</html>
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index c86b6456..e9c1cdd5 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -15,7 +15,7 @@ register = template.Library()
# When 1.3 becomes unsupported by REST framework, we can instead start to
# use the {% load staticfiles %} tag, remove the following code,
-# and add a dependancy that `django.contrib.staticfiles` must be installed.
+# and add a dependency that `django.contrib.staticfiles` must be installed.
# Note: We can't put this into the `compat` module because the compat import
# from rest_framework.compat import ...
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index f2117538..e2d4eacd 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -1,5 +1,7 @@
from __future__ import unicode_literals
from django.db import models
+from django.utils.translation import ugettext_lazy as _
+from rest_framework import serializers
def foobar():
@@ -32,7 +34,7 @@ class Anchor(RESTFrameworkModel):
class BasicModel(RESTFrameworkModel):
- text = models.CharField(max_length=100)
+ text = models.CharField(max_length=100, verbose_name=_("Text comes here"), help_text=_("Text description."))
class SlugBasedModel(RESTFrameworkModel):
@@ -58,13 +60,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel):
rel = models.ManyToManyField(Anchor)
-# Model to test filtering.
-class FilterableItem(RESTFrameworkModel):
- text = models.CharField(max_length=100)
- decimal = models.DecimalField(max_digits=4, decimal_places=2)
- date = models.DateField()
-
-
# Model for regression test for #285
class Comment(RESTFrameworkModel):
@@ -166,3 +161,9 @@ class NullableOneToOneSource(RESTFrameworkModel):
name = models.CharField(max_length=100)
target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
related_name='nullable_source')
+
+
+# Serializer used to test BasicModel
+class BasicModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py
deleted file mode 100644
index cbf93c65..00000000
--- a/rest_framework/tests/relations.py
+++ /dev/null
@@ -1,47 +0,0 @@
-"""
-General tests for relational fields.
-"""
-from __future__ import unicode_literals
-from django.db import models
-from django.test import TestCase
-from rest_framework import serializers
-
-
-class NullModel(models.Model):
- pass
-
-
-class FieldTests(TestCase):
- def test_pk_related_field_with_empty_string(self):
- """
- Regression test for #446
-
- https://github.com/tomchristie/django-rest-framework/issues/446
- """
- field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
- self.assertRaises(serializers.ValidationError, field.from_native, '')
- self.assertRaises(serializers.ValidationError, field.from_native, [])
-
- def test_hyperlinked_related_field_with_empty_string(self):
- field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
- self.assertRaises(serializers.ValidationError, field.from_native, '')
- self.assertRaises(serializers.ValidationError, field.from_native, [])
-
- def test_slug_related_field_with_empty_string(self):
- field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
- self.assertRaises(serializers.ValidationError, field.from_native, '')
- self.assertRaises(serializers.ValidationError, field.from_native, [])
-
-
-class TestManyRelateMixin(TestCase):
- def test_missing_many_to_many_related_field(self):
- '''
- Regression test for #632
-
- https://github.com/tomchristie/django-rest-framework/pull/632
- '''
- field = serializers.RelatedField(many=True, read_only=False)
-
- into = {}
- field.field_from_native({}, None, 'field_name', into)
- self.assertEqual(into['field_name'], [])
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/test_authentication.py
index 8e6d3e51..d46ac079 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/test_authentication.py
@@ -6,6 +6,8 @@ from django.utils import unittest
from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions
from rest_framework import permissions
+from rest_framework import renderers
+from rest_framework.response import Response
from rest_framework import status
from rest_framework.authentication import (
BaseAuthentication,
@@ -48,7 +50,7 @@ urlpatterns = patterns('',
(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
(r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])),
- (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication],
+ (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication],
permission_classes=[permissions.TokenHasReadWriteScope]))
)
@@ -56,14 +58,14 @@ if oauth2_provider is not None:
urlpatterns += patterns('',
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
- url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
+ url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
permission_classes=[permissions.TokenHasReadWriteScope])),
)
class BasicAuthTests(TestCase):
"""Basic authentication"""
- urls = 'rest_framework.tests.authentication'
+ urls = 'rest_framework.tests.test_authentication'
def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True)
@@ -102,7 +104,7 @@ class BasicAuthTests(TestCase):
class SessionAuthTests(TestCase):
"""User session authentication"""
- urls = 'rest_framework.tests.authentication'
+ urls = 'rest_framework.tests.test_authentication'
def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True)
@@ -149,7 +151,7 @@ class SessionAuthTests(TestCase):
class TokenAuthTests(TestCase):
"""Token authentication"""
- urls = 'rest_framework.tests.authentication'
+ urls = 'rest_framework.tests.test_authentication'
def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True)
@@ -243,7 +245,7 @@ class IncorrectCredentialsTests(TestCase):
class OAuthTests(TestCase):
"""OAuth 1.0a authentication"""
- urls = 'rest_framework.tests.authentication'
+ urls = 'rest_framework.tests.test_authentication'
def setUp(self):
# these imports are here because oauth is optional and hiding them in try..except block or compat
@@ -429,7 +431,7 @@ class OAuthTests(TestCase):
class OAuth2Tests(TestCase):
"""OAuth 2.0 authentication"""
- urls = 'rest_framework.tests.authentication'
+ urls = 'rest_framework.tests.test_authentication'
def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True)
@@ -553,3 +555,40 @@ class OAuth2Tests(TestCase):
auth = self._create_authorization_header(token=read_write_access_token.token)
response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
+
+
+class FailingAuthAccessedInRenderer(TestCase):
+ def setUp(self):
+ class AuthAccessingRenderer(renderers.BaseRenderer):
+ media_type = 'text/plain'
+ format = 'txt'
+
+ def render(self, data, media_type=None, renderer_context=None):
+ request = renderer_context['request']
+ if request.user.is_authenticated():
+ return b'authenticated'
+ return b'not authenticated'
+
+ class FailingAuth(BaseAuthentication):
+ def authenticate(self, request):
+ raise exceptions.AuthenticationFailed('authentication failed')
+
+ class ExampleView(APIView):
+ authentication_classes = (FailingAuth,)
+ renderer_classes = (AuthAccessingRenderer,)
+
+ def get(self, request):
+ return Response({'foo': 'bar'})
+
+ self.view = ExampleView.as_view()
+
+ def test_failing_auth_accessed_in_renderer(self):
+ """
+ When authentication fails the renderer should still be able to access
+ `request.user` without raising an exception. Particularly relevant
+ to HTML responses that might reasonably access `request.user`.
+ """
+ request = factory.get('/')
+ response = self.view(request)
+ content = response.render().content
+ self.assertEqual(content, b'not authenticated')
diff --git a/rest_framework/tests/breadcrumbs.py b/rest_framework/tests/test_breadcrumbs.py
index d9ed647e..41ddf2ce 100644
--- a/rest_framework/tests/breadcrumbs.py
+++ b/rest_framework/tests/test_breadcrumbs.py
@@ -36,7 +36,7 @@ urlpatterns = patterns('',
class BreadcrumbTests(TestCase):
"""Tests the breadcrumb functionality used by the HTML renderer."""
- urls = 'rest_framework.tests.breadcrumbs'
+ urls = 'rest_framework.tests.test_breadcrumbs'
def test_root_breadcrumbs(self):
url = '/'
diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/test_decorators.py
index 1016fed3..1016fed3 100644
--- a/rest_framework/tests/decorators.py
+++ b/rest_framework/tests/test_decorators.py
diff --git a/rest_framework/tests/description.py b/rest_framework/tests/test_description.py
index 52c1a34c..52c1a34c 100644
--- a/rest_framework/tests/description.py
+++ b/rest_framework/tests/test_description.py
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/test_fields.py
index 3cdfa0f6..69a0468e 100644
--- a/rest_framework/tests/fields.py
+++ b/rest_framework/tests/test_fields.py
@@ -2,15 +2,16 @@
General serializer field tests.
"""
from __future__ import unicode_literals
+
import datetime
from decimal import Decimal
-
+from uuid import uuid4
+from django.core import validators
from django.db import models
from django.test import TestCase
-from django.core import validators
-
+from django.utils.datastructures import SortedDict
from rest_framework import serializers
-from rest_framework.serializers import Serializer
+from rest_framework.tests.models import RESTFrameworkModel
class TimestampedModel(models.Model):
@@ -63,6 +64,20 @@ class BasicFieldTests(TestCase):
serializer = CharPrimaryKeyModelSerializer()
self.assertEqual(serializer.fields['id'].read_only, False)
+ def test_dict_field_ordering(self):
+ """
+ Field should preserve dictionary ordering, if it exists.
+ See: https://github.com/tomchristie/django-rest-framework/issues/832
+ """
+ ret = SortedDict()
+ ret['c'] = 1
+ ret['b'] = 1
+ ret['a'] = 1
+ ret['z'] = 1
+ field = serializers.Field()
+ keys = list(field.to_native(ret).keys())
+ self.assertEqual(keys, ['c', 'b', 'a', 'z'])
+
class DateFieldTest(TestCase):
"""
@@ -573,7 +588,7 @@ class DecimalFieldTest(TestCase):
"""
Make sure the serializer works correctly
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(max_value=9010,
min_value=9000,
max_digits=6,
@@ -591,7 +606,7 @@ class DecimalFieldTest(TestCase):
"""
Make sure max_value violations raises ValidationError
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(max_value=100)
s = DecimalSerializer(data={'decimal_field': '123'})
@@ -603,7 +618,7 @@ class DecimalFieldTest(TestCase):
"""
Make sure min_value violations raises ValidationError
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(min_value=100)
s = DecimalSerializer(data={'decimal_field': '99'})
@@ -615,7 +630,7 @@ class DecimalFieldTest(TestCase):
"""
Make sure max_digits violations raises ValidationError
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(max_digits=5)
s = DecimalSerializer(data={'decimal_field': '123.456'})
@@ -627,7 +642,7 @@ class DecimalFieldTest(TestCase):
"""
Make sure max_decimal_places violations raises ValidationError
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(decimal_places=3)
s = DecimalSerializer(data={'decimal_field': '123.4567'})
@@ -639,10 +654,215 @@ class DecimalFieldTest(TestCase):
"""
Make sure max_whole_digits violations raises ValidationError
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3)
s = DecimalSerializer(data={'decimal_field': '12345.6'})
self.assertFalse(s.is_valid())
- self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) \ No newline at end of file
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
+
+
+class ChoiceFieldTests(TestCase):
+ """
+ Tests for the ChoiceField options generator
+ """
+
+ SAMPLE_CHOICES = [
+ ('red', 'Red'),
+ ('green', 'Green'),
+ ('blue', 'Blue'),
+ ]
+
+ def test_choices_required(self):
+ """
+ Make sure proper choices are rendered if field is required
+ """
+ f = serializers.ChoiceField(required=True, choices=self.SAMPLE_CHOICES)
+ self.assertEqual(f.choices, self.SAMPLE_CHOICES)
+
+ def test_choices_not_required(self):
+ """
+ Make sure proper choices (plus blank) are rendered if the field isn't required
+ """
+ f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES)
+ self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES)
+
+
+class EmailFieldTests(TestCase):
+ """
+ Tests for EmailField attribute values
+ """
+
+ class EmailFieldModel(RESTFrameworkModel):
+ email_field = models.EmailField(blank=True)
+
+ class EmailFieldWithGivenMaxLengthModel(RESTFrameworkModel):
+ email_field = models.EmailField(max_length=150, blank=True)
+
+ def test_default_model_value(self):
+ class EmailFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.EmailFieldModel
+
+ serializer = EmailFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 75)
+
+ def test_given_model_value(self):
+ class EmailFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.EmailFieldWithGivenMaxLengthModel
+
+ serializer = EmailFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 150)
+
+ def test_given_serializer_value(self):
+ class EmailFieldSerializer(serializers.ModelSerializer):
+ email_field = serializers.EmailField(source='email_field', max_length=20, required=False)
+
+ class Meta:
+ model = self.EmailFieldModel
+
+ serializer = EmailFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 20)
+
+
+class SlugFieldTests(TestCase):
+ """
+ Tests for SlugField attribute values
+ """
+
+ class SlugFieldModel(RESTFrameworkModel):
+ slug_field = models.SlugField(blank=True)
+
+ class SlugFieldWithGivenMaxLengthModel(RESTFrameworkModel):
+ slug_field = models.SlugField(max_length=84, blank=True)
+
+ def test_default_model_value(self):
+ class SlugFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.SlugFieldModel
+
+ serializer = SlugFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 50)
+
+ def test_given_model_value(self):
+ class SlugFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.SlugFieldWithGivenMaxLengthModel
+
+ serializer = SlugFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 84)
+
+ def test_given_serializer_value(self):
+ class SlugFieldSerializer(serializers.ModelSerializer):
+ slug_field = serializers.SlugField(source='slug_field',
+ max_length=20, required=False)
+
+ class Meta:
+ model = self.SlugFieldModel
+
+ serializer = SlugFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['slug_field'],
+ 'max_length'), 20)
+
+ def test_invalid_slug(self):
+ """
+ Make sure an invalid slug raises ValidationError
+ """
+ class SlugFieldSerializer(serializers.ModelSerializer):
+ slug_field = serializers.SlugField(source='slug_field', max_length=20, required=True)
+
+ class Meta:
+ model = self.SlugFieldModel
+
+ s = SlugFieldSerializer(data={'slug_field': 'a b'})
+
+ self.assertEqual(s.is_valid(), False)
+ self.assertEqual(s.errors, {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]})
+
+
+class URLFieldTests(TestCase):
+ """
+ Tests for URLField attribute values
+ """
+
+ class URLFieldModel(RESTFrameworkModel):
+ url_field = models.URLField(blank=True)
+
+ class URLFieldWithGivenMaxLengthModel(RESTFrameworkModel):
+ url_field = models.URLField(max_length=128, blank=True)
+
+ def test_default_model_value(self):
+ class URLFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.URLFieldModel
+
+ serializer = URLFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['url_field'],
+ 'max_length'), 200)
+
+ def test_given_model_value(self):
+ class URLFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.URLFieldWithGivenMaxLengthModel
+
+ serializer = URLFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['url_field'],
+ 'max_length'), 128)
+
+ def test_given_serializer_value(self):
+ class URLFieldSerializer(serializers.ModelSerializer):
+ url_field = serializers.URLField(source='url_field',
+ max_length=20, required=False)
+
+ class Meta:
+ model = self.URLFieldWithGivenMaxLengthModel
+
+ serializer = URLFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['url_field'],
+ 'max_length'), 20)
+
+
+class FieldMetadata(TestCase):
+ def setUp(self):
+ self.required_field = serializers.Field()
+ self.required_field.label = uuid4().hex
+ self.required_field.required = True
+
+ self.optional_field = serializers.Field()
+ self.optional_field.label = uuid4().hex
+ self.optional_field.required = False
+
+ def test_required(self):
+ self.assertEqual(self.required_field.metadata()['required'], True)
+
+ def test_optional(self):
+ self.assertEqual(self.optional_field.metadata()['required'], False)
+
+ def test_label(self):
+ for field in (self.required_field, self.optional_field):
+ self.assertEqual(field.metadata()['label'], field.label)
+
+
+class FieldCallableDefault(TestCase):
+ def setUp(self):
+ self.simple_callable = lambda: 'foo bar'
+
+ def test_default_can_be_simple_callable(self):
+ """
+ Ensure that the 'default' argument can also be a simple callable.
+ """
+ field = serializers.WritableField(default=self.simple_callable)
+ into = {}
+ field.field_from_native({}, {}, 'field', into)
+ self.assertEqual(into, {'field': 'foo bar'})
diff --git a/rest_framework/tests/files.py b/rest_framework/tests/test_files.py
index 487046ac..487046ac 100644
--- a/rest_framework/tests/files.py
+++ b/rest_framework/tests/test_files.py
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/test_filters.py
index 023bd016..aaed6247 100644
--- a/rest_framework/tests/filterset.py
+++ b/rest_framework/tests/test_filters.py
@@ -1,23 +1,30 @@
from __future__ import unicode_literals
import datetime
from decimal import Decimal
+from django.db import models
from django.core.urlresolvers import reverse
from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest
from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url
-from rest_framework.tests.models import FilterableItem, BasicModel
+from rest_framework.tests.models import BasicModel
factory = RequestFactory()
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
if django_filters:
# Basic filter on a list view.
class FilterFieldsRootView(generics.ListCreateAPIView):
model = FilterableItem
filter_fields = ['decimal', 'date']
- filter_backend = filters.DjangoFilterBackend
+ filter_backends = (filters.DjangoFilterBackend,)
# These class are used to test a filter class.
class SeveralFieldsFilter(django_filters.FilterSet):
@@ -32,7 +39,7 @@ if django_filters:
class FilterClassRootView(generics.ListCreateAPIView):
model = FilterableItem
filter_class = SeveralFieldsFilter
- filter_backend = filters.DjangoFilterBackend
+ filter_backends = (filters.DjangoFilterBackend,)
# These classes are used to test a misconfigured filter class.
class MisconfiguredFilter(django_filters.FilterSet):
@@ -45,12 +52,12 @@ if django_filters:
class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
model = FilterableItem
filter_class = MisconfiguredFilter
- filter_backend = filters.DjangoFilterBackend
+ filter_backends = (filters.DjangoFilterBackend,)
class FilterClassDetailView(generics.RetrieveAPIView):
model = FilterableItem
filter_class = SeveralFieldsFilter
- filter_backend = filters.DjangoFilterBackend
+ filter_backends = (filters.DjangoFilterBackend,)
# Regression test for #814
class FilterableItemSerializer(serializers.ModelSerializer):
@@ -61,11 +68,21 @@ if django_filters:
queryset = FilterableItem.objects.all()
serializer_class = FilterableItemSerializer
filter_fields = ['decimal', 'date']
- filter_backend = filters.DjangoFilterBackend
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ class GetQuerysetView(generics.ListCreateAPIView):
+ serializer_class = FilterableItemSerializer
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ def get_queryset(self):
+ return FilterableItem.objects.all()
urlpatterns = patterns('',
url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
url(r'^$', FilterClassRootView.as_view(), name='root-view'),
+ url(r'^get-queryset/$', GetQuerysetView.as_view(),
+ name='get-queryset-view'),
)
@@ -141,6 +158,17 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_filter_with_get_queryset_only(self):
+ """
+ Regression test for #834.
+ """
+ view = GetQuerysetView.as_view()
+ request = factory.get('/get-queryset/')
+ view(request).render()
+ # Used to raise "issubclass() arg 2 must be a class or tuple of classes"
+ # here when neither `model' nor `queryset' was specified.
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
def test_get_filtered_class_root_view(self):
"""
GET requests to filtered ListCreateAPIView that have a filter_class set
@@ -215,7 +243,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
"""
Integration tests for filtered detail views.
"""
- urls = 'rest_framework.tests.filterset'
+ urls = 'rest_framework.tests.test_filters'
def _get_url(self, item):
return reverse('detail-view', kwargs=dict(pk=item.pk))
@@ -256,3 +284,191 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, valid_item_data)
+
+
+class SearchFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class SearchFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # z abc
+ # zz bcd
+ # zzz cde
+ # ...
+ for idx in range(10):
+ title = 'z' * (idx + 1)
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ SearchFilterModel(title=title, text=text).save()
+
+ def test_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'z', 'text': 'abc'},
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+ def test_exact_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('=title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=zzz')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'zzz', 'text': 'cde'}
+ ]
+ )
+
+ def test_startswith_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', '^text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+
+class OrdringFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class OrderingFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # zyx abc
+ # yxw bcd
+ # xwv cde
+ for idx in range(3):
+ title = (
+ chr(ord('z') - idx) +
+ chr(ord('y') - idx) +
+ chr(ord('x') - idx)
+ )
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ OrdringFilterModel(title=title, text=text).save()
+
+ def test_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ ]
+ )
+
+ def test_reverse_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=-text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_incorrectfield_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=foobar')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering_using_string(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/test_genericrelations.py
index c38bfb9f..c38bfb9f 100644
--- a/rest_framework/tests/genericrelations.py
+++ b/rest_framework/tests/test_genericrelations.py
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/test_generics.py
index eca50d82..37734195 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/test_generics.py
@@ -2,7 +2,7 @@ from __future__ import unicode_literals
from django.db import models
from django.shortcuts import get_object_or_404
from django.test import TestCase
-from rest_framework import generics, serializers, status
+from rest_framework import generics, renderers, serializers, status
from rest_framework.tests.utils import RequestFactory
from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
from rest_framework.compat import six
@@ -39,6 +39,7 @@ class SlugBasedInstanceView(InstanceView):
"""
model = SlugBasedModel
serializer_class = SlugSerializer
+ lookup_field = 'slug'
class TestRootView(TestCase):
@@ -120,7 +121,25 @@ class TestRootView(TestCase):
'text/html'
],
'name': 'Root',
- 'description': 'Example description for OPTIONS.'
+ 'description': 'Example description for OPTIONS.',
+ 'actions': {
+ 'POST': {
+ 'text': {
+ 'max_length': 100,
+ 'read_only': False,
+ 'required': True,
+ 'type': 'string',
+ "label": "Text comes here",
+ "help_text": "Text description."
+ },
+ 'id': {
+ 'read_only': True,
+ 'required': False,
+ 'type': 'integer',
+ 'label': 'ID',
+ },
+ }
+ }
}
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected)
@@ -223,9 +242,9 @@ class TestInstanceView(TestCase):
"""
OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
"""
- request = factory.options('/')
- with self.assertNumQueries(0):
- response = self.view(request).render()
+ request = factory.options('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
expected = {
'parses': [
'application/json',
@@ -237,11 +256,39 @@ class TestInstanceView(TestCase):
'text/html'
],
'name': 'Instance',
- 'description': 'Example description for OPTIONS.'
+ 'description': 'Example description for OPTIONS.',
+ 'actions': {
+ 'PUT': {
+ 'text': {
+ 'max_length': 100,
+ 'read_only': False,
+ 'required': True,
+ 'type': 'string',
+ 'label': 'Text comes here',
+ 'help_text': 'Text description.'
+ },
+ 'id': {
+ 'read_only': True,
+ 'required': False,
+ 'type': 'integer',
+ 'label': 'ID',
+ },
+ }
+ }
}
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected)
+ def test_get_instance_view_incorrect_arg(self):
+ """
+ GET requests with an incorrect pk type, should raise 404, not 500.
+ Regression test for #890.
+ """
+ request = factory.get('/a')
+ with self.assertNumQueries(0):
+ response = self.view(request, pk='a').render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
def test_put_cannot_set_id(self):
"""
PUT requests to create a new object should not be able to set the id.
@@ -434,22 +481,14 @@ class TestFilterBackendAppliedToViews(TestCase):
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
- self.root_view = RootView.as_view()
- self.instance_view = InstanceView.as_view()
- self.original_root_backend = getattr(RootView, 'filter_backend')
- self.original_instance_backend = getattr(InstanceView, 'filter_backend')
-
- def tearDown(self):
- setattr(RootView, 'filter_backend', self.original_root_backend)
- setattr(InstanceView, 'filter_backend', self.original_instance_backend)
def test_get_root_view_filters_by_name_with_filter_backend(self):
"""
GET requests to ListCreateAPIView should return filtered list.
"""
- setattr(RootView, 'filter_backend', InclusiveFilterBackend)
+ root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/')
- response = self.root_view(request).render()
+ response = root_view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}])
@@ -458,9 +497,9 @@ class TestFilterBackendAppliedToViews(TestCase):
"""
GET requests to ListCreateAPIView should return empty list when all models are filtered out.
"""
- setattr(RootView, 'filter_backend', ExclusiveFilterBackend)
+ root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/')
- response = self.root_view(request).render()
+ response = root_view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, [])
@@ -468,9 +507,9 @@ class TestFilterBackendAppliedToViews(TestCase):
"""
GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
"""
- setattr(InstanceView, 'filter_backend', ExclusiveFilterBackend)
+ instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/1')
- response = self.instance_view(request, pk=1).render()
+ response = instance_view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.data, {'detail': 'Not found'})
@@ -478,8 +517,40 @@ class TestFilterBackendAppliedToViews(TestCase):
"""
GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
"""
- setattr(InstanceView, 'filter_backend', InclusiveFilterBackend)
+ instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/1')
- response = self.instance_view(request, pk=1).render()
+ response = instance_view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foo'})
+
+
+class TwoFieldModel(models.Model):
+ field_a = models.CharField(max_length=100)
+ field_b = models.CharField(max_length=100)
+
+
+class DynamicSerializerView(generics.ListCreateAPIView):
+ model = TwoFieldModel
+ renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
+
+ def get_serializer_class(self):
+ if self.request.method == 'POST':
+ class DynamicSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TwoFieldModel
+ fields = ('field_b',)
+ return DynamicSerializer
+ return super(DynamicSerializerView, self).get_serializer_class()
+
+
+class TestFilterBackendAppliedToViews(TestCase):
+
+ def test_dynamic_serializer_form_in_browsable_api(self):
+ """
+ GET requests to ListCreateAPIView should return filtered list.
+ """
+ view = DynamicSerializerView.as_view()
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertContains(response, 'field_b')
+ self.assertNotContains(response, 'field_a')
diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/test_htmlrenderer.py
index 8f2e2b5a..8957a43c 100644
--- a/rest_framework/tests/htmlrenderer.py
+++ b/rest_framework/tests/test_htmlrenderer.py
@@ -42,7 +42,7 @@ urlpatterns = patterns('',
class TemplateHTMLRendererTests(TestCase):
- urls = 'rest_framework.tests.htmlrenderer'
+ urls = 'rest_framework.tests.test_htmlrenderer'
def setUp(self):
"""
@@ -66,23 +66,23 @@ class TemplateHTMLRendererTests(TestCase):
def test_simple_html_view(self):
response = self.client.get('/')
self.assertContains(response, "example: foobar")
- self.assertEqual(response['Content-Type'], 'text/html')
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
def test_not_found_html_view(self):
response = self.client.get('/not_found')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.content, six.b("404 Not Found"))
- self.assertEqual(response['Content-Type'], 'text/html')
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
def test_permission_denied_html_view(self):
response = self.client.get('/permission_denied')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.content, six.b("403 Forbidden"))
- self.assertEqual(response['Content-Type'], 'text/html')
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
class TemplateHTMLRendererExceptionTests(TestCase):
- urls = 'rest_framework.tests.htmlrenderer'
+ urls = 'rest_framework.tests.test_htmlrenderer'
def setUp(self):
"""
@@ -109,10 +109,10 @@ class TemplateHTMLRendererExceptionTests(TestCase):
response = self.client.get('/not_found')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.content, six.b("404: Not found"))
- self.assertEqual(response['Content-Type'], 'text/html')
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
def test_permission_denied_html_view_with_template(self):
response = self.client.get('/permission_denied')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.content, six.b("403: Permission denied"))
- self.assertEqual(response['Content-Type'], 'text/html')
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py
index 9a61f299..1894ddb2 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/test_hyperlinkedserializers.py
@@ -27,6 +27,14 @@ class PhotoSerializer(serializers.Serializer):
return Photo(**attrs)
+class AlbumSerializer(serializers.ModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title')
+
+ class Meta:
+ model = Album
+ fields = ('title', 'url')
+
+
class BasicList(generics.ListCreateAPIView):
model = BasicModel
model_serializer_class = serializers.HyperlinkedModelSerializer
@@ -73,6 +81,8 @@ class PhotoListCreate(generics.ListCreateAPIView):
class AlbumDetail(generics.RetrieveAPIView):
model = Album
+ serializer_class = AlbumSerializer
+ lookup_field = 'title'
class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
@@ -96,7 +106,7 @@ urlpatterns = patterns('',
class TestBasicHyperlinkedView(TestCase):
- urls = 'rest_framework.tests.hyperlinkedserializers'
+ urls = 'rest_framework.tests.test_hyperlinkedserializers'
def setUp(self):
"""
@@ -133,7 +143,7 @@ class TestBasicHyperlinkedView(TestCase):
class TestManyToManyHyperlinkedView(TestCase):
- urls = 'rest_framework.tests.hyperlinkedserializers'
+ urls = 'rest_framework.tests.test_hyperlinkedserializers'
def setUp(self):
"""
@@ -180,8 +190,38 @@ class TestManyToManyHyperlinkedView(TestCase):
self.assertEqual(response.data, self.data[0])
+class TestHyperlinkedIdentityFieldLookup(TestCase):
+ urls = 'rest_framework.tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 3 Album instances.
+ """
+ titles = ['foo', 'bar', 'baz']
+ for title in titles:
+ album = Album(title=title)
+ album.save()
+ self.detail_view = AlbumDetail.as_view()
+ self.data = {
+ 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'},
+ 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'},
+ 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'}
+ }
+
+ def test_lookup_field(self):
+ """
+ GET requests to AlbumDetail view should return serialized Albums
+ with a url field keyed by `title`.
+ """
+ for album in Album.objects.all():
+ request = factory.get('/albums/{0}/'.format(album.title))
+ response = self.detail_view(request, title=album.title)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[album.title])
+
+
class TestCreateWithForeignKeys(TestCase):
- urls = 'rest_framework.tests.hyperlinkedserializers'
+ urls = 'rest_framework.tests.test_hyperlinkedserializers'
def setUp(self):
"""
@@ -206,7 +246,7 @@ class TestCreateWithForeignKeys(TestCase):
class TestCreateWithForeignKeysAndCustomSlug(TestCase):
- urls = 'rest_framework.tests.hyperlinkedserializers'
+ urls = 'rest_framework.tests.test_hyperlinkedserializers'
def setUp(self):
"""
@@ -231,7 +271,7 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase):
class TestOptionalRelationHyperlinkedView(TestCase):
- urls = 'rest_framework.tests.hyperlinkedserializers'
+ urls = 'rest_framework.tests.test_hyperlinkedserializers'
def setUp(self):
"""
diff --git a/rest_framework/tests/multitable_inheritance.py b/rest_framework/tests/test_multitable_inheritance.py
index 00c15327..00c15327 100644
--- a/rest_framework/tests/multitable_inheritance.py
+++ b/rest_framework/tests/test_multitable_inheritance.py
diff --git a/rest_framework/tests/negotiation.py b/rest_framework/tests/test_negotiation.py
index 43721b84..7f84827f 100644
--- a/rest_framework/tests/negotiation.py
+++ b/rest_framework/tests/test_negotiation.py
@@ -3,19 +3,24 @@ from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework.negotiation import DefaultContentNegotiation
from rest_framework.request import Request
+from rest_framework.renderers import BaseRenderer
factory = RequestFactory()
-class MockJSONRenderer(object):
+class MockJSONRenderer(BaseRenderer):
media_type = 'application/json'
-class MockHTMLRenderer(object):
+class MockHTMLRenderer(BaseRenderer):
media_type = 'text/html'
+class NoCharsetSpecifiedRenderer(BaseRenderer):
+ media_type = 'my/media'
+
+
class TestAcceptedMediaType(TestCase):
def setUp(self):
self.renderers = [MockJSONRenderer(), MockHTMLRenderer()]
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/test_pagination.py
index 6b8ef02f..e538a78e 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/test_pagination.py
@@ -1,18 +1,24 @@
from __future__ import unicode_literals
import datetime
from decimal import Decimal
-import django
+from django.db import models
from django.core.paginator import Paginator
from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest
from rest_framework import generics, status, pagination, filters, serializers
from rest_framework.compat import django_filters
-from rest_framework.tests.models import BasicModel, FilterableItem
+from rest_framework.tests.models import BasicModel
factory = RequestFactory()
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
class RootView(generics.ListCreateAPIView):
"""
Example description for OPTIONS.
@@ -124,7 +130,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
model = FilterableItem
paginate_by = 10
filter_class = DecimalFilter
- filter_backend = filters.DjangoFilterBackend
+ filter_backends = (filters.DjangoFilterBackend,)
view = FilterFieldsRootView.as_view()
@@ -171,7 +177,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
class BasicFilterFieldsRootView(generics.ListCreateAPIView):
model = FilterableItem
paginate_by = 10
- filter_backend = DecimalFilterBackend
+ filter_backends = (DecimalFilterBackend,)
view = BasicFilterFieldsRootView.as_view()
diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/test_parsers.py
index 7699e10c..7699e10c 100644
--- a/rest_framework/tests/parsers.py
+++ b/rest_framework/tests/test_parsers.py
diff --git a/rest_framework/tests/permissions.py b/rest_framework/tests/test_permissions.py
index b3993be5..6caaf65b 100644
--- a/rest_framework/tests/permissions.py
+++ b/rest_framework/tests/test_permissions.py
@@ -108,6 +108,48 @@ class ModelPermissionsIntegrationTests(TestCase):
response = instance_view(request, pk='2')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ def test_options_permitted(self):
+ request = factory.options('/', content_type='application/json',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEqual(list(response.data['actions'].keys()), ['POST'])
+
+ request = factory.options('/1', content_type='application/json',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
+
+ def test_options_disallowed(self):
+ request = factory.options('/', content_type='application/json',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertNotIn('actions', response.data)
+
+ request = factory.options('/1', content_type='application/json',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertNotIn('actions', response.data)
+
+ def test_options_updateonly(self):
+ request = factory.options('/', content_type='application/json',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertNotIn('actions', response.data)
+
+ request = factory.options('/1', content_type='application/json',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
+
class OwnerModel(models.Model):
text = models.CharField(max_length=100)
diff --git a/rest_framework/tests/test_relations.py b/rest_framework/tests/test_relations.py
new file mode 100644
index 00000000..d19219c9
--- /dev/null
+++ b/rest_framework/tests/test_relations.py
@@ -0,0 +1,100 @@
+"""
+General tests for relational fields.
+"""
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.tests.models import BlogPost
+
+
+class NullModel(models.Model):
+ pass
+
+
+class FieldTests(TestCase):
+ def test_pk_related_field_with_empty_string(self):
+ """
+ Regression test for #446
+
+ https://github.com/tomchristie/django-rest-framework/issues/446
+ """
+ field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+ def test_hyperlinked_related_field_with_empty_string(self):
+ field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+ def test_slug_related_field_with_empty_string(self):
+ field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+
+class TestManyRelatedMixin(TestCase):
+ def test_missing_many_to_many_related_field(self):
+ '''
+ Regression test for #632
+
+ https://github.com/tomchristie/django-rest-framework/pull/632
+ '''
+ field = serializers.RelatedField(many=True, read_only=False)
+
+ into = {}
+ field.field_from_native({}, None, 'field_name', into)
+ self.assertEqual(into['field_name'], [])
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+class RelatedFieldSourceTests(TestCase):
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='get_blogposts_manager')
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='get_blogposts_queryset')
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='a.b.c')
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/test_relations_hyperlink.py
index b1eed9a7..2ca7f4f2 100644
--- a/rest_framework/tests/relations_hyperlink.py
+++ b/rest_framework/tests/test_relations_hyperlink.py
@@ -4,6 +4,7 @@ from django.test.client import RequestFactory
from rest_framework import serializers
from rest_framework.compat import patterns, url
from rest_framework.tests.models import (
+ BlogPost,
ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
)
@@ -16,6 +17,7 @@ def dummy_view(request, pk):
pass
urlpatterns = patterns('',
+ url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),
url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),
@@ -69,7 +71,7 @@ class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
# TODO: Add test that .data cannot be accessed prior to .is_valid
class HyperlinkedManyToManyTests(TestCase):
- urls = 'rest_framework.tests.relations_hyperlink'
+ urls = 'rest_framework.tests.test_relations_hyperlink'
def setUp(self):
for idx in range(1, 4):
@@ -177,7 +179,7 @@ class HyperlinkedManyToManyTests(TestCase):
class HyperlinkedForeignKeyTests(TestCase):
- urls = 'rest_framework.tests.relations_hyperlink'
+ urls = 'rest_framework.tests.test_relations_hyperlink'
def setUp(self):
target = ForeignKeyTarget(name='target-1')
@@ -305,7 +307,7 @@ class HyperlinkedForeignKeyTests(TestCase):
class HyperlinkedNullableForeignKeyTests(TestCase):
- urls = 'rest_framework.tests.relations_hyperlink'
+ urls = 'rest_framework.tests.test_relations_hyperlink'
def setUp(self):
target = ForeignKeyTarget(name='target-1')
@@ -433,7 +435,7 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
class HyperlinkedNullableOneToOneTests(TestCase):
- urls = 'rest_framework.tests.relations_hyperlink'
+ urls = 'rest_framework.tests.test_relations_hyperlink'
def setUp(self):
target = OneToOneTarget(name='target-1')
@@ -451,3 +453,72 @@ class HyperlinkedNullableOneToOneTests(TestCase):
{'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
]
self.assertEqual(serializer.data, expected)
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+class HyperlinkedRelatedFieldSourceTests(TestCase):
+ urls = 'rest_framework.tests.test_relations_hyperlink'
+
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='get_blogposts_manager',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='get_blogposts_queryset',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='a.b.c',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/test_relations_nested.py
index 8325580f..8325580f 100644
--- a/rest_framework/tests/relations_nested.py
+++ b/rest_framework/tests/test_relations_nested.py
diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/test_relations_pk.py
index 5ce8b567..e2a1b815 100644
--- a/rest_framework/tests/relations_pk.py
+++ b/rest_framework/tests/test_relations_pk.py
@@ -1,7 +1,11 @@
from __future__ import unicode_literals
+from django.db import models
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+from rest_framework.tests.models import (
+ BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
+ NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource,
+)
from rest_framework.compat import six
@@ -124,6 +128,7 @@ class PKManyToManyTests(TestCase):
# Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True)
+ self.assertFalse(serializer.fields['targets'].read_only)
expected = [
{'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
@@ -135,6 +140,7 @@ class PKManyToManyTests(TestCase):
def test_reverse_many_to_many_create(self):
data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
serializer = ManyToManyTargetSerializer(data=data)
+ self.assertFalse(serializer.fields['sources'].read_only)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEqual(serializer.data, data)
@@ -421,3 +427,116 @@ class PKNullableOneToOneTests(TestCase):
{'id': 2, 'name': 'target-2', 'nullable_source': 1},
]
self.assertEqual(serializer.data, expected)
+
+
+# The below models and tests ensure that serializer fields corresponding
+# to a ManyToManyField field with a user-specified ``through`` model are
+# set to read only
+
+
+class ManyToManyThroughTarget(models.Model):
+ name = models.CharField(max_length=100)
+
+
+class ManyToManyThrough(models.Model):
+ source = models.ForeignKey('ManyToManyThroughSource')
+ target = models.ForeignKey(ManyToManyThroughTarget)
+
+
+class ManyToManyThroughSource(models.Model):
+ name = models.CharField(max_length=100)
+ targets = models.ManyToManyField(ManyToManyThroughTarget,
+ related_name='sources',
+ through='ManyToManyThrough')
+
+
+class ManyToManyThroughTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManyThroughTarget
+ fields = ('id', 'name', 'sources')
+
+
+class ManyToManyThroughSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManyThroughSource
+ fields = ('id', 'name', 'targets')
+
+
+class PKManyToManyThroughTests(TestCase):
+ def setUp(self):
+ self.source = ManyToManyThroughSource.objects.create(
+ name='through-source-1')
+ self.target = ManyToManyThroughTarget.objects.create(
+ name='through-target-1')
+
+ def test_many_to_many_create(self):
+ data = {'id': 2, 'name': 'source-2', 'targets': [self.target.pk]}
+ serializer = ManyToManyThroughSourceSerializer(data=data)
+ self.assertTrue(serializer.fields['targets'].read_only)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(obj.name, 'source-2')
+ self.assertEqual(obj.targets.count(), 0)
+
+ def test_many_to_many_reverse_create(self):
+ data = {'id': 2, 'name': 'target-2', 'sources': [self.source.pk]}
+ serializer = ManyToManyThroughTargetSerializer(data=data)
+ self.assertTrue(serializer.fields['sources'].read_only)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ obj = serializer.save()
+ self.assertEqual(obj.name, 'target-2')
+ self.assertEqual(obj.sources.count(), 0)
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+
+class PrimaryKeyRelatedFieldSourceTests(TestCase):
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager')
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset')
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c')
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
diff --git a/rest_framework/tests/relations_slug.py b/rest_framework/tests/test_relations_slug.py
index 435c821c..435c821c 100644
--- a/rest_framework/tests/relations_slug.py
+++ b/rest_framework/tests/test_relations_slug.py
diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/test_renderers.py
index 40bac9cb..95b59741 100644
--- a/rest_framework/tests/renderers.py
+++ b/rest_framework/tests/test_renderers.py
@@ -1,14 +1,18 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+
from decimal import Decimal
from django.core.cache import cache
from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest
+from django.utils.translation import ugettext_lazy as _
from rest_framework import status, permissions
from rest_framework.compat import yaml, etree, patterns, url, include
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
- XMLRenderer, JSONPRenderer, BrowsableAPIRenderer
+ XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer
from rest_framework.parsers import YAMLParser, XMLParser
from rest_framework.settings import api_settings
from rest_framework.compat import StringIO
@@ -26,7 +30,7 @@ RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
expected_results = [
- ((elem for elem in [1, 2, 3]), JSONRenderer, '[1, 2, 3]') # Generator
+ ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1, 2, 3]') # Generator
]
@@ -129,12 +133,12 @@ class RendererEndToEndTests(TestCase):
End-to-end testing of renderers using an RendererMixin on a generic view.
"""
- urls = 'rest_framework.tests.renderers'
+ urls = 'rest_framework.tests.test_renderers'
def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/')
- self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -142,13 +146,13 @@ class RendererEndToEndTests(TestCase):
"""No response must be included in HEAD requests."""
resp = self.client.head('/')
self.assertEqual(resp.status_code, DUMMYSTATUS)
- self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, six.b(''))
def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*')
- self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -156,7 +160,7 @@ class RendererEndToEndTests(TestCase):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
- self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -164,7 +168,7 @@ class RendererEndToEndTests(TestCase):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -175,7 +179,7 @@ class RendererEndToEndTests(TestCase):
RendererB.media_type
)
resp = self.client.get('/' + param)
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -192,7 +196,7 @@ class RendererEndToEndTests(TestCase):
RendererB.format
)
resp = self.client.get('/' + param)
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -200,7 +204,7 @@ class RendererEndToEndTests(TestCase):
"""If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/something.formatb')
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -213,7 +217,7 @@ class RendererEndToEndTests(TestCase):
)
resp = self.client.get('/' + param,
HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -235,6 +239,13 @@ class JSONRendererTests(TestCase):
Tests specific to the JSON Renderer
"""
+ def test_render_lazy_strings(self):
+ """
+ JSONRenderer should deal with lazy translated strings.
+ """
+ ret = JSONRenderer().render(_('test'))
+ self.assertEqual(ret, b'"test"')
+
def test_without_content_type_args(self):
"""
Test basic JSON rendering.
@@ -243,7 +254,7 @@ class JSONRendererTests(TestCase):
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json')
# Fix failing test case which depends on version of JSON library.
- self.assertEqual(content, _flat_repr)
+ self.assertEqual(content.decode('utf-8'), _flat_repr)
def test_with_content_type_args(self):
"""
@@ -252,7 +263,24 @@ class JSONRendererTests(TestCase):
obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json; indent=2')
- self.assertEqual(strip_trailing_whitespace(content), _indented_repr)
+ self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr)
+
+ def test_check_ascii(self):
+ obj = {'countries': ['United Kingdom', 'France', 'España']}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ self.assertEqual(content, '{"countries": ["United Kingdom", "France", "Espa\\u00f1a"]}'.encode('utf-8'))
+
+
+class UnicodeJSONRendererTests(TestCase):
+ """
+ Tests specific for the Unicode JSON Renderer
+ """
+ def test_proper_encoding(self):
+ obj = {'countries': ['United Kingdom', 'France', 'España']}
+ renderer = UnicodeJSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ self.assertEqual(content, '{"countries": ["United Kingdom", "France", "España"]}'.encode('utf-8'))
class JSONPRendererTests(TestCase):
@@ -260,7 +288,7 @@ class JSONPRendererTests(TestCase):
Tests specific to the JSONP Renderer
"""
- urls = 'rest_framework.tests.renderers'
+ urls = 'rest_framework.tests.test_renderers'
def test_without_callback_with_json_renderer(self):
"""
@@ -269,7 +297,7 @@ class JSONPRendererTests(TestCase):
resp = self.client.get('/jsonp/jsonrenderer',
HTTP_ACCEPT='application/javascript')
self.assertEqual(resp.status_code, status.HTTP_200_OK)
- self.assertEqual(resp['Content-Type'], 'application/javascript')
+ self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
self.assertEqual(resp.content,
('callback(%s);' % _flat_repr).encode('ascii'))
@@ -280,7 +308,7 @@ class JSONPRendererTests(TestCase):
resp = self.client.get('/jsonp/nojsonrenderer',
HTTP_ACCEPT='application/javascript')
self.assertEqual(resp.status_code, status.HTTP_200_OK)
- self.assertEqual(resp['Content-Type'], 'application/javascript')
+ self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
self.assertEqual(resp.content,
('callback(%s);' % _flat_repr).encode('ascii'))
@@ -292,7 +320,7 @@ class JSONPRendererTests(TestCase):
resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func,
HTTP_ACCEPT='application/javascript')
self.assertEqual(resp.status_code, status.HTTP_200_OK)
- self.assertEqual(resp['Content-Type'], 'application/javascript')
+ self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
self.assertEqual(resp.content,
('%s(%s);' % (callback_func, _flat_repr)).encode('ascii'))
@@ -433,7 +461,7 @@ class CacheRenderTest(TestCase):
Tests specific to caching responses
"""
- urls = 'rest_framework.tests.renderers'
+ urls = 'rest_framework.tests.test_renderers'
cache_key = 'just_a_cache_key'
diff --git a/rest_framework/tests/request.py b/rest_framework/tests/test_request.py
index 97e5af20..a5c5e84c 100644
--- a/rest_framework/tests/request.py
+++ b/rest_framework/tests/test_request.py
@@ -254,7 +254,7 @@ urlpatterns = patterns('',
class TestContentParsingWithAuthentication(TestCase):
- urls = 'rest_framework.tests.request'
+ urls = 'rest_framework.tests.test_request'
def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True)
diff --git a/rest_framework/tests/response.py b/rest_framework/tests/test_response.py
index aecf83f4..eea3c641 100644
--- a/rest_framework/tests/response.py
+++ b/rest_framework/tests/test_response.py
@@ -1,14 +1,18 @@
from __future__ import unicode_literals
from django.test import TestCase
+from rest_framework.tests.models import BasicModel, BasicModelSerializer
from rest_framework.compat import patterns, url, include
from rest_framework.response import Response
from rest_framework.views import APIView
+from rest_framework import generics
+from rest_framework import routers
from rest_framework import status
from rest_framework.renderers import (
BaseRenderer,
JSONRenderer,
BrowsableAPIRenderer
)
+from rest_framework import viewsets
from rest_framework.settings import api_settings
from rest_framework.compat import six
@@ -21,6 +25,9 @@ class MockJsonRenderer(BaseRenderer):
media_type = 'application/json'
+class MockTextMediaRenderer(BaseRenderer):
+ media_type = 'text/html'
+
DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = 'dummycontent'
@@ -44,13 +51,26 @@ class RendererB(BaseRenderer):
return RENDERER_B_SERIALIZER(data)
+class RendererC(RendererB):
+ media_type = 'mock/rendererc'
+ format = 'formatc'
+ charset = "rendererc"
+
+
class MockView(APIView):
- renderer_classes = (RendererA, RendererB)
+ renderer_classes = (RendererA, RendererB, RendererC)
def get(self, request, **kwargs):
return Response(DUMMYCONTENT, status=DUMMYSTATUS)
+class MockViewSettingContentType(APIView):
+ renderer_classes = (RendererA, RendererB, RendererC)
+
+ def get(self, request, **kwargs):
+ return Response(DUMMYCONTENT, status=DUMMYSTATUS, content_type='setbyview')
+
+
class HTMLView(APIView):
renderer_classes = (BrowsableAPIRenderer, )
@@ -65,11 +85,29 @@ class HTMLView1(APIView):
return Response('text')
+class HTMLNewModelViewSet(viewsets.ModelViewSet):
+ model = BasicModel
+
+
+class HTMLNewModelView(generics.ListCreateAPIView):
+ renderer_classes = (BrowsableAPIRenderer,)
+ permission_classes = []
+ serializer_class = BasicModelSerializer
+ model = BasicModel
+
+
+new_model_viewset_router = routers.DefaultRouter()
+new_model_viewset_router.register(r'', HTMLNewModelViewSet)
+
+
urlpatterns = patterns('',
- url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
- url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
+ url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
+ url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
+ url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^html$', HTMLView.as_view()),
url(r'^html1$', HTMLView1.as_view()),
+ url(r'^html_new_model$', HTMLNewModelView.as_view()),
+ url(r'^html_new_model_viewset', include(new_model_viewset_router.urls)),
url(r'^restframework', include('rest_framework.urls', namespace='rest_framework'))
)
@@ -80,12 +118,12 @@ class RendererIntegrationTests(TestCase):
End-to-end testing of renderers using an ResponseMixin on a generic view.
"""
- urls = 'rest_framework.tests.response'
+ urls = 'rest_framework.tests.test_response'
def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/')
- self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -93,13 +131,13 @@ class RendererIntegrationTests(TestCase):
"""No response must be included in HEAD requests."""
resp = self.client.head('/')
self.assertEqual(resp.status_code, DUMMYSTATUS)
- self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, six.b(''))
def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*')
- self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -107,7 +145,7 @@ class RendererIntegrationTests(TestCase):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
- self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -115,7 +153,7 @@ class RendererIntegrationTests(TestCase):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -126,7 +164,7 @@ class RendererIntegrationTests(TestCase):
RendererB.media_type
)
resp = self.client.get('/' + param)
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -134,7 +172,7 @@ class RendererIntegrationTests(TestCase):
"""If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format)
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -142,7 +180,7 @@ class RendererIntegrationTests(TestCase):
"""If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/something.formatb')
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -151,7 +189,7 @@ class RendererIntegrationTests(TestCase):
the renderer with the matching format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format,
HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -160,7 +198,7 @@ class Issue122Tests(TestCase):
"""
Tests that covers #122.
"""
- urls = 'rest_framework.tests.response'
+ urls = 'rest_framework.tests.test_response'
def test_only_html_renderer(self):
"""
@@ -173,3 +211,68 @@ class Issue122Tests(TestCase):
Test if no infinite recursion occurs.
"""
self.client.get('/html1')
+
+
+class Issue467Tests(TestCase):
+ """
+ Tests for #467
+ """
+
+ urls = 'rest_framework.tests.test_response'
+
+ def test_form_has_label_and_help_text(self):
+ resp = self.client.get('/html_new_model')
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertContains(resp, 'Text comes here')
+ self.assertContains(resp, 'Text description.')
+
+
+class Issue807Tests(TestCase):
+ """
+ Covers #807
+ """
+
+ urls = 'rest_framework.tests.test_response'
+
+ def test_does_not_append_charset_by_default(self):
+ """
+ Renderers don't include a charset unless set explicitly.
+ """
+ headers = {"HTTP_ACCEPT": RendererA.media_type}
+ resp = self.client.get('/', **headers)
+ expected = "{0}; charset={1}".format(RendererA.media_type, 'utf-8')
+ self.assertEqual(expected, resp['Content-Type'])
+
+ def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self):
+ """
+ If renderer class has charset attribute declared, it gets appended
+ to Response's Content-Type
+ """
+ headers = {"HTTP_ACCEPT": RendererC.media_type}
+ resp = self.client.get('/', **headers)
+ expected = "{0}; charset={1}".format(RendererC.media_type, RendererC.charset)
+ self.assertEqual(expected, resp['Content-Type'])
+
+ def test_content_type_set_explictly_on_response(self):
+ """
+ The content type may be set explictly on the response.
+ """
+ headers = {"HTTP_ACCEPT": RendererC.media_type}
+ resp = self.client.get('/setbyview', **headers)
+ self.assertEqual('setbyview', resp['Content-Type'])
+
+ def test_viewset_label_help_text(self):
+ param = '?%s=%s' % (
+ api_settings.URL_ACCEPT_OVERRIDE,
+ 'text/html'
+ )
+ resp = self.client.get('/html_new_model_viewset/' + param)
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertContains(resp, 'Text comes here')
+ self.assertContains(resp, 'Text description.')
+
+ def test_form_has_label_and_help_text(self):
+ resp = self.client.get('/html_new_model')
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertContains(resp, 'Text comes here')
+ self.assertContains(resp, 'Text description.')
diff --git a/rest_framework/tests/reverse.py b/rest_framework/tests/test_reverse.py
index cb8d8132..93ef5637 100644
--- a/rest_framework/tests/reverse.py
+++ b/rest_framework/tests/test_reverse.py
@@ -19,7 +19,7 @@ class ReverseTests(TestCase):
"""
Tests for fully qualified URLs when using `reverse`.
"""
- urls = 'rest_framework.tests.reverse'
+ urls = 'rest_framework.tests.test_reverse'
def test_reversed_urls_are_fully_qualified(self):
request = factory.get('/view')
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
new file mode 100644
index 00000000..a7534f70
--- /dev/null
+++ b/rest_framework/tests/test_routers.py
@@ -0,0 +1,150 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from django.test.client import RequestFactory
+from rest_framework import serializers, viewsets
+from rest_framework.compat import include, patterns, url
+from rest_framework.decorators import link, action
+from rest_framework.response import Response
+from rest_framework.routers import SimpleRouter
+
+factory = RequestFactory()
+
+urlpatterns = patterns('',)
+
+
+class BasicViewSet(viewsets.ViewSet):
+ def list(self, request, *args, **kwargs):
+ return Response({'method': 'list'})
+
+ @action()
+ def action1(self, request, *args, **kwargs):
+ return Response({'method': 'action1'})
+
+ @action()
+ def action2(self, request, *args, **kwargs):
+ return Response({'method': 'action2'})
+
+ @action(methods=['post', 'delete'])
+ def action3(self, request, *args, **kwargs):
+ return Response({'method': 'action2'})
+
+ @link()
+ def link1(self, request, *args, **kwargs):
+ return Response({'method': 'link1'})
+
+ @link()
+ def link2(self, request, *args, **kwargs):
+ return Response({'method': 'link2'})
+
+
+class TestSimpleRouter(TestCase):
+ def setUp(self):
+ self.router = SimpleRouter()
+
+ def test_link_and_action_decorator(self):
+ routes = self.router.get_routes(BasicViewSet)
+ decorator_routes = routes[2:]
+ # Make sure all these endpoints exist and none have been clobbered
+ for i, endpoint in enumerate(['action1', 'action2', 'action3', 'link1', 'link2']):
+ route = decorator_routes[i]
+ # check url listing
+ self.assertEqual(route.url,
+ '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint))
+ # check method to function mapping
+ if endpoint == 'action3':
+ methods_map = ['post', 'delete']
+ elif endpoint.startswith('action'):
+ methods_map = ['post']
+ else:
+ methods_map = ['get']
+ for method in methods_map:
+ self.assertEqual(route.mapping[method], endpoint)
+
+
+class RouterTestModel(models.Model):
+ uuid = models.CharField(max_length=20)
+ text = models.CharField(max_length=200)
+
+
+class TestCustomLookupFields(TestCase):
+ """
+ Ensure that custom lookup fields are correctly routed.
+ """
+ urls = 'rest_framework.tests.test_routers'
+
+ def setUp(self):
+ class NoteSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = RouterTestModel
+ lookup_field = 'uuid'
+ fields = ('url', 'uuid', 'text')
+
+ class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
+ serializer_class = NoteSerializer
+ lookup_field = 'uuid'
+
+ RouterTestModel.objects.create(uuid='123', text='foo bar')
+
+ self.router = SimpleRouter()
+ self.router.register(r'notes', NoteViewSet)
+
+ from rest_framework.tests import test_routers
+ urls = getattr(test_routers, 'urlpatterns')
+ urls += patterns('',
+ url(r'^', include(self.router.urls)),
+ )
+
+ def test_custom_lookup_field_route(self):
+ detail_route = self.router.urls[-1]
+ detail_url_pattern = detail_route.regex.pattern
+ self.assertIn('<uuid>', detail_url_pattern)
+
+ def test_retrieve_lookup_field_list_view(self):
+ response = self.client.get('/notes/')
+ self.assertEqual(response.data,
+ [{
+ "url": "http://testserver/notes/123/",
+ "uuid": "123", "text": "foo bar"
+ }]
+ )
+
+ def test_retrieve_lookup_field_detail_view(self):
+ response = self.client.get('/notes/123/')
+ self.assertEqual(response.data,
+ {
+ "url": "http://testserver/notes/123/",
+ "uuid": "123", "text": "foo bar"
+ }
+ )
+
+
+class TestTrailingSlash(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+
+ self.router = SimpleRouter()
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_urls_have_trailing_slash_by_default(self):
+ expected = ['^notes/$', '^notes/(?P<pk>[^/]+)/$']
+ for idx in range(len(expected)):
+ self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
+
+
+class TestTrailingSlash(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+
+ self.router = SimpleRouter(trailing_slash=False)
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_urls_can_have_trailing_slash_removed(self):
+ expected = ['^notes$', '^notes/(?P<pk>[^/]+)$']
+ for idx in range(len(expected)):
+ self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/test_serializer.py
index 84e1ee4e..8b87a084 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/test_serializer.py
@@ -1,10 +1,14 @@
from __future__ import unicode_literals
-from django.utils.datastructures import MultiValueDict
+from django.db import models
+from django.db.models.fields import BLANK_CHOICE_DASH
from django.test import TestCase
-from rest_framework import serializers
+from django.utils.datastructures import MultiValueDict
+from django.utils.translation import ugettext_lazy as _
+from rest_framework import serializers, fields, relations
from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel,
- ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo)
+ ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel)
+from rest_framework.tests.models import BasicModelSerializer
import datetime
import pickle
@@ -43,6 +47,17 @@ class CommentSerializer(serializers.Serializer):
return instance
+class NamesSerializer(serializers.Serializer):
+ first = serializers.CharField()
+ last = serializers.CharField(required=False, default='')
+ initials = serializers.CharField(required=False, default='')
+
+
+class PersonIdentifierSerializer(serializers.Serializer):
+ ssn = serializers.CharField()
+ names = NamesSerializer(source='names', required=False)
+
+
class BookSerializer(serializers.ModelSerializer):
isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'})
@@ -78,6 +93,29 @@ class PersonSerializer(serializers.ModelSerializer):
read_only_fields = ('age',)
+class NestedSerializer(serializers.Serializer):
+ info = serializers.Field()
+
+
+class ModelSerializerWithNestedSerializer(serializers.ModelSerializer):
+ nested = NestedSerializer(source='*')
+
+ class Meta:
+ model = Person
+
+
+class PersonSerializerInvalidReadOnly(serializers.ModelSerializer):
+ """
+ Testing for #652.
+ """
+ info = serializers.Field(source='info')
+
+ class Meta:
+ model = Person
+ fields = ('name', 'age', 'info')
+ read_only_fields = ('age', 'info')
+
+
class AlbumsSerializer(serializers.ModelSerializer):
class Meta:
@@ -91,11 +129,6 @@ class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):
fields = ['some_integer']
-class BrokenModelSerializer(serializers.ModelSerializer):
- class Meta:
- fields = ['some_field']
-
-
class BasicTests(TestCase):
def setUp(self):
self.comment = Comment(
@@ -141,6 +174,42 @@ class BasicTests(TestCase):
self.assertFalse(serializer.object is expected)
self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!')
+ def test_create_nested(self):
+ """Test a serializer with nested data."""
+ names = {'first': 'John', 'last': 'Doe', 'initials': 'jd'}
+ data = {'ssn': '1234567890', 'names': names}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is data)
+ self.assertEqual(serializer.data['names'], names)
+
+ def test_create_partial_nested(self):
+ """Test a serializer with nested data which has missing fields."""
+ names = {'first': 'John'}
+ data = {'ssn': '1234567890', 'names': names}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ expected_names = {'first': 'John', 'last': '', 'initials': ''}
+ data['names'] = expected_names
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is expected_names)
+ self.assertEqual(serializer.data['names'], expected_names)
+
+ def test_null_nested(self):
+ """Test a serializer with a nonexistent nested field"""
+ data = {'ssn': '1234567890'}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is data)
+ expected = {'ssn': '1234567890', 'names': None}
+ self.assertEqual(serializer.data, expected)
+
def test_update(self):
serializer = CommentSerializer(self.comment, data=self.data)
expected = self.comment
@@ -189,6 +258,12 @@ class BasicTests(TestCase):
# Assert age is unchanged (35)
self.assertEqual(instance.age, self.person_data['age'])
+ def test_invalid_read_only_fields(self):
+ """
+ Regression test for #652.
+ """
+ self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, [])
+
class DictStyleSerializer(serializers.Serializer):
"""
@@ -344,19 +419,34 @@ class ValidationTests(TestCase):
Assert that a meaningful exception message is outputted when the model
field is missing (e.g. when mistyping ``model``).
"""
+ class BrokenModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ fields = ['some_field']
+
try:
- serializer = BrokenModelSerializer()
+ BrokenModelSerializer()
except AssertionError as e:
self.assertEqual(e.args[0], "Serializer class 'BrokenModelSerializer' is missing 'model' Meta option")
except:
self.fail('Wrong exception type thrown.')
+ def test_writable_star_source_on_nested_serializer(self):
+ """
+ Assert that a nested serializer instantiated with source='*' correctly
+ expands the data into the outer serializer.
+ """
+ serializer = ModelSerializerWithNestedSerializer(data={
+ 'name': 'marko',
+ 'nested': {'info': 'hi'}},
+ )
+ self.assertEqual(serializer.is_valid(), True)
+
class CustomValidationTests(TestCase):
class CommentSerializerWithFieldValidator(CommentSerializer):
def validate_email(self, attrs, source):
- value = attrs[source]
+ attrs[source]
return attrs
def validate_content(self, attrs, source):
@@ -853,23 +943,6 @@ class RelatedTraversalTest(TestCase):
self.assertEqual(serializer.data, expected)
- def test_queryset_nested_traversal(self):
- """
- Relational fields should be able to use methods as their source.
- """
- BlogPost.objects.create(title='blah')
-
- class QuerysetMethodSerializer(serializers.Serializer):
- blogposts = serializers.RelatedField(many=True, source='get_all_blogposts')
-
- class ClassWithQuerysetMethod(object):
- def get_all_blogposts(self):
- return BlogPost.objects
-
- obj = ClassWithQuerysetMethod()
- serializer = QuerysetMethodSerializer(obj)
- self.assertEqual(serializer.data, {'blogposts': ['BlogPost object']})
-
class SerializerMethodFieldTests(TestCase):
def setUp(self):
@@ -1000,6 +1073,130 @@ class SerializerPickleTests(TestCase):
repr(pickle.loads(pickle.dumps(data, 0)))
+# test for issue #725
+class SeveralChoicesModel(models.Model):
+ color = models.CharField(
+ max_length=10,
+ choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')],
+ blank=False
+ )
+ drink = models.CharField(
+ max_length=10,
+ choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')],
+ blank=False,
+ default='beer'
+ )
+ os = models.CharField(
+ max_length=10,
+ choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')],
+ blank=True
+ )
+ music_genre = models.CharField(
+ max_length=10,
+ choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')],
+ blank=True,
+ default='metal'
+ )
+
+
+class SerializerChoiceFields(TestCase):
+
+ def setUp(self):
+ super(SerializerChoiceFields, self).setUp()
+
+ class SeveralChoicesSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = SeveralChoicesModel
+ fields = ('color', 'drink', 'os', 'music_genre')
+
+ self.several_choices_serializer = SeveralChoicesSerializer
+
+ def test_choices_blank_false_not_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['color'].choices,
+ [('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')]
+ )
+
+ def test_choices_blank_false_with_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['drink'].choices,
+ [('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')]
+ )
+
+ def test_choices_blank_true_not_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['os'].choices,
+ BLANK_CHOICE_DASH + [('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')]
+ )
+
+ def test_choices_blank_true_with_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['music_genre'].choices,
+ BLANK_CHOICE_DASH + [('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')]
+ )
+
+
+# Regression tests for #675
+class Ticket(models.Model):
+ assigned = models.ForeignKey(
+ Person, related_name='assigned_tickets')
+ reviewer = models.ForeignKey(
+ Person, blank=True, null=True, related_name='reviewed_tickets')
+
+
+class SerializerRelatedChoicesTest(TestCase):
+
+ def setUp(self):
+ super(SerializerRelatedChoicesTest, self).setUp()
+
+ class RelatedChoicesSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Ticket
+ fields = ('assigned', 'reviewer')
+
+ self.related_fields_serializer = RelatedChoicesSerializer
+
+ def test_empty_queryset_required(self):
+ serializer = self.related_fields_serializer()
+ self.assertEqual(serializer.fields['assigned'].queryset.count(), 0)
+ self.assertEqual(
+ [x for x in serializer.fields['assigned'].widget.choices],
+ []
+ )
+
+ def test_empty_queryset_not_required(self):
+ serializer = self.related_fields_serializer()
+ self.assertEqual(serializer.fields['reviewer'].queryset.count(), 0)
+ self.assertEqual(
+ [x for x in serializer.fields['reviewer'].widget.choices],
+ [('', '---------')]
+ )
+
+ def test_with_some_persons_required(self):
+ Person.objects.create(name="Lionel Messi")
+ Person.objects.create(name="Xavi Hernandez")
+ serializer = self.related_fields_serializer()
+ self.assertEqual(serializer.fields['assigned'].queryset.count(), 2)
+ self.assertEqual(
+ [x for x in serializer.fields['assigned'].widget.choices],
+ [(1, 'Person object - 1'), (2, 'Person object - 2')]
+ )
+
+ def test_with_some_persons_not_required(self):
+ Person.objects.create(name="Lionel Messi")
+ Person.objects.create(name="Xavi Hernandez")
+ serializer = self.related_fields_serializer()
+ self.assertEqual(serializer.fields['reviewer'].queryset.count(), 2)
+ self.assertEqual(
+ [x for x in serializer.fields['reviewer'].widget.choices],
+ [('', '---------'), (1, 'Person object - 1'), (2, 'Person object - 2')]
+ )
+
+
class DepthTest(TestCase):
def test_implicit_nesting(self):
@@ -1125,3 +1322,312 @@ class DeserializeListTestCase(TestCase):
self.assertFalse(serializer.is_valid())
expected = [{}, {'email': ['This field is required.']}, {}]
self.assertEqual(serializer.errors, expected)
+
+
+# Test for issue 747
+
+class LazyStringModel(object):
+ def __init__(self, lazystring):
+ self.lazystring = lazystring
+
+
+class LazyStringSerializer(serializers.Serializer):
+ lazystring = serializers.Field()
+
+ def restore_object(self, attrs, instance=None):
+ if instance is not None:
+ instance.lazystring = attrs.get('lazystring', instance.lazystring)
+ return instance
+ return LazyStringModel(**attrs)
+
+
+class LazyStringsTestCase(TestCase):
+ def setUp(self):
+ self.model = LazyStringModel(lazystring=_('lazystring'))
+
+ def test_lazy_strings_are_translated(self):
+ serializer = LazyStringSerializer(self.model)
+ self.assertEqual(type(serializer.data['lazystring']),
+ type('lazystring'))
+
+
+# Test for issue #467
+
+class FieldLabelTest(TestCase):
+ def setUp(self):
+ self.serializer_class = BasicModelSerializer
+
+ def test_label_from_model(self):
+ """
+ Validates that label and help_text are correctly copied from the model class.
+ """
+ serializer = self.serializer_class()
+ text_field = serializer.fields['text']
+
+ self.assertEqual('Text comes here', text_field.label)
+ self.assertEqual('Text description.', text_field.help_text)
+
+ def test_field_ctor(self):
+ """
+ This is check that ctor supports both label and help_text.
+ """
+ self.assertEqual('Label', fields.Field(label='Label', help_text='Help').label)
+ self.assertEqual('Help', fields.CharField(label='Label', help_text='Help').help_text)
+ self.assertEqual('Label', relations.HyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help', many=True).label)
+
+
+class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
+
+ def setUp(self):
+ class AMOAFModel(RESTFrameworkModel):
+ char_field = models.CharField(max_length=1024, blank=True)
+ comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True)
+ decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True)
+ email_field = models.EmailField(max_length=1024, blank=True)
+ file_field = models.FileField(max_length=1024, blank=True)
+ image_field = models.ImageField(max_length=1024, blank=True)
+ slug_field = models.SlugField(max_length=1024, blank=True)
+ url_field = models.URLField(max_length=1024, blank=True)
+
+ class AMOAFSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = AMOAFModel
+
+ self.serializer_class = AMOAFSerializer
+ self.fields_attributes = {
+ 'char_field': [
+ ('max_length', 1024),
+ ],
+ 'comma_separated_integer_field': [
+ ('max_length', 1024),
+ ],
+ 'decimal_field': [
+ ('max_digits', 64),
+ ('decimal_places', 32),
+ ],
+ 'email_field': [
+ ('max_length', 1024),
+ ],
+ 'file_field': [
+ ('max_length', 1024),
+ ],
+ 'image_field': [
+ ('max_length', 1024),
+ ],
+ 'slug_field': [
+ ('max_length', 1024),
+ ],
+ 'url_field': [
+ ('max_length', 1024),
+ ],
+ }
+
+ def field_test(self, field):
+ serializer = self.serializer_class(data={})
+ self.assertEqual(serializer.is_valid(), True)
+
+ for attribute in self.fields_attributes[field]:
+ self.assertEqual(
+ getattr(serializer.fields[field], attribute[0]),
+ attribute[1]
+ )
+
+ def test_char_field(self):
+ self.field_test('char_field')
+
+ def test_comma_separated_integer_field(self):
+ self.field_test('comma_separated_integer_field')
+
+ def test_decimal_field(self):
+ self.field_test('decimal_field')
+
+ def test_email_field(self):
+ self.field_test('email_field')
+
+ def test_file_field(self):
+ self.field_test('file_field')
+
+ def test_image_field(self):
+ self.field_test('image_field')
+
+ def test_slug_field(self):
+ self.field_test('slug_field')
+
+ def test_url_field(self):
+ self.field_test('url_field')
+
+
+class DefaultValuesOnAutogeneratedFieldsTests(TestCase):
+
+ def setUp(self):
+ class DVOAFModel(RESTFrameworkModel):
+ positive_integer_field = models.PositiveIntegerField(blank=True)
+ positive_small_integer_field = models.PositiveSmallIntegerField(blank=True)
+ email_field = models.EmailField(blank=True)
+ file_field = models.FileField(blank=True)
+ image_field = models.ImageField(blank=True)
+ slug_field = models.SlugField(blank=True)
+ url_field = models.URLField(blank=True)
+
+ class DVOAFSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = DVOAFModel
+
+ self.serializer_class = DVOAFSerializer
+ self.fields_attributes = {
+ 'positive_integer_field': [
+ ('min_value', 0),
+ ],
+ 'positive_small_integer_field': [
+ ('min_value', 0),
+ ],
+ 'email_field': [
+ ('max_length', 75),
+ ],
+ 'file_field': [
+ ('max_length', 100),
+ ],
+ 'image_field': [
+ ('max_length', 100),
+ ],
+ 'slug_field': [
+ ('max_length', 50),
+ ],
+ 'url_field': [
+ ('max_length', 200),
+ ],
+ }
+
+ def field_test(self, field):
+ serializer = self.serializer_class(data={})
+ self.assertEqual(serializer.is_valid(), True)
+
+ for attribute in self.fields_attributes[field]:
+ self.assertEqual(
+ getattr(serializer.fields[field], attribute[0]),
+ attribute[1]
+ )
+
+ def test_positive_integer_field(self):
+ self.field_test('positive_integer_field')
+
+ def test_positive_small_integer_field(self):
+ self.field_test('positive_small_integer_field')
+
+ def test_email_field(self):
+ self.field_test('email_field')
+
+ def test_file_field(self):
+ self.field_test('file_field')
+
+ def test_image_field(self):
+ self.field_test('image_field')
+
+ def test_slug_field(self):
+ self.field_test('slug_field')
+
+ def test_url_field(self):
+ self.field_test('url_field')
+
+
+class MetadataSerializer(serializers.Serializer):
+ field1 = serializers.CharField(3, required=True)
+ field2 = serializers.CharField(10, required=False)
+
+
+class MetadataSerializerTestCase(TestCase):
+ def setUp(self):
+ self.serializer = MetadataSerializer()
+
+ def test_serializer_metadata(self):
+ metadata = self.serializer.metadata()
+ expected = {
+ 'field1': {
+ 'required': True,
+ 'max_length': 3,
+ 'type': 'string',
+ 'read_only': False
+ },
+ 'field2': {
+ 'required': False,
+ 'max_length': 10,
+ 'type': 'string',
+ 'read_only': False
+ }
+ }
+ self.assertEqual(expected, metadata)
+
+
+### Regression test for #840
+
+class SimpleModel(models.Model):
+ text = models.CharField(max_length=100)
+
+
+class SimpleModelSerializer(serializers.ModelSerializer):
+ text = serializers.CharField()
+ other = serializers.CharField()
+
+ class Meta:
+ model = SimpleModel
+
+ def validate_other(self, attrs, source):
+ del attrs['other']
+ return attrs
+
+
+class FieldValidationRemovingAttr(TestCase):
+ def test_removing_non_model_field_in_validation(self):
+ """
+ Removing an attr during field valiation should ensure that it is not
+ passed through when restoring the object.
+
+ This allows additional non-model fields to be supported.
+
+ Regression test for #840.
+ """
+ serializer = SimpleModelSerializer(data={'text': 'foo', 'other': 'bar'})
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.object.text, 'foo')
+
+
+### Regression test for #878
+
+class SimpleTargetModel(models.Model):
+ text = models.CharField(max_length=100)
+
+
+class SimplePKSourceModelSerializer(serializers.Serializer):
+ targets = serializers.PrimaryKeyRelatedField(queryset=SimpleTargetModel.objects.all(), many=True)
+ text = serializers.CharField()
+
+
+class SimpleSlugSourceModelSerializer(serializers.Serializer):
+ targets = serializers.SlugRelatedField(queryset=SimpleTargetModel.objects.all(), many=True, slug_field='pk')
+ text = serializers.CharField()
+
+
+class SerializerSupportsManyRelationships(TestCase):
+ def setUp(self):
+ SimpleTargetModel.objects.create(text='foo')
+ SimpleTargetModel.objects.create(text='bar')
+
+ def test_serializer_supports_pk_many_relationships(self):
+ """
+ Regression test for #878.
+
+ Note that pk behavior has a different code path to usual cases,
+ for performance reasons.
+ """
+ serializer = SimplePKSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})
+
+ def test_serializer_supports_slug_many_relationships(self):
+ """
+ Regression test for #878.
+ """
+ serializer = SimpleSlugSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})
diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/test_serializer_bulk_update.py
index 8b0ded1a..8b0ded1a 100644
--- a/rest_framework/tests/serializer_bulk_update.py
+++ b/rest_framework/tests/test_serializer_bulk_update.py
diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/test_serializer_nested.py
index 71d0e24b..71d0e24b 100644
--- a/rest_framework/tests/serializer_nested.py
+++ b/rest_framework/tests/test_serializer_nested.py
diff --git a/rest_framework/tests/settings.py b/rest_framework/tests/test_settings.py
index 857375c2..857375c2 100644
--- a/rest_framework/tests/settings.py
+++ b/rest_framework/tests/test_settings.py
diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/test_throttling.py
index 11cbd8eb..da400b2f 100644
--- a/rest_framework/tests/throttling.py
+++ b/rest_framework/tests/test_throttling.py
@@ -36,7 +36,7 @@ class MockView_MinuteThrottling(APIView):
class ThrottlingTests(TestCase):
- urls = 'rest_framework.tests.throttling'
+ urls = 'rest_framework.tests.test_throttling'
def setUp(self):
"""
diff --git a/rest_framework/tests/urlpatterns.py b/rest_framework/tests/test_urlpatterns.py
index 29ed4a96..29ed4a96 100644
--- a/rest_framework/tests/urlpatterns.py
+++ b/rest_framework/tests/test_urlpatterns.py
diff --git a/rest_framework/tests/validation.py b/rest_framework/tests/test_validation.py
index cbdd6515..a6ec0e99 100644
--- a/rest_framework/tests/validation.py
+++ b/rest_framework/tests/test_validation.py
@@ -63,3 +63,25 @@ class TestPreSaveValidationExclusions(TestCase):
# does not have `blank=True`, so this serializer should not validate.
serializer = ShouldValidateModelSerializer(data={'renamed': ''})
self.assertEqual(serializer.is_valid(), False)
+
+
+class ValidationSerializer(serializers.Serializer):
+ foo = serializers.CharField()
+
+ def validate_foo(self, attrs, source):
+ raise serializers.ValidationError("foo invalid")
+
+ def validate(self, attrs):
+ raise serializers.ValidationError("serializer invalid")
+
+
+class TestAvoidValidation(TestCase):
+ """
+ If serializer was initialized with invalid data (None or non dict-like), it
+ should avoid validation layer (validate_<field> and validate methods)
+ """
+ def test_serializer_errors_has_only_invalid_data_error(self):
+ serializer = ValidationSerializer(data='invalid data')
+ self.assertFalse(serializer.is_valid())
+ self.assertDictEqual(serializer.errors,
+ {'non_field_errors': ['Invalid data']})
diff --git a/rest_framework/tests/views.py b/rest_framework/tests/test_views.py
index 994cf6dc..2767d24c 100644
--- a/rest_framework/tests/views.py
+++ b/rest_framework/tests/test_views.py
@@ -1,12 +1,15 @@
from __future__ import unicode_literals
+
+import copy
+
from django.test import TestCase
from django.test.client import RequestFactory
+
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.views import APIView
-import copy
factory = RequestFactory()
diff --git a/rest_framework/tests/testcases.py b/rest_framework/tests/testcases.py
deleted file mode 100644
index f8c2579e..00000000
--- a/rest_framework/tests/testcases.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# http://djangosnippets.org/snippets/1011/
-from __future__ import unicode_literals
-from django.conf import settings
-from django.core.management import call_command
-from django.db.models import loading
-from django.test import TestCase
-
-NO_SETTING = ('!', None)
-
-
-class TestSettingsManager(object):
- """
- A class which can modify some Django settings temporarily for a
- test and then revert them to their original values later.
-
- Automatically handles resyncing the DB if INSTALLED_APPS is
- modified.
-
- """
- def __init__(self):
- self._original_settings = {}
-
- def set(self, **kwargs):
- for k, v in kwargs.iteritems():
- self._original_settings.setdefault(k, getattr(settings, k,
- NO_SETTING))
- setattr(settings, k, v)
- if 'INSTALLED_APPS' in kwargs:
- self.syncdb()
-
- def syncdb(self):
- loading.cache.loaded = False
- call_command('syncdb', verbosity=0)
-
- def revert(self):
- for k, v in self._original_settings.iteritems():
- if v == NO_SETTING:
- delattr(settings, k)
- else:
- setattr(settings, k, v)
- if 'INSTALLED_APPS' in self._original_settings:
- self.syncdb()
- self._original_settings = {}
-
-
-class SettingsTestCase(TestCase):
- """
- A subclass of the Django TestCase with a settings_manager
- attribute which is an instance of TestSettingsManager.
-
- Comes with a tearDown() method that calls
- self.settings_manager.revert().
-
- """
- def __init__(self, *args, **kwargs):
- super(SettingsTestCase, self).__init__(*args, **kwargs)
- self.settings_manager = TestSettingsManager()
-
- def tearDown(self):
- self.settings_manager.revert()
-
-
-class TestModelsTestCase(SettingsTestCase):
- def setUp(self, *args, **kwargs):
- installed_apps = tuple(settings.INSTALLED_APPS) + ('rest_framework.tests',)
- self.settings_manager.set(INSTALLED_APPS=installed_apps)
diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py
index 08f88e11..554ebd1a 100644
--- a/rest_framework/tests/tests.py
+++ b/rest_framework/tests/tests.py
@@ -4,11 +4,13 @@ runner to pick up the tests. Yowzers.
"""
from __future__ import unicode_literals
import os
+import django
modules = [filename.rsplit('.', 1)[0]
for filename in os.listdir(os.path.dirname(__file__))
if filename.endswith('.py') and not filename.startswith('_')]
__test__ = dict()
-for module in modules:
- exec("from rest_framework.tests.%s import *" % module)
+if django.VERSION < (1, 6):
+ for module in modules:
+ exec("from rest_framework.tests.%s import *" % module)
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index b6de18a8..b26a2085 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -3,7 +3,8 @@ Helper classes for parsers.
"""
from __future__ import unicode_literals
from django.utils.datastructures import SortedDict
-from rest_framework.compat import timezone
+from django.utils.functional import Promise
+from rest_framework.compat import timezone, force_text
from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
import datetime
import decimal
@@ -19,7 +20,9 @@ class JSONEncoder(json.JSONEncoder):
def default(self, o):
# For Date Time string spec, see ECMA 262
# http://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15
- if isinstance(o, datetime.datetime):
+ if isinstance(o, Promise):
+ return force_text(o)
+ elif isinstance(o, datetime.datetime):
r = o.isoformat()
if o.microsecond:
r = r[:23] + r[26:]
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 555fa2f4..e1b6705b 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -2,13 +2,15 @@
Provides an APIView class that is the base of all views in REST framework.
"""
from __future__ import unicode_literals
+
from django.core.exceptions import PermissionDenied
from django.http import Http404, HttpResponse
+from django.utils.datastructures import SortedDict
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status, exceptions
from rest_framework.compat import View
-from rest_framework.response import Response
from rest_framework.request import Request
+from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.utils.formatting import get_view_name, get_view_description
@@ -51,21 +53,6 @@ class APIView(View):
'Vary': 'Accept'
}
- def metadata(self, request):
- return {
- 'name': get_view_name(self.__class__),
- 'description': get_view_description(self.__class__),
- 'renders': [renderer.media_type for renderer in self.renderer_classes],
- 'parses': [parser.media_type for parser in self.parser_classes],
- }
- # TODO: Add 'fields', from serializer info, if it exists.
- # serializer = self.get_serializer()
- # if serializer is not None:
- # field_name_types = {}
- # for name, field in form.fields.iteritems():
- # field_name_types[name] = field.__class__.__name__
- # content['fields'] = field_name_types
-
def http_method_not_allowed(self, request, *args, **kwargs):
"""
If `request.method` does not correspond to a handler method,
@@ -348,3 +335,15 @@ class APIView(View):
a less useful default implementation.
"""
return Response(self.metadata(request), status=status.HTTP_200_OK)
+
+ def metadata(self, request):
+ """
+ Return a dictionary of metadata about the view.
+ Used to return responses for OPTIONS requests.
+ """
+ ret = SortedDict()
+ ret['name'] = get_view_name(self.__class__)
+ ret['description'] = get_view_description(self.__class__)
+ ret['renders'] = [renderer.media_type for renderer in self.renderer_classes]
+ ret['parses'] = [parser.media_type for parser in self.parser_classes]
+ return ret
diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py
index 0eb3e86d..d91323f2 100644
--- a/rest_framework/viewsets.py
+++ b/rest_framework/viewsets.py
@@ -108,10 +108,18 @@ class ViewSet(ViewSetMixin, views.APIView):
pass
+class GenericViewSet(ViewSetMixin, generics.GenericAPIView):
+ """
+ The GenericViewSet class does not provide any actions by default,
+ but does include the base set of generic view behavior, such as
+ the `get_object` and `get_queryset` methods.
+ """
+ pass
+
+
class ReadOnlyModelViewSet(mixins.RetrieveModelMixin,
mixins.ListModelMixin,
- ViewSetMixin,
- generics.GenericAPIView):
+ GenericViewSet):
"""
A viewset that provides default `list()` and `retrieve()` actions.
"""
@@ -123,8 +131,7 @@ class ModelViewSet(mixins.CreateModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
mixins.ListModelMixin,
- ViewSetMixin,
- generics.GenericAPIView):
+ GenericViewSet):
"""
A viewset that provides default `create()`, `retrieve()`, `update()`,
`partial_update()`, `destroy()` and `list()` actions.