aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorRyan Kaskel2013-05-18 14:17:50 +0100
committerRyan Kaskel2013-05-18 14:17:50 +0100
commit22874e441dd71101296a656e753bfc17907b5cca (patch)
tree6ebf7971e5bf8d40c6d60fa857cbe0c04fc91372 /rest_framework
parentb5640bb77843c50f42a649982b9b9592113c6f59 (diff)
parenta0e3c44c99a61a6dc878308bdf0890fbb10c41e4 (diff)
downloaddjango-rest-framework-22874e441dd71101296a656e753bfc17907b5cca.tar.bz2
Merge latest changes from master.
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py2
-rw-r--r--rest_framework/authentication.py30
-rw-r--r--rest_framework/compat.py60
-rw-r--r--rest_framework/decorators.py30
-rw-r--r--rest_framework/fields.py97
-rw-r--r--rest_framework/filters.py101
-rw-r--r--rest_framework/generics.py364
-rw-r--r--rest_framework/mixins.py66
-rw-r--r--rest_framework/negotiation.py4
-rw-r--r--rest_framework/pagination.py6
-rw-r--r--rest_framework/parsers.py90
-rw-r--r--rest_framework/permissions.py27
-rw-r--r--rest_framework/relations.py232
-rw-r--r--rest_framework/renderers.py60
-rw-r--r--rest_framework/request.py5
-rw-r--r--rest_framework/response.py6
-rw-r--r--rest_framework/routers.py249
-rw-r--r--rest_framework/runtests/settings.py2
-rw-r--r--rest_framework/serializers.py194
-rw-r--r--rest_framework/settings.py18
-rw-r--r--rest_framework/templates/rest_framework/base.html2
-rw-r--r--rest_framework/templates/rest_framework/login.html54
-rw-r--r--rest_framework/templates/rest_framework/login_base.html51
-rw-r--r--rest_framework/templatetags/rest_framework.py95
-rw-r--r--rest_framework/tests/authentication.py44
-rw-r--r--rest_framework/tests/description.py63
-rw-r--r--rest_framework/tests/fields.py208
-rw-r--r--rest_framework/tests/filters.py474
-rw-r--r--rest_framework/tests/filterset.py169
-rw-r--r--rest_framework/tests/generics.py105
-rw-r--r--rest_framework/tests/hyperlinkedserializers.py40
-rw-r--r--rest_framework/tests/models.py7
-rw-r--r--rest_framework/tests/pagination.py24
-rw-r--r--rest_framework/tests/parsers.py33
-rw-r--r--rest_framework/tests/relations.py55
-rw-r--r--rest_framework/tests/relations_hyperlink.py87
-rw-r--r--rest_framework/tests/relations_nested.py24
-rw-r--r--rest_framework/tests/relations_pk.py74
-rw-r--r--rest_framework/tests/routers.py55
-rw-r--r--rest_framework/tests/serializer.py224
-rw-r--r--rest_framework/tests/serializer_bulk_update.py34
-rw-r--r--rest_framework/tests/serializer_nested.py4
-rw-r--r--rest_framework/throttling.py8
-rw-r--r--rest_framework/utils/breadcrumbs.py35
-rw-r--r--rest_framework/utils/formatting.py80
-rw-r--r--rest_framework/views.py107
-rw-r--r--rest_framework/viewsets.py139
47 files changed, 3067 insertions, 871 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index cf005636..0b1e67fb 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '2.2.4'
+__version__ = '2.3.3'
VERSION = __version__ # synonym
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index 8f4ec536..9caca788 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -1,15 +1,17 @@
"""
-Provides a set of pluggable authentication policies.
+Provides various authentication policies.
"""
from __future__ import unicode_literals
+import base64
+from datetime import datetime
+
from django.contrib.auth import authenticate
from django.core.exceptions import ImproperlyConfigured
from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
-from rest_framework.compat import oauth2_provider, oauth2_provider_forms, oauth2_provider_backends
+from rest_framework.compat import oauth2_provider
from rest_framework.authtoken.models import Token
-import base64
def get_authorization_header(request):
@@ -228,7 +230,7 @@ class OAuthAuthentication(BaseAuthentication):
try:
consumer_key = oauth_request.get_parameter('oauth_consumer_key')
consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key)
- except oauth_provider_store.InvalidConsumerError as err:
+ except oauth_provider.store.InvalidConsumerError as err:
raise exceptions.AuthenticationFailed(err)
if consumer.status != oauth_provider.consts.ACCEPTED:
@@ -238,7 +240,7 @@ class OAuthAuthentication(BaseAuthentication):
try:
token_param = oauth_request.get_parameter('oauth_token')
token = oauth_provider_store.get_access_token(request, oauth_request, consumer, token_param)
- except oauth_provider_store.InvalidTokenError:
+ except oauth_provider.store.InvalidTokenError:
msg = 'Invalid access token: %s' % oauth_request.get_parameter('oauth_token')
raise exceptions.AuthenticationFailed(msg)
@@ -315,16 +317,12 @@ class OAuth2Authentication(BaseAuthentication):
Authenticate the request, given the access token.
"""
- # Authenticate the client
- oauth2_client_form = oauth2_provider_forms.ClientAuthForm(request.REQUEST)
- if not oauth2_client_form.is_valid():
- raise exceptions.AuthenticationFailed('Client could not be validated')
- client = oauth2_client_form.cleaned_data.get('client')
-
- # Retrieve the `OAuth2AccessToken` instance from the access_token
- auth_backend = oauth2_provider_backends.AccessTokenBackend()
- token = auth_backend.authenticate(access_token, client)
- if token is None:
+ try:
+ token = oauth2_provider.models.AccessToken.objects.select_related('user')
+ # TODO: Change to timezone aware datetime when oauth2_provider add
+ # support to it.
+ token = token.get(token=access_token, expires__gt=datetime.now())
+ except oauth2_provider.models.AccessToken.DoesNotExist:
raise exceptions.AuthenticationFailed('Invalid token')
user = token.user
@@ -333,7 +331,7 @@ class OAuth2Authentication(BaseAuthentication):
msg = 'User inactive or deleted: %s' % user.username
raise exceptions.AuthenticationFailed(msg)
- return (token.user, token)
+ return (user, token)
def authenticate_header(self, request):
"""
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 7b2ef738..cd39f544 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -6,6 +6,7 @@ versions of django/python, and compatibility wrappers around optional packages.
from __future__ import unicode_literals
import django
+from django.core.exceptions import ImproperlyConfigured
# Try to import six from Django, fallback to included `six`.
try:
@@ -87,9 +88,7 @@ else:
raise ImportError("User model is not to be found.")
-# First implementation of Django class-based views did not include head method
-# in base View class - https://code.djangoproject.com/ticket/15668
-if django.VERSION >= (1, 4):
+if django.VERSION >= (1, 5):
from django.views.generic import View
else:
from django.views.generic import View as _View
@@ -97,6 +96,8 @@ else:
from django.utils.functional import update_wrapper
class View(_View):
+ # 1.3 does not include head method in base View class
+ # See: https://code.djangoproject.com/ticket/15668
@classonlymethod
def as_view(cls, **initkwargs):
"""
@@ -126,11 +127,15 @@ else:
update_wrapper(view, cls.dispatch, assigned=())
return view
-# Taken from @markotibold's attempt at supporting PATCH.
-# https://github.com/markotibold/django-rest-framework/tree/patch
-http_method_names = set(View.http_method_names)
-http_method_names.add('patch')
-View.http_method_names = list(http_method_names) # PATCH method is not implemented by Django
+ # _allowed_methods only present from 1.5 onwards
+ def _allowed_methods(self):
+ return [m.upper() for m in self.http_method_names if hasattr(self, m)]
+
+
+# PATCH method is not implemented by Django
+if 'patch' not in View.http_method_names:
+ View.http_method_names = View.http_method_names + ['patch']
+
# PUT, DELETE do not require CSRF until 1.4. They should. Make it better.
if django.VERSION >= (1, 4):
@@ -395,6 +400,41 @@ except ImportError:
kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None)
return datetime.datetime(**kw)
+
+# smart_urlquote is new on Django 1.4
+try:
+ from django.utils.html import smart_urlquote
+except ImportError:
+ import re
+ from django.utils.encoding import smart_str
+ try:
+ from urllib.parse import quote, urlsplit, urlunsplit
+ except ImportError: # Python 2
+ from urllib import quote
+ from urlparse import urlsplit, urlunsplit
+
+ unquoted_percents_re = re.compile(r'%(?![0-9A-Fa-f]{2})')
+
+ def smart_urlquote(url):
+ "Quotes a URL if it isn't already quoted."
+ # Handle IDN before quoting.
+ scheme, netloc, path, query, fragment = urlsplit(url)
+ try:
+ netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE
+ except UnicodeError: # invalid domain part
+ pass
+ else:
+ url = urlunsplit((scheme, netloc, path, query, fragment))
+
+ # An URL is considered unquoted if it contains no % characters or
+ # contains a % not followed by two hexadecimal digits. See #9655.
+ if '%' not in url or unquoted_percents_re.search(url):
+ # See http://bugs.python.org/issue2637
+ url = quote(smart_str(url), safe=b'!*\'();:@&=+$,/?#[]~')
+
+ return force_text(url)
+
+
# Markdown is optional
try:
import markdown
@@ -438,21 +478,19 @@ except ImportError:
try:
import oauth_provider
from oauth_provider.store import store as oauth_provider_store
-except ImportError:
+except (ImportError, ImproperlyConfigured):
oauth_provider = None
oauth_provider_store = None
# OAuth 2 support is optional
try:
import provider.oauth2 as oauth2_provider
- from provider.oauth2 import backends as oauth2_provider_backends
from provider.oauth2 import models as oauth2_provider_models
from provider.oauth2 import forms as oauth2_provider_forms
from provider import scope as oauth2_provider_scope
from provider import constants as oauth2_constants
except ImportError:
oauth2_provider = None
- oauth2_provider_backends = None
oauth2_provider_models = None
oauth2_provider_forms = None
oauth2_provider_scope = None
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 8250cd3b..81e585e1 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -1,3 +1,11 @@
+"""
+The most imporant decorator in this module is `@api_view`, which is used
+for writing function-based views with REST framework.
+
+There are also various decorators for setting the API policies on function
+based views, as well as the `@action` and `@link` decorators, which are
+used to annotate methods on viewsets that should be included by routers.
+"""
from __future__ import unicode_literals
from rest_framework.compat import six
from rest_framework.views import APIView
@@ -97,3 +105,25 @@ def permission_classes(permission_classes):
func.permission_classes = permission_classes
return func
return decorator
+
+
+def link(**kwargs):
+ """
+ Used to mark a method on a ViewSet that should be routed for GET requests.
+ """
+ def decorator(func):
+ func.bind_to_method = 'get'
+ func.kwargs = kwargs
+ return func
+ return decorator
+
+
+def action(**kwargs):
+ """
+ Used to mark a method on a ViewSet that should be routed for POST requests.
+ """
+ def decorator(func):
+ func.bind_to_method = 'post'
+ func.kwargs = kwargs
+ return func
+ return decorator
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 09f076ab..2ab603cf 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,7 +1,13 @@
+"""
+Serializer fields perform validation on incoming data.
+
+They are very similar to Django's form fields.
+"""
from __future__ import unicode_literals
import copy
import datetime
+from decimal import Decimal, DecimalException
import inspect
import re
import warnings
@@ -9,10 +15,12 @@ import warnings
from django.core import validators
from django.core.exceptions import ValidationError
from django.conf import settings
+from django.db.models.fields import BLANK_CHOICE_DASH
from django import forms
from django.forms import widgets
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
+from django.utils.datastructures import SortedDict
from rest_framework import ISO_8601
from rest_framework.compat import timezone, parse_date, parse_datetime, parse_time
@@ -44,7 +52,7 @@ def get_component(obj, attr_name):
return that attribute on the object.
"""
if isinstance(obj, dict):
- val = obj[attr_name]
+ val = obj.get(attr_name)
else:
val = getattr(obj, attr_name)
@@ -164,7 +172,11 @@ class Field(object):
elif hasattr(value, '__iter__') and not isinstance(value, (dict, six.string_types)):
return [self.to_native(item) for item in value]
elif isinstance(value, dict):
- return dict(map(self.to_native, (k, v)) for k, v in value.items())
+ # Make sure we preserve field ordering, if it exists
+ ret = SortedDict()
+ for key, val in value.items():
+ ret[key] = self.to_native(val)
+ return ret
return force_text(value)
def attributes(self):
@@ -194,9 +206,9 @@ class WritableField(Field):
# 'blank' is to be deprecated in favor of 'required'
if blank is not None:
- warnings.warn('The `blank` keyword argument is due to deprecated. '
+ warnings.warn('The `blank` keyword argument is deprecated. '
'Use the `required` keyword argument instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
required = not(blank)
super(WritableField, self).__init__(source=source)
@@ -396,6 +408,8 @@ class ChoiceField(WritableField):
def __init__(self, choices=(), *args, **kwargs):
super(ChoiceField, self).__init__(*args, **kwargs)
self.choices = choices
+ if not self.required:
+ self.choices = BLANK_CHOICE_DASH + self.choices
def _get_choices(self):
return self._choices
@@ -494,7 +508,7 @@ class DateField(WritableField):
}
empty = None
input_formats = api_settings.DATE_INPUT_FORMATS
- format = None
+ format = api_settings.DATE_FORMAT
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
@@ -557,7 +571,7 @@ class DateTimeField(WritableField):
}
empty = None
input_formats = api_settings.DATETIME_INPUT_FORMATS
- format = None
+ format = api_settings.DATETIME_FORMAT
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
@@ -626,7 +640,7 @@ class TimeField(WritableField):
}
empty = None
input_formats = api_settings.TIME_INPUT_FORMATS
- format = None
+ format = api_settings.TIME_FORMAT
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
@@ -721,6 +735,75 @@ class FloatField(WritableField):
raise ValidationError(msg)
+class DecimalField(WritableField):
+ type_name = 'DecimalField'
+ form_field_class = forms.DecimalField
+
+ default_error_messages = {
+ 'invalid': _('Enter a number.'),
+ 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
+ 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
+ 'max_digits': _('Ensure that there are no more than %s digits in total.'),
+ 'max_decimal_places': _('Ensure that there are no more than %s decimal places.'),
+ 'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.')
+ }
+
+ def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs):
+ self.max_value, self.min_value = max_value, min_value
+ self.max_digits, self.decimal_places = max_digits, decimal_places
+ super(DecimalField, self).__init__(*args, **kwargs)
+
+ if max_value is not None:
+ self.validators.append(validators.MaxValueValidator(max_value))
+ if min_value is not None:
+ self.validators.append(validators.MinValueValidator(min_value))
+
+ def from_native(self, value):
+ """
+ Validates that the input is a decimal number. Returns a Decimal
+ instance. Returns None for empty values. Ensures that there are no more
+ than max_digits in the number, and no more than decimal_places digits
+ after the decimal point.
+ """
+ if value in validators.EMPTY_VALUES:
+ return None
+ value = smart_text(value).strip()
+ try:
+ value = Decimal(value)
+ except DecimalException:
+ raise ValidationError(self.error_messages['invalid'])
+ return value
+
+ def validate(self, value):
+ super(DecimalField, self).validate(value)
+ if value in validators.EMPTY_VALUES:
+ return
+ # Check for NaN, Inf and -Inf values. We can't compare directly for NaN,
+ # since it is never equal to itself. However, NaN is the only value that
+ # isn't equal to itself, so we can use this to identify NaN
+ if value != value or value == Decimal("Inf") or value == Decimal("-Inf"):
+ raise ValidationError(self.error_messages['invalid'])
+ sign, digittuple, exponent = value.as_tuple()
+ decimals = abs(exponent)
+ # digittuple doesn't include any leading zeros.
+ digits = len(digittuple)
+ if decimals > digits:
+ # We have leading zeros up to or past the decimal point. Count
+ # everything past the decimal point as a digit. We do not count
+ # 0 before the decimal point as a digit since that would mean
+ # we would not allow max_digits = decimal_places.
+ digits = decimals
+ whole_digits = digits - decimals
+
+ if self.max_digits is not None and digits > self.max_digits:
+ raise ValidationError(self.error_messages['max_digits'] % self.max_digits)
+ if self.decimal_places is not None and decimals > self.decimal_places:
+ raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places)
+ if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
+ raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places))
+ return value
+
+
class FileField(WritableField):
use_files = True
type_name = 'FileField'
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index 6fea46fa..c058bc71 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -1,5 +1,12 @@
+"""
+Provides generic filtering backends that can be used to filter the results
+returned by list views.
+"""
from __future__ import unicode_literals
-from rest_framework.compat import django_filters
+from django.db import models
+from rest_framework.compat import django_filters, six
+from functools import reduce
+import operator
FilterSet = django_filters and django_filters.FilterSet or None
@@ -25,36 +32,112 @@ class DjangoFilterBackend(BaseFilterBackend):
def __init__(self):
assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed'
- def get_filter_class(self, view):
+ def get_filter_class(self, view, queryset=None):
"""
Return the django-filters `FilterSet` used to filter the queryset.
"""
filter_class = getattr(view, 'filter_class', None)
filter_fields = getattr(view, 'filter_fields', None)
- view_model = getattr(view, 'model', None)
if filter_class:
filter_model = filter_class.Meta.model
- assert issubclass(filter_model, view_model), \
- 'FilterSet model %s does not match view model %s' % \
- (filter_model, view_model)
+ assert issubclass(filter_model, queryset.model), \
+ 'FilterSet model %s does not match queryset model %s' % \
+ (filter_model, queryset.model)
return filter_class
if filter_fields:
class AutoFilterSet(self.default_filter_set):
class Meta:
- model = view_model
+ model = queryset.model
fields = filter_fields
return AutoFilterSet
return None
def filter_queryset(self, request, queryset, view):
- filter_class = self.get_filter_class(view)
+ filter_class = self.get_filter_class(view, queryset)
if filter_class:
- return filter_class(request.QUERY_PARAMS, queryset=queryset)
+ return filter_class(request.QUERY_PARAMS, queryset=queryset).qs
+
+ return queryset
+
+
+class SearchFilter(BaseFilterBackend):
+ search_param = 'search' # The URL query parameter used for the search.
+
+ def get_search_terms(self, request):
+ """
+ Search terms are set by a ?search=... query parameter,
+ and may be comma and/or whitespace delimited.
+ """
+ params = request.QUERY_PARAMS.get(self.search_param, '')
+ return params.replace(',', ' ').split()
+
+ def construct_search(self, field_name):
+ if field_name.startswith('^'):
+ return "%s__istartswith" % field_name[1:]
+ elif field_name.startswith('='):
+ return "%s__iexact" % field_name[1:]
+ elif field_name.startswith('@'):
+ return "%s__search" % field_name[1:]
+ else:
+ return "%s__icontains" % field_name
+
+ def filter_queryset(self, request, queryset, view):
+ search_fields = getattr(view, 'search_fields', None)
+
+ if not search_fields:
+ return queryset
+
+ orm_lookups = [self.construct_search(str(search_field))
+ for search_field in search_fields]
+
+ for search_term in self.get_search_terms(request):
+ or_queries = [models.Q(**{orm_lookup: search_term})
+ for orm_lookup in orm_lookups]
+ queryset = queryset.filter(reduce(operator.or_, or_queries))
+
+ return queryset
+
+
+class OrderingFilter(BaseFilterBackend):
+ ordering_param = 'ordering' # The URL query parameter used for the ordering.
+
+ def get_ordering(self, request):
+ """
+ Search terms are set by a ?search=... query parameter,
+ and may be comma and/or whitespace delimited.
+ """
+ params = request.QUERY_PARAMS.get(self.ordering_param)
+ if params:
+ return [param.strip() for param in params.split(',')]
+
+ def get_default_ordering(self, view):
+ ordering = getattr(view, 'ordering', None)
+ if isinstance(ordering, six.string_types):
+ return (ordering,)
+ return ordering
+
+ def remove_invalid_fields(self, queryset, ordering):
+ field_names = [field.name for field in queryset.model._meta.fields]
+ return [term for term in ordering if term.lstrip('-') in field_names]
+
+ def filter_queryset(self, request, queryset, view):
+ ordering = self.get_ordering(request)
+
+ if ordering:
+ # Skip any incorrect parameters
+ ordering = self.remove_invalid_fields(queryset, ordering)
+
+ if not ordering:
+ # Use 'ordering' attribtue by default
+ ordering = self.get_default_ordering(view)
+
+ if ordering:
+ return queryset.order_by(*ordering)
return queryset
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 36ecf915..05ec93d3 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -2,32 +2,59 @@
Generic views that provide commonly needed behaviour.
"""
from __future__ import unicode_literals
+
+from django.core.exceptions import ImproperlyConfigured
+from django.core.paginator import Paginator, InvalidPage
+from django.http import Http404
+from django.shortcuts import get_object_or_404
+from django.utils.translation import ugettext as _
from rest_framework import views, mixins
+from rest_framework.exceptions import ConfigurationError
from rest_framework.settings import api_settings
-from django.views.generic.detail import SingleObjectMixin
-from django.views.generic.list import MultipleObjectMixin
-
+import warnings
-### Base classes for the generic views ###
class GenericAPIView(views.APIView):
"""
Base class for all other generic views.
"""
- model = None
+ # You'll need to either set these attributes,
+ # or override `get_queryset()`/`get_serializer_class()`.
+ queryset = None
serializer_class = None
+
+ # This shortcut may be used instead of setting either or both
+ # of the `queryset`/`serializer_class` attributes, although using
+ # the explicit style is generally preferred.
+ model = None
+
+ # If you want to use object lookups other than pk, set this attribute.
+ # For more complex lookup requirements override `get_object()`.
+ lookup_field = 'pk'
+
+ # Pagination settings
+ paginate_by = api_settings.PAGINATE_BY
+ paginate_by_param = api_settings.PAGINATE_BY_PARAM
+ pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
+ page_kwarg = 'page'
+
+ # The filter backend classes to use for queryset filtering
+ filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
+
+ # The following attributes may be subject to change,
+ # and should be considered private API.
model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
- filter_backend = api_settings.FILTER_BACKEND
+ paginator_class = Paginator
- def filter_queryset(self, queryset):
- """
- Given a queryset, filter it with whichever filter backend is in use.
- """
- if not self.filter_backend:
- return queryset
- backend = self.filter_backend()
- return backend.filter_queryset(self.request, queryset, self)
+ ######################################
+ # These are pending deprecation...
+
+ pk_url_kwarg = 'pk'
+ slug_url_kwarg = 'slug'
+ slug_field = 'slug'
+ allow_empty = True
+ filter_backend = api_settings.FILTER_BACKEND
def get_serializer_context(self):
"""
@@ -39,24 +66,6 @@ class GenericAPIView(views.APIView):
'view': self
}
- def get_serializer_class(self):
- """
- Return the class to use for the serializer.
-
- Defaults to using `self.serializer_class`, falls back to constructing a
- model serializer class using `self.model_serializer_class`, with
- `self.model` as the model.
- """
- serializer_class = self.serializer_class
-
- if serializer_class is None:
- class DefaultSerializer(self.model_serializer_class):
- class Meta:
- model = self.model
- serializer_class = DefaultSerializer
-
- return serializer_class
-
def get_serializer(self, instance=None, data=None,
files=None, many=False, partial=False):
"""
@@ -68,31 +77,7 @@ class GenericAPIView(views.APIView):
return serializer_class(instance, data=data, files=files,
many=many, partial=partial, context=context)
- def pre_save(self, obj):
- """
- Placeholder method for calling before saving an object.
- May be used eg. to set attributes on the object that are implicit
- in either the request, or the url.
- """
- pass
-
- def post_save(self, obj, created=False):
- """
- Placeholder method for calling after saving an object.
- """
- pass
-
-
-class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
- """
- Base class for generic views onto a queryset.
- """
-
- paginate_by = api_settings.PAGINATE_BY
- paginate_by_param = api_settings.PAGINATE_BY_PARAM
- pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
-
- def get_pagination_serializer(self, page=None):
+ def get_pagination_serializer(self, page):
"""
Return a serializer instance to use with paginated data.
"""
@@ -104,40 +89,232 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
context = self.get_serializer_context()
return pagination_serializer_class(instance=page, context=context)
- def get_paginate_by(self, queryset):
+ def paginate_queryset(self, queryset, page_size=None):
+ """
+ Paginate a queryset if required, either returning a page object,
+ or `None` if pagination is not configured for this view.
+ """
+ deprecated_style = False
+ if page_size is not None:
+ warnings.warn('The `page_size` parameter to `paginate_queryset()` '
+ 'is due to be deprecated. '
+ 'Note that the return style of this method is also '
+ 'changed, and will simply return a page object '
+ 'when called without a `page_size` argument.',
+ PendingDeprecationWarning, stacklevel=2)
+ deprecated_style = True
+ else:
+ # Determine the required page size.
+ # If pagination is not configured, simply return None.
+ page_size = self.get_paginate_by()
+ if not page_size:
+ return None
+
+ if not self.allow_empty:
+ warnings.warn(
+ 'The `allow_empty` parameter is due to be deprecated. '
+ 'To use `allow_empty=False` style behavior, You should override '
+ '`get_queryset()` and explicitly raise a 404 on empty querysets.',
+ PendingDeprecationWarning, stacklevel=2
+ )
+
+ paginator = self.paginator_class(queryset, page_size,
+ allow_empty_first_page=self.allow_empty)
+ page_kwarg = self.kwargs.get(self.page_kwarg)
+ page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)
+ page = page_kwarg or page_query_param or 1
+ try:
+ page_number = int(page)
+ except ValueError:
+ if page == 'last':
+ page_number = paginator.num_pages
+ else:
+ raise Http404(_("Page is not 'last', nor can it be converted to an int."))
+ try:
+ page = paginator.page(page_number)
+ except InvalidPage as e:
+ raise Http404(_('Invalid page (%(page_number)s): %(message)s') % {
+ 'page_number': page_number,
+ 'message': str(e)
+ })
+
+ if deprecated_style:
+ return (paginator, page, page.object_list, page.has_other_pages())
+ return page
+
+ def filter_queryset(self, queryset):
+ """
+ Given a queryset, filter it with whichever filter backend is in use.
+
+ You are unlikely to want to override this method, although you may need
+ to call it either from a list view, or from a custom `get_object`
+ method if you want to apply the configured filtering backend to the
+ default queryset.
+ """
+ filter_backends = self.filter_backends or []
+ if not filter_backends and self.filter_backend:
+ warnings.warn(
+ 'The `filter_backend` attribute and `FILTER_BACKEND` setting '
+ 'are due to be deprecated in favor of a `filter_backends` '
+ 'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take '
+ 'a *list* of filter backend classes.',
+ PendingDeprecationWarning, stacklevel=2
+ )
+ filter_backends = [self.filter_backend]
+
+ for backend in filter_backends:
+ queryset = backend().filter_queryset(self.request, queryset, self)
+ return queryset
+
+ ########################
+ ### The following methods provide default implementations
+ ### that you may want to override for more complex cases.
+
+ def get_paginate_by(self, queryset=None):
"""
Return the size of pages to use with pagination.
+
+ If `PAGINATE_BY_PARAM` is set it will attempt to get the page size
+ from a named query parameter in the url, eg. ?page_size=100
+
+ Otherwise defaults to using `self.paginate_by`.
"""
+ if queryset is not None:
+ warnings.warn('The `queryset` parameter to `get_paginate_by()` '
+ 'is due to be deprecated.',
+ PendingDeprecationWarning, stacklevel=2)
+
if self.paginate_by_param:
query_params = self.request.QUERY_PARAMS
try:
return int(query_params[self.paginate_by_param])
except (KeyError, ValueError):
pass
+
return self.paginate_by
+ def get_serializer_class(self):
+ """
+ Return the class to use for the serializer.
+ Defaults to using `self.serializer_class`.
+
+ You may want to override this if you need to provide different
+ serializations depending on the incoming request.
-class SingleObjectAPIView(SingleObjectMixin, GenericAPIView):
- """
- Base class for generic views onto a model instance.
- """
+ (Eg. admins get full serialization, others get basic serilization)
+ """
+ serializer_class = self.serializer_class
+ if serializer_class is not None:
+ return serializer_class
- pk_url_kwarg = 'pk' # Not provided in Django 1.3
- slug_url_kwarg = 'slug' # Not provided in Django 1.3
- slug_field = 'slug'
+ assert self.model is not None, \
+ "'%s' should either include a 'serializer_class' attribute, " \
+ "or use the 'model' attribute as a shortcut for " \
+ "automatically generating a serializer class." \
+ % self.__class__.__name__
+
+ class DefaultSerializer(self.model_serializer_class):
+ class Meta:
+ model = self.model
+ return DefaultSerializer
+
+ def get_queryset(self):
+ """
+ Get the list of items for this view.
+ This must be an iterable, and may be a queryset.
+ Defaults to using `self.queryset`.
+
+ You may want to override this if you need to provide different
+ querysets depending on the incoming request.
+
+ (Eg. return a list of items that is specific to the user)
+ """
+ if self.queryset is not None:
+ return self.queryset._clone()
+
+ if self.model is not None:
+ return self.model._default_manager.all()
+
+ raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'"
+ % self.__class__.__name__)
def get_object(self, queryset=None):
"""
- Override default to add support for object-level permissions.
+ Returns the object the view is displaying.
+
+ You may want to override this if you need to provide non-standard
+ queryset lookups. Eg if objects are referenced using multiple
+ keyword arguments in the url conf.
"""
- obj = super(SingleObjectAPIView, self).get_object(queryset)
+ # Determine the base queryset to use.
+ if queryset is None:
+ queryset = self.filter_queryset(self.get_queryset())
+ else:
+ pass # Deprecation warning
+
+ # Perform the lookup filtering.
+ pk = self.kwargs.get(self.pk_url_kwarg, None)
+ slug = self.kwargs.get(self.slug_url_kwarg, None)
+ lookup = self.kwargs.get(self.lookup_field, None)
+
+ if lookup is not None:
+ filter_kwargs = {self.lookup_field: lookup}
+ elif pk is not None and self.lookup_field == 'pk':
+ warnings.warn(
+ 'The `pk_url_kwarg` attribute is due to be deprecated. '
+ 'Use the `lookup_field` attribute instead',
+ PendingDeprecationWarning
+ )
+ filter_kwargs = {'pk': pk}
+ elif slug is not None and self.lookup_field == 'pk':
+ warnings.warn(
+ 'The `slug_url_kwarg` attribute is due to be deprecated. '
+ 'Use the `lookup_field` attribute instead',
+ PendingDeprecationWarning
+ )
+ filter_kwargs = {self.slug_field: slug}
+ else:
+ raise ConfigurationError(
+ 'Expected view %s to be called with a URL keyword argument '
+ 'named "%s". Fix your URL conf, or set the `.lookup_field` '
+ 'attribute on the view correctly.' %
+ (self.__class__.__name__, self.lookup_field)
+ )
+
+ obj = get_object_or_404(queryset, **filter_kwargs)
+
+ # May raise a permission denied
self.check_object_permissions(self.request, obj)
+
return obj
+ ########################
+ ### The following are placeholder methods,
+ ### and are intended to be overridden.
+ ###
+ ### The are not called by GenericAPIView directly,
+ ### but are used by the mixin methods.
+
+ def pre_save(self, obj):
+ """
+ Placeholder method for calling before saving an object.
+
+ May be used to set attributes on the object that are implicit
+ in either the request, or the url.
+ """
+ pass
+
+ def post_save(self, obj, created=False):
+ """
+ Placeholder method for calling after saving an object.
+ """
+ pass
-### Concrete view classes that provide method handlers ###
-### by composing the mixin classes with a base view. ###
+##########################################################
+### Concrete view classes that provide method handlers ###
+### by composing the mixin classes with the base view. ###
+##########################################################
class CreateAPIView(mixins.CreateModelMixin,
GenericAPIView):
@@ -150,7 +327,7 @@ class CreateAPIView(mixins.CreateModelMixin,
class ListAPIView(mixins.ListModelMixin,
- MultipleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for listing a queryset.
"""
@@ -159,7 +336,7 @@ class ListAPIView(mixins.ListModelMixin,
class RetrieveAPIView(mixins.RetrieveModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for retrieving a model instance.
"""
@@ -168,7 +345,7 @@ class RetrieveAPIView(mixins.RetrieveModelMixin,
class DestroyAPIView(mixins.DestroyModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for deleting a model instance.
@@ -178,7 +355,7 @@ class DestroyAPIView(mixins.DestroyModelMixin,
class UpdateAPIView(mixins.UpdateModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for updating a model instance.
@@ -187,13 +364,12 @@ class UpdateAPIView(mixins.UpdateModelMixin,
return self.update(request, *args, **kwargs)
def patch(self, request, *args, **kwargs):
- kwargs['partial'] = True
- return self.update(request, *args, **kwargs)
+ return self.partial_update(request, *args, **kwargs)
class ListCreateAPIView(mixins.ListModelMixin,
mixins.CreateModelMixin,
- MultipleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for listing a queryset or creating a model instance.
"""
@@ -206,7 +382,7 @@ class ListCreateAPIView(mixins.ListModelMixin,
class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for retrieving, updating a model instance.
"""
@@ -217,13 +393,12 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
return self.update(request, *args, **kwargs)
def patch(self, request, *args, **kwargs):
- kwargs['partial'] = True
- return self.update(request, *args, **kwargs)
+ return self.partial_update(request, *args, **kwargs)
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for retrieving or deleting a model instance.
"""
@@ -237,7 +412,7 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for retrieving, updating or deleting a model instance.
"""
@@ -248,8 +423,31 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
return self.update(request, *args, **kwargs)
def patch(self, request, *args, **kwargs):
- kwargs['partial'] = True
- return self.update(request, *args, **kwargs)
+ return self.partial_update(request, *args, **kwargs)
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
+
+
+##########################
+### Deprecated classes ###
+##########################
+
+class MultipleObjectAPIView(GenericAPIView):
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ 'Subclassing `MultipleObjectAPIView` is due to be deprecated. '
+ 'You should simply subclass `GenericAPIView` instead.',
+ PendingDeprecationWarning, stacklevel=2
+ )
+ super(MultipleObjectAPIView, self).__init__(*args, **kwargs)
+
+
+class SingleObjectAPIView(GenericAPIView):
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ 'Subclassing `SingleObjectAPIView` is due to be deprecated. '
+ 'You should simply subclass `GenericAPIView` instead.',
+ PendingDeprecationWarning, stacklevel=2
+ )
+ super(SingleObjectAPIView, self).__init__(*args, **kwargs)
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 7d9a6e65..f3cd5868 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -10,9 +10,10 @@ from django.http import Http404
from rest_framework import status
from rest_framework.response import Response
from rest_framework.request import clone_request
+import warnings
-def _get_validation_exclusions(obj, pk=None, slug_field=None):
+def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None):
"""
Given a model instance, and an optional pk and slug field,
return the full list of all other field names on that model.
@@ -23,21 +24,25 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None):
include = []
if pk:
+ # Pending deprecation
pk_field = obj._meta.pk
while pk_field.rel:
pk_field = pk_field.rel.to._meta.pk
include.append(pk_field.name)
if slug_field:
+ # Pending deprecation
include.append(slug_field)
+ if lookup_field and lookup_field != 'pk':
+ include.append(lookup_field)
+
return [field.name for field in obj._meta.fields if field.name not in include]
class CreateModelMixin(object):
"""
Create a model instance.
- Should be mixed in with any `GenericAPIView`.
"""
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA, files=request.FILES)
@@ -62,28 +67,28 @@ class CreateModelMixin(object):
class ListModelMixin(object):
"""
List a queryset.
- Should be mixed in with `MultipleObjectAPIView`.
"""
empty_error = "Empty list and '%(class_name)s.allow_empty' is False."
def list(self, request, *args, **kwargs):
- queryset = self.get_queryset()
- self.object_list = self.filter_queryset(queryset)
+ self.object_list = self.filter_queryset(self.get_queryset())
# Default is to allow empty querysets. This can be altered by setting
# `.allow_empty = False`, to raise 404 errors on empty querysets.
- allow_empty = self.get_allow_empty()
- if not allow_empty and not self.object_list:
+ if not self.allow_empty and not self.object_list:
+ warnings.warn(
+ 'The `allow_empty` parameter is due to be deprecated. '
+ 'To use `allow_empty=False` style behavior, You should override '
+ '`get_queryset()` and explicitly raise a 404 on empty querysets.',
+ PendingDeprecationWarning
+ )
class_name = self.__class__.__name__
error_msg = self.empty_error % {'class_name': class_name}
raise Http404(error_msg)
- # Pagination size is set by the `.paginate_by` attribute,
- # which may be `None` to disable pagination.
- page_size = self.get_paginate_by(self.object_list)
- if page_size:
- packed = self.paginate_queryset(self.object_list, page_size)
- paginator, page, queryset, is_paginated = packed
+ # Switch between paginated or standard style responses
+ page = self.paginate_queryset(self.object_list)
+ if page is not None:
serializer = self.get_pagination_serializer(page)
else:
serializer = self.get_serializer(self.object_list, many=True)
@@ -94,12 +99,9 @@ class ListModelMixin(object):
class RetrieveModelMixin(object):
"""
Retrieve a model instance.
- Should be mixed in with `SingleObjectAPIView`.
"""
def retrieve(self, request, *args, **kwargs):
- queryset = self.get_queryset()
- filtered_queryset = self.filter_queryset(queryset)
- self.object = self.get_object(filtered_queryset)
+ self.object = self.get_object()
serializer = self.get_serializer(self.object)
return Response(serializer.data)
@@ -107,17 +109,22 @@ class RetrieveModelMixin(object):
class UpdateModelMixin(object):
"""
Update a model instance.
- Should be mixed in with `SingleObjectAPIView`.
"""
- def update(self, request, *args, **kwargs):
- partial = kwargs.pop('partial', False)
- self.object = None
+ def get_object_or_none(self):
try:
- self.object = self.get_object()
+ return self.get_object()
except Http404:
# If this is a PUT-as-create operation, we need to ensure that
# we have relevant permissions, as if this was a POST request.
- self.check_permissions(clone_request(request, 'POST'))
+ # This will either raise a PermissionDenied exception,
+ # or simply return None
+ self.check_permissions(clone_request(self.request, 'POST'))
+
+ def update(self, request, *args, **kwargs):
+ partial = kwargs.pop('partial', False)
+ self.object = self.get_object_or_none()
+
+ if self.object is None:
created = True
save_kwargs = {'force_insert': True}
success_status_code = status.HTTP_201_CREATED
@@ -137,14 +144,22 @@ class UpdateModelMixin(object):
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+ def partial_update(self, request, *args, **kwargs):
+ kwargs['partial'] = True
+ return self.update(request, *args, **kwargs)
+
def pre_save(self, obj):
"""
Set any attributes on the object that are implicit in the request.
"""
# pk and/or slug attributes are implicit in the URL.
+ lookup = self.kwargs.get(self.lookup_field, None)
pk = self.kwargs.get(self.pk_url_kwarg, None)
slug = self.kwargs.get(self.slug_url_kwarg, None)
- slug_field = slug and self.get_slug_field() or None
+ slug_field = slug and self.slug_field or None
+
+ if lookup:
+ setattr(obj, self.lookup_field, lookup)
if pk:
setattr(obj, 'pk', pk)
@@ -155,14 +170,13 @@ class UpdateModelMixin(object):
# Ensure we clean the attributes so that we don't eg return integer
# pk using a string representation, as provided by the url conf kwarg.
if hasattr(obj, 'full_clean'):
- exclude = _get_validation_exclusions(obj, pk, slug_field)
+ exclude = _get_validation_exclusions(obj, pk, slug_field, self.lookup_field)
obj.full_clean(exclude)
class DestroyModelMixin(object):
"""
Destroy a model instance.
- Should be mixed in with `SingleObjectAPIView`.
"""
def destroy(self, request, *args, **kwargs):
obj = self.get_object()
diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py
index 0694d35f..4d205c0e 100644
--- a/rest_framework/negotiation.py
+++ b/rest_framework/negotiation.py
@@ -1,3 +1,7 @@
+"""
+Content negotiation deals with selecting an appropriate renderer given the
+incoming request. Typically this will be based on the request's Accept header.
+"""
from __future__ import unicode_literals
from django.http import Http404
from rest_framework import exceptions
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index 03a7a30f..d51ea929 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -1,9 +1,11 @@
+"""
+Pagination serializers determine the structure of the output that should
+be used for paginated responses.
+"""
from __future__ import unicode_literals
from rest_framework import serializers
from rest_framework.templatetags.rest_framework import replace_query_param
-# TODO: Support URLconf kwarg-style paging
-
class NextPageField(serializers.Field):
"""
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index 491acd68..25be2e6a 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -6,9 +6,10 @@ on the request, such as form content or json encoded data.
"""
from __future__ import unicode_literals
from django.conf import settings
+from django.core.files.uploadhandler import StopFutureHandlers
from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
-from django.http.multipartparser import MultiPartParserError
+from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter
from rest_framework.compat import yaml, etree
from rest_framework.exceptions import ParseError
from rest_framework.compat import six
@@ -205,3 +206,90 @@ class XMLParser(BaseParser):
pass
return value
+
+
+class FileUploadParser(BaseParser):
+ """
+ Parser for file upload data.
+ """
+ media_type = '*/*'
+
+ def parse(self, stream, media_type=None, parser_context=None):
+ """
+ Returns a DataAndFiles object.
+
+ `.data` will be None (we expect request body to be a file content).
+ `.files` will be a `QueryDict` containing one 'file' element.
+ """
+
+ parser_context = parser_context or {}
+ request = parser_context['request']
+ encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
+ meta = request.META
+ upload_handlers = request.upload_handlers
+ filename = self.get_filename(stream, media_type, parser_context)
+
+ # Note that this code is extracted from Django's handling of
+ # file uploads in MultiPartParser.
+ content_type = meta.get('HTTP_CONTENT_TYPE',
+ meta.get('CONTENT_TYPE', ''))
+ try:
+ content_length = int(meta.get('HTTP_CONTENT_LENGTH',
+ meta.get('CONTENT_LENGTH', 0)))
+ except (ValueError, TypeError):
+ content_length = None
+
+ # See if the handler will want to take care of the parsing.
+ for handler in upload_handlers:
+ result = handler.handle_raw_input(None,
+ meta,
+ content_length,
+ None,
+ encoding)
+ if result is not None:
+ return DataAndFiles(None, {'file': result[1]})
+
+ # This is the standard case.
+ possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size]
+ chunk_size = min([2 ** 31 - 4] + possible_sizes)
+ chunks = ChunkIter(stream, chunk_size)
+ counters = [0] * len(upload_handlers)
+
+ for handler in upload_handlers:
+ try:
+ handler.new_file(None, filename, content_type,
+ content_length, encoding)
+ except StopFutureHandlers:
+ break
+
+ for chunk in chunks:
+ for i, handler in enumerate(upload_handlers):
+ chunk_length = len(chunk)
+ chunk = handler.receive_data_chunk(chunk, counters[i])
+ counters[i] += chunk_length
+ if chunk is None:
+ break
+
+ for i, handler in enumerate(upload_handlers):
+ file_obj = handler.file_complete(counters[i])
+ if file_obj:
+ return DataAndFiles(None, {'file': file_obj})
+ raise ParseError("FileUpload parse error - "
+ "none of upload handlers can handle the stream")
+
+ def get_filename(self, stream, media_type, parser_context):
+ """
+ Detects the uploaded file name. First searches a 'filename' url kwarg.
+ Then tries to parse Content-Disposition header.
+ """
+ try:
+ return parser_context['kwargs']['filename']
+ except KeyError:
+ pass
+
+ try:
+ meta = parser_context['request'].META
+ disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'])
+ return disposition[1]['filename']
+ except (AttributeError, KeyError):
+ pass
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index ae895f39..45fcfd66 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -25,10 +25,12 @@ class BasePermission(object):
"""
Return `True` if permission is granted, `False` otherwise.
"""
- if len(inspect.getargspec(self.has_permission)[0]) == 4:
- warnings.warn('The `obj` argument in `has_permission` is due to be deprecated. '
- 'Use `has_object_permission()` instead for object permissions.',
- PendingDeprecationWarning, stacklevel=2)
+ if len(inspect.getargspec(self.has_permission).args) == 4:
+ warnings.warn(
+ 'The `obj` argument in `has_permission` is deprecated. '
+ 'Use `has_object_permission()` instead for object permissions.',
+ DeprecationWarning, stacklevel=2
+ )
return self.has_permission(request, view, obj)
return True
@@ -87,8 +89,8 @@ class DjangoModelPermissions(BasePermission):
It ensures that the user is authenticated, and has the appropriate
`add`/`change`/`delete` permissions on the model.
- This permission will only be applied against view classes that
- provide a `.model` attribute, such as the generic class-based views.
+ This permission can only be applied against view classes that
+ provide a `.model` or `.queryset` attribute.
"""
# Map methods into required permission codes.
@@ -124,6 +126,11 @@ class DjangoModelPermissions(BasePermission):
if model_cls is None and queryset is not None:
model_cls = queryset.model
+ # Workaround to ensure DjangoModelPermissions are not applied
+ # to the root view when using DefaultRouter.
+ if model_cls is None and getattr(view, '_ignore_model_permissions'):
+ return True
+
assert model_cls, ('Cannot apply DjangoModelPermissions on a view that'
' does not have `.model` or `.queryset` property.')
@@ -136,6 +143,14 @@ class DjangoModelPermissions(BasePermission):
return False
+class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions):
+ """
+ Similar to DjangoModelPermissions, except that anonymous users are
+ allowed read-only access.
+ """
+ authenticated_users_only = False
+
+
class TokenHasReadWriteScope(BasePermission):
"""
The request is authenticated as a user and the token used has the right scope
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 2a10e9af..884b954c 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -1,3 +1,9 @@
+"""
+Serializer fields that deal with relationships.
+
+These fields allow you to specify the style that should be used to represent
+model relationships, including hyperlinks, primary keys, or slugs.
+"""
from __future__ import unicode_literals
from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch
@@ -36,9 +42,9 @@ class RelatedField(WritableField):
# 'null' is to be deprecated in favor of 'required'
if 'null' in kwargs:
- warnings.warn('The `null` keyword argument is due to be deprecated. '
+ warnings.warn('The `null` keyword argument is deprecated. '
'Use the `required` keyword argument instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
kwargs['required'] = not kwargs.pop('null')
self.queryset = kwargs.pop('queryset', None)
@@ -215,12 +221,20 @@ class PrimaryKeyRelatedField(RelatedField):
def field_to_native(self, obj, field_name):
if self.many:
# To-many relationship
- try:
+
+ queryset = None
+ if not self.source:
# Prefer obj.serializable_value for performance reasons
- queryset = obj.serializable_value(self.source or field_name)
- except AttributeError:
+ try:
+ queryset = obj.serializable_value(field_name)
+ except AttributeError:
+ pass
+ if queryset is None:
# RelatedManager (reverse relationship)
- queryset = getattr(obj, self.source or field_name)
+ source = self.source or field_name
+ queryset = obj
+ for component in source.split('.'):
+ queryset = get_component(queryset, component)
# Forward relationship
return [self.to_native(item.pk) for item in queryset.all()]
@@ -282,10 +296,8 @@ class HyperlinkedRelatedField(RelatedField):
"""
Represents a relationship using hyperlinking.
"""
- pk_url_kwarg = 'pk'
- slug_field = 'slug'
- slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
read_only = False
+ lookup_field = 'pk'
default_error_messages = {
'no_match': _('Invalid hyperlink - No URL match'),
@@ -295,69 +307,138 @@ class HyperlinkedRelatedField(RelatedField):
'incorrect_type': _('Incorrect type. Expected url string, received %s.'),
}
+ # These are all pending deprecation
+ pk_url_kwarg = 'pk'
+ slug_field = 'slug'
+ slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
+
def __init__(self, *args, **kwargs):
try:
self.view_name = kwargs.pop('view_name')
except KeyError:
raise ValueError("Hyperlinked field requires 'view_name' kwarg")
+ self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
+ self.format = kwargs.pop('format', None)
+
+ # These are pending deprecation
+ if 'pk_url_kwarg' in kwargs:
+ msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ if 'slug_url_kwarg' in kwargs:
+ msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ if 'slug_field' in kwargs:
+ msg = 'slug_field is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+
+ self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
self.slug_field = kwargs.pop('slug_field', self.slug_field)
default_slug_kwarg = self.slug_url_kwarg or self.slug_field
- self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
- self.format = kwargs.pop('format', None)
super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
- def get_slug_field(self):
+ def get_url(self, obj, view_name, request, format):
"""
- Get the name of a slug field to be used to look up by slug.
- """
- return self.slug_field
-
- def to_native(self, obj):
- view_name = self.view_name
- request = self.context.get('request', None)
- format = self.format or self.context.get('format', None)
-
- if request is None:
- warnings.warn("Using `HyperlinkedRelatedField` without including the "
- "request in the serializer context is due to be deprecated. "
- "Add `context={'request': request}` when instantiating the serializer.",
- PendingDeprecationWarning, stacklevel=4)
+ Given an object, return the URL that hyperlinks to the object.
- pk = getattr(obj, 'pk', None)
- if pk is None:
- return
- kwargs = {self.pk_url_kwarg: pk}
+ May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
+ attributes are not configured to correctly match the URL conf.
+ """
+ lookup_field = getattr(obj, self.lookup_field)
+ kwargs = {self.lookup_field: lookup_field}
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
except NoReverseMatch:
pass
+ if self.pk_url_kwarg != 'pk':
+ # Only try pk if it has been explicitly set.
+ # Otherwise, the default `lookup_field = 'pk'` has us covered.
+ pk = obj.pk
+ kwargs = {self.pk_url_kwarg: pk}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except NoReverseMatch:
+ pass
+
slug = getattr(obj, self.slug_field, None)
+ if slug is not None:
+ # Only try slug if it corresponds to an attribute on the object.
+ kwargs = {self.slug_url_kwarg: slug}
+ try:
+ ret = reverse(view_name, kwargs=kwargs, request=request, format=format)
+ if self.slug_field == 'slug' and self.slug_url_kwarg == 'slug':
+ # If the lookup succeeds using the default slug params,
+ # then `slug_field` is being used implicitly, and we
+ # we need to warn about the pending deprecation.
+ msg = 'Implicit slug field hyperlinked fields are pending deprecation.' \
+ 'You should set `lookup_field=slug` on the HyperlinkedRelatedField.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ return ret
+ except NoReverseMatch:
+ pass
+
+ raise NoReverseMatch()
+
+ def get_object(self, queryset, view_name, view_args, view_kwargs):
+ """
+ Return the object corresponding to a matched URL.
- if not slug:
- raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
+ Takes the matched URL conf arguments, and the queryset, and should
+ return an object instance, or raise an `ObjectDoesNotExist` exception.
+ """
+ lookup = view_kwargs.get(self.lookup_field, None)
+ pk = view_kwargs.get(self.pk_url_kwarg, None)
+ slug = view_kwargs.get(self.slug_url_kwarg, None)
+
+ if lookup is not None:
+ filter_kwargs = {self.lookup_field: lookup}
+ elif pk is not None:
+ filter_kwargs = {'pk': pk}
+ elif slug is not None:
+ filter_kwargs = {self.slug_field: slug}
+ else:
+ raise ObjectDoesNotExist()
- kwargs = {self.slug_url_kwarg: slug}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
+ return queryset.get(**filter_kwargs)
- kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
+ def to_native(self, obj):
+ view_name = self.view_name
+ request = self.context.get('request', None)
+ format = self.format or self.context.get('format', None)
+
+ if request is None:
+ msg = (
+ "Using `HyperlinkedRelatedField` without including the request "
+ "in the serializer context is deprecated. "
+ "Add `context={'request': request}` when instantiating "
+ "the serializer."
+ )
+ warnings.warn(msg, DeprecationWarning, stacklevel=4)
+
+ # If the object has not yet been saved then we cannot hyperlink to it.
+ if getattr(obj, 'pk', None) is None:
+ return
+
+ # Return the hyperlink, or error if incorrectly configured.
try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ return self.get_url(obj, view_name, request, format)
except NoReverseMatch:
- pass
-
- raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
+ msg = (
+ 'Could not resolve URL for hyperlinked relationship using '
+ 'view name "%s". You may have failed to include the related '
+ 'model in your API, or incorrectly configured the '
+ '`lookup_field` attribute on this field.'
+ )
+ raise Exception(msg % view_name)
def from_native(self, value):
# Convert URL -> model instance pk
# TODO: Use values_list
- if self.queryset is None:
+ queryset = self.queryset
+ if queryset is None:
raise Exception('Writable related fields must include a `queryset` argument')
try:
@@ -381,39 +462,24 @@ class HyperlinkedRelatedField(RelatedField):
if match.view_name != self.view_name:
raise ValidationError(self.error_messages['incorrect_match'])
- pk = match.kwargs.get(self.pk_url_kwarg, None)
- slug = match.kwargs.get(self.slug_url_kwarg, None)
-
- # Try explicit primary key.
- if pk is not None:
- queryset = self.queryset.filter(pk=pk)
- # Next, try looking up by slug.
- elif slug is not None:
- slug_field = self.get_slug_field()
- queryset = self.queryset.filter(**{slug_field: slug})
- # If none of those are defined, it's probably a configuation error.
- else:
- raise ValidationError(self.error_messages['configuration_error'])
-
try:
- obj = queryset.get()
- except ObjectDoesNotExist:
+ return self.get_object(queryset, match.view_name,
+ match.args, match.kwargs)
+ except (ObjectDoesNotExist, TypeError, ValueError):
raise ValidationError(self.error_messages['does_not_exist'])
- except (TypeError, ValueError):
- msg = self.error_messages['incorrect_type']
- raise ValidationError(msg % type(value).__name__)
-
- return obj
class HyperlinkedIdentityField(Field):
"""
Represents the instance, or a property on the instance, using hyperlinking.
"""
+ lookup_field = 'pk'
+ read_only = True
+
+ # These are all pending deprecation
pk_url_kwarg = 'pk'
slug_field = 'slug'
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
- read_only = True
def __init__(self, *args, **kwargs):
# TODO: Make view_name mandatory, and have the
@@ -422,6 +488,19 @@ class HyperlinkedIdentityField(Field):
# Optionally the format of the target hyperlink may be specified
self.format = kwargs.pop('format', None)
+ self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
+
+ # These are pending deprecation
+ if 'pk_url_kwarg' in kwargs:
+ msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ if 'slug_url_kwarg' in kwargs:
+ msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ if 'slug_field' in kwargs:
+ msg = 'slug_field is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+
self.slug_field = kwargs.pop('slug_field', self.slug_field)
default_slug_kwarg = self.slug_url_kwarg or self.slug_field
self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
@@ -433,13 +512,14 @@ class HyperlinkedIdentityField(Field):
request = self.context.get('request', None)
format = self.context.get('format', None)
view_name = self.view_name or self.parent.opts.view_name
- kwargs = {self.pk_url_kwarg: obj.pk}
+ lookup_field = getattr(obj, self.lookup_field)
+ kwargs = {self.lookup_field: lookup_field}
if request is None:
warnings.warn("Using `HyperlinkedIdentityField` without including the "
- "request in the serializer context is due to be deprecated. "
+ "request in the serializer context is deprecated. "
"Add `context={'request': request}` when instantiating the serializer.",
- PendingDeprecationWarning, stacklevel=4)
+ DeprecationWarning, stacklevel=4)
# By default use whatever format is given for the current context
# unless the target is a different type to the source.
@@ -482,35 +562,35 @@ class HyperlinkedIdentityField(Field):
class ManyRelatedField(RelatedField):
def __init__(self, *args, **kwargs):
- warnings.warn('`ManyRelatedField()` is due to be deprecated. '
+ warnings.warn('`ManyRelatedField()` is deprecated. '
'Use `RelatedField(many=True)` instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManyRelatedField, self).__init__(*args, **kwargs)
class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
def __init__(self, *args, **kwargs):
- warnings.warn('`ManyPrimaryKeyRelatedField()` is due to be deprecated. '
+ warnings.warn('`ManyPrimaryKeyRelatedField()` is deprecated. '
'Use `PrimaryKeyRelatedField(many=True)` instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManyPrimaryKeyRelatedField, self).__init__(*args, **kwargs)
class ManySlugRelatedField(SlugRelatedField):
def __init__(self, *args, **kwargs):
- warnings.warn('`ManySlugRelatedField()` is due to be deprecated. '
+ warnings.warn('`ManySlugRelatedField()` is deprecated. '
'Use `SlugRelatedField(many=True)` instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManySlugRelatedField, self).__init__(*args, **kwargs)
class ManyHyperlinkedRelatedField(HyperlinkedRelatedField):
def __init__(self, *args, **kwargs):
- warnings.warn('`ManyHyperlinkedRelatedField()` is due to be deprecated. '
+ warnings.warn('`ManyHyperlinkedRelatedField()` is deprecated. '
'Use `HyperlinkedRelatedField(many=True)` instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManyHyperlinkedRelatedField, self).__init__(*args, **kwargs)
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 4c15e0db..8361cd40 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -24,6 +24,7 @@ from rest_framework.settings import api_settings
from rest_framework.request import clone_request
from rest_framework.utils import encoders
from rest_framework.utils.breadcrumbs import get_breadcrumbs
+from rest_framework.utils.formatting import get_view_name, get_view_description
from rest_framework import exceptions, parsers, status, VERSION
@@ -57,7 +58,7 @@ class JSONRenderer(BaseRenderer):
return ''
# If 'indent' is provided in the context, then pretty print the result.
- # E.g. If we're being called by the BrowseableAPIRenderer.
+ # E.g. If we're being called by the BrowsableAPIRenderer.
renderer_context = renderer_context or {}
indent = renderer_context.get('indent', None)
@@ -335,7 +336,7 @@ class BrowsableAPIRenderer(BaseRenderer):
return # Cannot use form overloading
try:
- view.check_permissions(clone_request(request, method))
+ view.check_permissions(request)
except exceptions.APIException:
return False # Doesn't have permissions
return True
@@ -371,6 +372,30 @@ class BrowsableAPIRenderer(BaseRenderer):
return fields
+ def _get_form(self, view, method, request):
+ # We need to impersonate a request with the correct method,
+ # so that eg. any dynamic get_serializer_class methods return the
+ # correct form for each method.
+ restore = view.request
+ request = clone_request(request, method)
+ view.request = request
+ try:
+ return self.get_form(view, method, request)
+ finally:
+ view.request = restore
+
+ def _get_raw_data_form(self, view, method, request, media_types):
+ # We need to impersonate a request with the correct method,
+ # so that eg. any dynamic get_serializer_class methods return the
+ # correct form for each method.
+ restore = view.request
+ request = clone_request(request, method)
+ view.request = request
+ try:
+ return self.get_raw_data_form(view, method, request, media_types)
+ finally:
+ view.request = restore
+
def get_form(self, view, method, request):
"""
Get a form, possibly bound to either the input or output data.
@@ -438,16 +463,13 @@ class BrowsableAPIRenderer(BaseRenderer):
return GenericContentForm()
def get_name(self, view):
- try:
- return view.get_name()
- except AttributeError:
- return smart_text(view.__class__.__name__)
+ return get_view_name(view.__class__, getattr(view, 'suffix', None))
def get_description(self, view):
- try:
- return view.get_description(html=True)
- except AttributeError:
- return smart_text(view.__doc__ or '')
+ return get_view_description(view.__class__, html=True)
+
+ def get_breadcrumbs(self, request):
+ return get_breadcrumbs(request.path)
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
@@ -467,20 +489,20 @@ class BrowsableAPIRenderer(BaseRenderer):
renderer = self.get_default_renderer(view)
content = self.get_content(renderer, data, accepted_media_type, renderer_context)
- put_form = self.get_form(view, 'PUT', request)
- post_form = self.get_form(view, 'POST', request)
- patch_form = self.get_form(view, 'PATCH', request)
- delete_form = self.get_form(view, 'DELETE', request)
- options_form = self.get_form(view, 'OPTIONS', request)
+ put_form = self._get_form(view, 'PUT', request)
+ post_form = self._get_form(view, 'POST', request)
+ patch_form = self._get_form(view, 'PATCH', request)
+ delete_form = self._get_form(view, 'DELETE', request)
+ options_form = self._get_form(view, 'OPTIONS', request)
- raw_data_put_form = self.get_raw_data_form(view, 'PUT', request, media_types)
- raw_data_post_form = self.get_raw_data_form(view, 'POST', request, media_types)
- raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request, media_types)
+ raw_data_put_form = self._get_raw_data_form(view, 'PUT', request, media_types)
+ raw_data_post_form = self._get_raw_data_form(view, 'POST', request, media_types)
+ raw_data_patch_form = self._get_raw_data_form(view, 'PATCH', request, media_types)
raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form
name = self.get_name(view)
description = self.get_description(view)
- breadcrumb_list = get_breadcrumbs(request.path)
+ breadcrumb_list = self.get_breadcrumbs(request)
template = loader.get_template(self.template)
context = RequestContext(request, {
diff --git a/rest_framework/request.py b/rest_framework/request.py
index ffbbab33..a434659c 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -1,11 +1,10 @@
"""
-The :mod:`request` module provides a :class:`Request` class used to wrap the standard `request`
-object received in all the views.
+The Request class is used as a wrapper around the standard request object.
The wrapped request then offers a richer API, in particular :
- content automatically parsed according to `Content-Type` header,
- and available as :meth:`.DATA<Request.DATA>`
+ and available as `request.DATA`
- full support of PUT method, including support for file uploads
- form overloading of HTTP method, content type and content
"""
diff --git a/rest_framework/response.py b/rest_framework/response.py
index 5e1bf46e..26e4ab37 100644
--- a/rest_framework/response.py
+++ b/rest_framework/response.py
@@ -1,3 +1,9 @@
+"""
+The Response class in REST framework is similiar to HTTPResponse, except that
+it is initialized with unrendered data, instead of a pre-rendered string.
+
+The appropriate renderer is called during Django's template response rendering.
+"""
from __future__ import unicode_literals
from django.core.handlers.wsgi import STATUS_CODE_TEXT
from django.template.response import SimpleTemplateResponse
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
new file mode 100644
index 00000000..dba104c3
--- /dev/null
+++ b/rest_framework/routers.py
@@ -0,0 +1,249 @@
+"""
+Routers provide a convenient and consistent way of automatically
+determining the URL conf for your API.
+
+They are used by simply instantiating a Router class, and then registering
+all the required ViewSets with that router.
+
+For example, you might have a `urls.py` that looks something like this:
+
+ router = routers.DefaultRouter()
+ router.register('users', UserViewSet, 'user')
+ router.register('accounts', AccountViewSet, 'account')
+
+ urlpatterns = router.urls
+"""
+from __future__ import unicode_literals
+
+from collections import namedtuple
+from rest_framework import views
+from rest_framework.compat import patterns, url
+from rest_framework.decorators import api_view
+from rest_framework.response import Response
+from rest_framework.reverse import reverse
+from rest_framework.urlpatterns import format_suffix_patterns
+
+
+Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs'])
+
+
+def replace_methodname(format_string, methodname):
+ """
+ Partially format a format_string, swapping out any
+ '{methodname}' or '{methodnamehyphen}' components.
+ """
+ methodnamehyphen = methodname.replace('_', '-')
+ ret = format_string
+ ret = ret.replace('{methodname}', methodname)
+ ret = ret.replace('{methodnamehyphen}', methodnamehyphen)
+ return ret
+
+
+class BaseRouter(object):
+ def __init__(self):
+ self.registry = []
+
+ def register(self, prefix, viewset, base_name=None):
+ if base_name is None:
+ base_name = self.get_default_base_name(viewset)
+ self.registry.append((prefix, viewset, base_name))
+
+ def get_default_base_name(self, viewset):
+ """
+ If `base_name` is not specified, attempt to automatically determine
+ it from the viewset.
+ """
+ raise NotImplemented('get_default_base_name must be overridden')
+
+ def get_urls(self):
+ """
+ Return a list of URL patterns, given the registered viewsets.
+ """
+ raise NotImplemented('get_urls must be overridden')
+
+ @property
+ def urls(self):
+ if not hasattr(self, '_urls'):
+ self._urls = patterns('', *self.get_urls())
+ return self._urls
+
+
+class SimpleRouter(BaseRouter):
+ routes = [
+ # List route.
+ Route(
+ url=r'^{prefix}/$',
+ mapping={
+ 'get': 'list',
+ 'post': 'create'
+ },
+ name='{basename}-list',
+ initkwargs={'suffix': 'List'}
+ ),
+ # Detail route.
+ Route(
+ url=r'^{prefix}/{lookup}/$',
+ mapping={
+ 'get': 'retrieve',
+ 'put': 'update',
+ 'patch': 'partial_update',
+ 'delete': 'destroy'
+ },
+ name='{basename}-detail',
+ initkwargs={'suffix': 'Instance'}
+ ),
+ # Dynamically generated routes.
+ # Generated using @action or @link decorators on methods of the viewset.
+ Route(
+ url=r'^{prefix}/{lookup}/{methodname}/$',
+ mapping={
+ '{httpmethod}': '{methodname}',
+ },
+ name='{basename}-{methodnamehyphen}',
+ initkwargs={}
+ ),
+ ]
+
+ def get_default_base_name(self, viewset):
+ """
+ If `base_name` is not specified, attempt to automatically determine
+ it from the viewset.
+ """
+ model_cls = getattr(viewset, 'model', None)
+ queryset = getattr(viewset, 'queryset', None)
+ if model_cls is None and queryset is not None:
+ model_cls = queryset.model
+
+ assert model_cls, '`name` not argument not specified, and could ' \
+ 'not automatically determine the name from the viewset, as ' \
+ 'it does not have a `.model` or `.queryset` attribute.'
+
+ return model_cls._meta.object_name.lower()
+
+ def get_routes(self, viewset):
+ """
+ Augment `self.routes` with any dynamically generated routes.
+
+ Returns a list of the Route namedtuple.
+ """
+
+ # Determine any `@action` or `@link` decorated methods on the viewset
+ dynamic_routes = []
+ for methodname in dir(viewset):
+ attr = getattr(viewset, methodname)
+ httpmethod = getattr(attr, 'bind_to_method', None)
+ if httpmethod:
+ dynamic_routes.append((httpmethod, methodname))
+
+ ret = []
+ for route in self.routes:
+ if route.mapping == {'{httpmethod}': '{methodname}'}:
+ # Dynamic routes (@link or @action decorator)
+ for httpmethod, methodname in dynamic_routes:
+ initkwargs = route.initkwargs.copy()
+ initkwargs.update(getattr(viewset, methodname).kwargs)
+ ret.append(Route(
+ url=replace_methodname(route.url, methodname),
+ mapping={httpmethod: methodname},
+ name=replace_methodname(route.name, methodname),
+ initkwargs=initkwargs,
+ ))
+ else:
+ # Standard route
+ ret.append(route)
+
+ return ret
+
+ def get_method_map(self, viewset, method_map):
+ """
+ Given a viewset, and a mapping of http methods to actions,
+ return a new mapping which only includes any mappings that
+ are actually implemented by the viewset.
+ """
+ bound_methods = {}
+ for method, action in method_map.items():
+ if hasattr(viewset, action):
+ bound_methods[method] = action
+ return bound_methods
+
+ def get_lookup_regex(self, viewset):
+ """
+ Given a viewset, return the portion of URL regex that is used
+ to match against a single instance.
+ """
+ base_regex = '(?P<{lookup_field}>[^/]+)'
+ lookup_field = getattr(viewset, 'lookup_field', 'pk')
+ return base_regex.format(lookup_field=lookup_field)
+
+ def get_urls(self):
+ """
+ Use the registered viewsets to generate a list of URL patterns.
+ """
+ ret = []
+
+ for prefix, viewset, basename in self.registry:
+ lookup = self.get_lookup_regex(viewset)
+ routes = self.get_routes(viewset)
+
+ for route in routes:
+
+ # Only actions which actually exist on the viewset will be bound
+ mapping = self.get_method_map(viewset, route.mapping)
+ if not mapping:
+ continue
+
+ # Build the url pattern
+ regex = route.url.format(prefix=prefix, lookup=lookup)
+ view = viewset.as_view(mapping, **route.initkwargs)
+ name = route.name.format(basename=basename)
+ ret.append(url(regex, view, name=name))
+
+ return ret
+
+
+class DefaultRouter(SimpleRouter):
+ """
+ The default router extends the SimpleRouter, but also adds in a default
+ API root view, and adds format suffix patterns to the URLs.
+ """
+ include_root_view = True
+ include_format_suffixes = True
+
+ def get_api_root_view(self):
+ """
+ Return a view to use as the API root.
+ """
+ api_root_dict = {}
+ list_name = self.routes[0].name
+ for prefix, viewset, basename in self.registry:
+ api_root_dict[prefix] = list_name.format(basename=basename)
+
+ class APIRoot(views.APIView):
+ _ignore_model_permissions = True
+
+ def get(self, request, format=None):
+ ret = {}
+ for key, url_name in api_root_dict.items():
+ ret[key] = reverse(url_name, request=request, format=format)
+ return Response(ret)
+
+ return APIRoot.as_view()
+
+ def get_urls(self):
+ """
+ Generate the list of URL patterns, including a default root view
+ for the API, and appending `.json` style format suffixes.
+ """
+ urls = []
+
+ if self.include_root_view:
+ root_url = url(r'^$', self.get_api_root_view(), name='api-root')
+ urls.append(root_url)
+
+ default_urls = super(DefaultRouter, self).get_urls()
+ urls.extend(default_urls)
+
+ if self.include_format_suffixes:
+ urls = format_suffix_patterns(urls)
+
+ return urls
diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py
index 9b519f27..9dd7b545 100644
--- a/rest_framework/runtests/settings.py
+++ b/rest_framework/runtests/settings.py
@@ -4,6 +4,8 @@ DEBUG = True
TEMPLATE_DEBUG = DEBUG
DEBUG_PROPAGATE_EXCEPTIONS = True
+ALLOWED_HOSTS = ['*']
+
ADMINS = (
# ('Your Name', 'your_email@domain.com'),
)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 6aca2f57..500bb306 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -1,3 +1,15 @@
+"""
+Serializers and ModelSerializers are similar to Forms and ModelForms.
+Unlike forms, they are not constrained to dealing with HTML output, and
+form encoded input.
+
+Serialization in REST framework is a two-phase process:
+
+1. Serializers marshal between complex types like model instances, and
+python primatives.
+2. The process of marshalling between python primatives and request and
+response content is handled by parsers and renderers.
+"""
from __future__ import unicode_literals
import copy
import datetime
@@ -130,14 +142,14 @@ class BaseSerializer(WritableField):
def __init__(self, instance=None, data=None, files=None,
context=None, partial=False, many=None,
- allow_delete=False, **kwargs):
+ allow_add_remove=False, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta)
self.parent = None
self.root = None
self.partial = partial
self.many = many
- self.allow_delete = allow_delete
+ self.allow_add_remove = allow_add_remove
self.context = context or {}
@@ -154,8 +166,8 @@ class BaseSerializer(WritableField):
if many and instance is not None and not hasattr(instance, '__iter__'):
raise ValueError('instance should be a queryset or other iterable with many=True')
- if allow_delete and not many:
- raise ValueError('allow_delete should only be used for bulk updates, but you have not set many=True')
+ if allow_add_remove and not many:
+ raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True')
#####
# Methods to determine which fields to use when (de)serializing objects.
@@ -188,7 +200,7 @@ class BaseSerializer(WritableField):
# If 'fields' is specified, use those fields, in that order.
if self.opts.fields:
- assert isinstance(self.opts.fields, (list, tuple)), '`include` must be a list or tuple'
+ assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple'
new = SortedDict()
for key in self.opts.fields:
new[key] = ret[key]
@@ -196,7 +208,7 @@ class BaseSerializer(WritableField):
# Remove anything in 'exclude'
if self.opts.exclude:
- assert isinstance(self.opts.fields, (list, tuple)), '`exclude` must be a list or tuple'
+ assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple'
for key in self.opts.exclude:
ret.pop(key, None)
@@ -206,18 +218,6 @@ class BaseSerializer(WritableField):
return ret
#####
- # Field methods - used when the serializer class is itself used as a field.
-
- def initialize(self, parent, field_name):
- """
- Same behaviour as usual Field, except that we need to keep track
- of state so that we can deal with handling maximum depth.
- """
- super(BaseSerializer, self).initialize(parent, field_name)
- if parent.opts.depth:
- self.opts.depth = parent.opts.depth - 1
-
- #####
# Methods to convert or revert from objects <--> primitive representations.
def get_field_key(self, field_name):
@@ -424,9 +424,9 @@ class BaseSerializer(WritableField):
else:
many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type))
if many:
- warnings.warn('Implict list/queryset serialization is due to be deprecated. '
+ warnings.warn('Implict list/queryset serialization is deprecated. '
'Use the `many=True` flag when instantiating the serializer.',
- PendingDeprecationWarning, stacklevel=3)
+ DeprecationWarning, stacklevel=3)
if many:
ret = []
@@ -448,6 +448,10 @@ class BaseSerializer(WritableField):
# Determine which object we're updating
identity = self.get_identity(item)
self.object = identity_to_objects.pop(identity, None)
+ if self.object is None and not self.allow_add_remove:
+ ret.append(None)
+ errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']})
+ continue
ret.append(self.from_native(item, None))
errors.append(self._errors)
@@ -457,7 +461,7 @@ class BaseSerializer(WritableField):
self._errors = any(errors) and errors or []
else:
- self._errors = {'non_field_errors': ['Expected a list of items']}
+ self._errors = {'non_field_errors': ['Expected a list of items.']}
else:
ret = self.from_native(data, files)
@@ -482,9 +486,9 @@ class BaseSerializer(WritableField):
else:
many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict))
if many:
- warnings.warn('Implict list/queryset serialization is due to be deprecated. '
+ warnings.warn('Implict list/queryset serialization is deprecated. '
'Use the `many=True` flag when instantiating the serializer.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
if many:
self._data = [self.to_native(item) for item in obj]
@@ -508,7 +512,7 @@ class BaseSerializer(WritableField):
else:
self.save_object(self.object, **kwargs)
- if self.allow_delete and self._deleted:
+ if self.allow_add_remove and self._deleted:
[self.delete_object(item) for item in self._deleted]
return self.object
@@ -544,6 +548,7 @@ class ModelSerializer(Serializer):
models.DateTimeField: DateTimeField,
models.DateField: DateField,
models.TimeField: TimeField,
+ models.DecimalField: DecimalField,
models.EmailField: EmailField,
models.CharField: CharField,
models.URLField: URLField,
@@ -564,39 +569,94 @@ class ModelSerializer(Serializer):
assert cls is not None, \
"Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__
opts = get_concrete_model(cls)._meta
- pk_field = opts.pk
+ ret = SortedDict()
+ nested = bool(self.opts.depth)
- # If model is a child via multitable inheritance, use parent's pk
+ # Deal with adding the primary key field
+ pk_field = opts.pk
while pk_field.rel and pk_field.rel.parent_link:
+ # If model is a child via multitable inheritance, use parent's pk
pk_field = pk_field.rel.to._meta.pk
- fields = [pk_field]
- fields += [field for field in opts.fields if field.serialize]
- fields += [field for field in opts.many_to_many if field.serialize]
+ field = self.get_pk_field(pk_field)
+ if field:
+ ret[pk_field.name] = field
- ret = SortedDict()
- nested = bool(self.opts.depth)
- is_pk = True # First field in the list is the pk
-
- for model_field in fields:
- if is_pk:
- field = self.get_pk_field(model_field)
- is_pk = False
- elif model_field.rel and nested:
- field = self.get_nested_field(model_field)
- elif model_field.rel:
+ # Deal with forward relationships
+ forward_rels = [field for field in opts.fields if field.serialize]
+ forward_rels += [field for field in opts.many_to_many if field.serialize]
+
+ for model_field in forward_rels:
+ if model_field.rel:
to_many = isinstance(model_field,
models.fields.related.ManyToManyField)
- field = self.get_related_field(model_field, to_many=to_many)
+ related_model = model_field.rel.to
+
+ if model_field.rel and nested:
+ if len(inspect.getargspec(self.get_nested_field).args) == 2:
+ warnings.warn(
+ 'The `get_nested_field(model_field)` call signature '
+ 'is due to be deprecated. '
+ 'Use `get_nested_field(model_field, related_model, '
+ 'to_many) instead',
+ PendingDeprecationWarning
+ )
+ field = self.get_nested_field(model_field)
+ else:
+ field = self.get_nested_field(model_field, related_model, to_many)
+ elif model_field.rel:
+ if len(inspect.getargspec(self.get_nested_field).args) == 3:
+ warnings.warn(
+ 'The `get_related_field(model_field, to_many)` call '
+ 'signature is due to be deprecated. '
+ 'Use `get_related_field(model_field, related_model, '
+ 'to_many) instead',
+ PendingDeprecationWarning
+ )
+ field = self.get_related_field(model_field, to_many=to_many)
+ else:
+ field = self.get_related_field(model_field, related_model, to_many)
else:
field = self.get_field(model_field)
if field:
ret[model_field.name] = field
+ # Deal with reverse relationships
+ if not self.opts.fields:
+ reverse_rels = []
+ else:
+ # Reverse relationships are only included if they are explicitly
+ # present in the `fields` option on the serializer
+ reverse_rels = opts.get_all_related_objects()
+ reverse_rels += opts.get_all_related_many_to_many_objects()
+
+ for relation in reverse_rels:
+ accessor_name = relation.get_accessor_name()
+ if not self.opts.fields or accessor_name not in self.opts.fields:
+ continue
+ related_model = relation.model
+ to_many = relation.field.rel.multiple
+
+ if nested:
+ field = self.get_nested_field(None, related_model, to_many)
+ else:
+ field = self.get_related_field(None, related_model, to_many)
+
+ if field:
+ ret[accessor_name] = field
+
+ # Add the `read_only` flag to any fields that have bee specified
+ # in the `read_only_fields` option
for field_name in self.opts.read_only_fields:
+ assert field_name not in self.base_fields.keys(), \
+ "field '%s' on serializer '%s' specfied in " \
+ "`read_only_fields`, but also added " \
+ "as an explict field. Remove it from `read_only_fields`." % \
+ (field_name, self.__class__.__name__)
assert field_name in ret, \
- "read_only_fields on '%s' included invalid item '%s'" % \
+ "Noexistant field '%s' specified in `read_only_fields` " \
+ "on serializer '%s'." % \
(self.__class__.__name__, field_name)
ret[field_name].read_only = True
@@ -608,27 +668,36 @@ class ModelSerializer(Serializer):
"""
return self.get_field(model_field)
- def get_nested_field(self, model_field):
+ def get_nested_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a nested relational field.
+
+ Note that model_field will be `None` for reverse relationships.
"""
class NestedModelSerializer(ModelSerializer):
class Meta:
- model = model_field.rel.to
- return NestedModelSerializer()
+ model = related_model
+ depth = self.opts.depth - 1
- def get_related_field(self, model_field, to_many=False):
+ return NestedModelSerializer(many=to_many)
+
+ def get_related_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a flat relational field.
+
+ Note that model_field will be `None` for reverse relationships.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
+
kwargs = {
- 'required': not(model_field.null or model_field.blank),
- 'queryset': model_field.rel.to._default_manager,
+ 'queryset': related_model._default_manager,
'many': to_many
}
+ if model_field:
+ kwargs['required'] = not(model_field.null or model_field.blank)
+
return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field):
@@ -636,15 +705,14 @@ class ModelSerializer(Serializer):
Creates a default instance of a basic non-relational field.
"""
kwargs = {}
- has_default = model_field.has_default()
- if model_field.null or model_field.blank or has_default:
+ if model_field.null or model_field.blank:
kwargs['required'] = False
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
- if has_default:
+ if model_field.has_default():
kwargs['default'] = model_field.get_default()
if issubclass(model_field.__class__, models.TextField):
@@ -737,7 +805,7 @@ class ModelSerializer(Serializer):
Override the default method to also include model field validation.
"""
instance = super(ModelSerializer, self).from_native(data, files)
- if instance:
+ if not self._errors:
return self.full_clean(instance)
def save_object(self, obj, **kwargs):
@@ -764,6 +832,7 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
def __init__(self, meta):
super(HyperlinkedModelSerializerOptions, self).__init__(meta)
self.view_name = getattr(meta, 'view_name', None)
+ self.lookup_field = getattr(meta, 'lookup_field', None)
class HyperlinkedModelSerializer(ModelSerializer):
@@ -773,6 +842,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
"""
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
+ _hyperlink_field_class = HyperlinkedRelatedField
url = HyperlinkedIdentityField()
@@ -793,22 +863,28 @@ class HyperlinkedModelSerializer(ModelSerializer):
return self._default_view_name % format_kwargs
def get_pk_field(self, model_field):
- return None
+ if self.opts.fields and model_field.name in self.opts.fields:
+ return self.get_field(model_field)
- def get_related_field(self, model_field, to_many):
+ def get_related_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a flat relational field.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
- rel = model_field.rel.to
kwargs = {
- 'required': not(model_field.null or model_field.blank),
- 'queryset': rel._default_manager,
- 'view_name': self._get_default_view_name(rel),
+ 'queryset': related_model._default_manager,
+ 'view_name': self._get_default_view_name(related_model),
'many': to_many
}
- return HyperlinkedRelatedField(**kwargs)
+
+ if model_field:
+ kwargs['required'] = not(model_field.null or model_field.blank)
+
+ if self.opts.lookup_field:
+ kwargs['lookup_field'] = self.opts.lookup_field
+
+ return self._hyperlink_field_class(**kwargs)
def get_identity(self, data):
"""
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index eede0c5a..beb511ac 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -29,6 +29,7 @@ from rest_framework.compat import six
USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None)
DEFAULTS = {
+ # Base API policies
'DEFAULT_RENDERER_CLASSES': (
'rest_framework.renderers.JSONRenderer',
'rest_framework.renderers.BrowsableAPIRenderer',
@@ -50,11 +51,15 @@ DEFAULTS = {
'DEFAULT_CONTENT_NEGOTIATION_CLASS':
'rest_framework.negotiation.DefaultContentNegotiation',
+
+ # Genric view behavior
'DEFAULT_MODEL_SERIALIZER_CLASS':
'rest_framework.serializers.ModelSerializer',
'DEFAULT_PAGINATION_SERIALIZER_CLASS':
'rest_framework.pagination.PaginationSerializer',
+ 'DEFAULT_FILTER_BACKENDS': (),
+ # Throttling
'DEFAULT_THROTTLE_RATES': {
'user': None,
'anon': None,
@@ -64,9 +69,6 @@ DEFAULTS = {
'PAGINATE_BY': None,
'PAGINATE_BY_PARAM': None,
- # Filtering
- 'FILTER_BACKEND': None,
-
# Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
@@ -84,17 +86,20 @@ DEFAULTS = {
'DATE_INPUT_FORMATS': (
ISO_8601,
),
- 'DATE_FORMAT': ISO_8601,
+ 'DATE_FORMAT': None,
'DATETIME_INPUT_FORMATS': (
ISO_8601,
),
- 'DATETIME_FORMAT': ISO_8601,
+ 'DATETIME_FORMAT': None,
'TIME_INPUT_FORMATS': (
ISO_8601,
),
- 'TIME_FORMAT': ISO_8601,
+ 'TIME_FORMAT': None,
+
+ # Pending deprecation
+ 'FILTER_BACKEND': None,
}
@@ -108,6 +113,7 @@ IMPORT_STRINGS = (
'DEFAULT_CONTENT_NEGOTIATION_CLASS',
'DEFAULT_MODEL_SERIALIZER_CLASS',
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
+ 'DEFAULT_FILTER_BACKENDS',
'FILTER_BACKEND',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 44633f5a..4410f285 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -115,7 +115,7 @@
</div>
<div class="response-info">
<pre class="prettyprint"><div class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %}
-{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|urlize_quoted_links }}</span>
+{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|break_long_headers|urlize_quoted_links }}</span>
{% endfor %}
</div>{{ content|urlize_quoted_links }}</pre>{% endautoescape %}
</div>
diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html
index e10ce20f..b7629327 100644
--- a/rest_framework/templates/rest_framework/login.html
+++ b/rest_framework/templates/rest_framework/login.html
@@ -1,53 +1,3 @@
-{% load url from future %}
-{% load rest_framework %}
-<html>
+{% extends "rest_framework/login_base.html" %}
- <head>
- <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>
- <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/>
- <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/>
- </head>
-
- <body class="container">
-
-<div class="container-fluid" style="margin-top: 30px">
- <div class="row-fluid">
-
- <div class="well" style="width: 320px; margin-left: auto; margin-right: auto">
- <div class="row-fluid">
- <div>
- <h3 style="margin: 0 0 20px;">Django REST framework</h3>
- </div>
- </div><!-- /row fluid -->
-
- <div class="row-fluid">
- <div>
- <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post">
- {% csrf_token %}
- <div id="div_id_username" class="clearfix control-group">
- <div class="controls">
- <Label class="span4">Username:</label>
- <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username">
- </div>
- </div>
- <div id="div_id_password" class="clearfix control-group">
- <div class="controls">
- <Label class="span4">Password:</label>
- <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password">
- </div>
- </div>
- <input type="hidden" name="next" value="{{ next }}" />
- <div class="form-actions-no-box">
- <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit">
- </div>
- </form>
- </div>
- </div><!-- /row fluid -->
- </div><!--/span-->
-
- </div><!-- /.row-fluid -->
- </div>
-
- </div>
- </body>
-</html>
+{# Override this template in your own templates directory to customize #}
diff --git a/rest_framework/templates/rest_framework/login_base.html b/rest_framework/templates/rest_framework/login_base.html
new file mode 100644
index 00000000..a3e73b6b
--- /dev/null
+++ b/rest_framework/templates/rest_framework/login_base.html
@@ -0,0 +1,51 @@
+{% load url from future %}
+{% load rest_framework %}
+<html>
+
+ <head>
+ {% block style %}
+ {% block bootstrap_theme %}<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>{% endblock %}
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/>
+ {% endblock %}
+ </head>
+
+ <body class="container">
+
+ <div class="container-fluid" style="margin-top: 30px">
+ <div class="row-fluid">
+ <div class="well" style="width: 320px; margin-left: auto; margin-right: auto">
+ <div class="row-fluid">
+ <div>
+ {% block branding %}<h3 style="margin: 0 0 20px;">Django REST framework</h3>{% endblock %}
+ </div>
+ </div><!-- /row fluid -->
+
+ <div class="row-fluid">
+ <div>
+ <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post">
+ {% csrf_token %}
+ <div id="div_id_username" class="clearfix control-group">
+ <div class="controls">
+ <Label class="span4">Username:</label>
+ <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username">
+ </div>
+ </div>
+ <div id="div_id_password" class="clearfix control-group">
+ <div class="controls">
+ <Label class="span4">Password:</label>
+ <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password">
+ </div>
+ </div>
+ <input type="hidden" name="next" value="{{ next }}" />
+ <div class="form-actions-no-box">
+ <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit">
+ </div>
+ </form>
+ </div>
+ </div><!-- /.row-fluid -->
+ </div><!--/.well-->
+ </div><!-- /.row-fluid -->
+ </div><!-- /.container-fluid -->
+ </body>
+</html>
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index c21ddcd7..c86b6456 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -4,11 +4,8 @@ from django.core.urlresolvers import reverse, NoReverseMatch
from django.http import QueryDict
from django.utils.html import escape
from django.utils.safestring import SafeData, mark_safe
-from rest_framework.compat import urlparse
-from rest_framework.compat import force_text
-from rest_framework.compat import six
-import re
-import string
+from rest_framework.compat import urlparse, force_text, six, smart_urlquote
+import re, string
register = template.Library()
@@ -112,22 +109,6 @@ def replace_query_param(url, key, val):
class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])')
-# Bunch of stuff cloned from urlize
-LEADING_PUNCTUATION = ['(', '<', '&lt;', '"', "'"]
-TRAILING_PUNCTUATION = ['.', ',', ')', '>', '\n', '&gt;', '"', "'"]
-DOTS = ['&middot;', '*', '\xe2\x80\xa2', '&#149;', '&bull;', '&#8226;']
-unencoded_ampersands_re = re.compile(r'&(?!(\w+|#\d+);)')
-word_split_re = re.compile(r'(\s+)')
-punctuation_re = re.compile('^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' % \
- ('|'.join([re.escape(x) for x in LEADING_PUNCTUATION]),
- '|'.join([re.escape(x) for x in TRAILING_PUNCTUATION])))
-simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$')
-link_target_attribute_re = re.compile(r'(<a [^>]*?)target=[^\s>]+')
-html_gunk_re = re.compile(r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)', re.IGNORECASE)
-hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|'.join([re.escape(x) for x in DOTS]), re.DOTALL)
-trailing_empty_content_re = re.compile(r'(?:<p>(?:&nbsp;|\s|<br \/>)*?</p>\s*)+\Z')
-
-
# And the template tags themselves...
@register.simple_tag
@@ -195,15 +176,25 @@ def add_class(value, css_class):
return value
+# Bunch of stuff cloned from urlize
+TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"]
+WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('&lt;', '&gt;'),
+ ('"', '"'), ("'", "'")]
+word_split_re = re.compile(r'(\s+)')
+simple_url_re = re.compile(r'^https?://\[?\w', re.IGNORECASE)
+simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE)
+simple_email_re = re.compile(r'^\S+@\S+\.\S+$')
+
+
@register.filter
def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True):
"""
Converts any URLs in text into clickable links.
- Works on http://, https://, www. links and links ending in .org, .net or
- .com. Links can have trailing punctuation (periods, commas, close-parens)
- and leading punctuation (opening parens) and it'll still do the right
- thing.
+ Works on http://, https://, www. links, and also on links ending in one of
+ the original seven gTLDs (.com, .edu, .gov, .int, .mil, .net, and .org).
+ Links can have trailing punctuation (periods, commas, close-parens) and
+ leading punctuation (opening parens) and it'll still do the right thing.
If trim_url_limit is not None, the URLs in link text longer than this limit
will truncated to trim_url_limit-3 characters and appended with an elipsis.
@@ -216,24 +207,41 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x
safe_input = isinstance(text, SafeData)
words = word_split_re.split(force_text(text))
- nofollow_attr = nofollow and ' rel="nofollow"' or ''
for i, word in enumerate(words):
match = None
if '.' in word or '@' in word or ':' in word:
- match = punctuation_re.match(word)
- if match:
- lead, middle, trail = match.groups()
+ # Deal with punctuation.
+ lead, middle, trail = '', word, ''
+ for punctuation in TRAILING_PUNCTUATION:
+ if middle.endswith(punctuation):
+ middle = middle[:-len(punctuation)]
+ trail = punctuation + trail
+ for opening, closing in WRAPPING_PUNCTUATION:
+ if middle.startswith(opening):
+ middle = middle[len(opening):]
+ lead = lead + opening
+ # Keep parentheses at the end only if they're balanced.
+ if (middle.endswith(closing)
+ and middle.count(closing) == middle.count(opening) + 1):
+ middle = middle[:-len(closing)]
+ trail = closing + trail
+
# Make URL we want to point to.
url = None
- if middle.startswith('http://') or middle.startswith('https://'):
- url = middle
- elif middle.startswith('www.') or ('@' not in middle and \
- middle and middle[0] in string.ascii_letters + string.digits and \
- (middle.endswith('.org') or middle.endswith('.net') or middle.endswith('.com'))):
- url = 'http://%s' % middle
- elif '@' in middle and not ':' in middle and simple_email_re.match(middle):
- url = 'mailto:%s' % middle
+ nofollow_attr = ' rel="nofollow"' if nofollow else ''
+ if simple_url_re.match(middle):
+ url = smart_urlquote(middle)
+ elif simple_url_2_re.match(middle):
+ url = smart_urlquote('http://%s' % middle)
+ elif not ':' in middle and simple_email_re.match(middle):
+ local, domain = middle.rsplit('@', 1)
+ try:
+ domain = domain.encode('idna').decode('ascii')
+ except UnicodeError:
+ continue
+ url = 'mailto:%s@%s' % (local, domain)
nofollow_attr = ''
+
# Make link.
if url:
trimmed = trim_url(middle)
@@ -251,4 +259,15 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
words[i] = mark_safe(word)
elif autoescape:
words[i] = escape(word)
- return mark_safe(''.join(words))
+ return ''.join(words)
+
+
+@register.filter
+def break_long_headers(header):
+ """
+ Breaks headers longer than 160 characters (~page length)
+ when possible (are comma separated)
+ """
+ if len(header) > 160 and ',' in header:
+ header = mark_safe('<br> ' + ', <br>'.join(header.split(',')))
+ return header
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index b663ca48..8e6d3e51 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -466,17 +466,13 @@ class OAuth2Tests(TestCase):
def _create_authorization_header(self, token=None):
return "Bearer {0}".format(token or self.access_token.token)
- def _client_credentials_params(self):
- return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET}
-
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_with_wrong_authorization_header_token_type_failing(self):
"""Ensure that a wrong token type lead to the correct HTTP error status code"""
auth = "Wrong token-type-obsviously"
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
- params = self._client_credentials_params()
- response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
@@ -485,8 +481,7 @@ class OAuth2Tests(TestCase):
auth = "Bearer wrong token format"
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
- params = self._client_credentials_params()
- response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
@@ -495,33 +490,21 @@ class OAuth2Tests(TestCase):
auth = "Bearer wrong-token"
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
- params = self._client_credentials_params()
- response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_with_wrong_client_data_failing_auth(self):
- """Ensure GETing form over OAuth with incorrect client credentials fails"""
- auth = self._create_authorization_header()
- params = self._client_credentials_params()
- params['client_id'] += 'a'
- response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_passing_auth(self):
"""Ensure GETing form over OAuth with correct client credentials succeed"""
auth = self._create_authorization_header()
- params = self._client_credentials_params()
- response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_passing_auth(self):
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
auth = self._create_authorization_header()
- params = self._client_credentials_params()
- response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
@@ -529,16 +512,14 @@ class OAuth2Tests(TestCase):
"""Ensure POSTing when there is no OAuth access token in db fails"""
self.access_token.delete()
auth = self._create_authorization_header()
- params = self._client_credentials_params()
- response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_with_refresh_token_failing_auth(self):
"""Ensure POSTing with refresh token instead of access token fails"""
auth = self._create_authorization_header(token=self.refresh_token.token)
- params = self._client_credentials_params()
- response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
@@ -547,8 +528,7 @@ class OAuth2Tests(TestCase):
self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
self.access_token.save()
auth = self._create_authorization_header()
- params = self._client_credentials_params()
- response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
self.assertIn('Invalid token', response.content)
@@ -559,10 +539,9 @@ class OAuth2Tests(TestCase):
read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']
read_only_access_token.save()
auth = self._create_authorization_header(token=read_only_access_token.token)
- params = self._client_credentials_params()
- response = self.csrf_client.get('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
- response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
@@ -572,6 +551,5 @@ class OAuth2Tests(TestCase):
read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']
read_write_access_token.save()
auth = self._create_authorization_header(token=read_write_access_token.token)
- params = self._client_credentials_params()
- response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py
index 5b3315bc..52c1a34c 100644
--- a/rest_framework/tests/description.py
+++ b/rest_framework/tests/description.py
@@ -4,6 +4,7 @@ from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.views import APIView
from rest_framework.compat import apply_markdown
+from rest_framework.utils.formatting import get_view_name, get_view_description
# We check that docstrings get nicely un-indented.
DESCRIPTION = """an example docstring
@@ -49,22 +50,16 @@ MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2>
class TestViewNamesAndDescriptions(TestCase):
- def test_resource_name_uses_classname_by_default(self):
- """Ensure Resource names are based on the classname by default."""
+ def test_view_name_uses_class_name(self):
+ """
+ Ensure view names are based on the class name.
+ """
class MockView(APIView):
pass
- self.assertEqual(MockView().get_name(), 'Mock')
+ self.assertEqual(get_view_name(MockView), 'Mock')
- def test_resource_name_can_be_set_explicitly(self):
- """Ensure Resource names can be set using the 'get_name' method."""
- example = 'Some Other Name'
- class MockView(APIView):
- def get_name(self):
- return example
- self.assertEqual(MockView().get_name(), example)
-
- def test_resource_description_uses_docstring_by_default(self):
- """Ensure Resource names are based on the docstring by default."""
+ def test_view_description_uses_docstring(self):
+ """Ensure view descriptions are based on the docstring."""
class MockView(APIView):
"""an example docstring
====================
@@ -81,44 +76,32 @@ class TestViewNamesAndDescriptions(TestCase):
# hash style header #"""
- self.assertEqual(MockView().get_description(), DESCRIPTION)
-
- def test_resource_description_can_be_set_explicitly(self):
- """Ensure Resource descriptions can be set using the 'get_description' method."""
- example = 'Some other description'
-
- class MockView(APIView):
- """docstring"""
- def get_description(self):
- return example
- self.assertEqual(MockView().get_description(), example)
+ self.assertEqual(get_view_description(MockView), DESCRIPTION)
- def test_resource_description_supports_unicode(self):
+ def test_view_description_supports_unicode(self):
+ """
+ Unicode in docstrings should be respected.
+ """
class MockView(APIView):
"""Проверка"""
pass
- self.assertEqual(MockView().get_description(), "Проверка")
-
-
- def test_resource_description_does_not_require_docstring(self):
- """Ensure that empty docstrings do not affect the Resource's description if it has been set using the 'get_description' method."""
- example = 'Some other description'
-
- class MockView(APIView):
- def get_description(self):
- return example
- self.assertEqual(MockView().get_description(), example)
+ self.assertEqual(get_view_description(MockView), "Проверка")
- def test_resource_description_can_be_empty(self):
- """Ensure that if a resource has no doctring or 'description' class attribute, then it's description is the empty string."""
+ def test_view_description_can_be_empty(self):
+ """
+ Ensure that if a view has no docstring,
+ then it's description is the empty string.
+ """
class MockView(APIView):
pass
- self.assertEqual(MockView().get_description(), '')
+ self.assertEqual(get_view_description(MockView), '')
def test_markdown(self):
- """Ensure markdown to HTML works as expected"""
+ """
+ Ensure markdown to HTML works as expected.
+ """
if apply_markdown:
gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21
lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py
index 19c663d8..6b1cdfc7 100644
--- a/rest_framework/tests/fields.py
+++ b/rest_framework/tests/fields.py
@@ -2,13 +2,14 @@
General serializer field tests.
"""
from __future__ import unicode_literals
+from django.utils.datastructures import SortedDict
import datetime
-
+from decimal import Decimal
from django.db import models
from django.test import TestCase
from django.core import validators
-
from rest_framework import serializers
+from rest_framework.serializers import Serializer
class TimestampedModel(models.Model):
@@ -61,6 +62,20 @@ class BasicFieldTests(TestCase):
serializer = CharPrimaryKeyModelSerializer()
self.assertEqual(serializer.fields['id'].read_only, False)
+ def test_dict_field_ordering(self):
+ """
+ Field should preserve dictionary ordering, if it exists.
+ See: https://github.com/tomchristie/django-rest-framework/issues/832
+ """
+ ret = SortedDict()
+ ret['c'] = 1
+ ret['b'] = 1
+ ret['a'] = 1
+ ret['z'] = 1
+ field = serializers.Field()
+ keys = list(field.to_native(ret).keys())
+ self.assertEqual(keys, ['c', 'b', 'a', 'z'])
+
class DateFieldTest(TestCase):
"""
@@ -481,3 +496,192 @@ class TimeFieldTest(TestCase):
self.assertEqual('04 - 00 [000000]', result_1)
self.assertEqual('04 - 59 [000000]', result_2)
self.assertEqual('04 - 59 [000200]', result_3)
+
+
+class DecimalFieldTest(TestCase):
+ """
+ Tests for the DecimalField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts string values
+ """
+ f = serializers.DecimalField()
+ result_1 = f.from_native('9000')
+ result_2 = f.from_native('1.00000001')
+
+ self.assertEqual(Decimal('9000'), result_1)
+ self.assertEqual(Decimal('1.00000001'), result_2)
+
+ def test_from_native_invalid_string(self):
+ """
+ Make sure from_native() raises ValidationError on passing invalid string
+ """
+ f = serializers.DecimalField()
+
+ try:
+ f.from_native('123.45.6')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Enter a number."])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_integer(self):
+ """
+ Make sure from_native() accepts integer values
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(9000)
+
+ self.assertEqual(Decimal('9000'), result)
+
+ def test_from_native_float(self):
+ """
+ Make sure from_native() accepts float values
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(1.00000001)
+
+ self.assertEqual(Decimal('1.00000001'), result)
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DecimalField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns Decimal as string.
+ """
+ f = serializers.DecimalField()
+
+ result_1 = f.to_native(Decimal('9000'))
+ result_2 = f.to_native(Decimal('1.00000001'))
+
+ self.assertEqual(Decimal('9000'), result_1)
+ self.assertEqual(Decimal('1.00000001'), result_2)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DecimalField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+ def test_valid_serialization(self):
+ """
+ Make sure the serializer works correctly
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_value=9010,
+ min_value=9000,
+ max_digits=6,
+ decimal_places=2)
+
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid())
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid())
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid())
+
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid())
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid())
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid())
+
+ def test_raise_max_value(self):
+ """
+ Make sure max_value violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_value=100)
+
+ s = DecimalSerializer(data={'decimal_field': '123'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']})
+
+ def test_raise_min_value(self):
+ """
+ Make sure min_value violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(min_value=100)
+
+ s = DecimalSerializer(data={'decimal_field': '99'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']})
+
+ def test_raise_max_digits(self):
+ """
+ Make sure max_digits violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_digits=5)
+
+ s = DecimalSerializer(data={'decimal_field': '123.456'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']})
+
+ def test_raise_max_decimal_places(self):
+ """
+ Make sure max_decimal_places violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(decimal_places=3)
+
+ s = DecimalSerializer(data={'decimal_field': '123.4567'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']})
+
+ def test_raise_max_whole_digits(self):
+ """
+ Make sure max_whole_digits violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3)
+
+ s = DecimalSerializer(data={'decimal_field': '12345.6'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
+
+
+class ChoiceFieldTests(TestCase):
+ """
+ Tests for the ChoiceField options generator
+ """
+
+ SAMPLE_CHOICES = [
+ ('red', 'Red'),
+ ('green', 'Green'),
+ ('blue', 'Blue'),
+ ]
+
+ def test_choices_required(self):
+ """
+ Make sure proper choices are rendered if field is required
+ """
+ f = serializers.ChoiceField(required=True, choices=self.SAMPLE_CHOICES)
+ self.assertEqual(f.choices, self.SAMPLE_CHOICES)
+
+ def test_choices_not_required(self):
+ """
+ Make sure proper choices (plus blank) are rendered if the field isn't required
+ """
+ f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES)
+ self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES)
diff --git a/rest_framework/tests/filters.py b/rest_framework/tests/filters.py
new file mode 100644
index 00000000..8ae6d530
--- /dev/null
+++ b/rest_framework/tests/filters.py
@@ -0,0 +1,474 @@
+from __future__ import unicode_literals
+import datetime
+from decimal import Decimal
+from django.db import models
+from django.core.urlresolvers import reverse
+from django.test import TestCase
+from django.test.client import RequestFactory
+from django.utils import unittest
+from rest_framework import generics, serializers, status, filters
+from rest_framework.compat import django_filters, patterns, url
+from rest_framework.tests.models import BasicModel
+
+factory = RequestFactory()
+
+
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
+if django_filters:
+ # Basic filter on a list view.
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_fields = ['decimal', 'date']
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # These class are used to test a filter class.
+ class SeveralFieldsFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+ date = django_filters.DateFilter(lookup_type='gt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterClassRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # These classes are used to test a misconfigured filter class.
+ class MisconfiguredFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+
+ class Meta:
+ model = BasicModel
+ fields = ['text']
+
+ class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = MisconfiguredFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ class FilterClassDetailView(generics.RetrieveAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # Regression test for #814
+ class FilterableItemSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = FilterableItem
+
+ class FilterFieldsQuerysetView(generics.ListCreateAPIView):
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
+ filter_fields = ['decimal', 'date']
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ class GetQuerysetView(generics.ListCreateAPIView):
+ serializer_class = FilterableItemSerializer
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ def get_queryset(self):
+ return FilterableItem.objects.all()
+
+ urlpatterns = patterns('',
+ url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
+ url(r'^$', FilterClassRootView.as_view(), name='root-view'),
+ url(r'^get-queryset/$', GetQuerysetView.as_view(),
+ name='get-queryset-view'),
+ )
+
+
+class CommonFilteringTestCase(TestCase):
+ def _serialize_object(self, obj):
+ return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+
+ def setUp(self):
+ """
+ Create 10 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(10):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ self._serialize_object(obj)
+ for obj in self.objects.all()
+ ]
+
+
+class IntegrationTestFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered list views.
+ """
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_fields_root_view(self):
+ """
+ GET requests to paginated ListCreateAPIView should return paginated results.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the date filter works.
+ search_date = datetime.date(2012, 9, 22)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] == search_date]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_filter_with_queryset(self):
+ """
+ Regression test for #814.
+ """
+ view = FilterFieldsQuerysetView.as_view()
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_filter_with_get_queryset_only(self):
+ """
+ Regression test for #834.
+ """
+ view = GetQuerysetView.as_view()
+ request = factory.get('/get-queryset/')
+ view(request).render()
+ # Used to raise "issubclass() arg 2 must be a class or tuple of classes"
+ # here when neither `model' nor `queryset' was specified.
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_class_root_view(self):
+ """
+ GET requests to filtered ListCreateAPIView that have a filter_class set
+ should return filtered results.
+ """
+ view = FilterClassRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ # Tests that the decimal filter set with 'lt' in the filter class works.
+ search_decimal = Decimal('4.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] < search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the date filter set with 'gt' in the filter class works.
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the text filter set with 'icontains' in the filter class works.
+ search_text = 'ff'
+ request = factory.get('/?text=%s' % search_text)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if search_text in f['text'].lower()]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date and
+ f['decimal'] < search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_incorrectly_configured_filter(self):
+ """
+ An error should be displayed when the filter class is misconfigured.
+ """
+ view = IncorrectlyConfiguredRootView.as_view()
+
+ request = factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_unknown_filter(self):
+ """
+ GET requests with filters that aren't configured should return 200.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ search_integer = 10
+ request = factory.get('/?integer=%s' % search_integer)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class IntegrationTestDetailFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered detail views.
+ """
+ urls = 'rest_framework.tests.filters'
+
+ def _get_url(self, item):
+ return reverse('detail-view', kwargs=dict(pk=item.pk))
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_detail_view(self):
+ """
+ GET requests to filtered RetrieveAPIView that have a filter_class set
+ should return filtered results.
+ """
+ item = self.objects.all()[0]
+ data = self._serialize_object(item)
+
+ # Basic test with no filter.
+ response = self.client.get(self._get_url(item))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, data)
+
+ # Tests that the decimal filter set that should fail.
+ search_decimal = Decimal('4.25')
+ high_item = self.objects.filter(decimal__gt=search_decimal)[0]
+ response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal))
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ # Tests that the decimal filter set that should succeed.
+ search_decimal = Decimal('4.25')
+ low_item = self.objects.filter(decimal__lt=search_decimal)[0]
+ low_item_data = self._serialize_object(low_item)
+ response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, low_item_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
+ valid_item_data = self._serialize_object(valid_item)
+ response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, valid_item_data)
+
+
+class SearchFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class SearchFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # z abc
+ # zz bcd
+ # zzz cde
+ # ...
+ for idx in range(10):
+ title = 'z' * (idx + 1)
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ SearchFilterModel(title=title, text=text).save()
+
+ def test_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'z', 'text': 'abc'},
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+ def test_exact_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('=title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=zzz')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'zzz', 'text': 'cde'}
+ ]
+ )
+
+ def test_startswith_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', '^text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+
+class OrdringFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class OrderingFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # zyx abc
+ # yxw bcd
+ # xwv cde
+ for idx in range(3):
+ title = (
+ chr(ord('z') - idx) +
+ chr(ord('y') - idx) +
+ chr(ord('x') - idx)
+ )
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ OrdringFilterModel(title=title, text=text).save()
+
+ def test_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ ]
+ )
+
+ def test_reverse_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=-text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_incorrectfield_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=foobar')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering_using_string(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
deleted file mode 100644
index 238da56e..00000000
--- a/rest_framework/tests/filterset.py
+++ /dev/null
@@ -1,169 +0,0 @@
-from __future__ import unicode_literals
-import datetime
-from decimal import Decimal
-from django.test import TestCase
-from django.test.client import RequestFactory
-from django.utils import unittest
-from rest_framework import generics, status, filters
-from rest_framework.compat import django_filters
-from rest_framework.tests.models import FilterableItem, BasicModel
-
-factory = RequestFactory()
-
-
-if django_filters:
- # Basic filter on a list view.
- class FilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_fields = ['decimal', 'date']
- filter_backend = filters.DjangoFilterBackend
-
- # These class are used to test a filter class.
- class SeveralFieldsFilter(django_filters.FilterSet):
- text = django_filters.CharFilter(lookup_type='icontains')
- decimal = django_filters.NumberFilter(lookup_type='lt')
- date = django_filters.DateFilter(lookup_type='gt')
-
- class Meta:
- model = FilterableItem
- fields = ['text', 'decimal', 'date']
-
- class FilterClassRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_class = SeveralFieldsFilter
- filter_backend = filters.DjangoFilterBackend
-
- # These classes are used to test a misconfigured filter class.
- class MisconfiguredFilter(django_filters.FilterSet):
- text = django_filters.CharFilter(lookup_type='icontains')
-
- class Meta:
- model = BasicModel
- fields = ['text']
-
- class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_class = MisconfiguredFilter
- filter_backend = filters.DjangoFilterBackend
-
-
-class IntegrationTestFiltering(TestCase):
- """
- Integration tests for filtered list views.
- """
-
- def setUp(self):
- """
- Create 10 FilterableItem instances.
- """
- base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
- for i in range(10):
- text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
- decimal = base_data[1] + i
- date = base_data[2] - datetime.timedelta(days=i * 2)
- FilterableItem(text=text, decimal=decimal, date=date).save()
-
- self.objects = FilterableItem.objects
- self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
- for obj in self.objects.all()
- ]
-
- @unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_get_filtered_fields_root_view(self):
- """
- GET requests to paginated ListCreateAPIView should return paginated results.
- """
- view = FilterFieldsRootView.as_view()
-
- # Basic test with no filter.
- request = factory.get('/')
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- # Tests that the decimal filter works.
- search_decimal = Decimal('2.25')
- request = factory.get('/?decimal=%s' % search_decimal)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['decimal'] == search_decimal]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the date filter works.
- search_date = datetime.date(2012, 9, 22)
- request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] == search_date]
- self.assertEqual(response.data, expected_data)
-
- @unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_get_filtered_class_root_view(self):
- """
- GET requests to filtered ListCreateAPIView that have a filter_class set
- should return filtered results.
- """
- view = FilterClassRootView.as_view()
-
- # Basic test with no filter.
- request = factory.get('/')
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- # Tests that the decimal filter set with 'lt' in the filter class works.
- search_decimal = Decimal('4.25')
- request = factory.get('/?decimal=%s' % search_decimal)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['decimal'] < search_decimal]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the date filter set with 'gt' in the filter class works.
- search_date = datetime.date(2012, 10, 2)
- request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the text filter set with 'icontains' in the filter class works.
- search_text = 'ff'
- request = factory.get('/?text=%s' % search_text)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if search_text in f['text'].lower()]
- self.assertEqual(response.data, expected_data)
-
- # Tests that multiple filters works.
- search_decimal = Decimal('5.25')
- search_date = datetime.date(2012, 10, 2)
- request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date and
- f['decimal'] < search_decimal]
- self.assertEqual(response.data, expected_data)
-
- @unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_incorrectly_configured_filter(self):
- """
- An error should be displayed when the filter class is misconfigured.
- """
- view = IncorrectlyConfiguredRootView.as_view()
-
- request = factory.get('/')
- self.assertRaises(AssertionError, view, request)
-
- @unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_unknown_filter(self):
- """
- GET requests with filters that aren't configured should return 200.
- """
- view = FilterFieldsRootView.as_view()
-
- search_integer = 10
- request = factory.get('/?integer=%s' % search_integer)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
index f564890c..15d87e86 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/generics.py
@@ -1,7 +1,8 @@
from __future__ import unicode_literals
from django.db import models
+from django.shortcuts import get_object_or_404
from django.test import TestCase
-from rest_framework import generics, serializers, status
+from rest_framework import generics, renderers, serializers, status
from rest_framework.tests.utils import RequestFactory
from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
from rest_framework.compat import six
@@ -38,6 +39,7 @@ class SlugBasedInstanceView(InstanceView):
"""
model = SlugBasedModel
serializer_class = SlugSerializer
+ lookup_field = 'slug'
class TestRootView(TestCase):
@@ -302,6 +304,47 @@ class TestInstanceView(TestCase):
self.assertEqual(new_obj.text, 'foobar')
+class TestOverriddenGetObject(TestCase):
+ """
+ Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the
+ queryset/model mechanism but instead overrides get_object()
+ """
+ def setUp(self):
+ """
+ Create 3 BasicModel intances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+
+ class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
+ """
+ Example detail view for override of get_object().
+ """
+ model = BasicModel
+
+ def get_object(self):
+ pk = int(self.kwargs['pk'])
+ return get_object_or_404(BasicModel.objects.all(), id=pk)
+
+ self.view = OverriddenGetObjectView.as_view()
+
+ def test_overridden_get_object_view(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object.
+ """
+ request = factory.get('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
+
+
# Regression test for #285
class CommentSerializer(serializers.ModelSerializer):
@@ -335,7 +378,7 @@ class TestCreateModelWithAutoNowAddField(TestCase):
self.assertEqual(created.content, 'foobar')
-# Test for particularly ugly regression with m2m in browseable API
+# Test for particularly ugly regression with m2m in browsable API
class ClassB(models.Model):
name = models.CharField(max_length=255)
@@ -360,7 +403,7 @@ class ExampleView(generics.ListCreateAPIView):
class TestM2MBrowseableAPI(TestCase):
def test_m2m_in_browseable_api(self):
"""
- Test for particularly ugly regression with m2m in browseable API
+ Test for particularly ugly regression with m2m in browsable API
"""
request = factory.get('/', HTTP_ACCEPT='text/html')
view = ExampleView().as_view()
@@ -392,22 +435,14 @@ class TestFilterBackendAppliedToViews(TestCase):
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
- self.root_view = RootView.as_view()
- self.instance_view = InstanceView.as_view()
- self.original_root_backend = getattr(RootView, 'filter_backend')
- self.original_instance_backend = getattr(InstanceView, 'filter_backend')
-
- def tearDown(self):
- setattr(RootView, 'filter_backend', self.original_root_backend)
- setattr(InstanceView, 'filter_backend', self.original_instance_backend)
def test_get_root_view_filters_by_name_with_filter_backend(self):
"""
GET requests to ListCreateAPIView should return filtered list.
"""
- setattr(RootView, 'filter_backend', InclusiveFilterBackend)
+ root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/')
- response = self.root_view(request).render()
+ response = root_view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}])
@@ -416,9 +451,9 @@ class TestFilterBackendAppliedToViews(TestCase):
"""
GET requests to ListCreateAPIView should return empty list when all models are filtered out.
"""
- setattr(RootView, 'filter_backend', ExclusiveFilterBackend)
+ root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/')
- response = self.root_view(request).render()
+ response = root_view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, [])
@@ -426,9 +461,9 @@ class TestFilterBackendAppliedToViews(TestCase):
"""
GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
"""
- setattr(InstanceView, 'filter_backend', ExclusiveFilterBackend)
+ instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/1')
- response = self.instance_view(request, pk=1).render()
+ response = instance_view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.data, {'detail': 'Not found'})
@@ -436,8 +471,40 @@ class TestFilterBackendAppliedToViews(TestCase):
"""
GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
"""
- setattr(InstanceView, 'filter_backend', InclusiveFilterBackend)
+ instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/1')
- response = self.instance_view(request, pk=1).render()
+ response = instance_view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foo'})
+
+
+class TwoFieldModel(models.Model):
+ field_a = models.CharField(max_length=100)
+ field_b = models.CharField(max_length=100)
+
+
+class DynamicSerializerView(generics.ListCreateAPIView):
+ model = TwoFieldModel
+ renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
+
+ def get_serializer_class(self):
+ if self.request.method == 'POST':
+ class DynamicSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TwoFieldModel
+ fields = ('field_b',)
+ return DynamicSerializer
+ return super(DynamicSerializerView, self).get_serializer_class()
+
+
+class TestFilterBackendAppliedToViews(TestCase):
+
+ def test_dynamic_serializer_form_in_browsable_api(self):
+ """
+ GET requests to ListCreateAPIView should return filtered list.
+ """
+ view = DynamicSerializerView.as_view()
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertContains(response, 'field_b')
+ self.assertNotContains(response, 'field_a')
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py
index 9a61f299..8fc6ba77 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/hyperlinkedserializers.py
@@ -27,6 +27,14 @@ class PhotoSerializer(serializers.Serializer):
return Photo(**attrs)
+class AlbumSerializer(serializers.ModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title')
+
+ class Meta:
+ model = Album
+ fields = ('title', 'url')
+
+
class BasicList(generics.ListCreateAPIView):
model = BasicModel
model_serializer_class = serializers.HyperlinkedModelSerializer
@@ -73,6 +81,8 @@ class PhotoListCreate(generics.ListCreateAPIView):
class AlbumDetail(generics.RetrieveAPIView):
model = Album
+ serializer_class = AlbumSerializer
+ lookup_field = 'title'
class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
@@ -180,6 +190,36 @@ class TestManyToManyHyperlinkedView(TestCase):
self.assertEqual(response.data, self.data[0])
+class TestHyperlinkedIdentityFieldLookup(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 3 Album instances.
+ """
+ titles = ['foo', 'bar', 'baz']
+ for title in titles:
+ album = Album(title=title)
+ album.save()
+ self.detail_view = AlbumDetail.as_view()
+ self.data = {
+ 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'},
+ 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'},
+ 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'}
+ }
+
+ def test_lookup_field(self):
+ """
+ GET requests to AlbumDetail view should return serialized Albums
+ with a url field keyed by `title`.
+ """
+ for album in Album.objects.all():
+ request = factory.get('/albums/{0}/'.format(album.title))
+ response = self.detail_view(request, title=album.title)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[album.title])
+
+
class TestCreateWithForeignKeys(TestCase):
urls = 'rest_framework.tests.hyperlinkedserializers'
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index f2117538..40e41a64 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -58,13 +58,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel):
rel = models.ManyToManyField(Anchor)
-# Model to test filtering.
-class FilterableItem(RESTFrameworkModel):
- text = models.CharField(max_length=100)
- decimal = models.DecimalField(max_digits=4, decimal_places=2)
- date = models.DateField()
-
-
# Model for regression test for #285
class Comment(RESTFrameworkModel):
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index d2c9b051..e538a78e 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -1,18 +1,24 @@
from __future__ import unicode_literals
import datetime
from decimal import Decimal
-import django
+from django.db import models
from django.core.paginator import Paginator
from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest
from rest_framework import generics, status, pagination, filters, serializers
from rest_framework.compat import django_filters
-from rest_framework.tests.models import BasicModel, FilterableItem
+from rest_framework.tests.models import BasicModel
factory = RequestFactory()
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
class RootView(generics.ListCreateAPIView):
"""
Example description for OPTIONS.
@@ -124,21 +130,11 @@ class IntegrationTestPaginationAndFiltering(TestCase):
model = FilterableItem
paginate_by = 10
filter_class = DecimalFilter
- filter_backend = filters.DjangoFilterBackend
+ filter_backends = (filters.DjangoFilterBackend,)
view = FilterFieldsRootView.as_view()
EXPECTED_NUM_QUERIES = 2
- if django.VERSION < (1, 4):
- # On Django 1.3 we need to use django-filter 0.5.4
- #
- # The filter objects there don't expose a `.count()` method,
- # which means we only make a single query *but* it's a single
- # query across *all* of the queryset, instead of a COUNT and then
- # a SELECT with a LIMIT.
- #
- # Although this is fewer queries, it's actually a regression.
- EXPECTED_NUM_QUERIES = 1
request = factory.get('/?decimal=15.20')
with self.assertNumQueries(EXPECTED_NUM_QUERIES):
@@ -181,7 +177,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
class BasicFilterFieldsRootView(generics.ListCreateAPIView):
model = FilterableItem
paginate_by = 10
- filter_backend = DecimalFilterBackend
+ filter_backends = (DecimalFilterBackend,)
view = BasicFilterFieldsRootView.as_view()
diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/parsers.py
index 539c5b44..7699e10c 100644
--- a/rest_framework/tests/parsers.py
+++ b/rest_framework/tests/parsers.py
@@ -1,10 +1,11 @@
from __future__ import unicode_literals
from rest_framework.compat import StringIO
from django import forms
+from django.core.files.uploadhandler import MemoryFileUploadHandler
from django.test import TestCase
from django.utils import unittest
from rest_framework.compat import etree
-from rest_framework.parsers import FormParser
+from rest_framework.parsers import FormParser, FileUploadParser
from rest_framework.parsers import XMLParser
import datetime
@@ -82,3 +83,33 @@ class TestXMLParser(TestCase):
parser = XMLParser()
data = parser.parse(self._complex_data_input)
self.assertEqual(data, self._complex_data)
+
+
+class TestFileUploadParser(TestCase):
+ def setUp(self):
+ class MockRequest(object):
+ pass
+ from io import BytesIO
+ self.stream = BytesIO(
+ "Test text file".encode('utf-8')
+ )
+ request = MockRequest()
+ request.upload_handlers = (MemoryFileUploadHandler(),)
+ request.META = {
+ 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'),
+ 'HTTP_CONTENT_LENGTH': 14,
+ }
+ self.parser_context = {'request': request, 'kwargs': {}}
+
+ def test_parse(self):
+ """ Make sure the `QueryDict` works OK """
+ parser = FileUploadParser()
+ self.stream.seek(0)
+ data_and_files = parser.parse(self.stream, None, self.parser_context)
+ file_obj = data_and_files.files['file']
+ self.assertEqual(file_obj._size, 14)
+
+ def test_get_filename(self):
+ parser = FileUploadParser()
+ filename = parser.get_filename(self.stream, None, self.parser_context)
+ self.assertEqual(filename, 'file.txt'.encode('utf-8'))
diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py
index cbf93c65..d19219c9 100644
--- a/rest_framework/tests/relations.py
+++ b/rest_framework/tests/relations.py
@@ -5,6 +5,7 @@ from __future__ import unicode_literals
from django.db import models
from django.test import TestCase
from rest_framework import serializers
+from rest_framework.tests.models import BlogPost
class NullModel(models.Model):
@@ -33,7 +34,7 @@ class FieldTests(TestCase):
self.assertRaises(serializers.ValidationError, field.from_native, [])
-class TestManyRelateMixin(TestCase):
+class TestManyRelatedMixin(TestCase):
def test_missing_many_to_many_related_field(self):
'''
Regression test for #632
@@ -45,3 +46,55 @@ class TestManyRelateMixin(TestCase):
into = {}
field.field_from_native({}, None, 'field_name', into)
self.assertEqual(into['field_name'], [])
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+class RelatedFieldSourceTests(TestCase):
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='get_blogposts_manager')
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='get_blogposts_queryset')
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='a.b.c')
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py
index b5702a48..b3efbf52 100644
--- a/rest_framework/tests/relations_hyperlink.py
+++ b/rest_framework/tests/relations_hyperlink.py
@@ -4,6 +4,7 @@ from django.test.client import RequestFactory
from rest_framework import serializers
from rest_framework.compat import patterns, url
from rest_framework.tests.models import (
+ BlogPost,
ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
)
@@ -16,6 +17,7 @@ def dummy_view(request, pk):
pass
urlpatterns = patterns('',
+ url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),
url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),
@@ -26,42 +28,44 @@ urlpatterns = patterns('',
)
+# ManyToMany
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail')
-
class Meta:
model = ManyToManyTarget
+ fields = ('url', 'name', 'sources')
class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ManyToManySource
+ fields = ('url', 'name', 'targets')
+# ForeignKey
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail')
-
class Meta:
model = ForeignKeyTarget
+ fields = ('url', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ForeignKeySource
+ fields = ('url', 'name', 'target')
# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = NullableForeignKeySource
+ fields = ('url', 'name', 'target')
-# OneToOne
+# Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
- nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail')
-
class Meta:
model = OneToOneTarget
+ fields = ('url', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid
@@ -449,3 +453,72 @@ class HyperlinkedNullableOneToOneTests(TestCase):
{'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
]
self.assertEqual(serializer.data, expected)
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+class HyperlinkedRelatedFieldSourceTests(TestCase):
+ urls = 'rest_framework.tests.relations_hyperlink'
+
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='get_blogposts_manager',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='get_blogposts_queryset',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='a.b.c',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py
index a125ba65..f6d006b3 100644
--- a/rest_framework/tests/relations_nested.py
+++ b/rest_framework/tests/relations_nested.py
@@ -6,38 +6,30 @@ from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, Null
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
- depth = 1
- model = ForeignKeySource
-
-
-class FlatForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
model = ForeignKeySource
+ fields = ('id', 'name', 'target')
+ depth = 1
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = FlatForeignKeySourceSerializer(many=True)
-
class Meta:
model = ForeignKeyTarget
+ fields = ('id', 'name', 'sources')
+ depth = 1
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
- depth = 1
model = NullableForeignKeySource
-
-
-class NullableOneToOneSourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = NullableOneToOneSource
+ fields = ('id', 'name', 'target')
+ depth = 1
class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- nullable_source = NullableOneToOneSourceSerializer()
-
class Meta:
model = OneToOneTarget
+ fields = ('id', 'name', 'nullable_source')
+ depth = 1
class ReverseForeignKeyTests(TestCase):
diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py
index f08e1808..0f8c5247 100644
--- a/rest_framework/tests/relations_pk.py
+++ b/rest_framework/tests/relations_pk.py
@@ -1,45 +1,51 @@
from __future__ import unicode_literals
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+from rest_framework.tests.models import (
+ BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
+ NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource,
+)
from rest_framework.compat import six
+# ManyToMany
class ManyToManyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.PrimaryKeyRelatedField(many=True)
-
class Meta:
model = ManyToManyTarget
+ fields = ('id', 'name', 'sources')
class ManyToManySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ManyToManySource
+ fields = ('id', 'name', 'targets')
+# ForeignKey
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.PrimaryKeyRelatedField(many=True)
-
class Meta:
model = ForeignKeyTarget
+ fields = ('id', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
+ fields = ('id', 'name', 'target')
+# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = NullableForeignKeySource
+ fields = ('id', 'name', 'target')
-# OneToOne
+# Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- nullable_source = serializers.PrimaryKeyRelatedField()
-
class Meta:
model = OneToOneTarget
+ fields = ('id', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid
@@ -418,3 +424,55 @@ class PKNullableOneToOneTests(TestCase):
{'id': 2, 'name': 'target-2', 'nullable_source': 1},
]
self.assertEqual(serializer.data, expected)
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+class PrimaryKeyRelatedFieldSourceTests(TestCase):
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager')
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset')
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c')
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
diff --git a/rest_framework/tests/routers.py b/rest_framework/tests/routers.py
new file mode 100644
index 00000000..4e4765cb
--- /dev/null
+++ b/rest_framework/tests/routers.py
@@ -0,0 +1,55 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from django.test.client import RequestFactory
+from rest_framework import status
+from rest_framework.response import Response
+from rest_framework import viewsets
+from rest_framework.decorators import link, action
+from rest_framework.routers import SimpleRouter
+import copy
+
+factory = RequestFactory()
+
+
+class BasicViewSet(viewsets.ViewSet):
+ def list(self, request, *args, **kwargs):
+ return Response({'method': 'list'})
+
+ @action()
+ def action1(self, request, *args, **kwargs):
+ return Response({'method': 'action1'})
+
+ @action()
+ def action2(self, request, *args, **kwargs):
+ return Response({'method': 'action2'})
+
+ @link()
+ def link1(self, request, *args, **kwargs):
+ return Response({'method': 'link1'})
+
+ @link()
+ def link2(self, request, *args, **kwargs):
+ return Response({'method': 'link2'})
+
+
+class TestSimpleRouter(TestCase):
+ def setUp(self):
+ self.router = SimpleRouter()
+
+ def test_link_and_action_decorator(self):
+ routes = self.router.get_routes(BasicViewSet)
+ decorator_routes = routes[2:]
+ # Make sure all these endpoints exist and none have been clobbered
+ for i, endpoint in enumerate(['action1', 'action2', 'link1', 'link2']):
+ route = decorator_routes[i]
+ # check url listing
+ self.assertEqual(route.url,
+ '^{{prefix}}/{{lookup}}/{0}/$'.format(endpoint))
+ # check method to function mapping
+ if endpoint.startswith('action'):
+ method_map = 'post'
+ else:
+ method_map = 'get'
+ self.assertEqual(route.mapping[method_map], endpoint)
+
+
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index 0386ca76..d978dc87 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -1,9 +1,11 @@
from __future__ import unicode_literals
+from django.db import models
+from django.db.models.fields import BLANK_CHOICE_DASH
from django.utils.datastructures import MultiValueDict
from django.test import TestCase
from rest_framework import serializers
from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
- BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel,
+ BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel,
ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo)
import datetime
import pickle
@@ -43,6 +45,17 @@ class CommentSerializer(serializers.Serializer):
return instance
+class NamesSerializer(serializers.Serializer):
+ first = serializers.CharField()
+ last = serializers.CharField(required=False, default='')
+ initials = serializers.CharField(required=False, default='')
+
+
+class PersonIdentifierSerializer(serializers.Serializer):
+ ssn = serializers.CharField()
+ names = NamesSerializer(source='names', required=False)
+
+
class BookSerializer(serializers.ModelSerializer):
isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'})
@@ -78,6 +91,18 @@ class PersonSerializer(serializers.ModelSerializer):
read_only_fields = ('age',)
+class PersonSerializerInvalidReadOnly(serializers.ModelSerializer):
+ """
+ Testing for #652.
+ """
+ info = serializers.Field(source='info')
+
+ class Meta:
+ model = Person
+ fields = ('name', 'age', 'info')
+ read_only_fields = ('age', 'info')
+
+
class AlbumsSerializer(serializers.ModelSerializer):
class Meta:
@@ -141,6 +166,42 @@ class BasicTests(TestCase):
self.assertFalse(serializer.object is expected)
self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!')
+ def test_create_nested(self):
+ """Test a serializer with nested data."""
+ names = {'first': 'John', 'last': 'Doe', 'initials': 'jd'}
+ data = {'ssn': '1234567890', 'names': names}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is data)
+ self.assertEqual(serializer.data['names'], names)
+
+ def test_create_partial_nested(self):
+ """Test a serializer with nested data which has missing fields."""
+ names = {'first': 'John'}
+ data = {'ssn': '1234567890', 'names': names}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ expected_names = {'first': 'John', 'last': '', 'initials': ''}
+ data['names'] = expected_names
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is expected_names)
+ self.assertEqual(serializer.data['names'], expected_names)
+
+ def test_null_nested(self):
+ """Test a serializer with a nonexistent nested field"""
+ data = {'ssn': '1234567890'}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is data)
+ expected = {'ssn': '1234567890', 'names': None}
+ self.assertEqual(serializer.data, expected)
+
def test_update(self):
serializer = CommentSerializer(self.comment, data=self.data)
expected = self.comment
@@ -189,6 +250,12 @@ class BasicTests(TestCase):
# Assert age is unchanged (35)
self.assertEqual(instance.age, self.person_data['age'])
+ def test_invalid_read_only_fields(self):
+ """
+ Regression test for #652.
+ """
+ self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, [])
+
class DictStyleSerializer(serializers.Serializer):
"""
@@ -357,7 +424,6 @@ class CustomValidationTests(TestCase):
def validate_email(self, attrs, source):
value = attrs[source]
-
return attrs
def validate_content(self, attrs, source):
@@ -738,6 +804,43 @@ class ManyRelatedTests(TestCase):
self.assertEqual(serializer.data, expected)
+ def test_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2]
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_depth_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+ depth = 1
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post',
+ 'blogpostcomment_set': [
+ {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1},
+ {'id': 2, 'text': 'I love this blog post', 'blog_post': 1}
+ ]
+ }
+ self.assertEqual(serializer.data, expected)
+
def test_callable_source(self):
post = BlogPost.objects.create(title="Test blog post")
post.blogpostcomment_set.create(text="I love this blog post")
@@ -767,8 +870,6 @@ class RelatedTraversalTest(TestCase):
post = BlogPost.objects.create(title="Test blog post", writer=user)
post.blogpostcomment_set.create(text="I love this blog post")
- from rest_framework.tests.models import BlogPostComment
-
class PersonSerializer(serializers.ModelSerializer):
class Meta:
model = Person
@@ -819,23 +920,6 @@ class RelatedTraversalTest(TestCase):
self.assertEqual(serializer.data, expected)
- def test_queryset_nested_traversal(self):
- """
- Relational fields should be able to use methods as their source.
- """
- BlogPost.objects.create(title='blah')
-
- class QuerysetMethodSerializer(serializers.Serializer):
- blogposts = serializers.RelatedField(many=True, source='get_all_blogposts')
-
- class ClassWithQuerysetMethod(object):
- def get_all_blogposts(self):
- return BlogPost.objects
-
- obj = ClassWithQuerysetMethod()
- serializer = QuerysetMethodSerializer(obj)
- self.assertEqual(serializer.data, {'blogposts': ['BlogPost object']})
-
class SerializerMethodFieldTests(TestCase):
def setUp(self):
@@ -966,25 +1050,95 @@ class SerializerPickleTests(TestCase):
repr(pickle.loads(pickle.dumps(data, 0)))
+# test for issue #725
+class SeveralChoicesModel(models.Model):
+ color = models.CharField(
+ max_length=10,
+ choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')],
+ blank=False
+ )
+ drink = models.CharField(
+ max_length=10,
+ choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')],
+ blank=False,
+ default='beer'
+ )
+ os = models.CharField(
+ max_length=10,
+ choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')],
+ blank=True
+ )
+ music_genre = models.CharField(
+ max_length=10,
+ choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')],
+ blank=True,
+ default='metal'
+ )
+
+
+class SerializerChoiceFields(TestCase):
+
+ def setUp(self):
+ super(SerializerChoiceFields, self).setUp()
+
+ class SeveralChoicesSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = SeveralChoicesModel
+ fields = ('color', 'drink', 'os', 'music_genre')
+
+ self.several_choices_serializer = SeveralChoicesSerializer
+
+ def test_choices_blank_false_not_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['color'].choices,
+ [('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')]
+ )
+
+ def test_choices_blank_false_with_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['drink'].choices,
+ [('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')]
+ )
+
+ def test_choices_blank_true_not_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['os'].choices,
+ BLANK_CHOICE_DASH + [('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')]
+ )
+
+ def test_choices_blank_true_with_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['music_genre'].choices,
+ BLANK_CHOICE_DASH + [('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')]
+ )
+
+
class DepthTest(TestCase):
def test_implicit_nesting(self):
+
writer = Person.objects.create(name="django", age=1)
post = BlogPost.objects.create(title="Test blog post", writer=writer)
+ comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post)
- class BlogPostSerializer(serializers.ModelSerializer):
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
class Meta:
- model = BlogPost
- depth = 1
+ model = BlogPostComment
+ depth = 2
- serializer = BlogPostSerializer(instance=post)
- expected = {'id': 1, 'title': 'Test blog post',
- 'writer': {'id': 1, 'name': 'django', 'age': 1}}
+ serializer = BlogPostCommentSerializer(instance=comment)
+ expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post',
+ 'writer': {'id': 1, 'name': 'django', 'age': 1}}}
self.assertEqual(serializer.data, expected)
def test_explicit_nesting(self):
writer = Person.objects.create(name="django", age=1)
post = BlogPost.objects.create(title="Test blog post", writer=writer)
+ comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post)
class PersonSerializer(serializers.ModelSerializer):
class Meta:
@@ -996,9 +1150,15 @@ class DepthTest(TestCase):
class Meta:
model = BlogPost
- serializer = BlogPostSerializer(instance=post)
- expected = {'id': 1, 'title': 'Test blog post',
- 'writer': {'id': 1, 'name': 'django', 'age': 1}}
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ blog_post = BlogPostSerializer()
+
+ class Meta:
+ model = BlogPostComment
+
+ serializer = BlogPostCommentSerializer(instance=comment)
+ expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post',
+ 'writer': {'id': 1, 'name': 'django', 'age': 1}}}
self.assertEqual(serializer.data, expected)
@@ -1066,7 +1226,7 @@ class DeserializeListTestCase(TestCase):
def test_no_errors(self):
data = [self.data.copy() for x in range(0, 3)]
- serializer = CommentSerializer(data=data)
+ serializer = CommentSerializer(data=data, many=True)
self.assertTrue(serializer.is_valid())
self.assertTrue(isinstance(serializer.object, list))
self.assertTrue(
@@ -1078,7 +1238,7 @@ class DeserializeListTestCase(TestCase):
invalid_item['email'] = ''
data = [self.data.copy(), invalid_item, self.data.copy()]
- serializer = CommentSerializer(data=data)
+ serializer = CommentSerializer(data=data, many=True)
self.assertFalse(serializer.is_valid())
expected = [{}, {'email': ['This field is required.']}, {}]
self.assertEqual(serializer.errors, expected)
diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py
index afc1a1a9..8b0ded1a 100644
--- a/rest_framework/tests/serializer_bulk_update.py
+++ b/rest_framework/tests/serializer_bulk_update.py
@@ -98,7 +98,7 @@ class BulkCreateSerializerTests(TestCase):
serializer = self.BookSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), False)
- expected_errors = {'non_field_errors': ['Expected a list of items']}
+ expected_errors = {'non_field_errors': ['Expected a list of items.']}
self.assertEqual(serializer.errors, expected_errors)
@@ -115,7 +115,7 @@ class BulkCreateSerializerTests(TestCase):
serializer = self.BookSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), False)
- expected_errors = {'non_field_errors': ['Expected a list of items']}
+ expected_errors = {'non_field_errors': ['Expected a list of items.']}
self.assertEqual(serializer.errors, expected_errors)
@@ -201,11 +201,12 @@ class BulkUpdateSerializerTests(TestCase):
'author': 'Haruki Murakami'
}
]
- serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True)
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.data, data)
serializer.save()
new_data = self.BookSerializer(self.books(), many=True).data
+
self.assertEqual(data, new_data)
def test_bulk_update_and_create(self):
@@ -223,13 +224,36 @@ class BulkUpdateSerializerTests(TestCase):
'author': 'Haruki Murakami'
}
]
- serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True)
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.data, data)
serializer.save()
new_data = self.BookSerializer(self.books(), many=True).data
self.assertEqual(data, new_data)
+ def test_bulk_update_invalid_create(self):
+ """
+ Bulk update serialization without allow_add_remove may not create items.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 3,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
def test_bulk_update_error(self):
"""
Incorrect bulk update serialization should return error data.
@@ -249,6 +273,6 @@ class BulkUpdateSerializerTests(TestCase):
{},
{'id': ['Enter a whole number.']}
]
- serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True)
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
self.assertEqual(serializer.is_valid(), False)
self.assertEqual(serializer.errors, expected_errors)
diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/serializer_nested.py
index 6a29c652..71d0e24b 100644
--- a/rest_framework/tests/serializer_nested.py
+++ b/rest_framework/tests/serializer_nested.py
@@ -109,7 +109,7 @@ class WritableNestedSerializerBasicTests(TestCase):
}
]
- serializer = self.AlbumSerializer(data=data)
+ serializer = self.AlbumSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), False)
self.assertEqual(serializer.errors, expected_errors)
@@ -241,6 +241,6 @@ class WritableNestedSerializerObjectTests(TestCase):
)
]
- serializer = self.AlbumSerializer(data=data)
+ serializer = self.AlbumSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.object, expected_object)
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index 810cad63..93ea9816 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -1,3 +1,6 @@
+"""
+Provides various throttling policies.
+"""
from __future__ import unicode_literals
from django.core.cache import cache
from rest_framework import exceptions
@@ -28,9 +31,8 @@ class SimpleRateThrottle(BaseThrottle):
A simple cache implementation, that only requires `.get_cache_key()`
to be overridden.
- The rate (requests / seconds) is set by a :attr:`throttle` attribute
- on the :class:`.View` class. The attribute is a string of the form 'number of
- requests/period'.
+ The rate (requests / seconds) is set by a `throttle` attribute on the View
+ class. The attribute is a string of the form 'number_of_requests/period'.
Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py
index af21ac79..d51374b0 100644
--- a/rest_framework/utils/breadcrumbs.py
+++ b/rest_framework/utils/breadcrumbs.py
@@ -1,26 +1,37 @@
from __future__ import unicode_literals
from django.core.urlresolvers import resolve, get_script_prefix
+from rest_framework.utils.formatting import get_view_name
def get_breadcrumbs(url):
- """Given a url returns a list of breadcrumbs, which are each a tuple of (name, url)."""
+ """
+ Given a url returns a list of breadcrumbs, which are each a
+ tuple of (name, url).
+ """
from rest_framework.views import APIView
def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen):
- """Add tuples of (name, url) to the breadcrumbs list, progressively chomping off parts of the url."""
+ """
+ Add tuples of (name, url) to the breadcrumbs list,
+ progressively chomping off parts of the url.
+ """
try:
(view, unused_args, unused_kwargs) = resolve(url)
except Exception:
pass
else:
- # Check if this is a REST framework view, and if so add it to the breadcrumbs
- if isinstance(getattr(view, 'cls_instance', None), APIView):
+ # Check if this is a REST framework view,
+ # and if so add it to the breadcrumbs
+ cls = getattr(view, 'cls', None)
+ if cls is not None and issubclass(cls, APIView):
# Don't list the same view twice in a row.
# Probably an optional trailing slash.
if not seen or seen[-1] != view:
- breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url))
+ suffix = getattr(view, 'suffix', None)
+ name = get_view_name(view.cls, suffix)
+ breadcrumbs_list.insert(0, (name, prefix + url))
seen.append(view)
if url == '':
@@ -28,11 +39,15 @@ def get_breadcrumbs(url):
return breadcrumbs_list
elif url.endswith('/'):
- # Drop trailing slash off the end and continue to try to resolve more breadcrumbs
- return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix, seen)
-
- # Drop trailing non-slash off the end and continue to try to resolve more breadcrumbs
- return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix, seen)
+ # Drop trailing slash off the end and continue to try to
+ # resolve more breadcrumbs
+ url = url.rstrip('/')
+ return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen)
+
+ # Drop trailing non-slash off the end and continue to try to
+ # resolve more breadcrumbs
+ url = url[:url.rfind('/') + 1]
+ return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen)
prefix = get_script_prefix().rstrip('/')
url = url[len(prefix):]
diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py
new file mode 100644
index 00000000..ebadb3a6
--- /dev/null
+++ b/rest_framework/utils/formatting.py
@@ -0,0 +1,80 @@
+"""
+Utility functions to return a formatted name and description for a given view.
+"""
+from __future__ import unicode_literals
+
+from django.utils.html import escape
+from django.utils.safestring import mark_safe
+from rest_framework.compat import apply_markdown
+import re
+
+
+def _remove_trailing_string(content, trailing):
+ """
+ Strip trailing component `trailing` from `content` if it exists.
+ Used when generating names from view classes.
+ """
+ if content.endswith(trailing) and content != trailing:
+ return content[:-len(trailing)]
+ return content
+
+
+def _remove_leading_indent(content):
+ """
+ Remove leading indent from a block of text.
+ Used when generating descriptions from docstrings.
+ """
+ whitespace_counts = [len(line) - len(line.lstrip(' '))
+ for line in content.splitlines()[1:] if line.lstrip()]
+
+ # unindent the content if needed
+ if whitespace_counts:
+ whitespace_pattern = '^' + (' ' * min(whitespace_counts))
+ content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
+ content = content.strip('\n')
+ return content
+
+
+def _camelcase_to_spaces(content):
+ """
+ Translate 'CamelCaseNames' to 'Camel Case Names'.
+ Used when generating names from view classes.
+ """
+ camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))'
+ content = re.sub(camelcase_boundry, ' \\1', content).strip()
+ return ' '.join(content.split('_')).title()
+
+
+def get_view_name(cls, suffix=None):
+ """
+ Return a formatted name for an `APIView` class or `@api_view` function.
+ """
+ name = cls.__name__
+ name = _remove_trailing_string(name, 'View')
+ name = _remove_trailing_string(name, 'ViewSet')
+ name = _camelcase_to_spaces(name)
+ if suffix:
+ name += ' ' + suffix
+ return name
+
+
+def get_view_description(cls, html=False):
+ """
+ Return a description for an `APIView` class or `@api_view` function.
+ """
+ description = cls.__doc__ or ''
+ description = _remove_leading_indent(description)
+ if html:
+ return markup_description(description)
+ return description
+
+
+def markup_description(description):
+ """
+ Apply HTML markup to the given description.
+ """
+ if apply_markdown:
+ description = apply_markdown(description)
+ else:
+ description = escape(description).replace('\n', '<br />')
+ return mark_safe(description)
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 81cbdcbb..555fa2f4 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -1,54 +1,16 @@
"""
-Provides an APIView class that is used as the base of all class-based views.
+Provides an APIView class that is the base of all views in REST framework.
"""
from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied
-from django.http import Http404
-from django.utils.html import escape
-from django.utils.safestring import mark_safe
+from django.http import Http404, HttpResponse
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status, exceptions
-from rest_framework.compat import View, apply_markdown
+from rest_framework.compat import View
from rest_framework.response import Response
from rest_framework.request import Request
from rest_framework.settings import api_settings
-import re
-
-
-def _remove_trailing_string(content, trailing):
- """
- Strip trailing component `trailing` from `content` if it exists.
- Used when generating names from view classes.
- """
- if content.endswith(trailing) and content != trailing:
- return content[:-len(trailing)]
- return content
-
-
-def _remove_leading_indent(content):
- """
- Remove leading indent from a block of text.
- Used when generating descriptions from docstrings.
- """
- whitespace_counts = [len(line) - len(line.lstrip(' '))
- for line in content.splitlines()[1:] if line.lstrip()]
-
- # unindent the content if needed
- if whitespace_counts:
- whitespace_pattern = '^' + (' ' * min(whitespace_counts))
- content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
- content = content.strip('\n')
- return content
-
-
-def _camelcase_to_spaces(content):
- """
- Translate 'CamelCaseNames' to 'Camel Case Names'.
- Used when generating names from view classes.
- """
- camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))'
- content = re.sub(camelcase_boundry, ' \\1', content).strip()
- return ' '.join(content.split('_')).title()
+from rest_framework.utils.formatting import get_view_name, get_view_description
class APIView(View):
@@ -64,22 +26,21 @@ class APIView(View):
@classmethod
def as_view(cls, **initkwargs):
"""
- Override the default :meth:`as_view` to store an instance of the view
- as an attribute on the callable function. This allows us to discover
- information about the view when we do URL reverse lookups.
+ Store the original class on the view function.
+
+ This allows us to discover information about the view when we do URL
+ reverse lookups. Used for breadcrumb generation.
"""
- # TODO: deprecate?
view = super(APIView, cls).as_view(**initkwargs)
- view.cls_instance = cls(**initkwargs)
+ view.cls = cls
return view
@property
def allowed_methods(self):
"""
- Return the list of allowed HTTP methods, uppercased.
+ Wrap Django's private `_allowed_methods` interface in a public property.
"""
- return [method.upper() for method in self.http_method_names
- if hasattr(self, method)]
+ return self._allowed_methods()
@property
def default_response_headers(self):
@@ -90,43 +51,10 @@ class APIView(View):
'Vary': 'Accept'
}
- def get_name(self):
- """
- Return the resource or view class name for use as this view's name.
- Override to customize.
- """
- # TODO: deprecate?
- name = self.__class__.__name__
- name = _remove_trailing_string(name, 'View')
- return _camelcase_to_spaces(name)
-
- def get_description(self, html=False):
- """
- Return the resource or view docstring for use as this view's description.
- Override to customize.
- """
- # TODO: deprecate?
- description = self.__doc__ or ''
- description = _remove_leading_indent(description)
- if html:
- return self.markup_description(description)
- return description
-
- def markup_description(self, description):
- """
- Apply HTML markup to the description of this view.
- """
- # TODO: deprecate?
- if apply_markdown:
- description = apply_markdown(description)
- else:
- description = escape(description).replace('\n', '<br />')
- return mark_safe(description)
-
def metadata(self, request):
return {
- 'name': self.get_name(),
- 'description': self.get_description(),
+ 'name': get_view_name(self.__class__),
+ 'description': get_view_description(self.__class__),
'renders': [renderer.media_type for renderer in self.renderer_classes],
'parses': [parser.media_type for parser in self.parser_classes],
}
@@ -140,7 +68,8 @@ class APIView(View):
def http_method_not_allowed(self, request, *args, **kwargs):
"""
- Called if `request.method` does not correspond to a handler method.
+ If `request.method` does not correspond to a handler method,
+ determine what kind of exception to raise.
"""
raise exceptions.MethodNotAllowed(request.method)
@@ -327,6 +256,12 @@ class APIView(View):
"""
Returns the final response object.
"""
+ # Make the error obvious if a proper response is not returned
+ assert isinstance(response, HttpResponse), (
+ 'Expected a `Response` to be returned from the view, '
+ 'but received a `%s`' % type(response)
+ )
+
if isinstance(response, Response):
if not getattr(request, 'accepted_renderer', None):
neg = self.perform_content_negotiation(request, force=True)
diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py
new file mode 100644
index 00000000..d91323f2
--- /dev/null
+++ b/rest_framework/viewsets.py
@@ -0,0 +1,139 @@
+"""
+ViewSets are essentially just a type of class based view, that doesn't provide
+any method handlers, such as `get()`, `post()`, etc... but instead has actions,
+such as `list()`, `retrieve()`, `create()`, etc...
+
+Actions are only bound to methods at the point of instantiating the views.
+
+ user_list = UserViewSet.as_view({'get': 'list'})
+ user_detail = UserViewSet.as_view({'get': 'retrieve'})
+
+Typically, rather than instantiate views from viewsets directly, you'll
+regsiter the viewset with a router and let the URL conf be determined
+automatically.
+
+ router = DefaultRouter()
+ router.register(r'users', UserViewSet, 'user')
+ urlpatterns = router.urls
+"""
+from __future__ import unicode_literals
+
+from functools import update_wrapper
+from django.utils.decorators import classonlymethod
+from rest_framework import views, generics, mixins
+
+
+class ViewSetMixin(object):
+ """
+ This is the magic.
+
+ Overrides `.as_view()` so that it takes an `actions` keyword that performs
+ the binding of HTTP methods to actions on the Resource.
+
+ For example, to create a concrete view binding the 'GET' and 'POST' methods
+ to the 'list' and 'create' actions...
+
+ view = MyViewSet.as_view({'get': 'list', 'post': 'create'})
+ """
+
+ @classonlymethod
+ def as_view(cls, actions=None, **initkwargs):
+ """
+ Because of the way class based views create a closure around the
+ instantiated view, we need to totally reimplement `.as_view`,
+ and slightly modify the view function that is created and returned.
+ """
+ # The suffix initkwarg is reserved for identifing the viewset type
+ # eg. 'List' or 'Instance'.
+ cls.suffix = None
+
+ # sanitize keyword arguments
+ for key in initkwargs:
+ if key in cls.http_method_names:
+ raise TypeError("You tried to pass in the %s method name as a "
+ "keyword argument to %s(). Don't do that."
+ % (key, cls.__name__))
+ if not hasattr(cls, key):
+ raise TypeError("%s() received an invalid keyword %r" % (
+ cls.__name__, key))
+
+ def view(request, *args, **kwargs):
+ self = cls(**initkwargs)
+ # We also store the mapping of request methods to actions,
+ # so that we can later set the action attribute.
+ # eg. `self.action = 'list'` on an incoming GET request.
+ self.action_map = actions
+
+ # Bind methods to actions
+ # This is the bit that's different to a standard view
+ for method, action in actions.items():
+ handler = getattr(self, action)
+ setattr(self, method, handler)
+
+ # Patch this in as it's otherwise only present from 1.5 onwards
+ if hasattr(self, 'get') and not hasattr(self, 'head'):
+ self.head = self.get
+
+ # And continue as usual
+ return self.dispatch(request, *args, **kwargs)
+
+ # take name and docstring from class
+ update_wrapper(view, cls, updated=())
+
+ # and possible attributes set by decorators
+ # like csrf_exempt from dispatch
+ update_wrapper(view, cls.dispatch, assigned=())
+
+ # We need to set these on the view function, so that breadcrumb
+ # generation can pick out these bits of information from a
+ # resolved URL.
+ view.cls = cls
+ view.suffix = initkwargs.get('suffix', None)
+ return view
+
+ def initialize_request(self, request, *args, **kargs):
+ """
+ Set the `.action` attribute on the view,
+ depending on the request method.
+ """
+ request = super(ViewSetMixin, self).initialize_request(request, *args, **kargs)
+ self.action = self.action_map.get(request.method.lower())
+ return request
+
+
+class ViewSet(ViewSetMixin, views.APIView):
+ """
+ The base ViewSet class does not provide any actions by default.
+ """
+ pass
+
+
+class GenericViewSet(ViewSetMixin, generics.GenericAPIView):
+ """
+ The GenericViewSet class does not provide any actions by default,
+ but does include the base set of generic view behavior, such as
+ the `get_object` and `get_queryset` methods.
+ """
+ pass
+
+
+class ReadOnlyModelViewSet(mixins.RetrieveModelMixin,
+ mixins.ListModelMixin,
+ GenericViewSet):
+ """
+ A viewset that provides default `list()` and `retrieve()` actions.
+ """
+ pass
+
+
+class ModelViewSet(mixins.CreateModelMixin,
+ mixins.RetrieveModelMixin,
+ mixins.UpdateModelMixin,
+ mixins.DestroyModelMixin,
+ mixins.ListModelMixin,
+ GenericViewSet):
+ """
+ A viewset that provides default `create()`, `retrieve()`, `update()`,
+ `partial_update()`, `destroy()` and `list()` actions.
+ """
+ pass