aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py5
-rw-r--r--rest_framework/authentication.py238
-rw-r--r--rest_framework/authtoken/migrations/0001_initial.py17
-rw-r--r--rest_framework/authtoken/models.py9
-rw-r--r--rest_framework/compat.py85
-rw-r--r--rest_framework/decorators.py30
-rw-r--r--rest_framework/fields.py296
-rw-r--r--rest_framework/filters.py101
-rw-r--r--rest_framework/generics.py364
-rw-r--r--rest_framework/mixins.py68
-rw-r--r--rest_framework/negotiation.py4
-rw-r--r--rest_framework/pagination.py6
-rw-r--r--rest_framework/parsers.py90
-rw-r--r--rest_framework/permissions.py63
-rw-r--r--rest_framework/relations.py217
-rw-r--r--rest_framework/renderers.py18
-rw-r--r--rest_framework/request.py13
-rw-r--r--rest_framework/response.py6
-rw-r--r--rest_framework/routers.py249
-rw-r--r--rest_framework/runtests/settings.py23
-rw-r--r--rest_framework/serializers.py422
-rw-r--r--rest_framework/settings.py31
-rw-r--r--rest_framework/templates/rest_framework/base.html2
-rw-r--r--rest_framework/templates/rest_framework/login.html54
-rw-r--r--rest_framework/templates/rest_framework/login_base.html51
-rw-r--r--rest_framework/templatetags/rest_framework.py95
-rw-r--r--rest_framework/tests/authentication.py341
-rw-r--r--rest_framework/tests/description.py63
-rw-r--r--rest_framework/tests/fields.py583
-rw-r--r--rest_framework/tests/filters.py474
-rw-r--r--rest_framework/tests/filterset.py169
-rw-r--r--rest_framework/tests/generics.py164
-rw-r--r--rest_framework/tests/hyperlinkedserializers.py40
-rw-r--r--rest_framework/tests/models.py7
-rw-r--r--rest_framework/tests/multitable_inheritance.py67
-rw-r--r--rest_framework/tests/pagination.py112
-rw-r--r--rest_framework/tests/parsers.py33
-rw-r--r--rest_framework/tests/relations_hyperlink.py16
-rw-r--r--rest_framework/tests/relations_nested.py24
-rw-r--r--rest_framework/tests/relations_pk.py23
-rw-r--r--rest_framework/tests/request.py8
-rw-r--r--rest_framework/tests/routers.py55
-rw-r--r--rest_framework/tests/serializer.py135
-rw-r--r--rest_framework/tests/serializer_bulk_update.py278
-rw-r--r--rest_framework/tests/serializer_nested.py246
-rw-r--r--rest_framework/tests/status.py13
-rw-r--r--rest_framework/throttling.py8
-rw-r--r--rest_framework/utils/breadcrumbs.py35
-rw-r--r--rest_framework/utils/formatting.py80
-rw-r--r--rest_framework/views.py107
-rw-r--r--rest_framework/viewsets.py139
51 files changed, 4798 insertions, 979 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 29f3d7bc..0b1e67fb 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,6 +1,9 @@
-__version__ = '2.2.1'
+__version__ = '2.3.3'
VERSION = __version__ # synonym
# Header encoding (see RFC5987)
HTTP_HEADER_ENCODING = 'iso-8859-1'
+
+# Default datetime input and output formats
+ISO_8601 = 'iso-8601'
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index 14b2136b..9caca788 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -1,13 +1,30 @@
"""
-Provides a set of pluggable authentication policies.
+Provides various authentication policies.
"""
from __future__ import unicode_literals
+import base64
+from datetime import datetime
+
from django.contrib.auth import authenticate
-from django.utils.encoding import DjangoUnicodeDecodeError
+from django.core.exceptions import ImproperlyConfigured
from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware
+from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
+from rest_framework.compat import oauth2_provider
from rest_framework.authtoken.models import Token
-import base64
+
+
+def get_authorization_header(request):
+ """
+ Return request's 'Authorization:' header, as a bytestring.
+
+ Hide some test client ickyness where the header can be unicode.
+ """
+ auth = request.META.get('HTTP_AUTHORIZATION', b'')
+ if type(auth) == type(''):
+ # Work around django test client oddness
+ auth = auth.encode(HTTP_HEADER_ENCODING)
+ return auth
class BaseAuthentication(object):
@@ -41,28 +58,25 @@ class BasicAuthentication(BaseAuthentication):
Returns a `User` if a correct username and password have been supplied
using HTTP Basic authentication. Otherwise returns `None`.
"""
- auth = request.META.get('HTTP_AUTHORIZATION', b'')
- if type(auth) == type(''):
- # Work around django test client oddness
- auth = auth.encode(HTTP_HEADER_ENCODING)
- auth = auth.split()
+ auth = get_authorization_header(request).split()
if not auth or auth[0].lower() != b'basic':
return None
- if len(auth) != 2:
- raise exceptions.AuthenticationFailed('Invalid basic header')
+ if len(auth) == 1:
+ msg = 'Invalid basic header. No credentials provided.'
+ raise exceptions.AuthenticationFailed(msg)
+ elif len(auth) > 2:
+ msg = 'Invalid basic header. Credentials string should not contain spaces.'
+ raise exceptions.AuthenticationFailed(msg)
try:
auth_parts = base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(':')
except (TypeError, UnicodeDecodeError):
- raise exceptions.AuthenticationFailed('Invalid basic header')
-
- try:
- userid, password = auth_parts[0], auth_parts[2]
- except DjangoUnicodeDecodeError:
- raise exceptions.AuthenticationFailed('Invalid basic header')
+ msg = 'Invalid basic header. Credentials not correctly base64 encoded'
+ raise exceptions.AuthenticationFailed(msg)
+ userid, password = auth_parts[0], auth_parts[2]
return self.authenticate_credentials(userid, password)
def authenticate_credentials(self, userid, password):
@@ -70,9 +84,9 @@ class BasicAuthentication(BaseAuthentication):
Authenticate the userid and password against username and password.
"""
user = authenticate(username=userid, password=password)
- if user is not None and user.is_active:
- return (user, None)
- raise exceptions.AuthenticationFailed('Invalid username/password')
+ if user is None or not user.is_active:
+ raise exceptions.AuthenticationFailed('Invalid username/password')
+ return (user, None)
def authenticate_header(self, request):
return 'Basic realm="%s"' % self.www_authenticate_realm
@@ -131,13 +145,17 @@ class TokenAuthentication(BaseAuthentication):
"""
def authenticate(self, request):
- auth = request.META.get('HTTP_AUTHORIZATION', '').split()
+ auth = get_authorization_header(request).split()
- if not auth or auth[0].lower() != "token":
+ if not auth or auth[0].lower() != b'token':
return None
- if len(auth) != 2:
- raise exceptions.AuthenticationFailed('Invalid token header')
+ if len(auth) == 1:
+ msg = 'Invalid token header. No credentials provided.'
+ raise exceptions.AuthenticationFailed(msg)
+ elif len(auth) > 2:
+ msg = 'Invalid token header. Token string should not contain spaces.'
+ raise exceptions.AuthenticationFailed(msg)
return self.authenticate_credentials(auth[1])
@@ -147,12 +165,178 @@ class TokenAuthentication(BaseAuthentication):
except self.model.DoesNotExist:
raise exceptions.AuthenticationFailed('Invalid token')
- if token.user.is_active:
- return (token.user, token)
- raise exceptions.AuthenticationFailed('User inactive or deleted')
+ if not token.user.is_active:
+ raise exceptions.AuthenticationFailed('User inactive or deleted')
+
+ return (token.user, token)
def authenticate_header(self, request):
return 'Token'
-# TODO: OAuthAuthentication
+class OAuthAuthentication(BaseAuthentication):
+ """
+ OAuth 1.0a authentication backend using `django-oauth-plus` and `oauth2`.
+
+ Note: The `oauth2` package actually provides oauth1.0a support. Urg.
+ We import it from the `compat` module as `oauth`.
+ """
+ www_authenticate_realm = 'api'
+
+ def __init__(self, *args, **kwargs):
+ super(OAuthAuthentication, self).__init__(*args, **kwargs)
+
+ if oauth is None:
+ raise ImproperlyConfigured(
+ "The 'oauth2' package could not be imported."
+ "It is required for use with the 'OAuthAuthentication' class.")
+
+ if oauth_provider is None:
+ raise ImproperlyConfigured(
+ "The 'django-oauth-plus' package could not be imported."
+ "It is required for use with the 'OAuthAuthentication' class.")
+
+ def authenticate(self, request):
+ """
+ Returns two-tuple of (user, token) if authentication succeeds,
+ or None otherwise.
+ """
+ try:
+ oauth_request = oauth_provider.utils.get_oauth_request(request)
+ except oauth.Error as err:
+ raise exceptions.AuthenticationFailed(err.message)
+
+ if not oauth_request:
+ return None
+
+ oauth_params = oauth_provider.consts.OAUTH_PARAMETERS_NAMES
+
+ found = any(param for param in oauth_params if param in oauth_request)
+ missing = list(param for param in oauth_params if param not in oauth_request)
+
+ if not found:
+ # OAuth authentication was not attempted.
+ return None
+
+ if missing:
+ # OAuth was attempted but missing parameters.
+ msg = 'Missing parameters: %s' % (', '.join(missing))
+ raise exceptions.AuthenticationFailed(msg)
+
+ if not self.check_nonce(request, oauth_request):
+ msg = 'Nonce check failed'
+ raise exceptions.AuthenticationFailed(msg)
+
+ try:
+ consumer_key = oauth_request.get_parameter('oauth_consumer_key')
+ consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key)
+ except oauth_provider.store.InvalidConsumerError as err:
+ raise exceptions.AuthenticationFailed(err)
+
+ if consumer.status != oauth_provider.consts.ACCEPTED:
+ msg = 'Invalid consumer key status: %s' % consumer.get_status_display()
+ raise exceptions.AuthenticationFailed(msg)
+
+ try:
+ token_param = oauth_request.get_parameter('oauth_token')
+ token = oauth_provider_store.get_access_token(request, oauth_request, consumer, token_param)
+ except oauth_provider.store.InvalidTokenError:
+ msg = 'Invalid access token: %s' % oauth_request.get_parameter('oauth_token')
+ raise exceptions.AuthenticationFailed(msg)
+
+ try:
+ self.validate_token(request, consumer, token)
+ except oauth.Error as err:
+ raise exceptions.AuthenticationFailed(err.message)
+
+ user = token.user
+
+ if not user.is_active:
+ msg = 'User inactive or deleted: %s' % user.username
+ raise exceptions.AuthenticationFailed(msg)
+
+ return (token.user, token)
+
+ def authenticate_header(self, request):
+ """
+ If permission is denied, return a '401 Unauthorized' response,
+ with an appropraite 'WWW-Authenticate' header.
+ """
+ return 'OAuth realm="%s"' % self.www_authenticate_realm
+
+ def validate_token(self, request, consumer, token):
+ """
+ Check the token and raise an `oauth.Error` exception if invalid.
+ """
+ oauth_server, oauth_request = oauth_provider.utils.initialize_server_request(request)
+ oauth_server.verify_request(oauth_request, consumer, token)
+
+ def check_nonce(self, request, oauth_request):
+ """
+ Checks nonce of request, and return True if valid.
+ """
+ return oauth_provider_store.check_nonce(request, oauth_request, oauth_request['oauth_nonce'])
+
+
+class OAuth2Authentication(BaseAuthentication):
+ """
+ OAuth 2 authentication backend using `django-oauth2-provider`
+ """
+ www_authenticate_realm = 'api'
+
+ def __init__(self, *args, **kwargs):
+ super(OAuth2Authentication, self).__init__(*args, **kwargs)
+
+ if oauth2_provider is None:
+ raise ImproperlyConfigured(
+ "The 'django-oauth2-provider' package could not be imported. "
+ "It is required for use with the 'OAuth2Authentication' class.")
+
+ def authenticate(self, request):
+ """
+ Returns two-tuple of (user, token) if authentication succeeds,
+ or None otherwise.
+ """
+
+ auth = get_authorization_header(request).split()
+
+ if not auth or auth[0].lower() != b'bearer':
+ return None
+
+ if len(auth) == 1:
+ msg = 'Invalid bearer header. No credentials provided.'
+ raise exceptions.AuthenticationFailed(msg)
+ elif len(auth) > 2:
+ msg = 'Invalid bearer header. Token string should not contain spaces.'
+ raise exceptions.AuthenticationFailed(msg)
+
+ return self.authenticate_credentials(request, auth[1])
+
+ def authenticate_credentials(self, request, access_token):
+ """
+ Authenticate the request, given the access token.
+ """
+
+ try:
+ token = oauth2_provider.models.AccessToken.objects.select_related('user')
+ # TODO: Change to timezone aware datetime when oauth2_provider add
+ # support to it.
+ token = token.get(token=access_token, expires__gt=datetime.now())
+ except oauth2_provider.models.AccessToken.DoesNotExist:
+ raise exceptions.AuthenticationFailed('Invalid token')
+
+ user = token.user
+
+ if not user.is_active:
+ msg = 'User inactive or deleted: %s' % user.username
+ raise exceptions.AuthenticationFailed(msg)
+
+ return (user, token)
+
+ def authenticate_header(self, request):
+ """
+ Bearer is the only finalized type currently
+
+ Check details on the `OAuth2Authentication.authenticate` method
+ """
+ return 'Bearer realm="%s"' % self.www_authenticate_realm
diff --git a/rest_framework/authtoken/migrations/0001_initial.py b/rest_framework/authtoken/migrations/0001_initial.py
index f4e052e4..d5965e40 100644
--- a/rest_framework/authtoken/migrations/0001_initial.py
+++ b/rest_framework/authtoken/migrations/0001_initial.py
@@ -4,6 +4,8 @@ from south.db import db
from south.v2 import SchemaMigration
from django.db import models
+from rest_framework.settings import api_settings
+
try:
from django.contrib.auth import get_user_model
@@ -45,20 +47,7 @@ class Migration(SchemaMigration):
'name': ('django.db.models.fields.CharField', [], {'max_length': '50'})
},
"%s.%s" % (User._meta.app_label, User._meta.module_name): {
- 'Meta': {'object_name': 'User'},
- 'date_joined': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}),
- 'email': ('django.db.models.fields.EmailField', [], {'max_length': '75', 'blank': 'True'}),
- 'first_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}),
- 'groups': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Group']", 'symmetrical': 'False', 'blank': 'True'}),
- 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
- 'is_active': ('django.db.models.fields.BooleanField', [], {'default': 'True'}),
- 'is_staff': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
- 'is_superuser': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
- 'last_login': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}),
- 'last_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}),
- 'password': ('django.db.models.fields.CharField', [], {'max_length': '128'}),
- 'user_permissions': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Permission']", 'symmetrical': 'False', 'blank': 'True'}),
- 'username': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '30'})
+ 'Meta': {'object_name': User._meta.module_name},
},
'authtoken.token': {
'Meta': {'object_name': 'Token'},
diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py
index 7f5a75a3..52c45ad1 100644
--- a/rest_framework/authtoken/models.py
+++ b/rest_framework/authtoken/models.py
@@ -2,6 +2,7 @@ import uuid
import hmac
from hashlib import sha1
from rest_framework.compat import User
+from django.conf import settings
from django.db import models
@@ -13,6 +14,14 @@ class Token(models.Model):
user = models.OneToOneField(User, related_name='auth_token')
created = models.DateTimeField(auto_now_add=True)
+ class Meta:
+ # Work around for a bug in Django:
+ # https://code.djangoproject.com/ticket/19422
+ #
+ # Also see corresponding ticket:
+ # https://github.com/tomchristie/django-rest-framework/issues/705
+ abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS
+
def save(self, *args, **kwargs):
if not self.key:
self.key = self.generate_key()
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 07fdddce..cd39f544 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -6,6 +6,7 @@ versions of django/python, and compatibility wrappers around optional packages.
from __future__ import unicode_literals
import django
+from django.core.exceptions import ImproperlyConfigured
# Try to import six from Django, fallback to included `six`.
try:
@@ -87,9 +88,7 @@ else:
raise ImportError("User model is not to be found.")
-# First implementation of Django class-based views did not include head method
-# in base View class - https://code.djangoproject.com/ticket/15668
-if django.VERSION >= (1, 4):
+if django.VERSION >= (1, 5):
from django.views.generic import View
else:
from django.views.generic import View as _View
@@ -97,6 +96,8 @@ else:
from django.utils.functional import update_wrapper
class View(_View):
+ # 1.3 does not include head method in base View class
+ # See: https://code.djangoproject.com/ticket/15668
@classonlymethod
def as_view(cls, **initkwargs):
"""
@@ -126,11 +127,15 @@ else:
update_wrapper(view, cls.dispatch, assigned=())
return view
-# Taken from @markotibold's attempt at supporting PATCH.
-# https://github.com/markotibold/django-rest-framework/tree/patch
-http_method_names = set(View.http_method_names)
-http_method_names.add('patch')
-View.http_method_names = list(http_method_names) # PATCH method is not implemented by Django
+ # _allowed_methods only present from 1.5 onwards
+ def _allowed_methods(self):
+ return [m.upper() for m in self.http_method_names if hasattr(self, m)]
+
+
+# PATCH method is not implemented by Django
+if 'patch' not in View.http_method_names:
+ View.http_method_names = View.http_method_names + ['patch']
+
# PUT, DELETE do not require CSRF until 1.4. They should. Make it better.
if django.VERSION >= (1, 4):
@@ -395,6 +400,41 @@ except ImportError:
kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None)
return datetime.datetime(**kw)
+
+# smart_urlquote is new on Django 1.4
+try:
+ from django.utils.html import smart_urlquote
+except ImportError:
+ import re
+ from django.utils.encoding import smart_str
+ try:
+ from urllib.parse import quote, urlsplit, urlunsplit
+ except ImportError: # Python 2
+ from urllib import quote
+ from urlparse import urlsplit, urlunsplit
+
+ unquoted_percents_re = re.compile(r'%(?![0-9A-Fa-f]{2})')
+
+ def smart_urlquote(url):
+ "Quotes a URL if it isn't already quoted."
+ # Handle IDN before quoting.
+ scheme, netloc, path, query, fragment = urlsplit(url)
+ try:
+ netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE
+ except UnicodeError: # invalid domain part
+ pass
+ else:
+ url = urlunsplit((scheme, netloc, path, query, fragment))
+
+ # An URL is considered unquoted if it contains no % characters or
+ # contains a % not followed by two hexadecimal digits. See #9655.
+ if '%' not in url or unquoted_percents_re.search(url):
+ # See http://bugs.python.org/issue2637
+ url = quote(smart_str(url), safe=b'!*\'();:@&=+$,/?#[]~')
+
+ return force_text(url)
+
+
# Markdown is optional
try:
import markdown
@@ -426,3 +466,32 @@ try:
import defusedxml.ElementTree as etree
except ImportError:
etree = None
+
+# OAuth is optional
+try:
+ # Note: The `oauth2` package actually provides oauth1.0a support. Urg.
+ import oauth2 as oauth
+except ImportError:
+ oauth = None
+
+# OAuth is optional
+try:
+ import oauth_provider
+ from oauth_provider.store import store as oauth_provider_store
+except (ImportError, ImproperlyConfigured):
+ oauth_provider = None
+ oauth_provider_store = None
+
+# OAuth 2 support is optional
+try:
+ import provider.oauth2 as oauth2_provider
+ from provider.oauth2 import models as oauth2_provider_models
+ from provider.oauth2 import forms as oauth2_provider_forms
+ from provider import scope as oauth2_provider_scope
+ from provider import constants as oauth2_constants
+except ImportError:
+ oauth2_provider = None
+ oauth2_provider_models = None
+ oauth2_provider_forms = None
+ oauth2_provider_scope = None
+ oauth2_constants = None
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 8250cd3b..81e585e1 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -1,3 +1,11 @@
+"""
+The most imporant decorator in this module is `@api_view`, which is used
+for writing function-based views with REST framework.
+
+There are also various decorators for setting the API policies on function
+based views, as well as the `@action` and `@link` decorators, which are
+used to annotate methods on viewsets that should be included by routers.
+"""
from __future__ import unicode_literals
from rest_framework.compat import six
from rest_framework.views import APIView
@@ -97,3 +105,25 @@ def permission_classes(permission_classes):
func.permission_classes = permission_classes
return func
return decorator
+
+
+def link(**kwargs):
+ """
+ Used to mark a method on a ViewSet that should be routed for GET requests.
+ """
+ def decorator(func):
+ func.bind_to_method = 'get'
+ func.kwargs = kwargs
+ return func
+ return decorator
+
+
+def action(**kwargs):
+ """
+ Used to mark a method on a ViewSet that should be routed for POST requests.
+ """
+ def decorator(func):
+ func.bind_to_method = 'post'
+ func.kwargs = kwargs
+ return func
+ return decorator
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 86c3a837..c83ee5ec 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,7 +1,13 @@
+"""
+Serializer fields perform validation on incoming data.
+
+They are very similar to Django's form fields.
+"""
from __future__ import unicode_literals
import copy
import datetime
+from decimal import Decimal, DecimalException
import inspect
import re
import warnings
@@ -13,26 +19,29 @@ from django import forms
from django.forms import widgets
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
-from rest_framework.compat import parse_date, parse_datetime
-from rest_framework.compat import timezone
+
+from rest_framework import ISO_8601
+from rest_framework.compat import timezone, parse_date, parse_datetime, parse_time
from rest_framework.compat import BytesIO
from rest_framework.compat import six
from rest_framework.compat import smart_text
-from rest_framework.compat import parse_time
+from rest_framework.settings import api_settings
def is_simple_callable(obj):
"""
True if the object is a callable that takes no arguments.
"""
- try:
- args, _, _, defaults = inspect.getargspec(obj)
- except TypeError:
+ function = inspect.isfunction(obj)
+ method = inspect.ismethod(obj)
+
+ if not (function or method):
return False
- else:
- len_args = len(args) if inspect.isfunction(obj) else len(args) - 1
- len_defaults = len(defaults) if defaults else 0
- return len_args <= len_defaults
+
+ args, _, _, defaults = inspect.getargspec(obj)
+ len_args = len(args) if function else len(args) - 1
+ len_defaults = len(defaults) if defaults else 0
+ return len_args <= len_defaults
def get_component(obj, attr_name):
@@ -50,6 +59,46 @@ def get_component(obj, attr_name):
return val
+def readable_datetime_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]')
+ return humanize_strptime(format)
+
+
+def readable_date_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]')
+ return humanize_strptime(format)
+
+
+def readable_time_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
+ return humanize_strptime(format)
+
+
+def humanize_strptime(format_string):
+ # Note that we're missing some of the locale specific mappings that
+ # don't really make sense.
+ mapping = {
+ "%Y": "YYYY",
+ "%y": "YY",
+ "%m": "MM",
+ "%b": "[Jan-Dec]",
+ "%B": "[January-December]",
+ "%d": "DD",
+ "%H": "hh",
+ "%I": "hh", # Requires '%p' to differentiate from '%H'.
+ "%M": "mm",
+ "%S": "ss",
+ "%f": "uuuuuu",
+ "%a": "[Mon-Sun]",
+ "%A": "[Monday-Sunday]",
+ "%p": "[AM|PM]",
+ "%z": "[+HHMM|-HHMM]"
+ }
+ for key, val in mapping.items():
+ format_string = format_string.replace(key, val)
+ return format_string
+
+
class Field(object):
read_only = True
creation_counter = 0
@@ -151,9 +200,9 @@ class WritableField(Field):
# 'blank' is to be deprecated in favor of 'required'
if blank is not None:
- warnings.warn('The `blank` keyword argument is due to deprecated. '
+ warnings.warn('The `blank` keyword argument is deprecated. '
'Use the `required` keyword argument instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
required = not(blank)
super(WritableField, self).__init__(source=source)
@@ -447,12 +496,16 @@ class DateField(WritableField):
form_field_class = forms.DateField
default_error_messages = {
- 'invalid': _("'%s' value has an invalid date format. It must be "
- "in YYYY-MM-DD format."),
- 'invalid_date': _("'%s' value has the correct format (YYYY-MM-DD) "
- "but it is an invalid date."),
+ 'invalid': _("Date has wrong format. Use one of these formats instead: %s"),
}
empty = None
+ input_formats = api_settings.DATE_INPUT_FORMATS
+ format = api_settings.DATE_FORMAT
+
+ def __init__(self, input_formats=None, format=None, *args, **kwargs):
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
+ self.format = format if format is not None else self.format
+ super(DateField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
@@ -468,17 +521,37 @@ class DateField(WritableField):
if isinstance(value, datetime.date):
return value
- try:
- parsed = parse_date(value)
- if parsed is not None:
- return parsed
- except (ValueError, TypeError):
- msg = self.error_messages['invalid_date'] % value
- raise ValidationError(msg)
+ for format in self.input_formats:
+ if format.lower() == ISO_8601:
+ try:
+ parsed = parse_date(value)
+ except (ValueError, TypeError):
+ pass
+ else:
+ if parsed is not None:
+ return parsed
+ else:
+ try:
+ parsed = datetime.datetime.strptime(value, format)
+ except (ValueError, TypeError):
+ pass
+ else:
+ return parsed.date()
- msg = self.error_messages['invalid'] % value
+ msg = self.error_messages['invalid'] % readable_date_formats(self.input_formats)
raise ValidationError(msg)
+ def to_native(self, value):
+ if value is None or self.format is None:
+ return value
+
+ if isinstance(value, datetime.datetime):
+ value = value.date()
+
+ if self.format.lower() == ISO_8601:
+ return value.isoformat()
+ return value.strftime(self.format)
+
class DateTimeField(WritableField):
type_name = 'DateTimeField'
@@ -486,15 +559,16 @@ class DateTimeField(WritableField):
form_field_class = forms.DateTimeField
default_error_messages = {
- 'invalid': _("'%s' value has an invalid format. It must be in "
- "YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."),
- 'invalid_date': _("'%s' value has the correct format "
- "(YYYY-MM-DD) but it is an invalid date."),
- 'invalid_datetime': _("'%s' value has the correct format "
- "(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) "
- "but it is an invalid date/time."),
+ 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"),
}
empty = None
+ input_formats = api_settings.DATETIME_INPUT_FORMATS
+ format = api_settings.DATETIME_FORMAT
+
+ def __init__(self, input_formats=None, format=None, *args, **kwargs):
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
+ self.format = format if format is not None else self.format
+ super(DateTimeField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
@@ -516,25 +590,37 @@ class DateTimeField(WritableField):
value = timezone.make_aware(value, default_timezone)
return value
- try:
- parsed = parse_datetime(value)
- if parsed is not None:
- return parsed
- except (ValueError, TypeError):
- msg = self.error_messages['invalid_datetime'] % value
- raise ValidationError(msg)
-
- try:
- parsed = parse_date(value)
- if parsed is not None:
- return datetime.datetime(parsed.year, parsed.month, parsed.day)
- except (ValueError, TypeError):
- msg = self.error_messages['invalid_date'] % value
- raise ValidationError(msg)
+ for format in self.input_formats:
+ if format.lower() == ISO_8601:
+ try:
+ parsed = parse_datetime(value)
+ except (ValueError, TypeError):
+ pass
+ else:
+ if parsed is not None:
+ return parsed
+ else:
+ try:
+ parsed = datetime.datetime.strptime(value, format)
+ except (ValueError, TypeError):
+ pass
+ else:
+ return parsed
- msg = self.error_messages['invalid'] % value
+ msg = self.error_messages['invalid'] % readable_datetime_formats(self.input_formats)
raise ValidationError(msg)
+ def to_native(self, value):
+ if value is None or self.format is None:
+ return value
+
+ if self.format.lower() == ISO_8601:
+ ret = value.isoformat()
+ if ret.endswith('+00:00'):
+ ret = ret[:-6] + 'Z'
+ return ret
+ return value.strftime(self.format)
+
class TimeField(WritableField):
type_name = 'TimeField'
@@ -542,10 +628,16 @@ class TimeField(WritableField):
form_field_class = forms.TimeField
default_error_messages = {
- 'invalid': _("'%s' value has an invalid format. It must be a valid "
- "time in the HH:MM[:ss[.uuuuuu]] format."),
+ 'invalid': _("Time has wrong format. Use one of these formats instead: %s"),
}
empty = None
+ input_formats = api_settings.TIME_INPUT_FORMATS
+ format = api_settings.TIME_FORMAT
+
+ def __init__(self, input_formats=None, format=None, *args, **kwargs):
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
+ self.format = format if format is not None else self.format
+ super(TimeField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
@@ -554,13 +646,36 @@ class TimeField(WritableField):
if isinstance(value, datetime.time):
return value
- try:
- parsed = parse_time(value)
- assert parsed is not None
- return parsed
- except (ValueError, TypeError):
- msg = self.error_messages['invalid'] % value
- raise ValidationError(msg)
+ for format in self.input_formats:
+ if format.lower() == ISO_8601:
+ try:
+ parsed = parse_time(value)
+ except (ValueError, TypeError):
+ pass
+ else:
+ if parsed is not None:
+ return parsed
+ else:
+ try:
+ parsed = datetime.datetime.strptime(value, format)
+ except (ValueError, TypeError):
+ pass
+ else:
+ return parsed.time()
+
+ msg = self.error_messages['invalid'] % readable_time_formats(self.input_formats)
+ raise ValidationError(msg)
+
+ def to_native(self, value):
+ if value is None or self.format is None:
+ return value
+
+ if isinstance(value, datetime.datetime):
+ value = value.time()
+
+ if self.format.lower() == ISO_8601:
+ return value.isoformat()
+ return value.strftime(self.format)
class IntegerField(WritableField):
@@ -612,6 +727,75 @@ class FloatField(WritableField):
raise ValidationError(msg)
+class DecimalField(WritableField):
+ type_name = 'DecimalField'
+ form_field_class = forms.DecimalField
+
+ default_error_messages = {
+ 'invalid': _('Enter a number.'),
+ 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
+ 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
+ 'max_digits': _('Ensure that there are no more than %s digits in total.'),
+ 'max_decimal_places': _('Ensure that there are no more than %s decimal places.'),
+ 'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.')
+ }
+
+ def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs):
+ self.max_value, self.min_value = max_value, min_value
+ self.max_digits, self.decimal_places = max_digits, decimal_places
+ super(DecimalField, self).__init__(*args, **kwargs)
+
+ if max_value is not None:
+ self.validators.append(validators.MaxValueValidator(max_value))
+ if min_value is not None:
+ self.validators.append(validators.MinValueValidator(min_value))
+
+ def from_native(self, value):
+ """
+ Validates that the input is a decimal number. Returns a Decimal
+ instance. Returns None for empty values. Ensures that there are no more
+ than max_digits in the number, and no more than decimal_places digits
+ after the decimal point.
+ """
+ if value in validators.EMPTY_VALUES:
+ return None
+ value = smart_text(value).strip()
+ try:
+ value = Decimal(value)
+ except DecimalException:
+ raise ValidationError(self.error_messages['invalid'])
+ return value
+
+ def validate(self, value):
+ super(DecimalField, self).validate(value)
+ if value in validators.EMPTY_VALUES:
+ return
+ # Check for NaN, Inf and -Inf values. We can't compare directly for NaN,
+ # since it is never equal to itself. However, NaN is the only value that
+ # isn't equal to itself, so we can use this to identify NaN
+ if value != value or value == Decimal("Inf") or value == Decimal("-Inf"):
+ raise ValidationError(self.error_messages['invalid'])
+ sign, digittuple, exponent = value.as_tuple()
+ decimals = abs(exponent)
+ # digittuple doesn't include any leading zeros.
+ digits = len(digittuple)
+ if decimals > digits:
+ # We have leading zeros up to or past the decimal point. Count
+ # everything past the decimal point as a digit. We do not count
+ # 0 before the decimal point as a digit since that would mean
+ # we would not allow max_digits = decimal_places.
+ digits = decimals
+ whole_digits = digits - decimals
+
+ if self.max_digits is not None and digits > self.max_digits:
+ raise ValidationError(self.error_messages['max_digits'] % self.max_digits)
+ if self.decimal_places is not None and decimals > self.decimal_places:
+ raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places)
+ if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
+ raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places))
+ return value
+
+
class FileField(WritableField):
use_files = True
type_name = 'FileField'
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index 6fea46fa..c058bc71 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -1,5 +1,12 @@
+"""
+Provides generic filtering backends that can be used to filter the results
+returned by list views.
+"""
from __future__ import unicode_literals
-from rest_framework.compat import django_filters
+from django.db import models
+from rest_framework.compat import django_filters, six
+from functools import reduce
+import operator
FilterSet = django_filters and django_filters.FilterSet or None
@@ -25,36 +32,112 @@ class DjangoFilterBackend(BaseFilterBackend):
def __init__(self):
assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed'
- def get_filter_class(self, view):
+ def get_filter_class(self, view, queryset=None):
"""
Return the django-filters `FilterSet` used to filter the queryset.
"""
filter_class = getattr(view, 'filter_class', None)
filter_fields = getattr(view, 'filter_fields', None)
- view_model = getattr(view, 'model', None)
if filter_class:
filter_model = filter_class.Meta.model
- assert issubclass(filter_model, view_model), \
- 'FilterSet model %s does not match view model %s' % \
- (filter_model, view_model)
+ assert issubclass(filter_model, queryset.model), \
+ 'FilterSet model %s does not match queryset model %s' % \
+ (filter_model, queryset.model)
return filter_class
if filter_fields:
class AutoFilterSet(self.default_filter_set):
class Meta:
- model = view_model
+ model = queryset.model
fields = filter_fields
return AutoFilterSet
return None
def filter_queryset(self, request, queryset, view):
- filter_class = self.get_filter_class(view)
+ filter_class = self.get_filter_class(view, queryset)
if filter_class:
- return filter_class(request.QUERY_PARAMS, queryset=queryset)
+ return filter_class(request.QUERY_PARAMS, queryset=queryset).qs
+
+ return queryset
+
+
+class SearchFilter(BaseFilterBackend):
+ search_param = 'search' # The URL query parameter used for the search.
+
+ def get_search_terms(self, request):
+ """
+ Search terms are set by a ?search=... query parameter,
+ and may be comma and/or whitespace delimited.
+ """
+ params = request.QUERY_PARAMS.get(self.search_param, '')
+ return params.replace(',', ' ').split()
+
+ def construct_search(self, field_name):
+ if field_name.startswith('^'):
+ return "%s__istartswith" % field_name[1:]
+ elif field_name.startswith('='):
+ return "%s__iexact" % field_name[1:]
+ elif field_name.startswith('@'):
+ return "%s__search" % field_name[1:]
+ else:
+ return "%s__icontains" % field_name
+
+ def filter_queryset(self, request, queryset, view):
+ search_fields = getattr(view, 'search_fields', None)
+
+ if not search_fields:
+ return queryset
+
+ orm_lookups = [self.construct_search(str(search_field))
+ for search_field in search_fields]
+
+ for search_term in self.get_search_terms(request):
+ or_queries = [models.Q(**{orm_lookup: search_term})
+ for orm_lookup in orm_lookups]
+ queryset = queryset.filter(reduce(operator.or_, or_queries))
+
+ return queryset
+
+
+class OrderingFilter(BaseFilterBackend):
+ ordering_param = 'ordering' # The URL query parameter used for the ordering.
+
+ def get_ordering(self, request):
+ """
+ Search terms are set by a ?search=... query parameter,
+ and may be comma and/or whitespace delimited.
+ """
+ params = request.QUERY_PARAMS.get(self.ordering_param)
+ if params:
+ return [param.strip() for param in params.split(',')]
+
+ def get_default_ordering(self, view):
+ ordering = getattr(view, 'ordering', None)
+ if isinstance(ordering, six.string_types):
+ return (ordering,)
+ return ordering
+
+ def remove_invalid_fields(self, queryset, ordering):
+ field_names = [field.name for field in queryset.model._meta.fields]
+ return [term for term in ordering if term.lstrip('-') in field_names]
+
+ def filter_queryset(self, request, queryset, view):
+ ordering = self.get_ordering(request)
+
+ if ordering:
+ # Skip any incorrect parameters
+ ordering = self.remove_invalid_fields(queryset, ordering)
+
+ if not ordering:
+ # Use 'ordering' attribtue by default
+ ordering = self.get_default_ordering(view)
+
+ if ordering:
+ return queryset.order_by(*ordering)
return queryset
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 9ae8cf0a..05ec93d3 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -2,22 +2,59 @@
Generic views that provide commonly needed behaviour.
"""
from __future__ import unicode_literals
+
+from django.core.exceptions import ImproperlyConfigured
+from django.core.paginator import Paginator, InvalidPage
+from django.http import Http404
+from django.shortcuts import get_object_or_404
+from django.utils.translation import ugettext as _
from rest_framework import views, mixins
+from rest_framework.exceptions import ConfigurationError
from rest_framework.settings import api_settings
-from django.views.generic.detail import SingleObjectMixin
-from django.views.generic.list import MultipleObjectMixin
-
+import warnings
-### Base classes for the generic views ###
class GenericAPIView(views.APIView):
"""
Base class for all other generic views.
"""
- model = None
+ # You'll need to either set these attributes,
+ # or override `get_queryset()`/`get_serializer_class()`.
+ queryset = None
serializer_class = None
+
+ # This shortcut may be used instead of setting either or both
+ # of the `queryset`/`serializer_class` attributes, although using
+ # the explicit style is generally preferred.
+ model = None
+
+ # If you want to use object lookups other than pk, set this attribute.
+ # For more complex lookup requirements override `get_object()`.
+ lookup_field = 'pk'
+
+ # Pagination settings
+ paginate_by = api_settings.PAGINATE_BY
+ paginate_by_param = api_settings.PAGINATE_BY_PARAM
+ pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
+ page_kwarg = 'page'
+
+ # The filter backend classes to use for queryset filtering
+ filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
+
+ # The following attributes may be subject to change,
+ # and should be considered private API.
model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
+ paginator_class = Paginator
+
+ ######################################
+ # These are pending deprecation...
+
+ pk_url_kwarg = 'pk'
+ slug_url_kwarg = 'slug'
+ slug_field = 'slug'
+ allow_empty = True
+ filter_backend = api_settings.FILTER_BACKEND
def get_serializer_context(self):
"""
@@ -29,24 +66,6 @@ class GenericAPIView(views.APIView):
'view': self
}
- def get_serializer_class(self):
- """
- Return the class to use for the serializer.
-
- Defaults to using `self.serializer_class`, falls back to constructing a
- model serializer class using `self.model_serializer_class`, with
- `self.model` as the model.
- """
- serializer_class = self.serializer_class
-
- if serializer_class is None:
- class DefaultSerializer(self.model_serializer_class):
- class Meta:
- model = self.model
- serializer_class = DefaultSerializer
-
- return serializer_class
-
def get_serializer(self, instance=None, data=None,
files=None, many=False, partial=False):
"""
@@ -58,86 +77,244 @@ class GenericAPIView(views.APIView):
return serializer_class(instance, data=data, files=files,
many=many, partial=partial, context=context)
- def pre_save(self, obj):
+ def get_pagination_serializer(self, page):
"""
- Placeholder method for calling before saving an object.
- May be used eg. to set attributes on the object that are implicit
- in either the request, or the url.
+ Return a serializer instance to use with paginated data.
"""
- pass
+ class SerializerClass(self.pagination_serializer_class):
+ class Meta:
+ object_serializer_class = self.get_serializer_class()
- def post_save(self, obj, created=False):
+ pagination_serializer_class = SerializerClass
+ context = self.get_serializer_context()
+ return pagination_serializer_class(instance=page, context=context)
+
+ def paginate_queryset(self, queryset, page_size=None):
"""
- Placeholder method for calling after saving an object.
+ Paginate a queryset if required, either returning a page object,
+ or `None` if pagination is not configured for this view.
"""
- pass
-
-
-class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
- """
- Base class for generic views onto a queryset.
- """
-
- paginate_by = api_settings.PAGINATE_BY
- paginate_by_param = api_settings.PAGINATE_BY_PARAM
- pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
- filter_backend = api_settings.FILTER_BACKEND
+ deprecated_style = False
+ if page_size is not None:
+ warnings.warn('The `page_size` parameter to `paginate_queryset()` '
+ 'is due to be deprecated. '
+ 'Note that the return style of this method is also '
+ 'changed, and will simply return a page object '
+ 'when called without a `page_size` argument.',
+ PendingDeprecationWarning, stacklevel=2)
+ deprecated_style = True
+ else:
+ # Determine the required page size.
+ # If pagination is not configured, simply return None.
+ page_size = self.get_paginate_by()
+ if not page_size:
+ return None
+
+ if not self.allow_empty:
+ warnings.warn(
+ 'The `allow_empty` parameter is due to be deprecated. '
+ 'To use `allow_empty=False` style behavior, You should override '
+ '`get_queryset()` and explicitly raise a 404 on empty querysets.',
+ PendingDeprecationWarning, stacklevel=2
+ )
+
+ paginator = self.paginator_class(queryset, page_size,
+ allow_empty_first_page=self.allow_empty)
+ page_kwarg = self.kwargs.get(self.page_kwarg)
+ page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)
+ page = page_kwarg or page_query_param or 1
+ try:
+ page_number = int(page)
+ except ValueError:
+ if page == 'last':
+ page_number = paginator.num_pages
+ else:
+ raise Http404(_("Page is not 'last', nor can it be converted to an int."))
+ try:
+ page = paginator.page(page_number)
+ except InvalidPage as e:
+ raise Http404(_('Invalid page (%(page_number)s): %(message)s') % {
+ 'page_number': page_number,
+ 'message': str(e)
+ })
+
+ if deprecated_style:
+ return (paginator, page, page.object_list, page.has_other_pages())
+ return page
def filter_queryset(self, queryset):
"""
Given a queryset, filter it with whichever filter backend is in use.
- """
- if not self.filter_backend:
- return queryset
- backend = self.filter_backend()
- return backend.filter_queryset(self.request, queryset, self)
- def get_pagination_serializer(self, page=None):
+ You are unlikely to want to override this method, although you may need
+ to call it either from a list view, or from a custom `get_object`
+ method if you want to apply the configured filtering backend to the
+ default queryset.
"""
- Return a serializer instance to use with paginated data.
+ filter_backends = self.filter_backends or []
+ if not filter_backends and self.filter_backend:
+ warnings.warn(
+ 'The `filter_backend` attribute and `FILTER_BACKEND` setting '
+ 'are due to be deprecated in favor of a `filter_backends` '
+ 'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take '
+ 'a *list* of filter backend classes.',
+ PendingDeprecationWarning, stacklevel=2
+ )
+ filter_backends = [self.filter_backend]
+
+ for backend in filter_backends:
+ queryset = backend().filter_queryset(self.request, queryset, self)
+ return queryset
+
+ ########################
+ ### The following methods provide default implementations
+ ### that you may want to override for more complex cases.
+
+ def get_paginate_by(self, queryset=None):
"""
- class SerializerClass(self.pagination_serializer_class):
- class Meta:
- object_serializer_class = self.get_serializer_class()
+ Return the size of pages to use with pagination.
- pagination_serializer_class = SerializerClass
- context = self.get_serializer_context()
- return pagination_serializer_class(instance=page, context=context)
+ If `PAGINATE_BY_PARAM` is set it will attempt to get the page size
+ from a named query parameter in the url, eg. ?page_size=100
- def get_paginate_by(self, queryset):
- """
- Return the size of pages to use with pagination.
+ Otherwise defaults to using `self.paginate_by`.
"""
+ if queryset is not None:
+ warnings.warn('The `queryset` parameter to `get_paginate_by()` '
+ 'is due to be deprecated.',
+ PendingDeprecationWarning, stacklevel=2)
+
if self.paginate_by_param:
query_params = self.request.QUERY_PARAMS
try:
return int(query_params[self.paginate_by_param])
except (KeyError, ValueError):
pass
+
return self.paginate_by
+ def get_serializer_class(self):
+ """
+ Return the class to use for the serializer.
+ Defaults to using `self.serializer_class`.
+
+ You may want to override this if you need to provide different
+ serializations depending on the incoming request.
-class SingleObjectAPIView(SingleObjectMixin, GenericAPIView):
- """
- Base class for generic views onto a model instance.
- """
+ (Eg. admins get full serialization, others get basic serilization)
+ """
+ serializer_class = self.serializer_class
+ if serializer_class is not None:
+ return serializer_class
- pk_url_kwarg = 'pk' # Not provided in Django 1.3
- slug_url_kwarg = 'slug' # Not provided in Django 1.3
- slug_field = 'slug'
+ assert self.model is not None, \
+ "'%s' should either include a 'serializer_class' attribute, " \
+ "or use the 'model' attribute as a shortcut for " \
+ "automatically generating a serializer class." \
+ % self.__class__.__name__
+
+ class DefaultSerializer(self.model_serializer_class):
+ class Meta:
+ model = self.model
+ return DefaultSerializer
+
+ def get_queryset(self):
+ """
+ Get the list of items for this view.
+ This must be an iterable, and may be a queryset.
+ Defaults to using `self.queryset`.
+
+ You may want to override this if you need to provide different
+ querysets depending on the incoming request.
+
+ (Eg. return a list of items that is specific to the user)
+ """
+ if self.queryset is not None:
+ return self.queryset._clone()
+
+ if self.model is not None:
+ return self.model._default_manager.all()
+
+ raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'"
+ % self.__class__.__name__)
def get_object(self, queryset=None):
"""
- Override default to add support for object-level permissions.
+ Returns the object the view is displaying.
+
+ You may want to override this if you need to provide non-standard
+ queryset lookups. Eg if objects are referenced using multiple
+ keyword arguments in the url conf.
"""
- obj = super(SingleObjectAPIView, self).get_object(queryset)
+ # Determine the base queryset to use.
+ if queryset is None:
+ queryset = self.filter_queryset(self.get_queryset())
+ else:
+ pass # Deprecation warning
+
+ # Perform the lookup filtering.
+ pk = self.kwargs.get(self.pk_url_kwarg, None)
+ slug = self.kwargs.get(self.slug_url_kwarg, None)
+ lookup = self.kwargs.get(self.lookup_field, None)
+
+ if lookup is not None:
+ filter_kwargs = {self.lookup_field: lookup}
+ elif pk is not None and self.lookup_field == 'pk':
+ warnings.warn(
+ 'The `pk_url_kwarg` attribute is due to be deprecated. '
+ 'Use the `lookup_field` attribute instead',
+ PendingDeprecationWarning
+ )
+ filter_kwargs = {'pk': pk}
+ elif slug is not None and self.lookup_field == 'pk':
+ warnings.warn(
+ 'The `slug_url_kwarg` attribute is due to be deprecated. '
+ 'Use the `lookup_field` attribute instead',
+ PendingDeprecationWarning
+ )
+ filter_kwargs = {self.slug_field: slug}
+ else:
+ raise ConfigurationError(
+ 'Expected view %s to be called with a URL keyword argument '
+ 'named "%s". Fix your URL conf, or set the `.lookup_field` '
+ 'attribute on the view correctly.' %
+ (self.__class__.__name__, self.lookup_field)
+ )
+
+ obj = get_object_or_404(queryset, **filter_kwargs)
+
+ # May raise a permission denied
self.check_object_permissions(self.request, obj)
+
return obj
+ ########################
+ ### The following are placeholder methods,
+ ### and are intended to be overridden.
+ ###
+ ### The are not called by GenericAPIView directly,
+ ### but are used by the mixin methods.
+
+ def pre_save(self, obj):
+ """
+ Placeholder method for calling before saving an object.
+
+ May be used to set attributes on the object that are implicit
+ in either the request, or the url.
+ """
+ pass
+
+ def post_save(self, obj, created=False):
+ """
+ Placeholder method for calling after saving an object.
+ """
+ pass
-### Concrete view classes that provide method handlers ###
-### by composing the mixin classes with a base view. ###
+##########################################################
+### Concrete view classes that provide method handlers ###
+### by composing the mixin classes with the base view. ###
+##########################################################
class CreateAPIView(mixins.CreateModelMixin,
GenericAPIView):
@@ -150,7 +327,7 @@ class CreateAPIView(mixins.CreateModelMixin,
class ListAPIView(mixins.ListModelMixin,
- MultipleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for listing a queryset.
"""
@@ -159,7 +336,7 @@ class ListAPIView(mixins.ListModelMixin,
class RetrieveAPIView(mixins.RetrieveModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for retrieving a model instance.
"""
@@ -168,7 +345,7 @@ class RetrieveAPIView(mixins.RetrieveModelMixin,
class DestroyAPIView(mixins.DestroyModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for deleting a model instance.
@@ -178,7 +355,7 @@ class DestroyAPIView(mixins.DestroyModelMixin,
class UpdateAPIView(mixins.UpdateModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for updating a model instance.
@@ -187,13 +364,12 @@ class UpdateAPIView(mixins.UpdateModelMixin,
return self.update(request, *args, **kwargs)
def patch(self, request, *args, **kwargs):
- kwargs['partial'] = True
- return self.update(request, *args, **kwargs)
+ return self.partial_update(request, *args, **kwargs)
class ListCreateAPIView(mixins.ListModelMixin,
mixins.CreateModelMixin,
- MultipleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for listing a queryset or creating a model instance.
"""
@@ -206,7 +382,7 @@ class ListCreateAPIView(mixins.ListModelMixin,
class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for retrieving, updating a model instance.
"""
@@ -217,13 +393,12 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
return self.update(request, *args, **kwargs)
def patch(self, request, *args, **kwargs):
- kwargs['partial'] = True
- return self.update(request, *args, **kwargs)
+ return self.partial_update(request, *args, **kwargs)
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for retrieving or deleting a model instance.
"""
@@ -237,7 +412,7 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for retrieving, updating or deleting a model instance.
"""
@@ -248,8 +423,31 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
return self.update(request, *args, **kwargs)
def patch(self, request, *args, **kwargs):
- kwargs['partial'] = True
- return self.update(request, *args, **kwargs)
+ return self.partial_update(request, *args, **kwargs)
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
+
+
+##########################
+### Deprecated classes ###
+##########################
+
+class MultipleObjectAPIView(GenericAPIView):
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ 'Subclassing `MultipleObjectAPIView` is due to be deprecated. '
+ 'You should simply subclass `GenericAPIView` instead.',
+ PendingDeprecationWarning, stacklevel=2
+ )
+ super(MultipleObjectAPIView, self).__init__(*args, **kwargs)
+
+
+class SingleObjectAPIView(GenericAPIView):
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ 'Subclassing `SingleObjectAPIView` is due to be deprecated. '
+ 'You should simply subclass `GenericAPIView` instead.',
+ PendingDeprecationWarning, stacklevel=2
+ )
+ super(SingleObjectAPIView, self).__init__(*args, **kwargs)
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 97201c4b..f3cd5868 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -10,9 +10,10 @@ from django.http import Http404
from rest_framework import status
from rest_framework.response import Response
from rest_framework.request import clone_request
+import warnings
-def _get_validation_exclusions(obj, pk=None, slug_field=None):
+def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None):
"""
Given a model instance, and an optional pk and slug field,
return the full list of all other field names on that model.
@@ -23,28 +24,32 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None):
include = []
if pk:
+ # Pending deprecation
pk_field = obj._meta.pk
while pk_field.rel:
pk_field = pk_field.rel.to._meta.pk
include.append(pk_field.name)
if slug_field:
+ # Pending deprecation
include.append(slug_field)
+ if lookup_field and lookup_field != 'pk':
+ include.append(lookup_field)
+
return [field.name for field in obj._meta.fields if field.name not in include]
class CreateModelMixin(object):
"""
Create a model instance.
- Should be mixed in with any `GenericAPIView`.
"""
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA, files=request.FILES)
if serializer.is_valid():
self.pre_save(serializer.object)
- self.object = serializer.save()
+ self.object = serializer.save(force_insert=True)
self.post_save(self.object, created=True)
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED,
@@ -62,28 +67,28 @@ class CreateModelMixin(object):
class ListModelMixin(object):
"""
List a queryset.
- Should be mixed in with `MultipleObjectAPIView`.
"""
empty_error = "Empty list and '%(class_name)s.allow_empty' is False."
def list(self, request, *args, **kwargs):
- queryset = self.get_queryset()
- self.object_list = self.filter_queryset(queryset)
+ self.object_list = self.filter_queryset(self.get_queryset())
# Default is to allow empty querysets. This can be altered by setting
# `.allow_empty = False`, to raise 404 errors on empty querysets.
- allow_empty = self.get_allow_empty()
- if not allow_empty and not self.object_list:
+ if not self.allow_empty and not self.object_list:
+ warnings.warn(
+ 'The `allow_empty` parameter is due to be deprecated. '
+ 'To use `allow_empty=False` style behavior, You should override '
+ '`get_queryset()` and explicitly raise a 404 on empty querysets.',
+ PendingDeprecationWarning
+ )
class_name = self.__class__.__name__
error_msg = self.empty_error % {'class_name': class_name}
raise Http404(error_msg)
- # Pagination size is set by the `.paginate_by` attribute,
- # which may be `None` to disable pagination.
- page_size = self.get_paginate_by(self.object_list)
- if page_size:
- packed = self.paginate_queryset(self.object_list, page_size)
- paginator, page, queryset, is_paginated = packed
+ # Switch between paginated or standard style responses
+ page = self.paginate_queryset(self.object_list)
+ if page is not None:
serializer = self.get_pagination_serializer(page)
else:
serializer = self.get_serializer(self.object_list, many=True)
@@ -94,7 +99,6 @@ class ListModelMixin(object):
class RetrieveModelMixin(object):
"""
Retrieve a model instance.
- Should be mixed in with `SingleObjectAPIView`.
"""
def retrieve(self, request, *args, **kwargs):
self.object = self.get_object()
@@ -105,21 +109,28 @@ class RetrieveModelMixin(object):
class UpdateModelMixin(object):
"""
Update a model instance.
- Should be mixed in with `SingleObjectAPIView`.
"""
- def update(self, request, *args, **kwargs):
- partial = kwargs.pop('partial', False)
- self.object = None
+ def get_object_or_none(self):
try:
- self.object = self.get_object()
+ return self.get_object()
except Http404:
# If this is a PUT-as-create operation, we need to ensure that
# we have relevant permissions, as if this was a POST request.
- self.check_permissions(clone_request(request, 'POST'))
+ # This will either raise a PermissionDenied exception,
+ # or simply return None
+ self.check_permissions(clone_request(self.request, 'POST'))
+
+ def update(self, request, *args, **kwargs):
+ partial = kwargs.pop('partial', False)
+ self.object = self.get_object_or_none()
+
+ if self.object is None:
created = True
+ save_kwargs = {'force_insert': True}
success_status_code = status.HTTP_201_CREATED
else:
created = False
+ save_kwargs = {'force_update': True}
success_status_code = status.HTTP_200_OK
serializer = self.get_serializer(self.object, data=request.DATA,
@@ -127,20 +138,28 @@ class UpdateModelMixin(object):
if serializer.is_valid():
self.pre_save(serializer.object)
- self.object = serializer.save()
+ self.object = serializer.save(**save_kwargs)
self.post_save(self.object, created=created)
return Response(serializer.data, status=success_status_code)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+ def partial_update(self, request, *args, **kwargs):
+ kwargs['partial'] = True
+ return self.update(request, *args, **kwargs)
+
def pre_save(self, obj):
"""
Set any attributes on the object that are implicit in the request.
"""
# pk and/or slug attributes are implicit in the URL.
+ lookup = self.kwargs.get(self.lookup_field, None)
pk = self.kwargs.get(self.pk_url_kwarg, None)
slug = self.kwargs.get(self.slug_url_kwarg, None)
- slug_field = slug and self.get_slug_field() or None
+ slug_field = slug and self.slug_field or None
+
+ if lookup:
+ setattr(obj, self.lookup_field, lookup)
if pk:
setattr(obj, 'pk', pk)
@@ -151,14 +170,13 @@ class UpdateModelMixin(object):
# Ensure we clean the attributes so that we don't eg return integer
# pk using a string representation, as provided by the url conf kwarg.
if hasattr(obj, 'full_clean'):
- exclude = _get_validation_exclusions(obj, pk, slug_field)
+ exclude = _get_validation_exclusions(obj, pk, slug_field, self.lookup_field)
obj.full_clean(exclude)
class DestroyModelMixin(object):
"""
Destroy a model instance.
- Should be mixed in with `SingleObjectAPIView`.
"""
def destroy(self, request, *args, **kwargs):
obj = self.get_object()
diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py
index 0694d35f..4d205c0e 100644
--- a/rest_framework/negotiation.py
+++ b/rest_framework/negotiation.py
@@ -1,3 +1,7 @@
+"""
+Content negotiation deals with selecting an appropriate renderer given the
+incoming request. Typically this will be based on the request's Accept header.
+"""
from __future__ import unicode_literals
from django.http import Http404
from rest_framework import exceptions
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index 03a7a30f..d51ea929 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -1,9 +1,11 @@
+"""
+Pagination serializers determine the structure of the output that should
+be used for paginated responses.
+"""
from __future__ import unicode_literals
from rest_framework import serializers
from rest_framework.templatetags.rest_framework import replace_query_param
-# TODO: Support URLconf kwarg-style paging
-
class NextPageField(serializers.Field):
"""
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index 491acd68..25be2e6a 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -6,9 +6,10 @@ on the request, such as form content or json encoded data.
"""
from __future__ import unicode_literals
from django.conf import settings
+from django.core.files.uploadhandler import StopFutureHandlers
from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
-from django.http.multipartparser import MultiPartParserError
+from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter
from rest_framework.compat import yaml, etree
from rest_framework.exceptions import ParseError
from rest_framework.compat import six
@@ -205,3 +206,90 @@ class XMLParser(BaseParser):
pass
return value
+
+
+class FileUploadParser(BaseParser):
+ """
+ Parser for file upload data.
+ """
+ media_type = '*/*'
+
+ def parse(self, stream, media_type=None, parser_context=None):
+ """
+ Returns a DataAndFiles object.
+
+ `.data` will be None (we expect request body to be a file content).
+ `.files` will be a `QueryDict` containing one 'file' element.
+ """
+
+ parser_context = parser_context or {}
+ request = parser_context['request']
+ encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
+ meta = request.META
+ upload_handlers = request.upload_handlers
+ filename = self.get_filename(stream, media_type, parser_context)
+
+ # Note that this code is extracted from Django's handling of
+ # file uploads in MultiPartParser.
+ content_type = meta.get('HTTP_CONTENT_TYPE',
+ meta.get('CONTENT_TYPE', ''))
+ try:
+ content_length = int(meta.get('HTTP_CONTENT_LENGTH',
+ meta.get('CONTENT_LENGTH', 0)))
+ except (ValueError, TypeError):
+ content_length = None
+
+ # See if the handler will want to take care of the parsing.
+ for handler in upload_handlers:
+ result = handler.handle_raw_input(None,
+ meta,
+ content_length,
+ None,
+ encoding)
+ if result is not None:
+ return DataAndFiles(None, {'file': result[1]})
+
+ # This is the standard case.
+ possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size]
+ chunk_size = min([2 ** 31 - 4] + possible_sizes)
+ chunks = ChunkIter(stream, chunk_size)
+ counters = [0] * len(upload_handlers)
+
+ for handler in upload_handlers:
+ try:
+ handler.new_file(None, filename, content_type,
+ content_length, encoding)
+ except StopFutureHandlers:
+ break
+
+ for chunk in chunks:
+ for i, handler in enumerate(upload_handlers):
+ chunk_length = len(chunk)
+ chunk = handler.receive_data_chunk(chunk, counters[i])
+ counters[i] += chunk_length
+ if chunk is None:
+ break
+
+ for i, handler in enumerate(upload_handlers):
+ file_obj = handler.file_complete(counters[i])
+ if file_obj:
+ return DataAndFiles(None, {'file': file_obj})
+ raise ParseError("FileUpload parse error - "
+ "none of upload handlers can handle the stream")
+
+ def get_filename(self, stream, media_type, parser_context):
+ """
+ Detects the uploaded file name. First searches a 'filename' url kwarg.
+ Then tries to parse Content-Disposition header.
+ """
+ try:
+ return parser_context['kwargs']['filename']
+ except KeyError:
+ pass
+
+ try:
+ meta = parser_context['request'].META
+ disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'])
+ return disposition[1]['filename']
+ except (AttributeError, KeyError):
+ pass
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index 306f00ca..45fcfd66 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -7,6 +7,8 @@ import warnings
SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS']
+from rest_framework.compat import oauth2_provider_scope, oauth2_constants
+
class BasePermission(object):
"""
@@ -23,10 +25,12 @@ class BasePermission(object):
"""
Return `True` if permission is granted, `False` otherwise.
"""
- if len(inspect.getargspec(self.has_permission)[0]) == 4:
- warnings.warn('The `obj` argument in `has_permission` is due to be deprecated. '
- 'Use `has_object_permission()` instead for object permissions.',
- PendingDeprecationWarning, stacklevel=2)
+ if len(inspect.getargspec(self.has_permission).args) == 4:
+ warnings.warn(
+ 'The `obj` argument in `has_permission` is deprecated. '
+ 'Use `has_object_permission()` instead for object permissions.',
+ DeprecationWarning, stacklevel=2
+ )
return self.has_permission(request, view, obj)
return True
@@ -85,8 +89,8 @@ class DjangoModelPermissions(BasePermission):
It ensures that the user is authenticated, and has the appropriate
`add`/`change`/`delete` permissions on the model.
- This permission will only be applied against view classes that
- provide a `.model` attribute, such as the generic class-based views.
+ This permission can only be applied against view classes that
+ provide a `.model` or `.queryset` attribute.
"""
# Map methods into required permission codes.
@@ -102,6 +106,8 @@ class DjangoModelPermissions(BasePermission):
'DELETE': ['%(app_label)s.delete_%(model_name)s'],
}
+ authenticated_users_only = True
+
def get_required_permissions(self, method, model_cls):
"""
Given a model and an HTTP method, return the list of permission
@@ -115,13 +121,54 @@ class DjangoModelPermissions(BasePermission):
def has_permission(self, request, view):
model_cls = getattr(view, 'model', None)
- if not model_cls:
+ queryset = getattr(view, 'queryset', None)
+
+ if model_cls is None and queryset is not None:
+ model_cls = queryset.model
+
+ # Workaround to ensure DjangoModelPermissions are not applied
+ # to the root view when using DefaultRouter.
+ if model_cls is None and getattr(view, '_ignore_model_permissions'):
return True
+ assert model_cls, ('Cannot apply DjangoModelPermissions on a view that'
+ ' does not have `.model` or `.queryset` property.')
+
perms = self.get_required_permissions(request.method, model_cls)
if (request.user and
- request.user.is_authenticated() and
+ (request.user.is_authenticated() or not self.authenticated_users_only) and
request.user.has_perms(perms)):
return True
return False
+
+
+class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions):
+ """
+ Similar to DjangoModelPermissions, except that anonymous users are
+ allowed read-only access.
+ """
+ authenticated_users_only = False
+
+
+class TokenHasReadWriteScope(BasePermission):
+ """
+ The request is authenticated as a user and the token used has the right scope
+ """
+
+ def has_permission(self, request, view):
+ token = request.auth
+ read_only = request.method in SAFE_METHODS
+
+ if not token:
+ return False
+
+ if hasattr(token, 'resource'): # OAuth 1
+ return read_only or not request.auth.resource.is_readonly
+ elif hasattr(token, 'scope'): # OAuth 2
+ required = oauth2_constants.READ if read_only else oauth2_constants.WRITE
+ return oauth2_provider_scope.check(required, request.auth.scope)
+
+ assert False, ('TokenHasReadWriteScope requires either the'
+ '`OAuthAuthentication` or `OAuth2Authentication` authentication '
+ 'class to be used.')
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index ef465b3c..884b954c 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -1,3 +1,9 @@
+"""
+Serializer fields that deal with relationships.
+
+These fields allow you to specify the style that should be used to represent
+model relationships, including hyperlinks, primary keys, or slugs.
+"""
from __future__ import unicode_literals
from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch
@@ -36,9 +42,9 @@ class RelatedField(WritableField):
# 'null' is to be deprecated in favor of 'required'
if 'null' in kwargs:
- warnings.warn('The `null` keyword argument is due to be deprecated. '
+ warnings.warn('The `null` keyword argument is deprecated. '
'Use the `required` keyword argument instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
kwargs['required'] = not kwargs.pop('null')
self.queryset = kwargs.pop('queryset', None)
@@ -243,7 +249,6 @@ class PrimaryKeyRelatedField(RelatedField):
pk = getattr(obj, self.source or field_name).pk
except ObjectDoesNotExist:
return None
- return self.to_native(obj.pk)
# Forward relationship
return self.to_native(pk)
@@ -291,10 +296,8 @@ class HyperlinkedRelatedField(RelatedField):
"""
Represents a relationship using hyperlinking.
"""
- pk_url_kwarg = 'pk'
- slug_field = 'slug'
- slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
read_only = False
+ lookup_field = 'pk'
default_error_messages = {
'no_match': _('Invalid hyperlink - No URL match'),
@@ -304,69 +307,138 @@ class HyperlinkedRelatedField(RelatedField):
'incorrect_type': _('Incorrect type. Expected url string, received %s.'),
}
+ # These are all pending deprecation
+ pk_url_kwarg = 'pk'
+ slug_field = 'slug'
+ slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
+
def __init__(self, *args, **kwargs):
try:
self.view_name = kwargs.pop('view_name')
except KeyError:
raise ValueError("Hyperlinked field requires 'view_name' kwarg")
+ self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
+ self.format = kwargs.pop('format', None)
+
+ # These are pending deprecation
+ if 'pk_url_kwarg' in kwargs:
+ msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ if 'slug_url_kwarg' in kwargs:
+ msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ if 'slug_field' in kwargs:
+ msg = 'slug_field is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+
+ self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
self.slug_field = kwargs.pop('slug_field', self.slug_field)
default_slug_kwarg = self.slug_url_kwarg or self.slug_field
- self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
- self.format = kwargs.pop('format', None)
super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
- def get_slug_field(self):
- """
- Get the name of a slug field to be used to look up by slug.
+ def get_url(self, obj, view_name, request, format):
"""
- return self.slug_field
-
- def to_native(self, obj):
- view_name = self.view_name
- request = self.context.get('request', None)
- format = self.format or self.context.get('format', None)
+ Given an object, return the URL that hyperlinks to the object.
- if request is None:
- warnings.warn("Using `HyperlinkedRelatedField` without including the "
- "request in the serializer context is due to be deprecated. "
- "Add `context={'request': request}` when instantiating the serializer.",
- PendingDeprecationWarning, stacklevel=4)
-
- pk = getattr(obj, 'pk', None)
- if pk is None:
- return
- kwargs = {self.pk_url_kwarg: pk}
+ May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
+ attributes are not configured to correctly match the URL conf.
+ """
+ lookup_field = getattr(obj, self.lookup_field)
+ kwargs = {self.lookup_field: lookup_field}
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
except NoReverseMatch:
pass
+ if self.pk_url_kwarg != 'pk':
+ # Only try pk if it has been explicitly set.
+ # Otherwise, the default `lookup_field = 'pk'` has us covered.
+ pk = obj.pk
+ kwargs = {self.pk_url_kwarg: pk}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except NoReverseMatch:
+ pass
+
slug = getattr(obj, self.slug_field, None)
+ if slug is not None:
+ # Only try slug if it corresponds to an attribute on the object.
+ kwargs = {self.slug_url_kwarg: slug}
+ try:
+ ret = reverse(view_name, kwargs=kwargs, request=request, format=format)
+ if self.slug_field == 'slug' and self.slug_url_kwarg == 'slug':
+ # If the lookup succeeds using the default slug params,
+ # then `slug_field` is being used implicitly, and we
+ # we need to warn about the pending deprecation.
+ msg = 'Implicit slug field hyperlinked fields are pending deprecation.' \
+ 'You should set `lookup_field=slug` on the HyperlinkedRelatedField.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ return ret
+ except NoReverseMatch:
+ pass
+
+ raise NoReverseMatch()
+
+ def get_object(self, queryset, view_name, view_args, view_kwargs):
+ """
+ Return the object corresponding to a matched URL.
- if not slug:
- raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
+ Takes the matched URL conf arguments, and the queryset, and should
+ return an object instance, or raise an `ObjectDoesNotExist` exception.
+ """
+ lookup = view_kwargs.get(self.lookup_field, None)
+ pk = view_kwargs.get(self.pk_url_kwarg, None)
+ slug = view_kwargs.get(self.slug_url_kwarg, None)
+
+ if lookup is not None:
+ filter_kwargs = {self.lookup_field: lookup}
+ elif pk is not None:
+ filter_kwargs = {'pk': pk}
+ elif slug is not None:
+ filter_kwargs = {self.slug_field: slug}
+ else:
+ raise ObjectDoesNotExist()
- kwargs = {self.slug_url_kwarg: slug}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
+ return queryset.get(**filter_kwargs)
- kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
+ def to_native(self, obj):
+ view_name = self.view_name
+ request = self.context.get('request', None)
+ format = self.format or self.context.get('format', None)
+
+ if request is None:
+ msg = (
+ "Using `HyperlinkedRelatedField` without including the request "
+ "in the serializer context is deprecated. "
+ "Add `context={'request': request}` when instantiating "
+ "the serializer."
+ )
+ warnings.warn(msg, DeprecationWarning, stacklevel=4)
+
+ # If the object has not yet been saved then we cannot hyperlink to it.
+ if getattr(obj, 'pk', None) is None:
+ return
+
+ # Return the hyperlink, or error if incorrectly configured.
try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ return self.get_url(obj, view_name, request, format)
except NoReverseMatch:
- pass
-
- raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
+ msg = (
+ 'Could not resolve URL for hyperlinked relationship using '
+ 'view name "%s". You may have failed to include the related '
+ 'model in your API, or incorrectly configured the '
+ '`lookup_field` attribute on this field.'
+ )
+ raise Exception(msg % view_name)
def from_native(self, value):
# Convert URL -> model instance pk
# TODO: Use values_list
- if self.queryset is None:
+ queryset = self.queryset
+ if queryset is None:
raise Exception('Writable related fields must include a `queryset` argument')
try:
@@ -390,39 +462,24 @@ class HyperlinkedRelatedField(RelatedField):
if match.view_name != self.view_name:
raise ValidationError(self.error_messages['incorrect_match'])
- pk = match.kwargs.get(self.pk_url_kwarg, None)
- slug = match.kwargs.get(self.slug_url_kwarg, None)
-
- # Try explicit primary key.
- if pk is not None:
- queryset = self.queryset.filter(pk=pk)
- # Next, try looking up by slug.
- elif slug is not None:
- slug_field = self.get_slug_field()
- queryset = self.queryset.filter(**{slug_field: slug})
- # If none of those are defined, it's probably a configuation error.
- else:
- raise ValidationError(self.error_messages['configuration_error'])
-
try:
- obj = queryset.get()
- except ObjectDoesNotExist:
+ return self.get_object(queryset, match.view_name,
+ match.args, match.kwargs)
+ except (ObjectDoesNotExist, TypeError, ValueError):
raise ValidationError(self.error_messages['does_not_exist'])
- except (TypeError, ValueError):
- msg = self.error_messages['incorrect_type']
- raise ValidationError(msg % type(value).__name__)
-
- return obj
class HyperlinkedIdentityField(Field):
"""
Represents the instance, or a property on the instance, using hyperlinking.
"""
+ lookup_field = 'pk'
+ read_only = True
+
+ # These are all pending deprecation
pk_url_kwarg = 'pk'
slug_field = 'slug'
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
- read_only = True
def __init__(self, *args, **kwargs):
# TODO: Make view_name mandatory, and have the
@@ -431,6 +488,19 @@ class HyperlinkedIdentityField(Field):
# Optionally the format of the target hyperlink may be specified
self.format = kwargs.pop('format', None)
+ self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
+
+ # These are pending deprecation
+ if 'pk_url_kwarg' in kwargs:
+ msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ if 'slug_url_kwarg' in kwargs:
+ msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ if 'slug_field' in kwargs:
+ msg = 'slug_field is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+
self.slug_field = kwargs.pop('slug_field', self.slug_field)
default_slug_kwarg = self.slug_url_kwarg or self.slug_field
self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
@@ -442,13 +512,14 @@ class HyperlinkedIdentityField(Field):
request = self.context.get('request', None)
format = self.context.get('format', None)
view_name = self.view_name or self.parent.opts.view_name
- kwargs = {self.pk_url_kwarg: obj.pk}
+ lookup_field = getattr(obj, self.lookup_field)
+ kwargs = {self.lookup_field: lookup_field}
if request is None:
warnings.warn("Using `HyperlinkedIdentityField` without including the "
- "request in the serializer context is due to be deprecated. "
+ "request in the serializer context is deprecated. "
"Add `context={'request': request}` when instantiating the serializer.",
- PendingDeprecationWarning, stacklevel=4)
+ DeprecationWarning, stacklevel=4)
# By default use whatever format is given for the current context
# unless the target is a different type to the source.
@@ -491,35 +562,35 @@ class HyperlinkedIdentityField(Field):
class ManyRelatedField(RelatedField):
def __init__(self, *args, **kwargs):
- warnings.warn('`ManyRelatedField()` is due to be deprecated. '
+ warnings.warn('`ManyRelatedField()` is deprecated. '
'Use `RelatedField(many=True)` instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManyRelatedField, self).__init__(*args, **kwargs)
class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
def __init__(self, *args, **kwargs):
- warnings.warn('`ManyPrimaryKeyRelatedField()` is due to be deprecated. '
+ warnings.warn('`ManyPrimaryKeyRelatedField()` is deprecated. '
'Use `PrimaryKeyRelatedField(many=True)` instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManyPrimaryKeyRelatedField, self).__init__(*args, **kwargs)
class ManySlugRelatedField(SlugRelatedField):
def __init__(self, *args, **kwargs):
- warnings.warn('`ManySlugRelatedField()` is due to be deprecated. '
+ warnings.warn('`ManySlugRelatedField()` is deprecated. '
'Use `SlugRelatedField(many=True)` instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManySlugRelatedField, self).__init__(*args, **kwargs)
class ManyHyperlinkedRelatedField(HyperlinkedRelatedField):
def __init__(self, *args, **kwargs):
- warnings.warn('`ManyHyperlinkedRelatedField()` is due to be deprecated. '
+ warnings.warn('`ManyHyperlinkedRelatedField()` is deprecated. '
'Use `HyperlinkedRelatedField(many=True)` instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManyHyperlinkedRelatedField, self).__init__(*args, **kwargs)
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 4c15e0db..1917a080 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -24,6 +24,7 @@ from rest_framework.settings import api_settings
from rest_framework.request import clone_request
from rest_framework.utils import encoders
from rest_framework.utils.breadcrumbs import get_breadcrumbs
+from rest_framework.utils.formatting import get_view_name, get_view_description
from rest_framework import exceptions, parsers, status, VERSION
@@ -57,7 +58,7 @@ class JSONRenderer(BaseRenderer):
return ''
# If 'indent' is provided in the context, then pretty print the result.
- # E.g. If we're being called by the BrowseableAPIRenderer.
+ # E.g. If we're being called by the BrowsableAPIRenderer.
renderer_context = renderer_context or {}
indent = renderer_context.get('indent', None)
@@ -438,16 +439,13 @@ class BrowsableAPIRenderer(BaseRenderer):
return GenericContentForm()
def get_name(self, view):
- try:
- return view.get_name()
- except AttributeError:
- return smart_text(view.__class__.__name__)
+ return get_view_name(view.__class__, getattr(view, 'suffix', None))
def get_description(self, view):
- try:
- return view.get_description(html=True)
- except AttributeError:
- return smart_text(view.__doc__ or '')
+ return get_view_description(view.__class__, html=True)
+
+ def get_breadcrumbs(self, request):
+ return get_breadcrumbs(request.path)
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
@@ -480,7 +478,7 @@ class BrowsableAPIRenderer(BaseRenderer):
name = self.get_name(view)
description = self.get_description(view)
- breadcrumb_list = get_breadcrumbs(request.path)
+ breadcrumb_list = self.get_breadcrumbs(request)
template = loader.get_template(self.template)
context = RequestContext(request, {
diff --git a/rest_framework/request.py b/rest_framework/request.py
index 3e2fbd88..a434659c 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -1,11 +1,10 @@
"""
-The :mod:`request` module provides a :class:`Request` class used to wrap the standard `request`
-object received in all the views.
+The Request class is used as a wrapper around the standard request object.
The wrapped request then offers a richer API, in particular :
- content automatically parsed according to `Content-Type` header,
- and available as :meth:`.DATA<Request.DATA>`
+ and available as `request.DATA`
- full support of PUT method, including support for file uploads
- form overloading of HTTP method, content type and content
"""
@@ -231,11 +230,17 @@ class Request(object):
"""
self._content_type = self.META.get('HTTP_CONTENT_TYPE',
self.META.get('CONTENT_TYPE', ''))
+
self._perform_form_overloading()
- # if the HTTP method was not overloaded, we take the raw HTTP method
+
if not _hasattr(self, '_method'):
self._method = self._request.method
+ if self._method == 'POST':
+ # Allow X-HTTP-METHOD-OVERRIDE header
+ self._method = self.META.get('HTTP_X_HTTP_METHOD_OVERRIDE',
+ self._method)
+
def _load_stream(self):
"""
Return the content body of the request, as a stream.
diff --git a/rest_framework/response.py b/rest_framework/response.py
index 5e1bf46e..26e4ab37 100644
--- a/rest_framework/response.py
+++ b/rest_framework/response.py
@@ -1,3 +1,9 @@
+"""
+The Response class in REST framework is similiar to HTTPResponse, except that
+it is initialized with unrendered data, instead of a pre-rendered string.
+
+The appropriate renderer is called during Django's template response rendering.
+"""
from __future__ import unicode_literals
from django.core.handlers.wsgi import STATUS_CODE_TEXT
from django.template.response import SimpleTemplateResponse
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
new file mode 100644
index 00000000..dba104c3
--- /dev/null
+++ b/rest_framework/routers.py
@@ -0,0 +1,249 @@
+"""
+Routers provide a convenient and consistent way of automatically
+determining the URL conf for your API.
+
+They are used by simply instantiating a Router class, and then registering
+all the required ViewSets with that router.
+
+For example, you might have a `urls.py` that looks something like this:
+
+ router = routers.DefaultRouter()
+ router.register('users', UserViewSet, 'user')
+ router.register('accounts', AccountViewSet, 'account')
+
+ urlpatterns = router.urls
+"""
+from __future__ import unicode_literals
+
+from collections import namedtuple
+from rest_framework import views
+from rest_framework.compat import patterns, url
+from rest_framework.decorators import api_view
+from rest_framework.response import Response
+from rest_framework.reverse import reverse
+from rest_framework.urlpatterns import format_suffix_patterns
+
+
+Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs'])
+
+
+def replace_methodname(format_string, methodname):
+ """
+ Partially format a format_string, swapping out any
+ '{methodname}' or '{methodnamehyphen}' components.
+ """
+ methodnamehyphen = methodname.replace('_', '-')
+ ret = format_string
+ ret = ret.replace('{methodname}', methodname)
+ ret = ret.replace('{methodnamehyphen}', methodnamehyphen)
+ return ret
+
+
+class BaseRouter(object):
+ def __init__(self):
+ self.registry = []
+
+ def register(self, prefix, viewset, base_name=None):
+ if base_name is None:
+ base_name = self.get_default_base_name(viewset)
+ self.registry.append((prefix, viewset, base_name))
+
+ def get_default_base_name(self, viewset):
+ """
+ If `base_name` is not specified, attempt to automatically determine
+ it from the viewset.
+ """
+ raise NotImplemented('get_default_base_name must be overridden')
+
+ def get_urls(self):
+ """
+ Return a list of URL patterns, given the registered viewsets.
+ """
+ raise NotImplemented('get_urls must be overridden')
+
+ @property
+ def urls(self):
+ if not hasattr(self, '_urls'):
+ self._urls = patterns('', *self.get_urls())
+ return self._urls
+
+
+class SimpleRouter(BaseRouter):
+ routes = [
+ # List route.
+ Route(
+ url=r'^{prefix}/$',
+ mapping={
+ 'get': 'list',
+ 'post': 'create'
+ },
+ name='{basename}-list',
+ initkwargs={'suffix': 'List'}
+ ),
+ # Detail route.
+ Route(
+ url=r'^{prefix}/{lookup}/$',
+ mapping={
+ 'get': 'retrieve',
+ 'put': 'update',
+ 'patch': 'partial_update',
+ 'delete': 'destroy'
+ },
+ name='{basename}-detail',
+ initkwargs={'suffix': 'Instance'}
+ ),
+ # Dynamically generated routes.
+ # Generated using @action or @link decorators on methods of the viewset.
+ Route(
+ url=r'^{prefix}/{lookup}/{methodname}/$',
+ mapping={
+ '{httpmethod}': '{methodname}',
+ },
+ name='{basename}-{methodnamehyphen}',
+ initkwargs={}
+ ),
+ ]
+
+ def get_default_base_name(self, viewset):
+ """
+ If `base_name` is not specified, attempt to automatically determine
+ it from the viewset.
+ """
+ model_cls = getattr(viewset, 'model', None)
+ queryset = getattr(viewset, 'queryset', None)
+ if model_cls is None and queryset is not None:
+ model_cls = queryset.model
+
+ assert model_cls, '`name` not argument not specified, and could ' \
+ 'not automatically determine the name from the viewset, as ' \
+ 'it does not have a `.model` or `.queryset` attribute.'
+
+ return model_cls._meta.object_name.lower()
+
+ def get_routes(self, viewset):
+ """
+ Augment `self.routes` with any dynamically generated routes.
+
+ Returns a list of the Route namedtuple.
+ """
+
+ # Determine any `@action` or `@link` decorated methods on the viewset
+ dynamic_routes = []
+ for methodname in dir(viewset):
+ attr = getattr(viewset, methodname)
+ httpmethod = getattr(attr, 'bind_to_method', None)
+ if httpmethod:
+ dynamic_routes.append((httpmethod, methodname))
+
+ ret = []
+ for route in self.routes:
+ if route.mapping == {'{httpmethod}': '{methodname}'}:
+ # Dynamic routes (@link or @action decorator)
+ for httpmethod, methodname in dynamic_routes:
+ initkwargs = route.initkwargs.copy()
+ initkwargs.update(getattr(viewset, methodname).kwargs)
+ ret.append(Route(
+ url=replace_methodname(route.url, methodname),
+ mapping={httpmethod: methodname},
+ name=replace_methodname(route.name, methodname),
+ initkwargs=initkwargs,
+ ))
+ else:
+ # Standard route
+ ret.append(route)
+
+ return ret
+
+ def get_method_map(self, viewset, method_map):
+ """
+ Given a viewset, and a mapping of http methods to actions,
+ return a new mapping which only includes any mappings that
+ are actually implemented by the viewset.
+ """
+ bound_methods = {}
+ for method, action in method_map.items():
+ if hasattr(viewset, action):
+ bound_methods[method] = action
+ return bound_methods
+
+ def get_lookup_regex(self, viewset):
+ """
+ Given a viewset, return the portion of URL regex that is used
+ to match against a single instance.
+ """
+ base_regex = '(?P<{lookup_field}>[^/]+)'
+ lookup_field = getattr(viewset, 'lookup_field', 'pk')
+ return base_regex.format(lookup_field=lookup_field)
+
+ def get_urls(self):
+ """
+ Use the registered viewsets to generate a list of URL patterns.
+ """
+ ret = []
+
+ for prefix, viewset, basename in self.registry:
+ lookup = self.get_lookup_regex(viewset)
+ routes = self.get_routes(viewset)
+
+ for route in routes:
+
+ # Only actions which actually exist on the viewset will be bound
+ mapping = self.get_method_map(viewset, route.mapping)
+ if not mapping:
+ continue
+
+ # Build the url pattern
+ regex = route.url.format(prefix=prefix, lookup=lookup)
+ view = viewset.as_view(mapping, **route.initkwargs)
+ name = route.name.format(basename=basename)
+ ret.append(url(regex, view, name=name))
+
+ return ret
+
+
+class DefaultRouter(SimpleRouter):
+ """
+ The default router extends the SimpleRouter, but also adds in a default
+ API root view, and adds format suffix patterns to the URLs.
+ """
+ include_root_view = True
+ include_format_suffixes = True
+
+ def get_api_root_view(self):
+ """
+ Return a view to use as the API root.
+ """
+ api_root_dict = {}
+ list_name = self.routes[0].name
+ for prefix, viewset, basename in self.registry:
+ api_root_dict[prefix] = list_name.format(basename=basename)
+
+ class APIRoot(views.APIView):
+ _ignore_model_permissions = True
+
+ def get(self, request, format=None):
+ ret = {}
+ for key, url_name in api_root_dict.items():
+ ret[key] = reverse(url_name, request=request, format=format)
+ return Response(ret)
+
+ return APIRoot.as_view()
+
+ def get_urls(self):
+ """
+ Generate the list of URL patterns, including a default root view
+ for the API, and appending `.json` style format suffixes.
+ """
+ urls = []
+
+ if self.include_root_view:
+ root_url = url(r'^$', self.get_api_root_view(), name='api-root')
+ urls.append(root_url)
+
+ default_urls = super(DefaultRouter, self).get_urls()
+ urls.extend(default_urls)
+
+ if self.include_format_suffixes:
+ urls = format_suffix_patterns(urls)
+
+ return urls
diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py
index 03bfc216..9b519f27 100644
--- a/rest_framework/runtests/settings.py
+++ b/rest_framework/runtests/settings.py
@@ -97,9 +97,30 @@ INSTALLED_APPS = (
# 'django.contrib.admindocs',
'rest_framework',
'rest_framework.authtoken',
- 'rest_framework.tests'
+ 'rest_framework.tests',
)
+# OAuth is optional and won't work if there is no oauth_provider & oauth2
+try:
+ import oauth_provider
+ import oauth2
+except ImportError:
+ pass
+else:
+ INSTALLED_APPS += (
+ 'oauth_provider',
+ )
+
+try:
+ import provider
+except ImportError:
+ pass
+else:
+ INSTALLED_APPS += (
+ 'provider',
+ 'provider.oauth2',
+ )
+
STATIC_URL = '/static/'
PASSWORD_HASHERS = (
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index ba9e9e9c..7707de7a 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -1,3 +1,15 @@
+"""
+Serializers and ModelSerializers are similar to Forms and ModelForms.
+Unlike forms, they are not constrained to dealing with HTML output, and
+form encoded input.
+
+Serialization in REST framework is a two-phase process:
+
+1. Serializers marshal between complex types like model instances, and
+python primatives.
+2. The process of marshalling between python primatives and request and
+response content is handled by parsers and renderers.
+"""
from __future__ import unicode_literals
import copy
import datetime
@@ -7,8 +19,7 @@ from django.core.paginator import Page
from django.db import models
from django.forms import widgets
from django.utils.datastructures import SortedDict
-from rest_framework.compat import get_concrete_model
-from rest_framework.compat import six
+from rest_framework.compat import get_concrete_model, six
# Note: We do the following so that users of the framework can use this style:
#
@@ -21,6 +32,25 @@ from rest_framework.relations import *
from rest_framework.fields import *
+class NestedValidationError(ValidationError):
+ """
+ The default ValidationError behavior is to stringify each item in the list
+ if the messages are a list of error messages.
+
+ In the case of nested serializers, where the parent has many children,
+ then the child's `serializer.errors` will be a list of dicts. In the case
+ of a single child, the `serializer.errors` will be a dict.
+
+ We need to override the default behavior to get properly nested error dicts.
+ """
+
+ def __init__(self, message):
+ if isinstance(message, dict):
+ self.messages = [message]
+ else:
+ self.messages = message
+
+
class DictWithMetadata(dict):
"""
A dict-like object, that can have additional properties attached.
@@ -99,7 +129,7 @@ class SerializerOptions(object):
self.exclude = getattr(meta, 'exclude', ())
-class BaseSerializer(Field):
+class BaseSerializer(WritableField):
"""
This is the Serializer implementation.
We need to implement it as `BaseSerializer` due to metaclass magicks.
@@ -111,13 +141,15 @@ class BaseSerializer(Field):
_dict_class = SortedDictWithMetadata
def __init__(self, instance=None, data=None, files=None,
- context=None, partial=False, many=None, source=None):
- super(BaseSerializer, self).__init__(source=source)
+ context=None, partial=False, many=None,
+ allow_add_remove=False, **kwargs):
+ super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta)
self.parent = None
self.root = None
self.partial = partial
self.many = many
+ self.allow_add_remove = allow_add_remove
self.context = context or {}
@@ -129,6 +161,13 @@ class BaseSerializer(Field):
self._data = None
self._files = None
self._errors = None
+ self._deleted = None
+
+ if many and instance is not None and not hasattr(instance, '__iter__'):
+ raise ValueError('instance should be a queryset or other iterable with many=True')
+
+ if allow_add_remove and not many:
+ raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True')
#####
# Methods to determine which fields to use when (de)serializing objects.
@@ -161,7 +200,7 @@ class BaseSerializer(Field):
# If 'fields' is specified, use those fields, in that order.
if self.opts.fields:
- assert isinstance(self.opts.fields, (list, tuple)), '`include` must be a list or tuple'
+ assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple'
new = SortedDict()
for key in self.opts.fields:
new[key] = ret[key]
@@ -169,7 +208,7 @@ class BaseSerializer(Field):
# Remove anything in 'exclude'
if self.opts.exclude:
- assert isinstance(self.opts.fields, (list, tuple)), '`exclude` must be a list or tuple'
+ assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple'
for key in self.opts.exclude:
ret.pop(key, None)
@@ -179,18 +218,6 @@ class BaseSerializer(Field):
return ret
#####
- # Field methods - used when the serializer class is itself used as a field.
-
- def initialize(self, parent, field_name):
- """
- Same behaviour as usual Field, except that we need to keep track
- of state so that we can deal with handling maximum depth.
- """
- super(BaseSerializer, self).initialize(parent, field_name)
- if parent.opts.depth:
- self.opts.depth = parent.opts.depth - 1
-
- #####
# Methods to convert or revert from objects <--> primitive representations.
def get_field_key(self, field_name):
@@ -285,10 +312,6 @@ class BaseSerializer(Field):
"""
Deserialize primitives -> objects.
"""
- if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)):
- # TODO: error data when deserializing lists
- return [self.from_native(item, None) for item in data]
-
self._errors = {}
if data is not None or files is not None:
attrs = self.restore_fields(data, files)
@@ -301,40 +324,91 @@ class BaseSerializer(Field):
def field_to_native(self, obj, field_name):
"""
- Override default so that we can apply ModelSerializer as a nested
- field to relationships.
+ Override default so that the serializer can be used as a nested field
+ across relationships.
"""
if self.source == '*':
return self.to_native(obj)
try:
- if self.source:
- for component in self.source.split('.'):
- obj = getattr(obj, component)
- if is_simple_callable(obj):
- obj = obj()
- else:
- obj = getattr(obj, field_name)
- if is_simple_callable(obj):
- obj = obj()
+ source = self.source or field_name
+ value = obj
+
+ for component in source.split('.'):
+ value = get_component(value, component)
+ if value is None:
+ break
except ObjectDoesNotExist:
return None
- # If the object has an "all" method, assume it's a relationship
- if is_simple_callable(getattr(obj, 'all', None)):
- return [self.to_native(item) for item in obj.all()]
+ if is_simple_callable(getattr(value, 'all', None)):
+ return [self.to_native(item) for item in value.all()]
- if obj is None:
+ if value is None:
return None
if self.many is not None:
many = self.many
else:
- many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict))
+ many = hasattr(value, '__iter__') and not isinstance(value, (Page, dict, six.text_type))
if many:
- return [self.to_native(item) for item in obj]
- return self.to_native(obj)
+ return [self.to_native(item) for item in value]
+ return self.to_native(value)
+
+ def field_from_native(self, data, files, field_name, into):
+ """
+ Override default so that the serializer can be used as a writable
+ nested field across relationships.
+ """
+ if self.read_only:
+ return
+
+ try:
+ value = data[field_name]
+ except KeyError:
+ if self.default is not None and not self.partial:
+ # Note: partial updates shouldn't set defaults
+ value = copy.deepcopy(self.default)
+ else:
+ if self.required:
+ raise ValidationError(self.error_messages['required'])
+ return
+
+ # Set the serializer object if it exists
+ obj = getattr(self.parent.object, field_name) if self.parent.object else None
+
+ if value in (None, ''):
+ into[(self.source or field_name)] = None
+ else:
+ kwargs = {
+ 'instance': obj,
+ 'data': value,
+ 'context': self.context,
+ 'partial': self.partial,
+ 'many': self.many
+ }
+ serializer = self.__class__(**kwargs)
+
+ if serializer.is_valid():
+ into[self.source or field_name] = serializer.object
+ else:
+ # Propagate errors up to our parent
+ raise NestedValidationError(serializer.errors)
+
+ def get_identity(self, data):
+ """
+ This hook is required for bulk update.
+ It is used to determine the canonical identity of a given object.
+
+ Note that the data has not been validated at this point, so we need
+ to make sure that we catch any cases of incorrect datatypes being
+ passed to this method.
+ """
+ try:
+ return data.get('id', None)
+ except AttributeError:
+ return None
@property
def errors(self):
@@ -348,19 +422,52 @@ class BaseSerializer(Field):
if self.many is not None:
many = self.many
else:
- many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict))
+ many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type))
if many:
- warnings.warn('Implict list/queryset serialization is due to be deprecated. '
+ warnings.warn('Implict list/queryset serialization is deprecated. '
'Use the `many=True` flag when instantiating the serializer.',
- PendingDeprecationWarning, stacklevel=3)
+ DeprecationWarning, stacklevel=3)
- # TODO: error data when deserializing lists
if many:
- ret = [self.from_native(item, None) for item in data]
- ret = self.from_native(data, files)
+ ret = []
+ errors = []
+ update = self.object is not None
+
+ if update:
+ # If this is a bulk update we need to map all the objects
+ # to a canonical identity so we can determine which
+ # individual object is being updated for each item in the
+ # incoming data
+ objects = self.object
+ identities = [self.get_identity(self.to_native(obj)) for obj in objects]
+ identity_to_objects = dict(zip(identities, objects))
+
+ if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)):
+ for item in data:
+ if update:
+ # Determine which object we're updating
+ identity = self.get_identity(item)
+ self.object = identity_to_objects.pop(identity, None)
+ if self.object is None and not self.allow_add_remove:
+ ret.append(None)
+ errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']})
+ continue
+
+ ret.append(self.from_native(item, None))
+ errors.append(self._errors)
+
+ if update:
+ self._deleted = identity_to_objects.values()
+
+ self._errors = any(errors) and errors or []
+ else:
+ self._errors = {'non_field_errors': ['Expected a list of items.']}
+ else:
+ ret = self.from_native(data, files)
if not self._errors:
self.object = ret
+
return self._errors
def is_valid(self):
@@ -379,9 +486,9 @@ class BaseSerializer(Field):
else:
many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict))
if many:
- warnings.warn('Implict list/queryset serialization is due to be deprecated. '
+ warnings.warn('Implict list/queryset serialization is deprecated. '
'Use the `many=True` flag when instantiating the serializer.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
if many:
self._data = [self.to_native(item) for item in obj]
@@ -390,11 +497,24 @@ class BaseSerializer(Field):
return self._data
- def save(self):
+ def save_object(self, obj, **kwargs):
+ obj.save(**kwargs)
+
+ def delete_object(self, obj):
+ obj.delete()
+
+ def save(self, **kwargs):
"""
Save the deserialized object and return it.
"""
- self.object.save()
+ if isinstance(self.object, list):
+ [self.save_object(item, **kwargs) for item in self.object]
+ else:
+ self.save_object(self.object, **kwargs)
+
+ if self.allow_add_remove and self._deleted:
+ [self.delete_object(item) for item in self._deleted]
+
return self.object
@@ -428,6 +548,7 @@ class ModelSerializer(Serializer):
models.DateTimeField: DateTimeField,
models.DateField: DateField,
models.TimeField: TimeField,
+ models.DecimalField: DecimalField,
models.EmailField: EmailField,
models.CharField: CharField,
models.URLField: URLField,
@@ -448,36 +569,94 @@ class ModelSerializer(Serializer):
assert cls is not None, \
"Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__
opts = get_concrete_model(cls)._meta
- pk_field = opts.pk
- # while pk_field.rel:
- # pk_field = pk_field.rel.to._meta.pk
- fields = [pk_field]
- fields += [field for field in opts.fields if field.serialize]
- fields += [field for field in opts.many_to_many if field.serialize]
-
ret = SortedDict()
nested = bool(self.opts.depth)
- is_pk = True # First field in the list is the pk
-
- for model_field in fields:
- if is_pk:
- field = self.get_pk_field(model_field)
- is_pk = False
- elif model_field.rel and nested:
- field = self.get_nested_field(model_field)
- elif model_field.rel:
+
+ # Deal with adding the primary key field
+ pk_field = opts.pk
+ while pk_field.rel and pk_field.rel.parent_link:
+ # If model is a child via multitable inheritance, use parent's pk
+ pk_field = pk_field.rel.to._meta.pk
+
+ field = self.get_pk_field(pk_field)
+ if field:
+ ret[pk_field.name] = field
+
+ # Deal with forward relationships
+ forward_rels = [field for field in opts.fields if field.serialize]
+ forward_rels += [field for field in opts.many_to_many if field.serialize]
+
+ for model_field in forward_rels:
+ if model_field.rel:
to_many = isinstance(model_field,
models.fields.related.ManyToManyField)
- field = self.get_related_field(model_field, to_many=to_many)
+ related_model = model_field.rel.to
+
+ if model_field.rel and nested:
+ if len(inspect.getargspec(self.get_nested_field).args) == 2:
+ warnings.warn(
+ 'The `get_nested_field(model_field)` call signature '
+ 'is due to be deprecated. '
+ 'Use `get_nested_field(model_field, related_model, '
+ 'to_many) instead',
+ PendingDeprecationWarning
+ )
+ field = self.get_nested_field(model_field)
+ else:
+ field = self.get_nested_field(model_field, related_model, to_many)
+ elif model_field.rel:
+ if len(inspect.getargspec(self.get_nested_field).args) == 3:
+ warnings.warn(
+ 'The `get_related_field(model_field, to_many)` call '
+ 'signature is due to be deprecated. '
+ 'Use `get_related_field(model_field, related_model, '
+ 'to_many) instead',
+ PendingDeprecationWarning
+ )
+ field = self.get_related_field(model_field, to_many=to_many)
+ else:
+ field = self.get_related_field(model_field, related_model, to_many)
else:
field = self.get_field(model_field)
if field:
ret[model_field.name] = field
+ # Deal with reverse relationships
+ if not self.opts.fields:
+ reverse_rels = []
+ else:
+ # Reverse relationships are only included if they are explicitly
+ # present in the `fields` option on the serializer
+ reverse_rels = opts.get_all_related_objects()
+ reverse_rels += opts.get_all_related_many_to_many_objects()
+
+ for relation in reverse_rels:
+ accessor_name = relation.get_accessor_name()
+ if not self.opts.fields or accessor_name not in self.opts.fields:
+ continue
+ related_model = relation.model
+ to_many = relation.field.rel.multiple
+
+ if nested:
+ field = self.get_nested_field(None, related_model, to_many)
+ else:
+ field = self.get_related_field(None, related_model, to_many)
+
+ if field:
+ ret[accessor_name] = field
+
+ # Add the `read_only` flag to any fields that have bee specified
+ # in the `read_only_fields` option
for field_name in self.opts.read_only_fields:
+ assert field_name not in self.base_fields.keys(), \
+ "field '%s' on serializer '%s' specfied in " \
+ "`read_only_fields`, but also added " \
+ "as an explict field. Remove it from `read_only_fields`." % \
+ (field_name, self.__class__.__name__)
assert field_name in ret, \
- "read_only_fields on '%s' included invalid item '%s'" % \
+ "Noexistant field '%s' specified in `read_only_fields` " \
+ "on serializer '%s'." % \
(self.__class__.__name__, field_name)
ret[field_name].read_only = True
@@ -489,27 +668,36 @@ class ModelSerializer(Serializer):
"""
return self.get_field(model_field)
- def get_nested_field(self, model_field):
+ def get_nested_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a nested relational field.
+
+ Note that model_field will be `None` for reverse relationships.
"""
class NestedModelSerializer(ModelSerializer):
class Meta:
- model = model_field.rel.to
- return NestedModelSerializer()
+ model = related_model
+ depth = self.opts.depth - 1
+
+ return NestedModelSerializer(many=to_many)
- def get_related_field(self, model_field, to_many=False):
+ def get_related_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a flat relational field.
+
+ Note that model_field will be `None` for reverse relationships.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
+
kwargs = {
- 'required': not(model_field.null or model_field.blank),
- 'queryset': model_field.rel.to._default_manager,
+ 'queryset': related_model._default_manager,
'many': to_many
}
+ if model_field:
+ kwargs['required'] = not(model_field.null or model_field.blank)
+
return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field):
@@ -574,33 +762,43 @@ class ModelSerializer(Serializer):
"""
Restore the model instance.
"""
- self.m2m_data = {}
- self.related_data = {}
+ m2m_data = {}
+ related_data = {}
+ meta = self.opts.model._meta
- # Reverse fk relations
- for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model():
+ # Reverse fk or one-to-one relations
+ for (obj, model) in meta.get_all_related_objects_with_model():
field_name = obj.field.related_query_name()
if field_name in attrs:
- self.related_data[field_name] = attrs.pop(field_name)
+ related_data[field_name] = attrs.pop(field_name)
# Reverse m2m relations
- for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model():
+ for (obj, model) in meta.get_all_related_m2m_objects_with_model():
field_name = obj.field.related_query_name()
if field_name in attrs:
- self.m2m_data[field_name] = attrs.pop(field_name)
+ m2m_data[field_name] = attrs.pop(field_name)
# Forward m2m relations
- for field in self.opts.model._meta.many_to_many:
+ for field in meta.many_to_many:
if field.name in attrs:
- self.m2m_data[field.name] = attrs.pop(field.name)
+ m2m_data[field.name] = attrs.pop(field.name)
+ # Update an existing instance...
if instance is not None:
for key, val in attrs.items():
setattr(instance, key, val)
+ # ...or create a new instance
else:
instance = self.opts.model(**attrs)
+ # Any relations that cannot be set until we've
+ # saved the model get hidden away on these
+ # private attributes, so we can deal with them
+ # at the point of save.
+ instance._related_data = related_data
+ instance._m2m_data = m2m_data
+
return instance
def from_native(self, data, files):
@@ -608,26 +806,24 @@ class ModelSerializer(Serializer):
Override the default method to also include model field validation.
"""
instance = super(ModelSerializer, self).from_native(data, files)
- if instance:
+ if not self._errors:
return self.full_clean(instance)
- def save(self):
+ def save_object(self, obj, **kwargs):
"""
Save the deserialized object and return it.
"""
- self.object.save()
-
- if getattr(self, 'm2m_data', None):
- for accessor_name, object_list in self.m2m_data.items():
- setattr(self.object, accessor_name, object_list)
- self.m2m_data = {}
+ obj.save(**kwargs)
- if getattr(self, 'related_data', None):
- for accessor_name, object_list in self.related_data.items():
- setattr(self.object, accessor_name, object_list)
- self.related_data = {}
+ if getattr(obj, '_m2m_data', None):
+ for accessor_name, object_list in obj._m2m_data.items():
+ setattr(obj, accessor_name, object_list)
+ del(obj._m2m_data)
- return self.object
+ if getattr(obj, '_related_data', None):
+ for accessor_name, related in obj._related_data.items():
+ setattr(obj, accessor_name, related)
+ del(obj._related_data)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
@@ -637,6 +833,7 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
def __init__(self, meta):
super(HyperlinkedModelSerializerOptions, self).__init__(meta)
self.view_name = getattr(meta, 'view_name', None)
+ self.lookup_field = getattr(meta, 'lookup_field', None)
class HyperlinkedModelSerializer(ModelSerializer):
@@ -646,6 +843,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
"""
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
+ _hyperlink_field_class = HyperlinkedRelatedField
url = HyperlinkedIdentityField()
@@ -666,19 +864,35 @@ class HyperlinkedModelSerializer(ModelSerializer):
return self._default_view_name % format_kwargs
def get_pk_field(self, model_field):
- return None
+ if self.opts.fields and model_field.name in self.opts.fields:
+ return self.get_field(model_field)
- def get_related_field(self, model_field, to_many):
+ def get_related_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a flat relational field.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
- rel = model_field.rel.to
kwargs = {
- 'required': not(model_field.null or model_field.blank),
- 'queryset': rel._default_manager,
- 'view_name': self._get_default_view_name(rel),
+ 'queryset': related_model._default_manager,
+ 'view_name': self._get_default_view_name(related_model),
'many': to_many
}
- return HyperlinkedRelatedField(**kwargs)
+
+ if model_field:
+ kwargs['required'] = not(model_field.null or model_field.blank)
+
+ if self.opts.lookup_field:
+ kwargs['lookup_field'] = self.opts.lookup_field
+
+ return self._hyperlink_field_class(**kwargs)
+
+ def get_identity(self, data):
+ """
+ This hook is required for bulk update.
+ We need to override the default, to use the url as the identity.
+ """
+ try:
+ return data.get('url', None)
+ except AttributeError:
+ return None
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index b7aa0bbe..beb511ac 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -18,14 +18,18 @@ REST framework settings, checking for user settings first, then falling
back to the defaults.
"""
from __future__ import unicode_literals
+
from django.conf import settings
from django.utils import importlib
+
+from rest_framework import ISO_8601
from rest_framework.compat import six
USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None)
DEFAULTS = {
+ # Base API policies
'DEFAULT_RENDERER_CLASSES': (
'rest_framework.renderers.JSONRenderer',
'rest_framework.renderers.BrowsableAPIRenderer',
@@ -47,11 +51,15 @@ DEFAULTS = {
'DEFAULT_CONTENT_NEGOTIATION_CLASS':
'rest_framework.negotiation.DefaultContentNegotiation',
+
+ # Genric view behavior
'DEFAULT_MODEL_SERIALIZER_CLASS':
'rest_framework.serializers.ModelSerializer',
'DEFAULT_PAGINATION_SERIALIZER_CLASS':
'rest_framework.pagination.PaginationSerializer',
+ 'DEFAULT_FILTER_BACKENDS': (),
+ # Throttling
'DEFAULT_THROTTLE_RATES': {
'user': None,
'anon': None,
@@ -61,9 +69,6 @@ DEFAULTS = {
'PAGINATE_BY': None,
'PAGINATE_BY_PARAM': None,
- # Filtering
- 'FILTER_BACKEND': None,
-
# Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
@@ -76,6 +81,25 @@ DEFAULTS = {
'URL_FORMAT_OVERRIDE': 'format',
'FORMAT_SUFFIX_KWARG': 'format',
+
+ # Input and output formats
+ 'DATE_INPUT_FORMATS': (
+ ISO_8601,
+ ),
+ 'DATE_FORMAT': None,
+
+ 'DATETIME_INPUT_FORMATS': (
+ ISO_8601,
+ ),
+ 'DATETIME_FORMAT': None,
+
+ 'TIME_INPUT_FORMATS': (
+ ISO_8601,
+ ),
+ 'TIME_FORMAT': None,
+
+ # Pending deprecation
+ 'FILTER_BACKEND': None,
}
@@ -89,6 +113,7 @@ IMPORT_STRINGS = (
'DEFAULT_CONTENT_NEGOTIATION_CLASS',
'DEFAULT_MODEL_SERIALIZER_CLASS',
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
+ 'DEFAULT_FILTER_BACKENDS',
'FILTER_BACKEND',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 44633f5a..4410f285 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -115,7 +115,7 @@
</div>
<div class="response-info">
<pre class="prettyprint"><div class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %}
-{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|urlize_quoted_links }}</span>
+{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|break_long_headers|urlize_quoted_links }}</span>
{% endfor %}
</div>{{ content|urlize_quoted_links }}</pre>{% endautoescape %}
</div>
diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html
index e10ce20f..b7629327 100644
--- a/rest_framework/templates/rest_framework/login.html
+++ b/rest_framework/templates/rest_framework/login.html
@@ -1,53 +1,3 @@
-{% load url from future %}
-{% load rest_framework %}
-<html>
+{% extends "rest_framework/login_base.html" %}
- <head>
- <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>
- <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/>
- <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/>
- </head>
-
- <body class="container">
-
-<div class="container-fluid" style="margin-top: 30px">
- <div class="row-fluid">
-
- <div class="well" style="width: 320px; margin-left: auto; margin-right: auto">
- <div class="row-fluid">
- <div>
- <h3 style="margin: 0 0 20px;">Django REST framework</h3>
- </div>
- </div><!-- /row fluid -->
-
- <div class="row-fluid">
- <div>
- <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post">
- {% csrf_token %}
- <div id="div_id_username" class="clearfix control-group">
- <div class="controls">
- <Label class="span4">Username:</label>
- <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username">
- </div>
- </div>
- <div id="div_id_password" class="clearfix control-group">
- <div class="controls">
- <Label class="span4">Password:</label>
- <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password">
- </div>
- </div>
- <input type="hidden" name="next" value="{{ next }}" />
- <div class="form-actions-no-box">
- <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit">
- </div>
- </form>
- </div>
- </div><!-- /row fluid -->
- </div><!--/span-->
-
- </div><!-- /.row-fluid -->
- </div>
-
- </div>
- </body>
-</html>
+{# Override this template in your own templates directory to customize #}
diff --git a/rest_framework/templates/rest_framework/login_base.html b/rest_framework/templates/rest_framework/login_base.html
new file mode 100644
index 00000000..a3e73b6b
--- /dev/null
+++ b/rest_framework/templates/rest_framework/login_base.html
@@ -0,0 +1,51 @@
+{% load url from future %}
+{% load rest_framework %}
+<html>
+
+ <head>
+ {% block style %}
+ {% block bootstrap_theme %}<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>{% endblock %}
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/>
+ {% endblock %}
+ </head>
+
+ <body class="container">
+
+ <div class="container-fluid" style="margin-top: 30px">
+ <div class="row-fluid">
+ <div class="well" style="width: 320px; margin-left: auto; margin-right: auto">
+ <div class="row-fluid">
+ <div>
+ {% block branding %}<h3 style="margin: 0 0 20px;">Django REST framework</h3>{% endblock %}
+ </div>
+ </div><!-- /row fluid -->
+
+ <div class="row-fluid">
+ <div>
+ <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post">
+ {% csrf_token %}
+ <div id="div_id_username" class="clearfix control-group">
+ <div class="controls">
+ <Label class="span4">Username:</label>
+ <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username">
+ </div>
+ </div>
+ <div id="div_id_password" class="clearfix control-group">
+ <div class="controls">
+ <Label class="span4">Password:</label>
+ <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password">
+ </div>
+ </div>
+ <input type="hidden" name="next" value="{{ next }}" />
+ <div class="form-actions-no-box">
+ <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit">
+ </div>
+ </form>
+ </div>
+ </div><!-- /.row-fluid -->
+ </div><!--/.well-->
+ </div><!-- /.row-fluid -->
+ </div><!-- /.container-fluid -->
+ </body>
+</html>
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index c21ddcd7..c86b6456 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -4,11 +4,8 @@ from django.core.urlresolvers import reverse, NoReverseMatch
from django.http import QueryDict
from django.utils.html import escape
from django.utils.safestring import SafeData, mark_safe
-from rest_framework.compat import urlparse
-from rest_framework.compat import force_text
-from rest_framework.compat import six
-import re
-import string
+from rest_framework.compat import urlparse, force_text, six, smart_urlquote
+import re, string
register = template.Library()
@@ -112,22 +109,6 @@ def replace_query_param(url, key, val):
class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])')
-# Bunch of stuff cloned from urlize
-LEADING_PUNCTUATION = ['(', '<', '&lt;', '"', "'"]
-TRAILING_PUNCTUATION = ['.', ',', ')', '>', '\n', '&gt;', '"', "'"]
-DOTS = ['&middot;', '*', '\xe2\x80\xa2', '&#149;', '&bull;', '&#8226;']
-unencoded_ampersands_re = re.compile(r'&(?!(\w+|#\d+);)')
-word_split_re = re.compile(r'(\s+)')
-punctuation_re = re.compile('^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' % \
- ('|'.join([re.escape(x) for x in LEADING_PUNCTUATION]),
- '|'.join([re.escape(x) for x in TRAILING_PUNCTUATION])))
-simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$')
-link_target_attribute_re = re.compile(r'(<a [^>]*?)target=[^\s>]+')
-html_gunk_re = re.compile(r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)', re.IGNORECASE)
-hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|'.join([re.escape(x) for x in DOTS]), re.DOTALL)
-trailing_empty_content_re = re.compile(r'(?:<p>(?:&nbsp;|\s|<br \/>)*?</p>\s*)+\Z')
-
-
# And the template tags themselves...
@register.simple_tag
@@ -195,15 +176,25 @@ def add_class(value, css_class):
return value
+# Bunch of stuff cloned from urlize
+TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"]
+WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('&lt;', '&gt;'),
+ ('"', '"'), ("'", "'")]
+word_split_re = re.compile(r'(\s+)')
+simple_url_re = re.compile(r'^https?://\[?\w', re.IGNORECASE)
+simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE)
+simple_email_re = re.compile(r'^\S+@\S+\.\S+$')
+
+
@register.filter
def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True):
"""
Converts any URLs in text into clickable links.
- Works on http://, https://, www. links and links ending in .org, .net or
- .com. Links can have trailing punctuation (periods, commas, close-parens)
- and leading punctuation (opening parens) and it'll still do the right
- thing.
+ Works on http://, https://, www. links, and also on links ending in one of
+ the original seven gTLDs (.com, .edu, .gov, .int, .mil, .net, and .org).
+ Links can have trailing punctuation (periods, commas, close-parens) and
+ leading punctuation (opening parens) and it'll still do the right thing.
If trim_url_limit is not None, the URLs in link text longer than this limit
will truncated to trim_url_limit-3 characters and appended with an elipsis.
@@ -216,24 +207,41 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x
safe_input = isinstance(text, SafeData)
words = word_split_re.split(force_text(text))
- nofollow_attr = nofollow and ' rel="nofollow"' or ''
for i, word in enumerate(words):
match = None
if '.' in word or '@' in word or ':' in word:
- match = punctuation_re.match(word)
- if match:
- lead, middle, trail = match.groups()
+ # Deal with punctuation.
+ lead, middle, trail = '', word, ''
+ for punctuation in TRAILING_PUNCTUATION:
+ if middle.endswith(punctuation):
+ middle = middle[:-len(punctuation)]
+ trail = punctuation + trail
+ for opening, closing in WRAPPING_PUNCTUATION:
+ if middle.startswith(opening):
+ middle = middle[len(opening):]
+ lead = lead + opening
+ # Keep parentheses at the end only if they're balanced.
+ if (middle.endswith(closing)
+ and middle.count(closing) == middle.count(opening) + 1):
+ middle = middle[:-len(closing)]
+ trail = closing + trail
+
# Make URL we want to point to.
url = None
- if middle.startswith('http://') or middle.startswith('https://'):
- url = middle
- elif middle.startswith('www.') or ('@' not in middle and \
- middle and middle[0] in string.ascii_letters + string.digits and \
- (middle.endswith('.org') or middle.endswith('.net') or middle.endswith('.com'))):
- url = 'http://%s' % middle
- elif '@' in middle and not ':' in middle and simple_email_re.match(middle):
- url = 'mailto:%s' % middle
+ nofollow_attr = ' rel="nofollow"' if nofollow else ''
+ if simple_url_re.match(middle):
+ url = smart_urlquote(middle)
+ elif simple_url_2_re.match(middle):
+ url = smart_urlquote('http://%s' % middle)
+ elif not ':' in middle and simple_email_re.match(middle):
+ local, domain = middle.rsplit('@', 1)
+ try:
+ domain = domain.encode('idna').decode('ascii')
+ except UnicodeError:
+ continue
+ url = 'mailto:%s@%s' % (local, domain)
nofollow_attr = ''
+
# Make link.
if url:
trimmed = trim_url(middle)
@@ -251,4 +259,15 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
words[i] = mark_safe(word)
elif autoescape:
words[i] = escape(word)
- return mark_safe(''.join(words))
+ return ''.join(words)
+
+
+@register.filter
+def break_long_headers(header):
+ """
+ Breaks headers longer than 160 characters (~page length)
+ when possible (are comma separated)
+ """
+ if len(header) > 160 and ',' in header:
+ header = mark_safe('<br> ' + ', <br>'.join(header.split(',')))
+ return header
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index 7b754af5..8e6d3e51 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -2,23 +2,29 @@ from __future__ import unicode_literals
from django.contrib.auth.models import User
from django.http import HttpResponse
from django.test import Client, TestCase
+from django.utils import unittest
from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions
from rest_framework import permissions
from rest_framework import status
-from rest_framework.authtoken.models import Token
from rest_framework.authentication import (
BaseAuthentication,
TokenAuthentication,
BasicAuthentication,
- SessionAuthentication
+ SessionAuthentication,
+ OAuthAuthentication,
+ OAuth2Authentication
)
-from rest_framework.compat import patterns
+from rest_framework.authtoken.models import Token
+from rest_framework.compat import patterns, url, include
+from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope
+from rest_framework.compat import oauth, oauth_provider
from rest_framework.tests.utils import RequestFactory
from rest_framework.views import APIView
import json
import base64
-
+import time
+import datetime
factory = RequestFactory()
@@ -41,8 +47,19 @@ urlpatterns = patterns('',
(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
+ (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])),
+ (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication],
+ permission_classes=[permissions.TokenHasReadWriteScope]))
)
+if oauth2_provider is not None:
+ urlpatterns += patterns('',
+ url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
+ url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
+ url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
+ permission_classes=[permissions.TokenHasReadWriteScope])),
+ )
+
class BasicAuthTests(TestCase):
"""Basic authentication"""
@@ -146,7 +163,7 @@ class TokenAuthTests(TestCase):
def test_post_form_passing_token_auth(self):
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
- auth = "Token " + self.key
+ auth = 'Token ' + self.key
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -222,3 +239,317 @@ class IncorrectCredentialsTests(TestCase):
response = view(request)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.data, {'detail': 'Bad credentials'})
+
+
+class OAuthTests(TestCase):
+ """OAuth 1.0a authentication"""
+ urls = 'rest_framework.tests.authentication'
+
+ def setUp(self):
+ # these imports are here because oauth is optional and hiding them in try..except block or compat
+ # could obscure problems if something breaks
+ from oauth_provider.models import Consumer, Resource
+ from oauth_provider.models import Token as OAuthToken
+ from oauth_provider import consts
+
+ self.consts = consts
+
+ self.csrf_client = Client(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.CONSUMER_KEY = 'consumer_key'
+ self.CONSUMER_SECRET = 'consumer_secret'
+ self.TOKEN_KEY = "token_key"
+ self.TOKEN_SECRET = "token_secret"
+
+ self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
+ name='example', user=self.user, status=self.consts.ACCEPTED)
+
+ self.resource = Resource.objects.create(name="resource name", url="api/")
+ self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, resource=self.resource,
+ token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True
+ )
+
+ def _create_authorization_header(self):
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="GET", url="http://example.com", parameters=params)
+
+ signature_method = oauth.SignatureMethod_PLAINTEXT()
+ req.sign_request(signature_method, self.consumer, self.token)
+
+ return req.to_header()["Authorization"]
+
+ def _create_authorization_url_parameters(self):
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="GET", url="http://example.com", parameters=params)
+
+ signature_method = oauth.SignatureMethod_PLAINTEXT()
+ req.sign_request(signature_method, self.consumer, self.token)
+ return dict(req)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_passing_oauth(self):
+ """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_repeated_nonce_failing_oauth(self):
+ """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ # simulate reply attack auth header containes already used (nonce, timestamp) pair
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_token_removed_failing_oauth(self):
+ """Ensure POSTing when there is no OAuth access token in db fails"""
+ self.token.delete()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_consumer_status_not_accepted_failing_oauth(self):
+ """Ensure POSTing when consumer status is anything other than ACCEPTED fails"""
+ for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED):
+ self.consumer.status = consumer_status
+ self.consumer.save()
+
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_request_token_failing_oauth(self):
+ """Ensure POSTing with unauthorized request token instead of access token fails"""
+ self.token.token_type = self.token.REQUEST
+ self.token.save()
+
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_urlencoded_parameters(self):
+ """Ensure POSTing with x-www-form-urlencoded auth parameters passes"""
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.post('/oauth/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_get_form_with_url_parameters(self):
+ """Ensure GETing with auth in url parameters passes"""
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.get('/oauth/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_hmac_sha1_signature_passes(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_get_form_with_readonly_resource_passing_auth(self):
+ """Ensure POSTing with a readonly resource instead of a write scope fails"""
+ read_only_access_token = self.token
+ read_only_access_token.resource.is_readonly = True
+ read_only_access_token.resource.save()
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.get('/oauth-with-scope/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_readonly_resource_failing_auth(self):
+ """Ensure POSTing with a readonly resource instead of a write scope fails"""
+ read_only_access_token = self.token
+ read_only_access_token.resource.is_readonly = True
+ read_only_access_token.resource.save()
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.post('/oauth-with-scope/', params)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_write_resource_passing_auth(self):
+ """Ensure POSTing with a write resource succeed"""
+ read_write_access_token = self.token
+ read_write_access_token.resource.is_readonly = False
+ read_write_access_token.resource.save()
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.post('/oauth-with-scope/', params)
+ self.assertEqual(response.status_code, 200)
+
+
+class OAuth2Tests(TestCase):
+ """OAuth 2.0 authentication"""
+ urls = 'rest_framework.tests.authentication'
+
+ def setUp(self):
+ self.csrf_client = Client(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.CLIENT_ID = 'client_key'
+ self.CLIENT_SECRET = 'client_secret'
+ self.ACCESS_TOKEN = "access_token"
+ self.REFRESH_TOKEN = "refresh_token"
+
+ self.oauth2_client = oauth2_provider_models.Client.objects.create(
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ redirect_uri='',
+ client_type=0,
+ name='example',
+ user=None,
+ )
+
+ self.access_token = oauth2_provider_models.AccessToken.objects.create(
+ token=self.ACCESS_TOKEN,
+ client=self.oauth2_client,
+ user=self.user,
+ )
+ self.refresh_token = oauth2_provider_models.RefreshToken.objects.create(
+ user=self.user,
+ access_token=self.access_token,
+ client=self.oauth2_client
+ )
+
+ def _create_authorization_header(self, token=None):
+ return "Bearer {0}".format(token or self.access_token.token)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_type_failing(self):
+ """Ensure that a wrong token type lead to the correct HTTP error status code"""
+ auth = "Wrong token-type-obsviously"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_format_failing(self):
+ """Ensure that a wrong token format lead to the correct HTTP error status code"""
+ auth = "Bearer wrong token format"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_failing(self):
+ """Ensure that a wrong token lead to the correct HTTP error status code"""
+ auth = "Bearer wrong-token"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_passing_auth(self):
+ """Ensure GETing form over OAuth with correct client credentials succeed"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_passing_auth(self):
+ """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_token_removed_failing_auth(self):
+ """Ensure POSTing when there is no OAuth access token in db fails"""
+ self.access_token.delete()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_refresh_token_failing_auth(self):
+ """Ensure POSTing with refresh token instead of access token fails"""
+ auth = self._create_authorization_header(token=self.refresh_token.token)
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_expired_access_token_failing_auth(self):
+ """Ensure POSTing with expired access token fails with an 'Invalid token' error"""
+ self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
+ self.access_token.save()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+ self.assertIn('Invalid token', response.content)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_invalid_scope_failing_auth(self):
+ """Ensure POSTing with a readonly scope instead of a write scope fails"""
+ read_only_access_token = self.access_token
+ read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']
+ read_only_access_token.save()
+ auth = self._create_authorization_header(token=read_only_access_token.token)
+ response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+ response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_valid_scope_passing_auth(self):
+ """Ensure POSTing with a write scope succeed"""
+ read_write_access_token = self.access_token
+ read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']
+ read_write_access_token.save()
+ auth = self._create_authorization_header(token=read_write_access_token.token)
+ response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py
index 5b3315bc..52c1a34c 100644
--- a/rest_framework/tests/description.py
+++ b/rest_framework/tests/description.py
@@ -4,6 +4,7 @@ from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.views import APIView
from rest_framework.compat import apply_markdown
+from rest_framework.utils.formatting import get_view_name, get_view_description
# We check that docstrings get nicely un-indented.
DESCRIPTION = """an example docstring
@@ -49,22 +50,16 @@ MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2>
class TestViewNamesAndDescriptions(TestCase):
- def test_resource_name_uses_classname_by_default(self):
- """Ensure Resource names are based on the classname by default."""
+ def test_view_name_uses_class_name(self):
+ """
+ Ensure view names are based on the class name.
+ """
class MockView(APIView):
pass
- self.assertEqual(MockView().get_name(), 'Mock')
+ self.assertEqual(get_view_name(MockView), 'Mock')
- def test_resource_name_can_be_set_explicitly(self):
- """Ensure Resource names can be set using the 'get_name' method."""
- example = 'Some Other Name'
- class MockView(APIView):
- def get_name(self):
- return example
- self.assertEqual(MockView().get_name(), example)
-
- def test_resource_description_uses_docstring_by_default(self):
- """Ensure Resource names are based on the docstring by default."""
+ def test_view_description_uses_docstring(self):
+ """Ensure view descriptions are based on the docstring."""
class MockView(APIView):
"""an example docstring
====================
@@ -81,44 +76,32 @@ class TestViewNamesAndDescriptions(TestCase):
# hash style header #"""
- self.assertEqual(MockView().get_description(), DESCRIPTION)
-
- def test_resource_description_can_be_set_explicitly(self):
- """Ensure Resource descriptions can be set using the 'get_description' method."""
- example = 'Some other description'
-
- class MockView(APIView):
- """docstring"""
- def get_description(self):
- return example
- self.assertEqual(MockView().get_description(), example)
+ self.assertEqual(get_view_description(MockView), DESCRIPTION)
- def test_resource_description_supports_unicode(self):
+ def test_view_description_supports_unicode(self):
+ """
+ Unicode in docstrings should be respected.
+ """
class MockView(APIView):
"""Проверка"""
pass
- self.assertEqual(MockView().get_description(), "Проверка")
-
-
- def test_resource_description_does_not_require_docstring(self):
- """Ensure that empty docstrings do not affect the Resource's description if it has been set using the 'get_description' method."""
- example = 'Some other description'
-
- class MockView(APIView):
- def get_description(self):
- return example
- self.assertEqual(MockView().get_description(), example)
+ self.assertEqual(get_view_description(MockView), "Проверка")
- def test_resource_description_can_be_empty(self):
- """Ensure that if a resource has no doctring or 'description' class attribute, then it's description is the empty string."""
+ def test_view_description_can_be_empty(self):
+ """
+ Ensure that if a view has no docstring,
+ then it's description is the empty string.
+ """
class MockView(APIView):
pass
- self.assertEqual(MockView().get_description(), '')
+ self.assertEqual(get_view_description(MockView), '')
def test_markdown(self):
- """Ensure markdown to HTML works as expected"""
+ """
+ Ensure markdown to HTML works as expected.
+ """
if apply_markdown:
gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21
lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py
index 840ed320..3cdfa0f6 100644
--- a/rest_framework/tests/fields.py
+++ b/rest_framework/tests/fields.py
@@ -3,10 +3,14 @@ General serializer field tests.
"""
from __future__ import unicode_literals
import datetime
+from decimal import Decimal
+
from django.db import models
from django.test import TestCase
from django.core import validators
+
from rest_framework import serializers
+from rest_framework.serializers import Serializer
class TimestampedModel(models.Model):
@@ -59,37 +63,586 @@ class BasicFieldTests(TestCase):
serializer = CharPrimaryKeyModelSerializer()
self.assertEqual(serializer.fields['id'].read_only, False)
- def test_TimeField_from_native(self):
+
+class DateFieldTest(TestCase):
+ """
+ Tests for the DateFieldTest from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.DateField()
+ result_1 = f.from_native('1984-07-31')
+
+ self.assertEqual(datetime.date(1984, 7, 31), result_1)
+
+ def test_from_native_datetime_date(self):
+ """
+ Make sure from_native() accepts a datetime.date instance.
+ """
+ f = serializers.DateField()
+ result_1 = f.from_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual(result_1, datetime.date(1984, 7, 31))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.DateField(input_formats=['%Y -- %d'])
+ result = f.from_native('1984 -- 31')
+
+ self.assertEqual(datetime.date(1984, 1, 31), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.DateField(input_formats=['%Y -- %d'])
+
+ try:
+ f.from_native('1984-07-31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DateField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_date(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid date.
+ """
+ f = serializers.DateField()
+
+ try:
+ f.from_native('1984-13-31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.DateField()
+
+ try:
+ f.from_native('1984 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns datetime as default.
+ """
+ f = serializers.DateField()
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual(datetime.date(1984, 7, 31), result_1)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with 'iso-8601' returns iso formated date.
+ """
+ f = serializers.DateField(format='iso-8601')
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual('1984-07-31', result_1)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.DateField(format="%Y - %m.%d")
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual('1984 - 07.31', result_1)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+
+class DateTimeFieldTest(TestCase):
+ """
+ Tests for the DateTimeField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.DateTimeField()
+ result_1 = f.from_native('1984-07-31 04:31')
+ result_2 = f.from_native('1984-07-31 04:31:59')
+ result_3 = f.from_native('1984-07-31 04:31:59.000200')
+
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3)
+
+ def test_from_native_datetime_datetime(self):
+ """
+ Make sure from_native() accepts a datetime.datetime instance.
+ """
+ f = serializers.DateTimeField()
+ result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31))
+ self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59))
+ self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
+ result = f.from_native('1984 -- 04:59')
+
+ self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
+
+ try:
+ f.from_native('1984-07-31 04:31:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DateTimeField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateTimeField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_datetime(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid datetime.
+ """
+ f = serializers.DateTimeField()
+
+ try:
+ f.from_native('04:61:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
+ "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.DateTimeField()
+
+ try:
+ f.from_native('04 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
+ "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns isoformat as default.
+ """
+ f = serializers.DateTimeField()
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual(datetime.datetime(1984, 7, 31), result_1)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with format=iso-8601 returns iso formatted datetime.
+ """
+ f = serializers.DateTimeField(format='iso-8601')
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual('1984-07-31T00:00:00', result_1)
+ self.assertEqual('1984-07-31T04:31:00', result_2)
+ self.assertEqual('1984-07-31T04:31:59', result_3)
+ self.assertEqual('1984-07-31T04:31:59.000200', result_4)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.DateTimeField(format="%Y - %H:%M")
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual('1984 - 00:00', result_1)
+ self.assertEqual('1984 - 04:31', result_2)
+ self.assertEqual('1984 - 04:31', result_3)
+ self.assertEqual('1984 - 04:31', result_4)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateTimeField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+
+class TimeFieldTest(TestCase):
+ """
+ Tests for the TimeField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
f = serializers.TimeField()
- result = f.from_native('12:34:56.987654')
+ result_1 = f.from_native('04:31')
+ result_2 = f.from_native('04:31:59')
+ result_3 = f.from_native('04:31:59.000200')
- self.assertEqual(datetime.time(12, 34, 56, 987654), result)
+ self.assertEqual(datetime.time(4, 31), result_1)
+ self.assertEqual(datetime.time(4, 31, 59), result_2)
+ self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
- def test_TimeField_from_native_datetime_time(self):
+ def test_from_native_datetime_time(self):
"""
Make sure from_native() accepts a datetime.time instance.
"""
f = serializers.TimeField()
- result = f.from_native(datetime.time(12, 34, 56))
- self.assertEqual(result, datetime.time(12, 34, 56))
+ result_1 = f.from_native(datetime.time(4, 31))
+ result_2 = f.from_native(datetime.time(4, 31, 59))
+ result_3 = f.from_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual(result_1, datetime.time(4, 31))
+ self.assertEqual(result_2, datetime.time(4, 31, 59))
+ self.assertEqual(result_3, datetime.time(4, 31, 59, 200))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.TimeField(input_formats=['%H -- %M'])
+ result = f.from_native('04 -- 31')
- def test_TimeField_from_native_empty(self):
+ self.assertEqual(datetime.time(4, 31), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.TimeField(input_formats=['%H -- %M'])
+
+ try:
+ f.from_native('04:31:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
f = serializers.TimeField()
result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.TimeField()
+ result = f.from_native(None)
+
self.assertEqual(result, None)
- def test_TimeField_from_native_invalid_time(self):
+ def test_from_native_invalid_time(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid time.
+ """
+ f = serializers.TimeField()
+
+ try:
+ f.from_native('04:61:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
+ "hh:mm[:ss[.uuuuuu]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.TimeField()
+
+ try:
+ f.from_native('04 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
+ "hh:mm[:ss[.uuuuuu]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns time object as default.
+ """
f = serializers.TimeField()
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual(datetime.time(4, 31), result_1)
+ self.assertEqual(datetime.time(4, 31, 59), result_2)
+ self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with format='iso-8601' returns iso formatted time.
+ """
+ f = serializers.TimeField(format='iso-8601')
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual('04:31:00', result_1)
+ self.assertEqual('04:31:59', result_2)
+ self.assertEqual('04:31:59.000200', result_3)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.TimeField(format="%H - %S [%f]")
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual('04 - 00 [000000]', result_1)
+ self.assertEqual('04 - 59 [000000]', result_2)
+ self.assertEqual('04 - 59 [000200]', result_3)
+
+
+class DecimalFieldTest(TestCase):
+ """
+ Tests for the DecimalField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts string values
+ """
+ f = serializers.DecimalField()
+ result_1 = f.from_native('9000')
+ result_2 = f.from_native('1.00000001')
+
+ self.assertEqual(Decimal('9000'), result_1)
+ self.assertEqual(Decimal('1.00000001'), result_2)
+
+ def test_from_native_invalid_string(self):
+ """
+ Make sure from_native() raises ValidationError on passing invalid string
+ """
+ f = serializers.DecimalField()
try:
- f.from_native('12:69:12')
+ f.from_native('123.45.6')
except validators.ValidationError as e:
- self.assertEqual(e.messages, ["'12:69:12' value has an invalid "
- "format. It must be a valid time "
- "in the HH:MM[:ss[.uuuuuu]] format."])
+ self.assertEqual(e.messages, ["Enter a number."])
else:
self.fail("ValidationError was not properly raised")
- def test_TimeFieldModelSerializer(self):
- serializer = TimeFieldModelSerializer()
- self.assertTrue(isinstance(serializer.fields['clock'], serializers.TimeField))
+ def test_from_native_integer(self):
+ """
+ Make sure from_native() accepts integer values
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(9000)
+
+ self.assertEqual(Decimal('9000'), result)
+
+ def test_from_native_float(self):
+ """
+ Make sure from_native() accepts float values
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(1.00000001)
+
+ self.assertEqual(Decimal('1.00000001'), result)
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DecimalField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns Decimal as string.
+ """
+ f = serializers.DecimalField()
+
+ result_1 = f.to_native(Decimal('9000'))
+ result_2 = f.to_native(Decimal('1.00000001'))
+
+ self.assertEqual(Decimal('9000'), result_1)
+ self.assertEqual(Decimal('1.00000001'), result_2)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DecimalField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+ def test_valid_serialization(self):
+ """
+ Make sure the serializer works correctly
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_value=9010,
+ min_value=9000,
+ max_digits=6,
+ decimal_places=2)
+
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid())
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid())
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid())
+
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid())
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid())
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid())
+
+ def test_raise_max_value(self):
+ """
+ Make sure max_value violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_value=100)
+
+ s = DecimalSerializer(data={'decimal_field': '123'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']})
+
+ def test_raise_min_value(self):
+ """
+ Make sure min_value violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(min_value=100)
+
+ s = DecimalSerializer(data={'decimal_field': '99'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']})
+
+ def test_raise_max_digits(self):
+ """
+ Make sure max_digits violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_digits=5)
+
+ s = DecimalSerializer(data={'decimal_field': '123.456'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']})
+
+ def test_raise_max_decimal_places(self):
+ """
+ Make sure max_decimal_places violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(decimal_places=3)
+
+ s = DecimalSerializer(data={'decimal_field': '123.4567'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']})
+
+ def test_raise_max_whole_digits(self):
+ """
+ Make sure max_whole_digits violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3)
+
+ s = DecimalSerializer(data={'decimal_field': '12345.6'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) \ No newline at end of file
diff --git a/rest_framework/tests/filters.py b/rest_framework/tests/filters.py
new file mode 100644
index 00000000..8ae6d530
--- /dev/null
+++ b/rest_framework/tests/filters.py
@@ -0,0 +1,474 @@
+from __future__ import unicode_literals
+import datetime
+from decimal import Decimal
+from django.db import models
+from django.core.urlresolvers import reverse
+from django.test import TestCase
+from django.test.client import RequestFactory
+from django.utils import unittest
+from rest_framework import generics, serializers, status, filters
+from rest_framework.compat import django_filters, patterns, url
+from rest_framework.tests.models import BasicModel
+
+factory = RequestFactory()
+
+
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
+if django_filters:
+ # Basic filter on a list view.
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_fields = ['decimal', 'date']
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # These class are used to test a filter class.
+ class SeveralFieldsFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+ date = django_filters.DateFilter(lookup_type='gt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterClassRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # These classes are used to test a misconfigured filter class.
+ class MisconfiguredFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+
+ class Meta:
+ model = BasicModel
+ fields = ['text']
+
+ class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = MisconfiguredFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ class FilterClassDetailView(generics.RetrieveAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # Regression test for #814
+ class FilterableItemSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = FilterableItem
+
+ class FilterFieldsQuerysetView(generics.ListCreateAPIView):
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
+ filter_fields = ['decimal', 'date']
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ class GetQuerysetView(generics.ListCreateAPIView):
+ serializer_class = FilterableItemSerializer
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ def get_queryset(self):
+ return FilterableItem.objects.all()
+
+ urlpatterns = patterns('',
+ url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
+ url(r'^$', FilterClassRootView.as_view(), name='root-view'),
+ url(r'^get-queryset/$', GetQuerysetView.as_view(),
+ name='get-queryset-view'),
+ )
+
+
+class CommonFilteringTestCase(TestCase):
+ def _serialize_object(self, obj):
+ return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+
+ def setUp(self):
+ """
+ Create 10 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(10):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ self._serialize_object(obj)
+ for obj in self.objects.all()
+ ]
+
+
+class IntegrationTestFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered list views.
+ """
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_fields_root_view(self):
+ """
+ GET requests to paginated ListCreateAPIView should return paginated results.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the date filter works.
+ search_date = datetime.date(2012, 9, 22)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] == search_date]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_filter_with_queryset(self):
+ """
+ Regression test for #814.
+ """
+ view = FilterFieldsQuerysetView.as_view()
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_filter_with_get_queryset_only(self):
+ """
+ Regression test for #834.
+ """
+ view = GetQuerysetView.as_view()
+ request = factory.get('/get-queryset/')
+ view(request).render()
+ # Used to raise "issubclass() arg 2 must be a class or tuple of classes"
+ # here when neither `model' nor `queryset' was specified.
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_class_root_view(self):
+ """
+ GET requests to filtered ListCreateAPIView that have a filter_class set
+ should return filtered results.
+ """
+ view = FilterClassRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ # Tests that the decimal filter set with 'lt' in the filter class works.
+ search_decimal = Decimal('4.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] < search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the date filter set with 'gt' in the filter class works.
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the text filter set with 'icontains' in the filter class works.
+ search_text = 'ff'
+ request = factory.get('/?text=%s' % search_text)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if search_text in f['text'].lower()]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date and
+ f['decimal'] < search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_incorrectly_configured_filter(self):
+ """
+ An error should be displayed when the filter class is misconfigured.
+ """
+ view = IncorrectlyConfiguredRootView.as_view()
+
+ request = factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_unknown_filter(self):
+ """
+ GET requests with filters that aren't configured should return 200.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ search_integer = 10
+ request = factory.get('/?integer=%s' % search_integer)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class IntegrationTestDetailFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered detail views.
+ """
+ urls = 'rest_framework.tests.filters'
+
+ def _get_url(self, item):
+ return reverse('detail-view', kwargs=dict(pk=item.pk))
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_detail_view(self):
+ """
+ GET requests to filtered RetrieveAPIView that have a filter_class set
+ should return filtered results.
+ """
+ item = self.objects.all()[0]
+ data = self._serialize_object(item)
+
+ # Basic test with no filter.
+ response = self.client.get(self._get_url(item))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, data)
+
+ # Tests that the decimal filter set that should fail.
+ search_decimal = Decimal('4.25')
+ high_item = self.objects.filter(decimal__gt=search_decimal)[0]
+ response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal))
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ # Tests that the decimal filter set that should succeed.
+ search_decimal = Decimal('4.25')
+ low_item = self.objects.filter(decimal__lt=search_decimal)[0]
+ low_item_data = self._serialize_object(low_item)
+ response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, low_item_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
+ valid_item_data = self._serialize_object(valid_item)
+ response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, valid_item_data)
+
+
+class SearchFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class SearchFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # z abc
+ # zz bcd
+ # zzz cde
+ # ...
+ for idx in range(10):
+ title = 'z' * (idx + 1)
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ SearchFilterModel(title=title, text=text).save()
+
+ def test_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'z', 'text': 'abc'},
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+ def test_exact_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('=title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=zzz')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'zzz', 'text': 'cde'}
+ ]
+ )
+
+ def test_startswith_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', '^text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+
+class OrdringFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class OrderingFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # zyx abc
+ # yxw bcd
+ # xwv cde
+ for idx in range(3):
+ title = (
+ chr(ord('z') - idx) +
+ chr(ord('y') - idx) +
+ chr(ord('x') - idx)
+ )
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ OrdringFilterModel(title=title, text=text).save()
+
+ def test_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ ]
+ )
+
+ def test_reverse_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=-text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_incorrectfield_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=foobar')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering_using_string(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
deleted file mode 100644
index 8c13947c..00000000
--- a/rest_framework/tests/filterset.py
+++ /dev/null
@@ -1,169 +0,0 @@
-from __future__ import unicode_literals
-import datetime
-from decimal import Decimal
-from django.test import TestCase
-from django.test.client import RequestFactory
-from django.utils import unittest
-from rest_framework import generics, status, filters
-from rest_framework.compat import django_filters
-from rest_framework.tests.models import FilterableItem, BasicModel
-
-factory = RequestFactory()
-
-
-if django_filters:
- # Basic filter on a list view.
- class FilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_fields = ['decimal', 'date']
- filter_backend = filters.DjangoFilterBackend
-
- # These class are used to test a filter class.
- class SeveralFieldsFilter(django_filters.FilterSet):
- text = django_filters.CharFilter(lookup_type='icontains')
- decimal = django_filters.NumberFilter(lookup_type='lt')
- date = django_filters.DateFilter(lookup_type='gt')
-
- class Meta:
- model = FilterableItem
- fields = ['text', 'decimal', 'date']
-
- class FilterClassRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_class = SeveralFieldsFilter
- filter_backend = filters.DjangoFilterBackend
-
- # These classes are used to test a misconfigured filter class.
- class MisconfiguredFilter(django_filters.FilterSet):
- text = django_filters.CharFilter(lookup_type='icontains')
-
- class Meta:
- model = BasicModel
- fields = ['text']
-
- class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_class = MisconfiguredFilter
- filter_backend = filters.DjangoFilterBackend
-
-
-class IntegrationTestFiltering(TestCase):
- """
- Integration tests for filtered list views.
- """
-
- def setUp(self):
- """
- Create 10 FilterableItem instances.
- """
- base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
- for i in range(10):
- text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
- decimal = base_data[1] + i
- date = base_data[2] - datetime.timedelta(days=i * 2)
- FilterableItem(text=text, decimal=decimal, date=date).save()
-
- self.objects = FilterableItem.objects
- self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
- for obj in self.objects.all()
- ]
-
- @unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_get_filtered_fields_root_view(self):
- """
- GET requests to paginated ListCreateAPIView should return paginated results.
- """
- view = FilterFieldsRootView.as_view()
-
- # Basic test with no filter.
- request = factory.get('/')
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- # Tests that the decimal filter works.
- search_decimal = Decimal('2.25')
- request = factory.get('/?decimal=%s' % search_decimal)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['decimal'] == search_decimal]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the date filter works.
- search_date = datetime.date(2012, 9, 22)
- request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] == search_date]
- self.assertEqual(response.data, expected_data)
-
- @unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_get_filtered_class_root_view(self):
- """
- GET requests to filtered ListCreateAPIView that have a filter_class set
- should return filtered results.
- """
- view = FilterClassRootView.as_view()
-
- # Basic test with no filter.
- request = factory.get('/')
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- # Tests that the decimal filter set with 'lt' in the filter class works.
- search_decimal = Decimal('4.25')
- request = factory.get('/?decimal=%s' % search_decimal)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['decimal'] < search_decimal]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the date filter set with 'gt' in the filter class works.
- search_date = datetime.date(2012, 10, 2)
- request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the text filter set with 'icontains' in the filter class works.
- search_text = 'ff'
- request = factory.get('/?text=%s' % search_text)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if search_text in f['text'].lower()]
- self.assertEqual(response.data, expected_data)
-
- # Tests that multiple filters works.
- search_decimal = Decimal('5.25')
- search_date = datetime.date(2012, 10, 2)
- request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date and
- f['decimal'] < search_decimal]
- self.assertEqual(response.data, expected_data)
-
- @unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_incorrectly_configured_filter(self):
- """
- An error should be displayed when the filter class is misconfigured.
- """
- view = IncorrectlyConfiguredRootView.as_view()
-
- request = factory.get('/')
- self.assertRaises(AssertionError, view, request)
-
- @unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_unknown_filter(self):
- """
- GET requests with filters that aren't configured should return 200.
- """
- view = FilterFieldsRootView.as_view()
-
- search_integer = 10
- request = factory.get('/?integer=%s' % search_integer)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
index f8f2ddaa..2799d143 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/generics.py
@@ -1,5 +1,6 @@
from __future__ import unicode_literals
from django.db import models
+from django.shortcuts import get_object_or_404
from django.test import TestCase
from rest_framework import generics, serializers, status
from rest_framework.tests.utils import RequestFactory
@@ -38,6 +39,7 @@ class SlugBasedInstanceView(InstanceView):
"""
model = SlugBasedModel
serializer_class = SlugSerializer
+ lookup_field = 'slug'
class TestRootView(TestCase):
@@ -60,7 +62,8 @@ class TestRootView(TestCase):
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/')
- response = self.view(request).render()
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, self.data)
@@ -71,7 +74,8 @@ class TestRootView(TestCase):
content = {'text': 'foobar'}
request = factory.post('/', json.dumps(content),
content_type='application/json')
- response = self.view(request).render()
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
created = self.objects.get(id=4)
@@ -84,7 +88,8 @@ class TestRootView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/', json.dumps(content),
content_type='application/json')
- response = self.view(request).render()
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."})
@@ -93,7 +98,8 @@ class TestRootView(TestCase):
DELETE requests to ListCreateAPIView should not be allowed
"""
request = factory.delete('/')
- response = self.view(request).render()
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."})
@@ -102,7 +108,8 @@ class TestRootView(TestCase):
OPTIONS requests to ListCreateAPIView should return metadata
"""
request = factory.options('/')
- response = self.view(request).render()
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
expected = {
'parses': [
'application/json',
@@ -126,7 +133,8 @@ class TestRootView(TestCase):
content = {'id': 999, 'text': 'foobar'}
request = factory.post('/', json.dumps(content),
content_type='application/json')
- response = self.view(request).render()
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
created = self.objects.get(id=4)
@@ -154,7 +162,8 @@ class TestInstanceView(TestCase):
GET requests to RetrieveUpdateDestroyAPIView should return a single object.
"""
request = factory.get('/1')
- response = self.view(request, pk=1).render()
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, self.data[0])
@@ -165,7 +174,8 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.post('/', json.dumps(content),
content_type='application/json')
- response = self.view(request).render()
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."})
@@ -176,7 +186,8 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk='1').render()
+ with self.assertNumQueries(2):
+ response = self.view(request, pk='1').render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
@@ -190,7 +201,8 @@ class TestInstanceView(TestCase):
request = factory.patch('/1', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk=1).render()
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
@@ -201,7 +213,8 @@ class TestInstanceView(TestCase):
DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
"""
request = factory.delete('/1')
- response = self.view(request, pk=1).render()
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
self.assertEqual(response.content, six.b(''))
ids = [obj.id for obj in self.objects.all()]
@@ -212,7 +225,8 @@ class TestInstanceView(TestCase):
OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
"""
request = factory.options('/')
- response = self.view(request).render()
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
expected = {
'parses': [
'application/json',
@@ -236,7 +250,8 @@ class TestInstanceView(TestCase):
content = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk=1).render()
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
@@ -251,7 +266,8 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk=1).render()
+ with self.assertNumQueries(3):
+ response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
@@ -263,10 +279,11 @@ class TestInstanceView(TestCase):
at the requested url if it doesn't exist.
"""
content = {'text': 'foobar'}
- # pk fields can not be created on demand, only the database can set th pk for a new object
+ # pk fields can not be created on demand, only the database can set the pk for a new object
request = factory.put('/5', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk=5).render()
+ with self.assertNumQueries(3):
+ response = self.view(request, pk=5).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
new_obj = self.objects.get(pk=5)
self.assertEqual(new_obj.text, 'foobar')
@@ -279,13 +296,55 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/test_slug', json.dumps(content),
content_type='application/json')
- response = self.slug_based_view(request, slug='test_slug').render()
+ with self.assertNumQueries(2):
+ response = self.slug_based_view(request, slug='test_slug').render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data, {'slug': 'test_slug', 'text': 'foobar'})
new_obj = SlugBasedModel.objects.get(slug='test_slug')
self.assertEqual(new_obj.text, 'foobar')
+class TestOverriddenGetObject(TestCase):
+ """
+ Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the
+ queryset/model mechanism but instead overrides get_object()
+ """
+ def setUp(self):
+ """
+ Create 3 BasicModel intances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+
+ class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
+ """
+ Example detail view for override of get_object().
+ """
+ model = BasicModel
+
+ def get_object(self):
+ pk = int(self.kwargs['pk'])
+ return get_object_or_404(BasicModel.objects.all(), id=pk)
+
+ self.view = OverriddenGetObjectView.as_view()
+
+ def test_overridden_get_object_view(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object.
+ """
+ request = factory.get('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
+
+
# Regression test for #285
class CommentSerializer(serializers.ModelSerializer):
@@ -319,7 +378,7 @@ class TestCreateModelWithAutoNowAddField(TestCase):
self.assertEqual(created.content, 'foobar')
-# Test for particularly ugly regression with m2m in browseable API
+# Test for particularly ugly regression with m2m in browsable API
class ClassB(models.Model):
name = models.CharField(max_length=255)
@@ -344,9 +403,76 @@ class ExampleView(generics.ListCreateAPIView):
class TestM2MBrowseableAPI(TestCase):
def test_m2m_in_browseable_api(self):
"""
- Test for particularly ugly regression with m2m in browseable API
+ Test for particularly ugly regression with m2m in browsable API
"""
request = factory.get('/', HTTP_ACCEPT='text/html')
view = ExampleView().as_view()
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class InclusiveFilterBackend(object):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(text='foo')
+
+
+class ExclusiveFilterBackend(object):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(text='other')
+
+
+class TestFilterBackendAppliedToViews(TestCase):
+
+ def setUp(self):
+ """
+ Create 3 BasicModel instances to filter on.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+
+ def test_get_root_view_filters_by_name_with_filter_backend(self):
+ """
+ GET requests to ListCreateAPIView should return filtered list.
+ """
+ root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
+ request = factory.get('/')
+ response = root_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(len(response.data), 1)
+ self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}])
+
+ def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
+ """
+ GET requests to ListCreateAPIView should return empty list when all models are filtered out.
+ """
+ root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
+ request = factory.get('/')
+ response = root_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, [])
+
+ def test_get_instance_view_filters_out_name_with_filter_backend(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
+ """
+ instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
+ request = factory.get('/1')
+ response = instance_view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(response.data, {'detail': 'Not found'})
+
+ def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
+ """
+ instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
+ request = factory.get('/1')
+ response = instance_view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foo'})
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py
index 9a61f299..8fc6ba77 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/hyperlinkedserializers.py
@@ -27,6 +27,14 @@ class PhotoSerializer(serializers.Serializer):
return Photo(**attrs)
+class AlbumSerializer(serializers.ModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title')
+
+ class Meta:
+ model = Album
+ fields = ('title', 'url')
+
+
class BasicList(generics.ListCreateAPIView):
model = BasicModel
model_serializer_class = serializers.HyperlinkedModelSerializer
@@ -73,6 +81,8 @@ class PhotoListCreate(generics.ListCreateAPIView):
class AlbumDetail(generics.RetrieveAPIView):
model = Album
+ serializer_class = AlbumSerializer
+ lookup_field = 'title'
class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
@@ -180,6 +190,36 @@ class TestManyToManyHyperlinkedView(TestCase):
self.assertEqual(response.data, self.data[0])
+class TestHyperlinkedIdentityFieldLookup(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 3 Album instances.
+ """
+ titles = ['foo', 'bar', 'baz']
+ for title in titles:
+ album = Album(title=title)
+ album.save()
+ self.detail_view = AlbumDetail.as_view()
+ self.data = {
+ 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'},
+ 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'},
+ 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'}
+ }
+
+ def test_lookup_field(self):
+ """
+ GET requests to AlbumDetail view should return serialized Albums
+ with a url field keyed by `title`.
+ """
+ for album in Album.objects.all():
+ request = factory.get('/albums/{0}/'.format(album.title))
+ response = self.detail_view(request, title=album.title)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[album.title])
+
+
class TestCreateWithForeignKeys(TestCase):
urls = 'rest_framework.tests.hyperlinkedserializers'
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index f2117538..40e41a64 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -58,13 +58,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel):
rel = models.ManyToManyField(Anchor)
-# Model to test filtering.
-class FilterableItem(RESTFrameworkModel):
- text = models.CharField(max_length=100)
- decimal = models.DecimalField(max_digits=4, decimal_places=2)
- date = models.DateField()
-
-
# Model for regression test for #285
class Comment(RESTFrameworkModel):
diff --git a/rest_framework/tests/multitable_inheritance.py b/rest_framework/tests/multitable_inheritance.py
new file mode 100644
index 00000000..00c15327
--- /dev/null
+++ b/rest_framework/tests/multitable_inheritance.py
@@ -0,0 +1,67 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.tests.models import RESTFrameworkModel
+
+
+# Models
+class ParentModel(RESTFrameworkModel):
+ name1 = models.CharField(max_length=100)
+
+
+class ChildModel(ParentModel):
+ name2 = models.CharField(max_length=100)
+
+
+class AssociatedModel(RESTFrameworkModel):
+ ref = models.OneToOneField(ParentModel, primary_key=True)
+ name = models.CharField(max_length=100)
+
+
+# Serializers
+class DerivedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChildModel
+
+
+class AssociatedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = AssociatedModel
+
+
+# Tests
+class IneritedModelSerializationTests(TestCase):
+
+ def test_multitable_inherited_model_fields_as_expected(self):
+ """
+ Assert that the parent pointer field is not included in the fields
+ serialized fields
+ """
+ child = ChildModel(name1='parent name', name2='child name')
+ serializer = DerivedModelSerializer(child)
+ self.assertEqual(set(serializer.data.keys()),
+ set(['name1', 'name2', 'id']))
+
+ def test_onetoone_primary_key_model_fields_as_expected(self):
+ """
+ Assert that a model with a onetoone field that is the primary key is
+ not treated like a derived model
+ """
+ parent = ParentModel(name1='parent name')
+ associate = AssociatedModel(name='hello', ref=parent)
+ serializer = AssociatedModelSerializer(associate)
+ self.assertEqual(set(serializer.data.keys()),
+ set(['name', 'ref']))
+
+ def test_data_is_valid_without_parent_ptr(self):
+ """
+ Assert that the pointer to the parent table is not a required field
+ for input data
+ """
+ data = {
+ 'name1': 'parent name',
+ 'name2': 'child name',
+ }
+ serializer = DerivedModelSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index 6b9970a6..e538a78e 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -1,17 +1,24 @@
from __future__ import unicode_literals
import datetime
from decimal import Decimal
+from django.db import models
from django.core.paginator import Paginator
from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest
from rest_framework import generics, status, pagination, filters, serializers
from rest_framework.compat import django_filters
-from rest_framework.tests.models import BasicModel, FilterableItem
+from rest_framework.tests.models import BasicModel
factory = RequestFactory()
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
class RootView(generics.ListCreateAPIView):
"""
Example description for OPTIONS.
@@ -20,21 +27,6 @@ class RootView(generics.ListCreateAPIView):
paginate_by = 10
-if django_filters:
- class DecimalFilter(django_filters.FilterSet):
- decimal = django_filters.NumberFilter(lookup_type='lt')
-
- class Meta:
- model = FilterableItem
- fields = ['text', 'decimal', 'date']
-
- class FilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
- paginate_by = 10
- filter_class = DecimalFilter
- filter_backend = filters.DjangoFilterBackend
-
-
class DefaultPageSizeKwargView(generics.ListAPIView):
"""
View for testing default paginate_by_param usage
@@ -73,7 +65,9 @@ class IntegrationTestPagination(TestCase):
GET requests to paginated ListCreateAPIView should return paginated results.
"""
request = factory.get('/')
- response = self.view(request).render()
+ # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>`
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['count'], 26)
self.assertEqual(response.data['results'], self.data[:10])
@@ -81,7 +75,8 @@ class IntegrationTestPagination(TestCase):
self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next'])
- response = self.view(request).render()
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['count'], 26)
self.assertEqual(response.data['results'], self.data[10:20])
@@ -89,7 +84,8 @@ class IntegrationTestPagination(TestCase):
self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['next'])
- response = self.view(request).render()
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['count'], 26)
self.assertEqual(response.data['results'], self.data[20:])
@@ -112,20 +108,37 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.objects = FilterableItem.objects
self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
- for obj in self.objects.all()
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+ for obj in self.objects.all()
]
- self.view = FilterFieldsRootView.as_view()
@unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_get_paginated_filtered_root_view(self):
+ def test_get_django_filter_paginated_filtered_root_view(self):
"""
GET requests to paginated filtered ListCreateAPIView should return
paginated results. The next and previous links should preserve the
filtered parameters.
"""
+ class DecimalFilter(django_filters.FilterSet):
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_class = DecimalFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ view = FilterFieldsRootView.as_view()
+
+ EXPECTED_NUM_QUERIES = 2
+
request = factory.get('/?decimal=15.20')
- response = self.view(request).render()
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['count'], 15)
self.assertEqual(response.data['results'], self.data[:10])
@@ -133,7 +146,8 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next'])
- response = self.view(request).render()
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['count'], 15)
self.assertEqual(response.data['results'], self.data[10:15])
@@ -141,7 +155,53 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['previous'])
- response = self.view(request).render()
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+ def test_get_basic_paginated_filtered_root_view(self):
+ """
+ Same as `test_get_django_filter_paginated_filtered_root_view`,
+ except using a custom filter backend instead of the django-filter
+ backend,
+ """
+
+ class DecimalFilterBackend(filters.BaseFilterBackend):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(decimal__lt=Decimal(request.GET['decimal']))
+
+ class BasicFilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_backends = (DecimalFilterBackend,)
+
+ view = BasicFilterFieldsRootView.as_view()
+
+ request = factory.get('/?decimal=15.20')
+ with self.assertNumQueries(2):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ with self.assertNumQueries(2):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[10:15])
+ self.assertEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['previous'])
+ with self.assertNumQueries(2):
+ response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['count'], 15)
self.assertEqual(response.data['results'], self.data[:10])
diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/parsers.py
index 539c5b44..7699e10c 100644
--- a/rest_framework/tests/parsers.py
+++ b/rest_framework/tests/parsers.py
@@ -1,10 +1,11 @@
from __future__ import unicode_literals
from rest_framework.compat import StringIO
from django import forms
+from django.core.files.uploadhandler import MemoryFileUploadHandler
from django.test import TestCase
from django.utils import unittest
from rest_framework.compat import etree
-from rest_framework.parsers import FormParser
+from rest_framework.parsers import FormParser, FileUploadParser
from rest_framework.parsers import XMLParser
import datetime
@@ -82,3 +83,33 @@ class TestXMLParser(TestCase):
parser = XMLParser()
data = parser.parse(self._complex_data_input)
self.assertEqual(data, self._complex_data)
+
+
+class TestFileUploadParser(TestCase):
+ def setUp(self):
+ class MockRequest(object):
+ pass
+ from io import BytesIO
+ self.stream = BytesIO(
+ "Test text file".encode('utf-8')
+ )
+ request = MockRequest()
+ request.upload_handlers = (MemoryFileUploadHandler(),)
+ request.META = {
+ 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'),
+ 'HTTP_CONTENT_LENGTH': 14,
+ }
+ self.parser_context = {'request': request, 'kwargs': {}}
+
+ def test_parse(self):
+ """ Make sure the `QueryDict` works OK """
+ parser = FileUploadParser()
+ self.stream.seek(0)
+ data_and_files = parser.parse(self.stream, None, self.parser_context)
+ file_obj = data_and_files.files['file']
+ self.assertEqual(file_obj._size, 14)
+
+ def test_get_filename(self):
+ parser = FileUploadParser()
+ filename = parser.get_filename(self.stream, None, self.parser_context)
+ self.assertEqual(filename, 'file.txt'.encode('utf-8'))
diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py
index b5702a48..b1eed9a7 100644
--- a/rest_framework/tests/relations_hyperlink.py
+++ b/rest_framework/tests/relations_hyperlink.py
@@ -26,42 +26,44 @@ urlpatterns = patterns('',
)
+# ManyToMany
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail')
-
class Meta:
model = ManyToManyTarget
+ fields = ('url', 'name', 'sources')
class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ManyToManySource
+ fields = ('url', 'name', 'targets')
+# ForeignKey
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail')
-
class Meta:
model = ForeignKeyTarget
+ fields = ('url', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ForeignKeySource
+ fields = ('url', 'name', 'target')
# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = NullableForeignKeySource
+ fields = ('url', 'name', 'target')
-# OneToOne
+# Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
- nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail')
-
class Meta:
model = OneToOneTarget
+ fields = ('url', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid
diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py
index a125ba65..f6d006b3 100644
--- a/rest_framework/tests/relations_nested.py
+++ b/rest_framework/tests/relations_nested.py
@@ -6,38 +6,30 @@ from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, Null
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
- depth = 1
- model = ForeignKeySource
-
-
-class FlatForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
model = ForeignKeySource
+ fields = ('id', 'name', 'target')
+ depth = 1
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = FlatForeignKeySourceSerializer(many=True)
-
class Meta:
model = ForeignKeyTarget
+ fields = ('id', 'name', 'sources')
+ depth = 1
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
- depth = 1
model = NullableForeignKeySource
-
-
-class NullableOneToOneSourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = NullableOneToOneSource
+ fields = ('id', 'name', 'target')
+ depth = 1
class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- nullable_source = NullableOneToOneSourceSerializer()
-
class Meta:
model = OneToOneTarget
+ fields = ('id', 'name', 'nullable_source')
+ depth = 1
class ReverseForeignKeyTests(TestCase):
diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py
index d6ae3176..5ce8b567 100644
--- a/rest_framework/tests/relations_pk.py
+++ b/rest_framework/tests/relations_pk.py
@@ -5,41 +5,44 @@ from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, Fore
from rest_framework.compat import six
+# ManyToMany
class ManyToManyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.PrimaryKeyRelatedField(many=True)
-
class Meta:
model = ManyToManyTarget
+ fields = ('id', 'name', 'sources')
class ManyToManySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ManyToManySource
+ fields = ('id', 'name', 'targets')
+# ForeignKey
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.PrimaryKeyRelatedField(many=True)
-
class Meta:
model = ForeignKeyTarget
+ fields = ('id', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
+ fields = ('id', 'name', 'target')
+# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = NullableForeignKeySource
+ fields = ('id', 'name', 'target')
-# OneToOne
+# Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- nullable_source = serializers.PrimaryKeyRelatedField()
-
class Meta:
model = OneToOneTarget
+ fields = ('id', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid
@@ -407,14 +410,14 @@ class PKNullableOneToOneTests(TestCase):
target.save()
new_target = OneToOneTarget(name='target-2')
new_target.save()
- source = NullableOneToOneSource(name='source-1', target=target)
+ source = NullableOneToOneSource(name='source-1', target=new_target)
source.save()
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'nullable_source': 1},
- {'id': 2, 'name': 'target-2', 'nullable_source': None},
+ {'id': 1, 'name': 'target-1', 'nullable_source': None},
+ {'id': 2, 'name': 'target-2', 'nullable_source': 1},
]
self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py
index 4892f7a6..97e5af20 100644
--- a/rest_framework/tests/request.py
+++ b/rest_framework/tests/request.py
@@ -58,6 +58,14 @@ class TestMethodOverloading(TestCase):
request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'}))
self.assertEqual(request.method, 'DELETE')
+ def test_x_http_method_override_header(self):
+ """
+ POST requests can also be overloaded to another method by setting
+ the X-HTTP-Method-Override header.
+ """
+ request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE'))
+ self.assertEqual(request.method, 'DELETE')
+
class TestContentParsing(TestCase):
def test_standard_behaviour_determines_no_content_GET(self):
diff --git a/rest_framework/tests/routers.py b/rest_framework/tests/routers.py
new file mode 100644
index 00000000..4e4765cb
--- /dev/null
+++ b/rest_framework/tests/routers.py
@@ -0,0 +1,55 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from django.test.client import RequestFactory
+from rest_framework import status
+from rest_framework.response import Response
+from rest_framework import viewsets
+from rest_framework.decorators import link, action
+from rest_framework.routers import SimpleRouter
+import copy
+
+factory = RequestFactory()
+
+
+class BasicViewSet(viewsets.ViewSet):
+ def list(self, request, *args, **kwargs):
+ return Response({'method': 'list'})
+
+ @action()
+ def action1(self, request, *args, **kwargs):
+ return Response({'method': 'action1'})
+
+ @action()
+ def action2(self, request, *args, **kwargs):
+ return Response({'method': 'action2'})
+
+ @link()
+ def link1(self, request, *args, **kwargs):
+ return Response({'method': 'link1'})
+
+ @link()
+ def link2(self, request, *args, **kwargs):
+ return Response({'method': 'link2'})
+
+
+class TestSimpleRouter(TestCase):
+ def setUp(self):
+ self.router = SimpleRouter()
+
+ def test_link_and_action_decorator(self):
+ routes = self.router.get_routes(BasicViewSet)
+ decorator_routes = routes[2:]
+ # Make sure all these endpoints exist and none have been clobbered
+ for i, endpoint in enumerate(['action1', 'action2', 'link1', 'link2']):
+ route = decorator_routes[i]
+ # check url listing
+ self.assertEqual(route.url,
+ '^{{prefix}}/{{lookup}}/{0}/$'.format(endpoint))
+ # check method to function mapping
+ if endpoint.startswith('action'):
+ method_map = 'post'
+ else:
+ method_map = 'get'
+ self.assertEqual(route.mapping[method_map], endpoint)
+
+
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index d0300f9e..db3881f9 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -3,7 +3,7 @@ from django.utils.datastructures import MultiValueDict
from django.test import TestCase
from rest_framework import serializers
from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
- BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel,
+ BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel,
ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo)
import datetime
import pickle
@@ -78,6 +78,18 @@ class PersonSerializer(serializers.ModelSerializer):
read_only_fields = ('age',)
+class PersonSerializerInvalidReadOnly(serializers.ModelSerializer):
+ """
+ Testing for #652.
+ """
+ info = serializers.Field(source='info')
+
+ class Meta:
+ model = Person
+ fields = ('name', 'age', 'info')
+ read_only_fields = ('age', 'info')
+
+
class AlbumsSerializer(serializers.ModelSerializer):
class Meta:
@@ -189,6 +201,12 @@ class BasicTests(TestCase):
# Assert age is unchanged (35)
self.assertEqual(instance.age, self.person_data['age'])
+ def test_invalid_read_only_fields(self):
+ """
+ Regression test for #652.
+ """
+ self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, [])
+
class DictStyleSerializer(serializers.Serializer):
"""
@@ -261,25 +279,6 @@ class ValidationTests(TestCase):
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.errors, {})
- def test_bad_type_data_is_false(self):
- """
- Data of the wrong type is not valid.
- """
- data = ['i am', 'a', 'list']
- serializer = CommentSerializer(self.comment, data=data, many=True)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']})
-
- data = 'and i am a string'
- serializer = CommentSerializer(self.comment, data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']})
-
- data = 42
- serializer = CommentSerializer(self.comment, data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']})
-
def test_cross_field_validation(self):
class CommentSerializerWithCrossFieldValidator(CommentSerializer):
@@ -376,7 +375,6 @@ class CustomValidationTests(TestCase):
def validate_email(self, attrs, source):
value = attrs[source]
-
return attrs
def validate_content(self, attrs, source):
@@ -757,6 +755,43 @@ class ManyRelatedTests(TestCase):
self.assertEqual(serializer.data, expected)
+ def test_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2]
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_depth_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+ depth = 1
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post',
+ 'blogpostcomment_set': [
+ {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1},
+ {'id': 2, 'text': 'I love this blog post', 'blog_post': 1}
+ ]
+ }
+ self.assertEqual(serializer.data, expected)
+
def test_callable_source(self):
post = BlogPost.objects.create(title="Test blog post")
post.blogpostcomment_set.create(text="I love this blog post")
@@ -786,8 +821,6 @@ class RelatedTraversalTest(TestCase):
post = BlogPost.objects.create(title="Test blog post", writer=user)
post.blogpostcomment_set.create(text="I love this blog post")
- from rest_framework.tests.models import BlogPostComment
-
class PersonSerializer(serializers.ModelSerializer):
class Meta:
model = Person
@@ -987,23 +1020,26 @@ class SerializerPickleTests(TestCase):
class DepthTest(TestCase):
def test_implicit_nesting(self):
+
writer = Person.objects.create(name="django", age=1)
post = BlogPost.objects.create(title="Test blog post", writer=writer)
+ comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post)
- class BlogPostSerializer(serializers.ModelSerializer):
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
class Meta:
- model = BlogPost
- depth = 1
+ model = BlogPostComment
+ depth = 2
- serializer = BlogPostSerializer(instance=post)
- expected = {'id': 1, 'title': 'Test blog post',
- 'writer': {'id': 1, 'name': 'django', 'age': 1}}
+ serializer = BlogPostCommentSerializer(instance=comment)
+ expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post',
+ 'writer': {'id': 1, 'name': 'django', 'age': 1}}}
self.assertEqual(serializer.data, expected)
def test_explicit_nesting(self):
writer = Person.objects.create(name="django", age=1)
post = BlogPost.objects.create(title="Test blog post", writer=writer)
+ comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post)
class PersonSerializer(serializers.ModelSerializer):
class Meta:
@@ -1015,9 +1051,15 @@ class DepthTest(TestCase):
class Meta:
model = BlogPost
- serializer = BlogPostSerializer(instance=post)
- expected = {'id': 1, 'title': 'Test blog post',
- 'writer': {'id': 1, 'name': 'django', 'age': 1}}
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ blog_post = BlogPostSerializer()
+
+ class Meta:
+ model = BlogPostComment
+
+ serializer = BlogPostCommentSerializer(instance=comment)
+ expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post',
+ 'writer': {'id': 1, 'name': 'django', 'age': 1}}}
self.assertEqual(serializer.data, expected)
@@ -1072,3 +1114,32 @@ class NestedSerializerContextTests(TestCase):
# This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers
AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data
+
+
+class DeserializeListTestCase(TestCase):
+
+ def setUp(self):
+ self.data = {
+ 'email': 'nobody@nowhere.com',
+ 'content': 'This is some test content',
+ 'created': datetime.datetime(2013, 3, 7),
+ }
+
+ def test_no_errors(self):
+ data = [self.data.copy() for x in range(0, 3)]
+ serializer = CommentSerializer(data=data, many=True)
+ self.assertTrue(serializer.is_valid())
+ self.assertTrue(isinstance(serializer.object, list))
+ self.assertTrue(
+ all((isinstance(item, Comment) for item in serializer.object))
+ )
+
+ def test_errors_return_as_list(self):
+ invalid_item = self.data.copy()
+ invalid_item['email'] = ''
+ data = [self.data.copy(), invalid_item, self.data.copy()]
+
+ serializer = CommentSerializer(data=data, many=True)
+ self.assertFalse(serializer.is_valid())
+ expected = [{}, {'email': ['This field is required.']}, {}]
+ self.assertEqual(serializer.errors, expected)
diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py
new file mode 100644
index 00000000..8b0ded1a
--- /dev/null
+++ b/rest_framework/tests/serializer_bulk_update.py
@@ -0,0 +1,278 @@
+"""
+Tests to cover bulk create and update using serializers.
+"""
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class BulkCreateSerializerTests(TestCase):
+ """
+ Creating multiple instances using serializers.
+ """
+
+ def setUp(self):
+ class BookSerializer(serializers.Serializer):
+ id = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ author = serializers.CharField(max_length=100)
+
+ self.BookSerializer = BookSerializer
+
+ def test_bulk_create_success(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 2,
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+
+ def test_bulk_create_errors(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 'foo',
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {},
+ {'id': ['Enter a whole number.']}
+ ]
+
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_list_datatype(self):
+ """
+ Data containing list of incorrect data type should return errors.
+ """
+ data = ['foo', 'bar', 'baz']
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = [
+ {'non_field_errors': ['Invalid data']},
+ {'non_field_errors': ['Invalid data']},
+ {'non_field_errors': ['Invalid data']}
+ ]
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_single_datatype(self):
+ """
+ Data containing a single incorrect data type should return errors.
+ """
+ data = 123
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = {'non_field_errors': ['Expected a list of items.']}
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_single_object(self):
+ """
+ Data containing only a single object, instead of a list of objects
+ should return errors.
+ """
+ data = {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = {'non_field_errors': ['Expected a list of items.']}
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+
+class BulkUpdateSerializerTests(TestCase):
+ """
+ Updating multiple instances using serializers.
+ """
+
+ def setUp(self):
+ class Book(object):
+ """
+ A data type that can be persisted to a mock storage backend
+ with `.save()` and `.delete()`.
+ """
+ object_map = {}
+
+ def __init__(self, id, title, author):
+ self.id = id
+ self.title = title
+ self.author = author
+
+ def save(self):
+ Book.object_map[self.id] = self
+
+ def delete(self):
+ del Book.object_map[self.id]
+
+ class BookSerializer(serializers.Serializer):
+ id = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ author = serializers.CharField(max_length=100)
+
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.id = attrs['id']
+ instance.title = attrs['title']
+ instance.author = attrs['author']
+ return instance
+ return Book(**attrs)
+
+ self.Book = Book
+ self.BookSerializer = BookSerializer
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 2,
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+
+ for item in data:
+ book = Book(item['id'], item['title'], item['author'])
+ book.save()
+
+ def books(self):
+ """
+ Return all the objects in the mock storage backend.
+ """
+ return self.Book.object_map.values()
+
+ def test_bulk_update_success(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 2,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+ new_data = self.BookSerializer(self.books(), many=True).data
+
+ self.assertEqual(data, new_data)
+
+ def test_bulk_update_and_create(self):
+ """
+ Bulk update serialization may also include created items.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 3,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+ new_data = self.BookSerializer(self.books(), many=True).data
+ self.assertEqual(data, new_data)
+
+ def test_bulk_update_invalid_create(self):
+ """
+ Bulk update serialization without allow_add_remove may not create items.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 3,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_bulk_update_error(self):
+ """
+ Incorrect bulk update serialization should return error data.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 'foo',
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {'id': ['Enter a whole number.']}
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/serializer_nested.py
new file mode 100644
index 00000000..71d0e24b
--- /dev/null
+++ b/rest_framework/tests/serializer_nested.py
@@ -0,0 +1,246 @@
+"""
+Tests to cover nested serializers.
+
+Doesn't cover model serializers.
+"""
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class WritableNestedSerializerBasicTests(TestCase):
+ """
+ Tests for deserializing nested entities.
+ Basic tests that use serializers that simply restore to dicts.
+ """
+
+ def setUp(self):
+ class TrackSerializer(serializers.Serializer):
+ order = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ duration = serializers.IntegerField()
+
+ class AlbumSerializer(serializers.Serializer):
+ album_name = serializers.CharField(max_length=100)
+ artist = serializers.CharField(max_length=100)
+ tracks = TrackSerializer(many=True)
+
+ self.AlbumSerializer = AlbumSerializer
+
+ def test_nested_validation_success(self):
+ """
+ Correct nested serialization should return the input data.
+ """
+
+ data = {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 239}
+ ]
+ }
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+
+ def test_nested_validation_error(self):
+ """
+ Incorrect nested serialization should return appropriate error data.
+ """
+
+ data = {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'}
+ ]
+ }
+ expected_errors = {
+ 'tracks': [
+ {},
+ {},
+ {'duration': ['Enter a whole number.']}
+ ]
+ }
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_many_nested_validation_error(self):
+ """
+ Incorrect nested serialization should return appropriate error data
+ when multiple entities are being deserialized.
+ """
+
+ data = [
+ {
+ 'album_name': 'Russian Red',
+ 'artist': 'I Love Your Glasses',
+ 'tracks': [
+ {'order': 1, 'title': 'Cigarettes', 'duration': 121},
+ {'order': 2, 'title': 'No Past Land', 'duration': 198},
+ {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191}
+ ]
+ },
+ {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'}
+ ]
+ }
+ ]
+ expected_errors = [
+ {},
+ {
+ 'tracks': [
+ {},
+ {},
+ {'duration': ['Enter a whole number.']}
+ ]
+ }
+ ]
+
+ serializer = self.AlbumSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+
+class WritableNestedSerializerObjectTests(TestCase):
+ """
+ Tests for deserializing nested entities.
+ These tests use serializers that restore to concrete objects.
+ """
+
+ def setUp(self):
+ # Couple of concrete objects that we're going to deserialize into
+ class Track(object):
+ def __init__(self, order, title, duration):
+ self.order, self.title, self.duration = order, title, duration
+
+ def __eq__(self, other):
+ return (
+ self.order == other.order and
+ self.title == other.title and
+ self.duration == other.duration
+ )
+
+ class Album(object):
+ def __init__(self, album_name, artist, tracks):
+ self.album_name, self.artist, self.tracks = album_name, artist, tracks
+
+ def __eq__(self, other):
+ return (
+ self.album_name == other.album_name and
+ self.artist == other.artist and
+ self.tracks == other.tracks
+ )
+
+ # And their corresponding serializers
+ class TrackSerializer(serializers.Serializer):
+ order = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ duration = serializers.IntegerField()
+
+ def restore_object(self, attrs, instance=None):
+ return Track(attrs['order'], attrs['title'], attrs['duration'])
+
+ class AlbumSerializer(serializers.Serializer):
+ album_name = serializers.CharField(max_length=100)
+ artist = serializers.CharField(max_length=100)
+ tracks = TrackSerializer(many=True)
+
+ def restore_object(self, attrs, instance=None):
+ return Album(attrs['album_name'], attrs['artist'], attrs['tracks'])
+
+ self.Album, self.Track = Album, Track
+ self.AlbumSerializer = AlbumSerializer
+
+ def test_nested_validation_success(self):
+ """
+ Correct nested serialization should return a restored object
+ that corresponds to the input data.
+ """
+
+ data = {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 239}
+ ]
+ }
+ expected_object = self.Album(
+ album_name='Discovery',
+ artist='Daft Punk',
+ tracks=[
+ self.Track(order=1, title='One More Time', duration=235),
+ self.Track(order=2, title='Aerodynamic', duration=184),
+ self.Track(order=3, title='Digital Love', duration=239),
+ ]
+ )
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected_object)
+
+ def test_many_nested_validation_success(self):
+ """
+ Correct nested serialization should return multiple restored objects
+ that corresponds to the input data when multiple objects are
+ being deserialized.
+ """
+
+ data = [
+ {
+ 'album_name': 'Russian Red',
+ 'artist': 'I Love Your Glasses',
+ 'tracks': [
+ {'order': 1, 'title': 'Cigarettes', 'duration': 121},
+ {'order': 2, 'title': 'No Past Land', 'duration': 198},
+ {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191}
+ ]
+ },
+ {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 239}
+ ]
+ }
+ ]
+ expected_object = [
+ self.Album(
+ album_name='Russian Red',
+ artist='I Love Your Glasses',
+ tracks=[
+ self.Track(order=1, title='Cigarettes', duration=121),
+ self.Track(order=2, title='No Past Land', duration=198),
+ self.Track(order=3, title='They Don\'t Believe', duration=191),
+ ]
+ ),
+ self.Album(
+ album_name='Discovery',
+ artist='Daft Punk',
+ tracks=[
+ self.Track(order=1, title='One More Time', duration=235),
+ self.Track(order=2, title='Aerodynamic', duration=184),
+ self.Track(order=3, title='Digital Love', duration=239),
+ ]
+ )
+ ]
+
+ serializer = self.AlbumSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected_object)
diff --git a/rest_framework/tests/status.py b/rest_framework/tests/status.py
deleted file mode 100644
index e1644a6b..00000000
--- a/rest_framework/tests/status.py
+++ /dev/null
@@ -1,13 +0,0 @@
-"""Tests for the status module"""
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework import status
-
-
-class TestStatus(TestCase):
- """Simple sanity test to check the status module"""
-
- def test_status(self):
- """Ensure the status module is present and correct."""
- self.assertEqual(200, status.HTTP_200_OK)
- self.assertEqual(404, status.HTTP_404_NOT_FOUND)
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index 810cad63..93ea9816 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -1,3 +1,6 @@
+"""
+Provides various throttling policies.
+"""
from __future__ import unicode_literals
from django.core.cache import cache
from rest_framework import exceptions
@@ -28,9 +31,8 @@ class SimpleRateThrottle(BaseThrottle):
A simple cache implementation, that only requires `.get_cache_key()`
to be overridden.
- The rate (requests / seconds) is set by a :attr:`throttle` attribute
- on the :class:`.View` class. The attribute is a string of the form 'number of
- requests/period'.
+ The rate (requests / seconds) is set by a `throttle` attribute on the View
+ class. The attribute is a string of the form 'number_of_requests/period'.
Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py
index af21ac79..d51374b0 100644
--- a/rest_framework/utils/breadcrumbs.py
+++ b/rest_framework/utils/breadcrumbs.py
@@ -1,26 +1,37 @@
from __future__ import unicode_literals
from django.core.urlresolvers import resolve, get_script_prefix
+from rest_framework.utils.formatting import get_view_name
def get_breadcrumbs(url):
- """Given a url returns a list of breadcrumbs, which are each a tuple of (name, url)."""
+ """
+ Given a url returns a list of breadcrumbs, which are each a
+ tuple of (name, url).
+ """
from rest_framework.views import APIView
def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen):
- """Add tuples of (name, url) to the breadcrumbs list, progressively chomping off parts of the url."""
+ """
+ Add tuples of (name, url) to the breadcrumbs list,
+ progressively chomping off parts of the url.
+ """
try:
(view, unused_args, unused_kwargs) = resolve(url)
except Exception:
pass
else:
- # Check if this is a REST framework view, and if so add it to the breadcrumbs
- if isinstance(getattr(view, 'cls_instance', None), APIView):
+ # Check if this is a REST framework view,
+ # and if so add it to the breadcrumbs
+ cls = getattr(view, 'cls', None)
+ if cls is not None and issubclass(cls, APIView):
# Don't list the same view twice in a row.
# Probably an optional trailing slash.
if not seen or seen[-1] != view:
- breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url))
+ suffix = getattr(view, 'suffix', None)
+ name = get_view_name(view.cls, suffix)
+ breadcrumbs_list.insert(0, (name, prefix + url))
seen.append(view)
if url == '':
@@ -28,11 +39,15 @@ def get_breadcrumbs(url):
return breadcrumbs_list
elif url.endswith('/'):
- # Drop trailing slash off the end and continue to try to resolve more breadcrumbs
- return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix, seen)
-
- # Drop trailing non-slash off the end and continue to try to resolve more breadcrumbs
- return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix, seen)
+ # Drop trailing slash off the end and continue to try to
+ # resolve more breadcrumbs
+ url = url.rstrip('/')
+ return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen)
+
+ # Drop trailing non-slash off the end and continue to try to
+ # resolve more breadcrumbs
+ url = url[:url.rfind('/') + 1]
+ return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen)
prefix = get_script_prefix().rstrip('/')
url = url[len(prefix):]
diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py
new file mode 100644
index 00000000..ebadb3a6
--- /dev/null
+++ b/rest_framework/utils/formatting.py
@@ -0,0 +1,80 @@
+"""
+Utility functions to return a formatted name and description for a given view.
+"""
+from __future__ import unicode_literals
+
+from django.utils.html import escape
+from django.utils.safestring import mark_safe
+from rest_framework.compat import apply_markdown
+import re
+
+
+def _remove_trailing_string(content, trailing):
+ """
+ Strip trailing component `trailing` from `content` if it exists.
+ Used when generating names from view classes.
+ """
+ if content.endswith(trailing) and content != trailing:
+ return content[:-len(trailing)]
+ return content
+
+
+def _remove_leading_indent(content):
+ """
+ Remove leading indent from a block of text.
+ Used when generating descriptions from docstrings.
+ """
+ whitespace_counts = [len(line) - len(line.lstrip(' '))
+ for line in content.splitlines()[1:] if line.lstrip()]
+
+ # unindent the content if needed
+ if whitespace_counts:
+ whitespace_pattern = '^' + (' ' * min(whitespace_counts))
+ content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
+ content = content.strip('\n')
+ return content
+
+
+def _camelcase_to_spaces(content):
+ """
+ Translate 'CamelCaseNames' to 'Camel Case Names'.
+ Used when generating names from view classes.
+ """
+ camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))'
+ content = re.sub(camelcase_boundry, ' \\1', content).strip()
+ return ' '.join(content.split('_')).title()
+
+
+def get_view_name(cls, suffix=None):
+ """
+ Return a formatted name for an `APIView` class or `@api_view` function.
+ """
+ name = cls.__name__
+ name = _remove_trailing_string(name, 'View')
+ name = _remove_trailing_string(name, 'ViewSet')
+ name = _camelcase_to_spaces(name)
+ if suffix:
+ name += ' ' + suffix
+ return name
+
+
+def get_view_description(cls, html=False):
+ """
+ Return a description for an `APIView` class or `@api_view` function.
+ """
+ description = cls.__doc__ or ''
+ description = _remove_leading_indent(description)
+ if html:
+ return markup_description(description)
+ return description
+
+
+def markup_description(description):
+ """
+ Apply HTML markup to the given description.
+ """
+ if apply_markdown:
+ description = apply_markdown(description)
+ else:
+ description = escape(description).replace('\n', '<br />')
+ return mark_safe(description)
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 81cbdcbb..555fa2f4 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -1,54 +1,16 @@
"""
-Provides an APIView class that is used as the base of all class-based views.
+Provides an APIView class that is the base of all views in REST framework.
"""
from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied
-from django.http import Http404
-from django.utils.html import escape
-from django.utils.safestring import mark_safe
+from django.http import Http404, HttpResponse
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status, exceptions
-from rest_framework.compat import View, apply_markdown
+from rest_framework.compat import View
from rest_framework.response import Response
from rest_framework.request import Request
from rest_framework.settings import api_settings
-import re
-
-
-def _remove_trailing_string(content, trailing):
- """
- Strip trailing component `trailing` from `content` if it exists.
- Used when generating names from view classes.
- """
- if content.endswith(trailing) and content != trailing:
- return content[:-len(trailing)]
- return content
-
-
-def _remove_leading_indent(content):
- """
- Remove leading indent from a block of text.
- Used when generating descriptions from docstrings.
- """
- whitespace_counts = [len(line) - len(line.lstrip(' '))
- for line in content.splitlines()[1:] if line.lstrip()]
-
- # unindent the content if needed
- if whitespace_counts:
- whitespace_pattern = '^' + (' ' * min(whitespace_counts))
- content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
- content = content.strip('\n')
- return content
-
-
-def _camelcase_to_spaces(content):
- """
- Translate 'CamelCaseNames' to 'Camel Case Names'.
- Used when generating names from view classes.
- """
- camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))'
- content = re.sub(camelcase_boundry, ' \\1', content).strip()
- return ' '.join(content.split('_')).title()
+from rest_framework.utils.formatting import get_view_name, get_view_description
class APIView(View):
@@ -64,22 +26,21 @@ class APIView(View):
@classmethod
def as_view(cls, **initkwargs):
"""
- Override the default :meth:`as_view` to store an instance of the view
- as an attribute on the callable function. This allows us to discover
- information about the view when we do URL reverse lookups.
+ Store the original class on the view function.
+
+ This allows us to discover information about the view when we do URL
+ reverse lookups. Used for breadcrumb generation.
"""
- # TODO: deprecate?
view = super(APIView, cls).as_view(**initkwargs)
- view.cls_instance = cls(**initkwargs)
+ view.cls = cls
return view
@property
def allowed_methods(self):
"""
- Return the list of allowed HTTP methods, uppercased.
+ Wrap Django's private `_allowed_methods` interface in a public property.
"""
- return [method.upper() for method in self.http_method_names
- if hasattr(self, method)]
+ return self._allowed_methods()
@property
def default_response_headers(self):
@@ -90,43 +51,10 @@ class APIView(View):
'Vary': 'Accept'
}
- def get_name(self):
- """
- Return the resource or view class name for use as this view's name.
- Override to customize.
- """
- # TODO: deprecate?
- name = self.__class__.__name__
- name = _remove_trailing_string(name, 'View')
- return _camelcase_to_spaces(name)
-
- def get_description(self, html=False):
- """
- Return the resource or view docstring for use as this view's description.
- Override to customize.
- """
- # TODO: deprecate?
- description = self.__doc__ or ''
- description = _remove_leading_indent(description)
- if html:
- return self.markup_description(description)
- return description
-
- def markup_description(self, description):
- """
- Apply HTML markup to the description of this view.
- """
- # TODO: deprecate?
- if apply_markdown:
- description = apply_markdown(description)
- else:
- description = escape(description).replace('\n', '<br />')
- return mark_safe(description)
-
def metadata(self, request):
return {
- 'name': self.get_name(),
- 'description': self.get_description(),
+ 'name': get_view_name(self.__class__),
+ 'description': get_view_description(self.__class__),
'renders': [renderer.media_type for renderer in self.renderer_classes],
'parses': [parser.media_type for parser in self.parser_classes],
}
@@ -140,7 +68,8 @@ class APIView(View):
def http_method_not_allowed(self, request, *args, **kwargs):
"""
- Called if `request.method` does not correspond to a handler method.
+ If `request.method` does not correspond to a handler method,
+ determine what kind of exception to raise.
"""
raise exceptions.MethodNotAllowed(request.method)
@@ -327,6 +256,12 @@ class APIView(View):
"""
Returns the final response object.
"""
+ # Make the error obvious if a proper response is not returned
+ assert isinstance(response, HttpResponse), (
+ 'Expected a `Response` to be returned from the view, '
+ 'but received a `%s`' % type(response)
+ )
+
if isinstance(response, Response):
if not getattr(request, 'accepted_renderer', None):
neg = self.perform_content_negotiation(request, force=True)
diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py
new file mode 100644
index 00000000..d91323f2
--- /dev/null
+++ b/rest_framework/viewsets.py
@@ -0,0 +1,139 @@
+"""
+ViewSets are essentially just a type of class based view, that doesn't provide
+any method handlers, such as `get()`, `post()`, etc... but instead has actions,
+such as `list()`, `retrieve()`, `create()`, etc...
+
+Actions are only bound to methods at the point of instantiating the views.
+
+ user_list = UserViewSet.as_view({'get': 'list'})
+ user_detail = UserViewSet.as_view({'get': 'retrieve'})
+
+Typically, rather than instantiate views from viewsets directly, you'll
+regsiter the viewset with a router and let the URL conf be determined
+automatically.
+
+ router = DefaultRouter()
+ router.register(r'users', UserViewSet, 'user')
+ urlpatterns = router.urls
+"""
+from __future__ import unicode_literals
+
+from functools import update_wrapper
+from django.utils.decorators import classonlymethod
+from rest_framework import views, generics, mixins
+
+
+class ViewSetMixin(object):
+ """
+ This is the magic.
+
+ Overrides `.as_view()` so that it takes an `actions` keyword that performs
+ the binding of HTTP methods to actions on the Resource.
+
+ For example, to create a concrete view binding the 'GET' and 'POST' methods
+ to the 'list' and 'create' actions...
+
+ view = MyViewSet.as_view({'get': 'list', 'post': 'create'})
+ """
+
+ @classonlymethod
+ def as_view(cls, actions=None, **initkwargs):
+ """
+ Because of the way class based views create a closure around the
+ instantiated view, we need to totally reimplement `.as_view`,
+ and slightly modify the view function that is created and returned.
+ """
+ # The suffix initkwarg is reserved for identifing the viewset type
+ # eg. 'List' or 'Instance'.
+ cls.suffix = None
+
+ # sanitize keyword arguments
+ for key in initkwargs:
+ if key in cls.http_method_names:
+ raise TypeError("You tried to pass in the %s method name as a "
+ "keyword argument to %s(). Don't do that."
+ % (key, cls.__name__))
+ if not hasattr(cls, key):
+ raise TypeError("%s() received an invalid keyword %r" % (
+ cls.__name__, key))
+
+ def view(request, *args, **kwargs):
+ self = cls(**initkwargs)
+ # We also store the mapping of request methods to actions,
+ # so that we can later set the action attribute.
+ # eg. `self.action = 'list'` on an incoming GET request.
+ self.action_map = actions
+
+ # Bind methods to actions
+ # This is the bit that's different to a standard view
+ for method, action in actions.items():
+ handler = getattr(self, action)
+ setattr(self, method, handler)
+
+ # Patch this in as it's otherwise only present from 1.5 onwards
+ if hasattr(self, 'get') and not hasattr(self, 'head'):
+ self.head = self.get
+
+ # And continue as usual
+ return self.dispatch(request, *args, **kwargs)
+
+ # take name and docstring from class
+ update_wrapper(view, cls, updated=())
+
+ # and possible attributes set by decorators
+ # like csrf_exempt from dispatch
+ update_wrapper(view, cls.dispatch, assigned=())
+
+ # We need to set these on the view function, so that breadcrumb
+ # generation can pick out these bits of information from a
+ # resolved URL.
+ view.cls = cls
+ view.suffix = initkwargs.get('suffix', None)
+ return view
+
+ def initialize_request(self, request, *args, **kargs):
+ """
+ Set the `.action` attribute on the view,
+ depending on the request method.
+ """
+ request = super(ViewSetMixin, self).initialize_request(request, *args, **kargs)
+ self.action = self.action_map.get(request.method.lower())
+ return request
+
+
+class ViewSet(ViewSetMixin, views.APIView):
+ """
+ The base ViewSet class does not provide any actions by default.
+ """
+ pass
+
+
+class GenericViewSet(ViewSetMixin, generics.GenericAPIView):
+ """
+ The GenericViewSet class does not provide any actions by default,
+ but does include the base set of generic view behavior, such as
+ the `get_object` and `get_queryset` methods.
+ """
+ pass
+
+
+class ReadOnlyModelViewSet(mixins.RetrieveModelMixin,
+ mixins.ListModelMixin,
+ GenericViewSet):
+ """
+ A viewset that provides default `list()` and `retrieve()` actions.
+ """
+ pass
+
+
+class ModelViewSet(mixins.CreateModelMixin,
+ mixins.RetrieveModelMixin,
+ mixins.UpdateModelMixin,
+ mixins.DestroyModelMixin,
+ mixins.ListModelMixin,
+ GenericViewSet):
+ """
+ A viewset that provides default `create()`, `retrieve()`, `update()`,
+ `partial_update()`, `destroy()` and `list()` actions.
+ """
+ pass