aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2013-03-18 21:03:05 +0000
committerTom Christie2013-03-18 21:03:05 +0000
commit74fb366c595db87bb71baeffcacfb7d2482e3a18 (patch)
tree2e28cb52542742f32cdd3fbeb625f7f59cba0a3f /rest_framework
parent4c6396108704d38f534a16577de59178b1d0df3b (diff)
parent034c4ce4081dd6d15ea47fb8318754321a3faf0c (diff)
downloaddjango-rest-framework-74fb366c595db87bb71baeffcacfb7d2482e3a18.tar.bz2
Merge branch 'master' into resources-routers
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py8
-rw-r--r--rest_framework/authentication.py276
-rw-r--r--rest_framework/authtoken/migrations/0001_initial.py17
-rw-r--r--rest_framework/authtoken/models.py13
-rw-r--r--rest_framework/compat.py91
-rw-r--r--rest_framework/decorators.py13
-rw-r--r--rest_framework/exceptions.py17
-rw-r--r--rest_framework/fields.py318
-rw-r--r--rest_framework/filters.py3
-rw-r--r--rest_framework/generics.py43
-rw-r--r--rest_framework/mixins.py67
-rw-r--r--rest_framework/negotiation.py5
-rw-r--r--rest_framework/pagination.py27
-rw-r--r--rest_framework/parsers.py54
-rw-r--r--rest_framework/permissions.py69
-rw-r--r--rest_framework/relations.py317
-rw-r--r--rest_framework/renderers.py97
-rw-r--r--rest_framework/request.py83
-rw-r--r--rest_framework/response.py6
-rw-r--r--rest_framework/reverse.py1
-rwxr-xr-xrest_framework/runtests/runcoverage.py7
-rwxr-xr-xrest_framework/runtests/runtests.py2
-rw-r--r--rest_framework/runtests/settings.py32
-rw-r--r--rest_framework/serializers.py258
-rw-r--r--rest_framework/settings.py23
-rw-r--r--rest_framework/six.py389
-rw-r--r--rest_framework/static/rest_framework/css/default.css43
-rw-r--r--rest_framework/static/rest_framework/js/default.js8
-rw-r--r--rest_framework/status.py1
-rw-r--r--rest_framework/templates/rest_framework/base.html120
-rw-r--r--rest_framework/templates/rest_framework/form.html13
-rw-r--r--rest_framework/templates/rest_framework/login.html8
-rw-r--r--rest_framework/templatetags/rest_framework.py26
-rw-r--r--rest_framework/tests/authentication.py466
-rw-r--r--rest_framework/tests/breadcrumbs.py1
-rw-r--r--rest_framework/tests/decorators.py41
-rw-r--r--rest_framework/tests/description.py24
-rw-r--r--rest_framework/tests/fields.py403
-rw-r--r--rest_framework/tests/files.py16
-rw-r--r--rest_framework/tests/filterset.py48
-rw-r--r--rest_framework/tests/genericrelations.py103
-rw-r--r--rest_framework/tests/generics.py224
-rw-r--r--rest_framework/tests/htmlrenderer.py29
-rw-r--r--rest_framework/tests/hyperlinkedserializers.py27
-rw-r--r--rest_framework/tests/models.py52
-rw-r--r--rest_framework/tests/modelviews.py90
-rw-r--r--rest_framework/tests/multitable_inheritance.py67
-rw-r--r--rest_framework/tests/negotiation.py15
-rw-r--r--rest_framework/tests/pagination.py246
-rw-r--r--rest_framework/tests/parsers.py140
-rw-r--r--rest_framework/tests/permissions.py153
-rw-r--r--rest_framework/tests/relations.py16
-rw-r--r--rest_framework/tests/relations_hyperlink.py323
-rw-r--r--rest_framework/tests/relations_nested.py45
-rw-r--r--rest_framework/tests/relations_pk.py289
-rw-r--r--rest_framework/tests/relations_slug.py257
-rw-r--r--rest_framework/tests/renderers.py113
-rw-r--r--rest_framework/tests/request.py30
-rw-r--r--rest_framework/tests/response.py60
-rw-r--r--rest_framework/tests/reverse.py3
-rw-r--r--rest_framework/tests/serializer.py497
-rw-r--r--rest_framework/tests/settings.py1
-rw-r--r--rest_framework/tests/status.py5
-rw-r--r--rest_framework/tests/testcases.py1
-rw-r--r--rest_framework/tests/tests.py1
-rw-r--r--rest_framework/tests/throttling.py5
-rw-r--r--rest_framework/tests/urlpatterns.py76
-rw-r--r--rest_framework/tests/utils.py21
-rw-r--r--rest_framework/tests/validation.py65
-rw-r--r--rest_framework/tests/validators.py329
-rw-r--r--rest_framework/tests/views.py27
-rw-r--r--rest_framework/throttling.py3
-rw-r--r--rest_framework/urlpatterns.py46
-rw-r--r--rest_framework/urls.py1
-rw-r--r--rest_framework/utils/__init__.py100
-rw-r--r--rest_framework/utils/breadcrumbs.py1
-rw-r--r--rest_framework/utils/encoders.py7
-rw-r--r--rest_framework/utils/mediatypes.py5
-rw-r--r--rest_framework/views.py65
79 files changed, 4833 insertions, 2159 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index bc267fad..cf005636 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,3 +1,9 @@
-__version__ = '2.1.16'
+__version__ = '2.2.4'
VERSION = __version__ # synonym
+
+# Header encoding (see RFC5987)
+HTTP_HEADER_ENCODING = 'iso-8859-1'
+
+# Default datetime input and output formats
+ISO_8601 = 'iso-8601'
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index 30c78ebc..b4b73699 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -1,15 +1,30 @@
"""
Provides a set of pluggable authentication policies.
"""
-
+from __future__ import unicode_literals
from django.contrib.auth import authenticate
-from django.utils.encoding import smart_unicode, DjangoUnicodeDecodeError
-from rest_framework import exceptions
+from django.core.exceptions import ImproperlyConfigured
+from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware
+from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
+from rest_framework.compat import oauth2_provider, oauth2_provider_forms, oauth2_provider_backends
from rest_framework.authtoken.models import Token
import base64
+def get_authorization_header(request):
+ """
+ Return request's 'Authorization:' header, as a bytestring.
+
+ Hide some test client ickyness where the header can be unicode.
+ """
+ auth = request.META.get('HTTP_AUTHORIZATION', b'')
+ if type(auth) == type(''):
+ # Work around django test client oddness
+ auth = auth.encode(HTTP_HEADER_ENCODING)
+ return auth
+
+
class BaseAuthentication(object):
"""
All authentication classes should extend BaseAuthentication.
@@ -21,40 +36,58 @@ class BaseAuthentication(object):
"""
raise NotImplementedError(".authenticate() must be overridden.")
+ def authenticate_header(self, request):
+ """
+ Return a string to be used as the value of the `WWW-Authenticate`
+ header in a `401 Unauthenticated` response, or `None` if the
+ authentication scheme should return `403 Permission Denied` responses.
+ """
+ pass
+
class BasicAuthentication(BaseAuthentication):
"""
HTTP Basic authentication against username/password.
"""
+ www_authenticate_realm = 'api'
def authenticate(self, request):
"""
Returns a `User` if a correct username and password have been supplied
using HTTP Basic authentication. Otherwise returns `None`.
"""
- if 'HTTP_AUTHORIZATION' in request.META:
- auth = request.META['HTTP_AUTHORIZATION'].split()
- if len(auth) == 2 and auth[0].lower() == "basic":
- try:
- auth_parts = base64.b64decode(auth[1]).partition(':')
- except TypeError:
- return None
+ auth = get_authorization_header(request).split()
+
+ if not auth or auth[0].lower() != b'basic':
+ return None
- try:
- userid = smart_unicode(auth_parts[0])
- password = smart_unicode(auth_parts[2])
- except DjangoUnicodeDecodeError:
- return None
+ if len(auth) == 1:
+ msg = 'Invalid basic header. No credentials provided.'
+ raise exceptions.AuthenticationFailed(msg)
+ elif len(auth) > 2:
+ msg = 'Invalid basic header. Credentials string should not contain spaces.'
+ raise exceptions.AuthenticationFailed(msg)
- return self.authenticate_credentials(userid, password)
+ try:
+ auth_parts = base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(':')
+ except (TypeError, UnicodeDecodeError):
+ msg = 'Invalid basic header. Credentials not correctly base64 encoded'
+ raise exceptions.AuthenticationFailed(msg)
+
+ userid, password = auth_parts[0], auth_parts[2]
+ return self.authenticate_credentials(userid, password)
def authenticate_credentials(self, userid, password):
"""
Authenticate the userid and password against username and password.
"""
user = authenticate(username=userid, password=password)
- if user is not None and user.is_active:
- return (user, None)
+ if user is None or not user.is_active:
+ raise exceptions.AuthenticationFailed('Invalid username/password')
+ return (user, None)
+
+ def authenticate_header(self, request):
+ return 'Basic realm="%s"' % self.www_authenticate_realm
class SessionAuthentication(BaseAuthentication):
@@ -74,7 +107,7 @@ class SessionAuthentication(BaseAuthentication):
# Unauthenticated, CSRF validation not required
if not user or not user.is_active:
- return
+ return None
# Enforce CSRF validation for session based authentication.
class CSRFCheck(CsrfViewMiddleware):
@@ -85,7 +118,7 @@ class SessionAuthentication(BaseAuthentication):
reason = CSRFCheck().process_view(http_request, None, (), {})
if reason:
# CSRF failed, bail with explicit error message
- raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
+ raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason)
# CSRF passed with authenticated user
return (user, None)
@@ -110,16 +143,199 @@ class TokenAuthentication(BaseAuthentication):
"""
def authenticate(self, request):
- auth = request.META.get('HTTP_AUTHORIZATION', '').split()
+ auth = get_authorization_header(request).split()
+
+ if not auth or auth[0].lower() != b'token':
+ return None
+
+ if len(auth) == 1:
+ msg = 'Invalid token header. No credentials provided.'
+ raise exceptions.AuthenticationFailed(msg)
+ elif len(auth) > 2:
+ msg = 'Invalid token header. Token string should not contain spaces.'
+ raise exceptions.AuthenticationFailed(msg)
+
+ return self.authenticate_credentials(auth[1])
+
+ def authenticate_credentials(self, key):
+ try:
+ token = self.model.objects.get(key=key)
+ except self.model.DoesNotExist:
+ raise exceptions.AuthenticationFailed('Invalid token')
+
+ if not token.user.is_active:
+ raise exceptions.AuthenticationFailed('User inactive or deleted')
+
+ return (token.user, token)
+
+ def authenticate_header(self, request):
+ return 'Token'
+
+
+class OAuthAuthentication(BaseAuthentication):
+ """
+ OAuth 1.0a authentication backend using `django-oauth-plus` and `oauth2`.
+
+ Note: The `oauth2` package actually provides oauth1.0a support. Urg.
+ We import it from the `compat` module as `oauth`.
+ """
+ www_authenticate_realm = 'api'
+
+ def __init__(self, *args, **kwargs):
+ super(OAuthAuthentication, self).__init__(*args, **kwargs)
+
+ if oauth is None:
+ raise ImproperlyConfigured(
+ "The 'oauth2' package could not be imported."
+ "It is required for use with the 'OAuthAuthentication' class.")
+
+ if oauth_provider is None:
+ raise ImproperlyConfigured(
+ "The 'django-oauth-plus' package could not be imported."
+ "It is required for use with the 'OAuthAuthentication' class.")
+
+ def authenticate(self, request):
+ """
+ Returns two-tuple of (user, token) if authentication succeeds,
+ or None otherwise.
+ """
+ try:
+ oauth_request = oauth_provider.utils.get_oauth_request(request)
+ except oauth.Error as err:
+ raise exceptions.AuthenticationFailed(err.message)
+
+ oauth_params = oauth_provider.consts.OAUTH_PARAMETERS_NAMES
+
+ found = any(param for param in oauth_params if param in oauth_request)
+ missing = list(param for param in oauth_params if param not in oauth_request)
+
+ if not found:
+ # OAuth authentication was not attempted.
+ return None
+
+ if missing:
+ # OAuth was attempted but missing parameters.
+ msg = 'Missing parameters: %s' % (', '.join(missing))
+ raise exceptions.AuthenticationFailed(msg)
+
+ if not self.check_nonce(request, oauth_request):
+ msg = 'Nonce check failed'
+ raise exceptions.AuthenticationFailed(msg)
+
+ try:
+ consumer_key = oauth_request.get_parameter('oauth_consumer_key')
+ consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key)
+ except oauth_provider_store.InvalidConsumerError as err:
+ raise exceptions.AuthenticationFailed(err)
+
+ if consumer.status != oauth_provider.consts.ACCEPTED:
+ msg = 'Invalid consumer key status: %s' % consumer.get_status_display()
+ raise exceptions.AuthenticationFailed(msg)
+
+ try:
+ token_param = oauth_request.get_parameter('oauth_token')
+ token = oauth_provider_store.get_access_token(request, oauth_request, consumer, token_param)
+ except oauth_provider_store.InvalidTokenError:
+ msg = 'Invalid access token: %s' % oauth_request.get_parameter('oauth_token')
+ raise exceptions.AuthenticationFailed(msg)
+
+ try:
+ self.validate_token(request, consumer, token)
+ except oauth.Error as err:
+ raise exceptions.AuthenticationFailed(err.message)
+
+ user = token.user
+
+ if not user.is_active:
+ msg = 'User inactive or deleted: %s' % user.username
+ raise exceptions.AuthenticationFailed(msg)
+
+ return (token.user, token)
+
+ def authenticate_header(self, request):
+ """
+ If permission is denied, return a '401 Unauthorized' response,
+ with an appropraite 'WWW-Authenticate' header.
+ """
+ return 'OAuth realm="%s"' % self.www_authenticate_realm
+
+ def validate_token(self, request, consumer, token):
+ """
+ Check the token and raise an `oauth.Error` exception if invalid.
+ """
+ oauth_server, oauth_request = oauth_provider.utils.initialize_server_request(request)
+ oauth_server.verify_request(oauth_request, consumer, token)
+
+ def check_nonce(self, request, oauth_request):
+ """
+ Checks nonce of request, and return True if valid.
+ """
+ return oauth_provider_store.check_nonce(request, oauth_request, oauth_request['oauth_nonce'])
+
+
+class OAuth2Authentication(BaseAuthentication):
+ """
+ OAuth 2 authentication backend using `django-oauth2-provider`
+ """
+ www_authenticate_realm = 'api'
+
+ def __init__(self, *args, **kwargs):
+ super(OAuth2Authentication, self).__init__(*args, **kwargs)
+
+ if oauth2_provider is None:
+ raise ImproperlyConfigured(
+ "The 'django-oauth2-provider' package could not be imported. "
+ "It is required for use with the 'OAuth2Authentication' class.")
+
+ def authenticate(self, request):
+ """
+ Returns two-tuple of (user, token) if authentication succeeds,
+ or None otherwise.
+ """
+
+ auth = get_authorization_header(request).split()
+
+ if not auth or auth[0].lower() != b'bearer':
+ return None
- if len(auth) == 2 and auth[0].lower() == "token":
- key = auth[1]
- try:
- token = self.model.objects.get(key=key)
- except self.model.DoesNotExist:
- return None
+ if len(auth) == 1:
+ msg = 'Invalid bearer header. No credentials provided.'
+ raise exceptions.AuthenticationFailed(msg)
+ elif len(auth) > 2:
+ msg = 'Invalid bearer header. Token string should not contain spaces.'
+ raise exceptions.AuthenticationFailed(msg)
- if token.user.is_active:
- return (token.user, token)
+ return self.authenticate_credentials(request, auth[1])
-# TODO: OAuthAuthentication
+ def authenticate_credentials(self, request, access_token):
+ """
+ Authenticate the request, given the access token.
+ """
+
+ # Authenticate the client
+ oauth2_client_form = oauth2_provider_forms.ClientAuthForm(request.REQUEST)
+ if not oauth2_client_form.is_valid():
+ raise exceptions.AuthenticationFailed('Client could not be validated')
+ client = oauth2_client_form.cleaned_data.get('client')
+
+ # Retrieve the `OAuth2AccessToken` instance from the access_token
+ auth_backend = oauth2_provider_backends.AccessTokenBackend()
+ token = auth_backend.authenticate(access_token, client)
+ if token is None:
+ raise exceptions.AuthenticationFailed('Invalid token')
+
+ user = token.user
+
+ if not user.is_active:
+ msg = 'User inactive or deleted: %s' % user.username
+ raise exceptions.AuthenticationFailed(msg)
+
+ return (token.user, token)
+
+ def authenticate_header(self, request):
+ """
+ Bearer is the only finalized type currently
+
+ Check details on the `OAuth2Authentication.authenticate` method
+ """
+ return 'Bearer realm="%s"' % self.www_authenticate_realm
diff --git a/rest_framework/authtoken/migrations/0001_initial.py b/rest_framework/authtoken/migrations/0001_initial.py
index f4e052e4..d5965e40 100644
--- a/rest_framework/authtoken/migrations/0001_initial.py
+++ b/rest_framework/authtoken/migrations/0001_initial.py
@@ -4,6 +4,8 @@ from south.db import db
from south.v2 import SchemaMigration
from django.db import models
+from rest_framework.settings import api_settings
+
try:
from django.contrib.auth import get_user_model
@@ -45,20 +47,7 @@ class Migration(SchemaMigration):
'name': ('django.db.models.fields.CharField', [], {'max_length': '50'})
},
"%s.%s" % (User._meta.app_label, User._meta.module_name): {
- 'Meta': {'object_name': 'User'},
- 'date_joined': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}),
- 'email': ('django.db.models.fields.EmailField', [], {'max_length': '75', 'blank': 'True'}),
- 'first_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}),
- 'groups': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Group']", 'symmetrical': 'False', 'blank': 'True'}),
- 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
- 'is_active': ('django.db.models.fields.BooleanField', [], {'default': 'True'}),
- 'is_staff': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
- 'is_superuser': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
- 'last_login': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}),
- 'last_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}),
- 'password': ('django.db.models.fields.CharField', [], {'max_length': '128'}),
- 'user_permissions': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Permission']", 'symmetrical': 'False', 'blank': 'True'}),
- 'username': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '30'})
+ 'Meta': {'object_name': User._meta.module_name},
},
'authtoken.token': {
'Meta': {'object_name': 'Token'},
diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py
index 4da2aa62..52c45ad1 100644
--- a/rest_framework/authtoken/models.py
+++ b/rest_framework/authtoken/models.py
@@ -2,6 +2,7 @@ import uuid
import hmac
from hashlib import sha1
from rest_framework.compat import User
+from django.conf import settings
from django.db import models
@@ -13,14 +14,22 @@ class Token(models.Model):
user = models.OneToOneField(User, related_name='auth_token')
created = models.DateTimeField(auto_now_add=True)
+ class Meta:
+ # Work around for a bug in Django:
+ # https://code.djangoproject.com/ticket/19422
+ #
+ # Also see corresponding ticket:
+ # https://github.com/tomchristie/django-rest-framework/issues/705
+ abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS
+
def save(self, *args, **kwargs):
if not self.key:
self.key = self.generate_key()
return super(Token, self).save(*args, **kwargs)
def generate_key(self):
- unique = str(uuid.uuid4())
- return hmac.new(unique, digestmod=sha1).hexdigest()
+ unique = uuid.uuid4()
+ return hmac.new(unique.bytes, digestmod=sha1).hexdigest()
def __unicode__(self):
return self.key
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 5508f6c0..7b2ef738 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -3,26 +3,56 @@ The `compat` module provides support for backwards compatibility with older
versions of django/python, and compatibility wrappers around optional packages.
"""
# flake8: noqa
+from __future__ import unicode_literals
+
import django
+# Try to import six from Django, fallback to included `six`.
+try:
+ from django.utils import six
+except ImportError:
+ from rest_framework import six
+
# location of patterns, url, include changes in 1.4 onwards
try:
from django.conf.urls import patterns, url, include
-except:
+except ImportError:
from django.conf.urls.defaults import patterns, url, include
+# Handle django.utils.encoding rename:
+# smart_unicode -> smart_text
+# force_unicode -> force_text
+try:
+ from django.utils.encoding import smart_text
+except ImportError:
+ from django.utils.encoding import smart_unicode as smart_text
+try:
+ from django.utils.encoding import force_text
+except ImportError:
+ from django.utils.encoding import force_unicode as force_text
+
+
# django-filter is optional
try:
import django_filters
-except:
+except ImportError:
django_filters = None
# cStringIO only if it's available, otherwise StringIO
try:
- import cStringIO as StringIO
+ import cStringIO.StringIO as StringIO
+except ImportError:
+ StringIO = six.StringIO
+
+BytesIO = six.BytesIO
+
+
+# urlparse compat import (Required because it changed in python 3.x)
+try:
+ from urllib import parse as urlparse
except ImportError:
- import StringIO
+ import urlparse
# Try to import PIL in either of the two ways it can end up installed.
@@ -54,7 +84,7 @@ else:
try:
from django.contrib.auth.models import User
except ImportError:
- raise ImportError(u"User model is not to be found.")
+ raise ImportError("User model is not to be found.")
# First implementation of Django class-based views did not include head method
@@ -75,11 +105,11 @@ else:
# sanitize keyword arguments
for key in initkwargs:
if key in cls.http_method_names:
- raise TypeError(u"You tried to pass in the %s method name as a "
- u"keyword argument to %s(). Don't do that."
+ raise TypeError("You tried to pass in the %s method name as a "
+ "keyword argument to %s(). Don't do that."
% (key, cls.__name__))
if not hasattr(cls, key):
- raise TypeError(u"%s() received an invalid keyword %r" % (
+ raise TypeError("%s() received an invalid keyword %r" % (
cls.__name__, key))
def view(request, *args, **kwargs):
@@ -110,7 +140,6 @@ else:
import re
import random
import logging
- import urlparse
from django.conf import settings
from django.core.urlresolvers import get_callable
@@ -152,7 +181,8 @@ else:
randrange = random.SystemRandom().randrange
else:
randrange = random.randrange
- _MAX_CSRF_KEY = 18446744073709551616L # 2 << 63
+
+ _MAX_CSRF_KEY = 18446744073709551616 # 2 << 63
REASON_NO_REFERER = "Referer checking failed - no Referer."
REASON_BAD_REFERER = "Referer checking failed - %s does not match %s."
@@ -319,7 +349,7 @@ except ImportError:
# dateparse is ALSO new in Django 1.4
try:
- from django.utils.dateparse import parse_date, parse_datetime
+ from django.utils.dateparse import parse_date, parse_datetime, parse_time
except ImportError:
import datetime
import re
@@ -391,8 +421,39 @@ except ImportError:
yaml = None
-# xml.etree.parse only throws ParseError for python >= 2.7
+# XML is optional
+try:
+ import defusedxml.ElementTree as etree
+except ImportError:
+ etree = None
+
+# OAuth is optional
try:
- from xml.etree import ParseError as ETParseError
-except ImportError: # python < 2.7
- ETParseError = None
+ # Note: The `oauth2` package actually provides oauth1.0a support. Urg.
+ import oauth2 as oauth
+except ImportError:
+ oauth = None
+
+# OAuth is optional
+try:
+ import oauth_provider
+ from oauth_provider.store import store as oauth_provider_store
+except ImportError:
+ oauth_provider = None
+ oauth_provider_store = None
+
+# OAuth 2 support is optional
+try:
+ import provider.oauth2 as oauth2_provider
+ from provider.oauth2 import backends as oauth2_provider_backends
+ from provider.oauth2 import models as oauth2_provider_models
+ from provider.oauth2 import forms as oauth2_provider_forms
+ from provider import scope as oauth2_provider_scope
+ from provider import constants as oauth2_constants
+except ImportError:
+ oauth2_provider = None
+ oauth2_provider_backends = None
+ oauth2_provider_models = None
+ oauth2_provider_forms = None
+ oauth2_provider_scope = None
+ oauth2_constants = None
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 1b710a03..8250cd3b 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -1,4 +1,7 @@
+from __future__ import unicode_literals
+from rest_framework.compat import six
from rest_framework.views import APIView
+import types
def api_view(http_method_names):
@@ -11,7 +14,7 @@ def api_view(http_method_names):
def decorator(func):
WrappedAPIView = type(
- 'WrappedAPIView',
+ six.PY3 and 'WrappedAPIView' or b'WrappedAPIView',
(APIView,),
{'__doc__': func.__doc__}
)
@@ -23,6 +26,14 @@ def api_view(http_method_names):
# pass
# WrappedAPIView.__doc__ = func.doc <--- Not possible to do this
+ # api_view applied without (method_names)
+ assert not(isinstance(http_method_names, types.FunctionType)), \
+ '@api_view missing list of allowed HTTP methods'
+
+ # api_view applied with eg. string instead of list of strings
+ assert isinstance(http_method_names, (list, tuple)), \
+ '@api_view expected a list of strings, recieved %s' % type(http_method_names).__name__
+
allowed_methods = set(http_method_names) | set(('options',))
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index 89479deb..0c96ecdd 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -4,6 +4,7 @@ Handled exceptions raised by REST framework.
In addition Django's built in 403 and 404 exceptions are handled.
(`django.http.Http404` and `django.core.exceptions.PermissionDenied`)
"""
+from __future__ import unicode_literals
from rest_framework import status
@@ -23,6 +24,22 @@ class ParseError(APIException):
self.detail = detail or self.default_detail
+class AuthenticationFailed(APIException):
+ status_code = status.HTTP_401_UNAUTHORIZED
+ default_detail = 'Incorrect authentication credentials.'
+
+ def __init__(self, detail=None):
+ self.detail = detail or self.default_detail
+
+
+class NotAuthenticated(APIException):
+ status_code = status.HTTP_401_UNAUTHORIZED
+ default_detail = 'Authentication credentials were not provided.'
+
+ def __init__(self, detail=None):
+ self.detail = detail or self.default_detail
+
+
class PermissionDenied(APIException):
status_code = status.HTTP_403_FORBIDDEN
default_detail = 'You do not have permission to perform this action.'
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 998911e1..4b6931ad 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,30 +1,96 @@
+from __future__ import unicode_literals
+
import copy
import datetime
import inspect
import re
import warnings
-from io import BytesIO
-
from django.core import validators
from django.core.exceptions import ValidationError
from django.conf import settings
from django import forms
from django.forms import widgets
-from django.utils.encoding import is_protected_type, smart_unicode
+from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
-from rest_framework.compat import parse_date, parse_datetime
-from rest_framework.compat import timezone
+
+from rest_framework import ISO_8601
+from rest_framework.compat import timezone, parse_date, parse_datetime, parse_time
+from rest_framework.compat import BytesIO
+from rest_framework.compat import six
+from rest_framework.compat import smart_text
+from rest_framework.settings import api_settings
def is_simple_callable(obj):
"""
True if the object is a callable that takes no arguments.
"""
- return (
- (inspect.isfunction(obj) and not inspect.getargspec(obj)[0]) or
- (inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) <= 1)
- )
+ function = inspect.isfunction(obj)
+ method = inspect.ismethod(obj)
+
+ if not (function or method):
+ return False
+
+ args, _, _, defaults = inspect.getargspec(obj)
+ len_args = len(args) if function else len(args) - 1
+ len_defaults = len(defaults) if defaults else 0
+ return len_args <= len_defaults
+
+
+def get_component(obj, attr_name):
+ """
+ Given an object, and an attribute name,
+ return that attribute on the object.
+ """
+ if isinstance(obj, dict):
+ val = obj[attr_name]
+ else:
+ val = getattr(obj, attr_name)
+
+ if is_simple_callable(val):
+ return val()
+ return val
+
+
+def readable_datetime_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]')
+ return humanize_strptime(format)
+
+
+def readable_date_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]')
+ return humanize_strptime(format)
+
+
+def readable_time_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
+ return humanize_strptime(format)
+
+
+def humanize_strptime(format_string):
+ # Note that we're missing some of the locale specific mappings that
+ # don't really make sense.
+ mapping = {
+ "%Y": "YYYY",
+ "%y": "YY",
+ "%m": "MM",
+ "%b": "[Jan-Dec]",
+ "%B": "[January-December]",
+ "%d": "DD",
+ "%H": "hh",
+ "%I": "hh", # Requires '%p' to differentiate from '%H'.
+ "%M": "mm",
+ "%S": "ss",
+ "%f": "uuuuuu",
+ "%a": "[Mon-Sun]",
+ "%A": "[Monday-Sunday]",
+ "%p": "[AM|PM]",
+ "%z": "[+HHMM|-HHMM]"
+ }
+ for key, val in mapping.items():
+ format_string = format_string.replace(key, val)
+ return format_string
class Field(object):
@@ -32,7 +98,8 @@ class Field(object):
creation_counter = 0
empty = ''
type_name = None
- _use_files = None
+ partial = False
+ use_files = False
form_field_class = forms.CharField
def __init__(self, source=None):
@@ -53,7 +120,8 @@ class Field(object):
self.parent = parent
self.root = parent.root or parent
self.context = self.root.context
- if self.root.partial:
+ self.partial = self.root.partial
+ if self.partial:
self.required = False
def field_from_native(self, data, files, field_name, into):
@@ -74,14 +142,14 @@ class Field(object):
if self.source == '*':
return self.to_native(obj)
- if self.source:
- value = obj
- for component in self.source.split('.'):
- value = getattr(value, component)
- if is_simple_callable(value):
- value = value()
- else:
- value = getattr(obj, field_name)
+ source = self.source or field_name
+ value = obj
+
+ for component in source.split('.'):
+ value = get_component(value, component)
+ if value is None:
+ break
+
return self.to_native(value)
def to_native(self, value):
@@ -93,11 +161,11 @@ class Field(object):
if is_protected_type(value):
return value
- elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)):
+ elif hasattr(value, '__iter__') and not isinstance(value, (dict, six.string_types)):
return [self.to_native(item) for item in value]
elif isinstance(value, dict):
return dict(map(self.to_native, (k, v)) for k, v in value.items())
- return smart_unicode(value)
+ return smart_text(value)
def attributes(self):
"""
@@ -124,6 +192,13 @@ class WritableField(Field):
validators=[], error_messages=None, widget=None,
default=None, blank=None):
+ # 'blank' is to be deprecated in favor of 'required'
+ if blank is not None:
+ warnings.warn('The `blank` keyword argument is due to deprecated. '
+ 'Use the `required` keyword argument instead.',
+ PendingDeprecationWarning, stacklevel=2)
+ required = not(blank)
+
super(WritableField, self).__init__(source=source)
self.read_only = read_only
@@ -141,7 +216,6 @@ class WritableField(Field):
self.validators = self.default_validators + validators
self.default = default if default is not None else self.default
- self.blank = blank
# Widgets are ony used for HTML forms.
widget = widget or self.widget
@@ -180,13 +254,13 @@ class WritableField(Field):
return
try:
- if self._use_files:
+ if self.use_files:
files = files or {}
native = files[field_name]
else:
native = data[field_name]
except KeyError:
- if self.default is not None and not self.root.partial:
+ if self.default is not None and not self.partial:
# Note: partial updates shouldn't set defaults
native = self.default
else:
@@ -217,7 +291,7 @@ class ModelField(WritableField):
def __init__(self, *args, **kwargs):
try:
self.model_field = kwargs.pop('model_field')
- except:
+ except KeyError:
raise ValueError("ModelField requires 'model_field' kwarg")
self.min_length = kwargs.pop('min_length',
@@ -258,7 +332,7 @@ class BooleanField(WritableField):
form_field_class = forms.BooleanField
widget = widgets.CheckboxInput
default_error_messages = {
- 'invalid': _(u"'%s' value must be either True or False."),
+ 'invalid': _("'%s' value must be either True or False."),
}
empty = False
@@ -287,20 +361,10 @@ class CharField(WritableField):
if max_length is not None:
self.validators.append(validators.MaxLengthValidator(max_length))
- def validate(self, value):
- """
- Validates that the value is supplied (if required).
- """
- # if empty string and allow blank
- if self.blank and not value:
- return
- else:
- super(CharField, self).validate(value)
-
def from_native(self, value):
- if isinstance(value, basestring) or value is None:
+ if isinstance(value, six.string_types) or value is None:
return value
- return smart_unicode(value)
+ return smart_text(value)
class URLField(CharField):
@@ -325,7 +389,8 @@ class ChoiceField(WritableField):
form_field_class = forms.ChoiceField
widget = widgets.Select
default_error_messages = {
- 'invalid_choice': _('Select a valid choice. %(value)s is not one of the available choices.'),
+ 'invalid_choice': _('Select a valid choice. %(value)s is not one of '
+ 'the available choices.'),
}
def __init__(self, choices=(), *args, **kwargs):
@@ -359,10 +424,10 @@ class ChoiceField(WritableField):
if isinstance(v, (list, tuple)):
# This is an optgroup, so look inside the group for options
for k2, v2 in v:
- if value == smart_unicode(k2):
+ if value == smart_text(k2):
return True
else:
- if value == smart_unicode(k) or value == k:
+ if value == smart_text(k) or value == k:
return True
return False
@@ -402,7 +467,7 @@ class RegexField(CharField):
return self._regex
def _set_regex(self, regex):
- if isinstance(regex, basestring):
+ if isinstance(regex, six.string_types):
regex = re.compile(regex)
self._regex = regex
if hasattr(self, '_regex_validator') and self._regex_validator in self.validators:
@@ -425,12 +490,16 @@ class DateField(WritableField):
form_field_class = forms.DateField
default_error_messages = {
- 'invalid': _(u"'%s' value has an invalid date format. It must be "
- u"in YYYY-MM-DD format."),
- 'invalid_date': _(u"'%s' value has the correct format (YYYY-MM-DD) "
- u"but it is an invalid date."),
+ 'invalid': _("Date has wrong format. Use one of these formats instead: %s"),
}
empty = None
+ input_formats = api_settings.DATE_INPUT_FORMATS
+ format = api_settings.DATE_FORMAT
+
+ def __init__(self, input_formats=None, format=None, *args, **kwargs):
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
+ self.format = format if format is not None else self.format
+ super(DateField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
@@ -446,17 +515,37 @@ class DateField(WritableField):
if isinstance(value, datetime.date):
return value
- try:
- parsed = parse_date(value)
- if parsed is not None:
- return parsed
- except ValueError:
- msg = self.error_messages['invalid_date'] % value
- raise ValidationError(msg)
+ for format in self.input_formats:
+ if format.lower() == ISO_8601:
+ try:
+ parsed = parse_date(value)
+ except (ValueError, TypeError):
+ pass
+ else:
+ if parsed is not None:
+ return parsed
+ else:
+ try:
+ parsed = datetime.datetime.strptime(value, format)
+ except (ValueError, TypeError):
+ pass
+ else:
+ return parsed.date()
- msg = self.error_messages['invalid'] % value
+ msg = self.error_messages['invalid'] % readable_date_formats(self.input_formats)
raise ValidationError(msg)
+ def to_native(self, value):
+ if value is None:
+ return None
+
+ if isinstance(value, datetime.datetime):
+ value = value.date()
+
+ if self.format.lower() == ISO_8601:
+ return value.isoformat()
+ return value.strftime(self.format)
+
class DateTimeField(WritableField):
type_name = 'DateTimeField'
@@ -464,15 +553,16 @@ class DateTimeField(WritableField):
form_field_class = forms.DateTimeField
default_error_messages = {
- 'invalid': _(u"'%s' value has an invalid format. It must be in "
- u"YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."),
- 'invalid_date': _(u"'%s' value has the correct format "
- u"(YYYY-MM-DD) but it is an invalid date."),
- 'invalid_datetime': _(u"'%s' value has the correct format "
- u"(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) "
- u"but it is an invalid date/time."),
+ 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"),
}
empty = None
+ input_formats = api_settings.DATETIME_INPUT_FORMATS
+ format = api_settings.DATETIME_FORMAT
+
+ def __init__(self, input_formats=None, format=None, *args, **kwargs):
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
+ self.format = format if format is not None else self.format
+ super(DateTimeField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
@@ -487,32 +577,97 @@ class DateTimeField(WritableField):
# local time. This won't work during DST change, but we can't
# do much about it, so we let the exceptions percolate up the
# call stack.
- warnings.warn(u"DateTimeField received a naive datetime (%s)"
- u" while time zone support is active." % value,
+ warnings.warn("DateTimeField received a naive datetime (%s)"
+ " while time zone support is active." % value,
RuntimeWarning)
default_timezone = timezone.get_default_timezone()
value = timezone.make_aware(value, default_timezone)
return value
- try:
- parsed = parse_datetime(value)
- if parsed is not None:
- return parsed
- except ValueError:
- msg = self.error_messages['invalid_datetime'] % value
- raise ValidationError(msg)
+ for format in self.input_formats:
+ if format.lower() == ISO_8601:
+ try:
+ parsed = parse_datetime(value)
+ except (ValueError, TypeError):
+ pass
+ else:
+ if parsed is not None:
+ return parsed
+ else:
+ try:
+ parsed = datetime.datetime.strptime(value, format)
+ except (ValueError, TypeError):
+ pass
+ else:
+ return parsed
- try:
- parsed = parse_date(value)
- if parsed is not None:
- return datetime.datetime(parsed.year, parsed.month, parsed.day)
- except ValueError:
- msg = self.error_messages['invalid_date'] % value
- raise ValidationError(msg)
+ msg = self.error_messages['invalid'] % readable_datetime_formats(self.input_formats)
+ raise ValidationError(msg)
+
+ def to_native(self, value):
+ if value is None:
+ return None
+
+ if self.format.lower() == ISO_8601:
+ return value.isoformat()
+ return value.strftime(self.format)
+
+
+class TimeField(WritableField):
+ type_name = 'TimeField'
+ widget = widgets.TimeInput
+ form_field_class = forms.TimeField
+
+ default_error_messages = {
+ 'invalid': _("Time has wrong format. Use one of these formats instead: %s"),
+ }
+ empty = None
+ input_formats = api_settings.TIME_INPUT_FORMATS
+ format = api_settings.TIME_FORMAT
- msg = self.error_messages['invalid'] % value
+ def __init__(self, input_formats=None, format=None, *args, **kwargs):
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
+ self.format = format if format is not None else self.format
+ super(TimeField, self).__init__(*args, **kwargs)
+
+ def from_native(self, value):
+ if value in validators.EMPTY_VALUES:
+ return None
+
+ if isinstance(value, datetime.time):
+ return value
+
+ for format in self.input_formats:
+ if format.lower() == ISO_8601:
+ try:
+ parsed = parse_time(value)
+ except (ValueError, TypeError):
+ pass
+ else:
+ if parsed is not None:
+ return parsed
+ else:
+ try:
+ parsed = datetime.datetime.strptime(value, format)
+ except (ValueError, TypeError):
+ pass
+ else:
+ return parsed.time()
+
+ msg = self.error_messages['invalid'] % readable_time_formats(self.input_formats)
raise ValidationError(msg)
+ def to_native(self, value):
+ if value is None:
+ return None
+
+ if isinstance(value, datetime.datetime):
+ value = value.time()
+
+ if self.format.lower() == ISO_8601:
+ return value.isoformat()
+ return value.strftime(self.format)
+
class IntegerField(WritableField):
type_name = 'IntegerField'
@@ -564,7 +719,7 @@ class FloatField(WritableField):
class FileField(WritableField):
- _use_files = True
+ use_files = True
type_name = 'FileField'
form_field_class = forms.FileField
widget = widgets.FileInput
@@ -608,11 +763,12 @@ class FileField(WritableField):
class ImageField(FileField):
- _use_files = True
+ use_files = True
form_field_class = forms.ImageField
default_error_messages = {
- 'invalid_image': _("Upload a valid image. The file you uploaded was either not an image or a corrupted image."),
+ 'invalid_image': _("Upload a valid image. The file you uploaded was "
+ "either not an image or a corrupted image."),
}
def from_native(self, data):
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index bcc87660..6fea46fa 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
from rest_framework.compat import django_filters
FilterSet = django_filters and django_filters.FilterSet or None
@@ -54,6 +55,6 @@ class DjangoFilterBackend(BaseFilterBackend):
filter_class = self.get_filter_class(view)
if filter_class:
- return filter_class(request.GET, queryset=queryset)
+ return filter_class(request.QUERY_PARAMS, queryset=queryset)
return queryset
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index f575470e..55918267 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -1,7 +1,7 @@
"""
Generic views that provide commonly needed behaviour.
"""
-
+from __future__ import unicode_literals
from rest_framework import views, mixins
from rest_framework.settings import api_settings
from django.views.generic.detail import SingleObjectMixin
@@ -18,6 +18,16 @@ class GenericAPIView(views.APIView):
model = None
serializer_class = None
model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
+ filter_backend = api_settings.FILTER_BACKEND
+
+ def filter_queryset(self, queryset):
+ """
+ Given a queryset, filter it with whichever filter backend is in use.
+ """
+ if not self.filter_backend:
+ return queryset
+ backend = self.filter_backend()
+ return backend.filter_queryset(self.request, queryset, self)
def get_serializer_context(self):
"""
@@ -48,7 +58,7 @@ class GenericAPIView(views.APIView):
return serializer_class
def get_serializer(self, instance=None, data=None,
- files=None, partial=False):
+ files=None, many=False, partial=False):
"""
Return the serializer instance that should be used for validating and
deserializing input, and for serializing output.
@@ -56,7 +66,21 @@ class GenericAPIView(views.APIView):
serializer_class = self.get_serializer_class()
context = self.get_serializer_context()
return serializer_class(instance, data=data, files=files,
- partial=partial, context=context)
+ many=many, partial=partial, context=context)
+
+ def pre_save(self, obj):
+ """
+ Placeholder method for calling before saving an object.
+ May be used eg. to set attributes on the object that are implicit
+ in either the request, or the url.
+ """
+ pass
+
+ def post_save(self, obj, created=False):
+ """
+ Placeholder method for calling after saving an object.
+ """
+ pass
def pre_save(self, obj):
pass
@@ -70,16 +94,6 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
paginate_by = api_settings.PAGINATE_BY
paginate_by_param = api_settings.PAGINATE_BY_PARAM
pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
- filter_backend = api_settings.FILTER_BACKEND
-
- def filter_queryset(self, queryset):
- """
- Given a queryset, filter it with whichever filter backend is in use.
- """
- if not self.filter_backend:
- return queryset
- backend = self.filter_backend()
- return backend.filter_queryset(self.request, queryset, self)
def get_pagination_serializer(self, page=None):
"""
@@ -120,8 +134,7 @@ class SingleObjectAPIView(SingleObjectMixin, GenericAPIView):
Override default to add support for object-level permissions.
"""
obj = super(SingleObjectAPIView, self).get_object(queryset)
- if not self.has_permission(self.request, obj):
- self.permission_denied(self.request)
+ self.check_object_permissions(self.request, obj)
return obj
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index e0ae216e..7d9a6e65 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -4,22 +4,48 @@ Basic building blocks for generic class based views.
We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways.
"""
+from __future__ import unicode_literals
+
from django.http import Http404
from rest_framework import status
from rest_framework.response import Response
+from rest_framework.request import clone_request
+
+
+def _get_validation_exclusions(obj, pk=None, slug_field=None):
+ """
+ Given a model instance, and an optional pk and slug field,
+ return the full list of all other field names on that model.
+
+ For use when performing full_clean on a model instance,
+ so we only clean the required fields.
+ """
+ include = []
+
+ if pk:
+ pk_field = obj._meta.pk
+ while pk_field.rel:
+ pk_field = pk_field.rel.to._meta.pk
+ include.append(pk_field.name)
+
+ if slug_field:
+ include.append(slug_field)
+
+ return [field.name for field in obj._meta.fields if field.name not in include]
class CreateModelMixin(object):
"""
Create a model instance.
- Should be mixed in with any `BaseView`.
+ Should be mixed in with any `GenericAPIView`.
"""
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA, files=request.FILES)
if serializer.is_valid():
self.pre_save(serializer.object)
- self.object = serializer.save()
+ self.object = serializer.save(force_insert=True)
+ self.post_save(self.object, created=True)
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED,
headers=headers)
@@ -38,7 +64,7 @@ class ListModelMixin(object):
List a queryset.
Should be mixed in with `MultipleObjectAPIView`.
"""
- empty_error = u"Empty list and '%(class_name)s.allow_empty' is False."
+ empty_error = "Empty list and '%(class_name)s.allow_empty' is False."
def list(self, request, *args, **kwargs):
queryset = self.get_queryset()
@@ -60,7 +86,7 @@ class ListModelMixin(object):
paginator, page, queryset, is_paginated = packed
serializer = self.get_pagination_serializer(page)
else:
- serializer = self.get_serializer(self.object_list)
+ serializer = self.get_serializer(self.object_list, many=True)
return Response(serializer.data)
@@ -68,10 +94,12 @@ class ListModelMixin(object):
class RetrieveModelMixin(object):
"""
Retrieve a model instance.
- Should be mixed in with `SingleObjectBaseView`.
+ Should be mixed in with `SingleObjectAPIView`.
"""
def retrieve(self, request, *args, **kwargs):
- self.object = self.get_object()
+ queryset = self.get_queryset()
+ filtered_queryset = self.filter_queryset(queryset)
+ self.object = self.get_object(filtered_queryset)
serializer = self.get_serializer(self.object)
return Response(serializer.data)
@@ -79,23 +107,32 @@ class RetrieveModelMixin(object):
class UpdateModelMixin(object):
"""
Update a model instance.
- Should be mixed in with `SingleObjectBaseView`.
+ Should be mixed in with `SingleObjectAPIView`.
"""
def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
+ self.object = None
try:
self.object = self.get_object()
- success_status_code = status.HTTP_200_OK
except Http404:
- self.object = None
+ # If this is a PUT-as-create operation, we need to ensure that
+ # we have relevant permissions, as if this was a POST request.
+ self.check_permissions(clone_request(request, 'POST'))
+ created = True
+ save_kwargs = {'force_insert': True}
success_status_code = status.HTTP_201_CREATED
+ else:
+ created = False
+ save_kwargs = {'force_update': True}
+ success_status_code = status.HTTP_200_OK
serializer = self.get_serializer(self.object, data=request.DATA,
files=request.FILES, partial=partial)
if serializer.is_valid():
self.pre_save(serializer.object)
- self.object = serializer.save()
+ self.object = serializer.save(**save_kwargs)
+ self.post_save(self.object, created=created)
return Response(serializer.data, status=success_status_code)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
@@ -106,24 +143,26 @@ class UpdateModelMixin(object):
"""
# pk and/or slug attributes are implicit in the URL.
pk = self.kwargs.get(self.pk_url_kwarg, None)
+ slug = self.kwargs.get(self.slug_url_kwarg, None)
+ slug_field = slug and self.get_slug_field() or None
+
if pk:
setattr(obj, 'pk', pk)
- slug = self.kwargs.get(self.slug_url_kwarg, None)
if slug:
- slug_field = self.get_slug_field()
setattr(obj, slug_field, slug)
# Ensure we clean the attributes so that we don't eg return integer
# pk using a string representation, as provided by the url conf kwarg.
if hasattr(obj, 'full_clean'):
- obj.full_clean()
+ exclude = _get_validation_exclusions(obj, pk, slug_field)
+ obj.full_clean(exclude)
class DestroyModelMixin(object):
"""
Destroy a model instance.
- Should be mixed in with `SingleObjectBaseView`.
+ Should be mixed in with `SingleObjectAPIView`.
"""
def destroy(self, request, *args, **kwargs):
obj = self.get_object()
diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py
index ee2800a6..0694d35f 100644
--- a/rest_framework/negotiation.py
+++ b/rest_framework/negotiation.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
from django.http import Http404
from rest_framework import exceptions
from rest_framework.settings import api_settings
@@ -33,7 +34,7 @@ class DefaultContentNegotiation(BaseContentNegotiation):
"""
# Allow URL style format override. eg. "?format=json
format_query_param = self.settings.URL_FORMAT_OVERRIDE
- format = format_suffix or request.GET.get(format_query_param)
+ format = format_suffix or request.QUERY_PARAMS.get(format_query_param)
if format:
renderers = self.filter_renderers(renderers, format)
@@ -80,5 +81,5 @@ class DefaultContentNegotiation(BaseContentNegotiation):
Allows URL style accept override. eg. "?accept=application/json"
"""
header = request.META.get('HTTP_ACCEPT', '*/*')
- header = request.GET.get(self.settings.URL_ACCEPT_OVERRIDE, header)
+ header = request.QUERY_PARAMS.get(self.settings.URL_ACCEPT_OVERRIDE, header)
return [token.strip() for token in header.split(',')]
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index d241ade7..03a7a30f 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
from rest_framework import serializers
from rest_framework.templatetags.rest_framework import replace_query_param
@@ -34,6 +35,17 @@ class PreviousPageField(serializers.Field):
return replace_query_param(url, self.page_field, page)
+class DefaultObjectSerializer(serializers.Field):
+ """
+ If no object serializer is specified, then this serializer will be applied
+ as the default.
+ """
+
+ def __init__(self, source=None, context=None):
+ # Note: Swallow context kwarg - only required for eg. ModelSerializer.
+ super(DefaultObjectSerializer, self).__init__(source=source)
+
+
class PaginationSerializerOptions(serializers.SerializerOptions):
"""
An object that stores the options that may be provided to a
@@ -44,7 +56,7 @@ class PaginationSerializerOptions(serializers.SerializerOptions):
def __init__(self, meta):
super(PaginationSerializerOptions, self).__init__(meta)
self.object_serializer_class = getattr(meta, 'object_serializer_class',
- serializers.Field)
+ DefaultObjectSerializer)
class BasePaginationSerializer(serializers.Serializer):
@@ -62,14 +74,13 @@ class BasePaginationSerializer(serializers.Serializer):
super(BasePaginationSerializer, self).__init__(*args, **kwargs)
results_field = self.results_field
object_serializer = self.opts.object_serializer_class
- self.fields[results_field] = object_serializer(source='object_list')
- def to_native(self, obj):
- """
- Prevent default behaviour of iterating over elements, and serializing
- each in turn.
- """
- return self.convert_object(obj)
+ if 'context' in kwargs:
+ context_kwarg = {'context': kwargs['context']}
+ else:
+ context_kwarg = {}
+
+ self.fields[results_field] = object_serializer(source='object_list', **context_kwarg)
class PaginationSerializer(BasePaginationSerializer):
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index 149d6431..491acd68 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -4,14 +4,14 @@ Parsers are used to parse the content of incoming HTTP requests.
They give us a generic way of being able to handle various media types
on the request, such as form content or json encoded data.
"""
-
+from __future__ import unicode_literals
+from django.conf import settings
from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError
-from rest_framework.compat import yaml, ETParseError
+from rest_framework.compat import yaml, etree
from rest_framework.exceptions import ParseError
-from xml.etree import ElementTree as ET
-from xml.parsers.expat import ExpatError
+from rest_framework.compat import six
import json
import datetime
import decimal
@@ -54,10 +54,14 @@ class JSONParser(BaseParser):
`data` will be an object which is the parsed content of the response.
`files` will always be `None`.
"""
+ parser_context = parser_context or {}
+ encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
+
try:
- return json.load(stream)
- except ValueError, exc:
- raise ParseError('JSON parse error - %s' % unicode(exc))
+ data = stream.read().decode(encoding)
+ return json.loads(data)
+ except ValueError as exc:
+ raise ParseError('JSON parse error - %s' % six.text_type(exc))
class YAMLParser(BaseParser):
@@ -74,10 +78,16 @@ class YAMLParser(BaseParser):
`data` will be an object which is the parsed content of the response.
`files` will always be `None`.
"""
+ assert yaml, 'YAMLParser requires pyyaml to be installed'
+
+ parser_context = parser_context or {}
+ encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
+
try:
- return yaml.safe_load(stream)
- except (ValueError, yaml.parser.ParserError), exc:
- raise ParseError('YAML parse error - %s' % unicode(exc))
+ data = stream.read().decode(encoding)
+ return yaml.safe_load(data)
+ except (ValueError, yaml.parser.ParserError) as exc:
+ raise ParseError('YAML parse error - %s' % six.u(exc))
class FormParser(BaseParser):
@@ -94,7 +104,9 @@ class FormParser(BaseParser):
`data` will be a :class:`QueryDict` containing all the form parameters.
`files` will always be :const:`None`.
"""
- data = QueryDict(stream.read())
+ parser_context = parser_context or {}
+ encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
+ data = QueryDict(stream.read(), encoding=encoding)
return data
@@ -114,15 +126,16 @@ class MultiPartParser(BaseParser):
"""
parser_context = parser_context or {}
request = parser_context['request']
+ encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
meta = request.META
upload_handlers = request.upload_handlers
try:
- parser = DjangoMultiPartParser(meta, stream, upload_handlers)
+ parser = DjangoMultiPartParser(meta, stream, upload_handlers, encoding)
data, files = parser.parse()
return DataAndFiles(data, files)
- except MultiPartParserError, exc:
- raise ParseError('Multipart form parse error - %s' % unicode(exc))
+ except MultiPartParserError as exc:
+ raise ParseError('Multipart form parse error - %s' % six.u(exc))
class XMLParser(BaseParser):
@@ -133,10 +146,15 @@ class XMLParser(BaseParser):
media_type = 'application/xml'
def parse(self, stream, media_type=None, parser_context=None):
+ assert etree, 'XMLParser requires defusedxml to be installed'
+
+ parser_context = parser_context or {}
+ encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
+ parser = etree.DefusedXMLParser(encoding=encoding)
try:
- tree = ET.parse(stream)
- except (ExpatError, ETParseError, ValueError), exc:
- raise ParseError('XML parse error - %s' % unicode(exc))
+ tree = etree.parse(stream, parser=parser, forbid_dtd=True)
+ except (etree.ParseError, ValueError) as exc:
+ raise ParseError('XML parse error - %s' % six.u(exc))
data = self._xml_convert(tree.getroot())
return data
@@ -146,7 +164,7 @@ class XMLParser(BaseParser):
convert the xml `element` into the corresponding python object
"""
- children = element.getchildren()
+ children = list(element)
if len(children) == 0:
return self._type_convert(element.text)
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index 655b78a3..ae895f39 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -1,21 +1,36 @@
"""
Provides a set of pluggable permission policies.
"""
-
+from __future__ import unicode_literals
+import inspect
+import warnings
SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS']
+from rest_framework.compat import oauth2_provider_scope, oauth2_constants
+
class BasePermission(object):
"""
A base class from which all permission classes should inherit.
"""
- def has_permission(self, request, view, obj=None):
+ def has_permission(self, request, view):
+ """
+ Return `True` if permission is granted, `False` otherwise.
+ """
+ return True
+
+ def has_object_permission(self, request, view, obj):
"""
Return `True` if permission is granted, `False` otherwise.
"""
- raise NotImplementedError(".has_permission() must be overridden.")
+ if len(inspect.getargspec(self.has_permission)[0]) == 4:
+ warnings.warn('The `obj` argument in `has_permission` is due to be deprecated. '
+ 'Use `has_object_permission()` instead for object permissions.',
+ PendingDeprecationWarning, stacklevel=2)
+ return self.has_permission(request, view, obj)
+ return True
class AllowAny(BasePermission):
@@ -25,7 +40,7 @@ class AllowAny(BasePermission):
permission_classes list, but it's useful because it makes the intention
more explicit.
"""
- def has_permission(self, request, view, obj=None):
+ def has_permission(self, request, view):
return True
@@ -34,7 +49,7 @@ class IsAuthenticated(BasePermission):
Allows access only to authenticated users.
"""
- def has_permission(self, request, view, obj=None):
+ def has_permission(self, request, view):
if request.user and request.user.is_authenticated():
return True
return False
@@ -45,7 +60,7 @@ class IsAdminUser(BasePermission):
Allows access only to admin users.
"""
- def has_permission(self, request, view, obj=None):
+ def has_permission(self, request, view):
if request.user and request.user.is_staff:
return True
return False
@@ -56,7 +71,7 @@ class IsAuthenticatedOrReadOnly(BasePermission):
The request is authenticated as a user, or is a read-only request.
"""
- def has_permission(self, request, view, obj=None):
+ def has_permission(self, request, view):
if (request.method in SAFE_METHODS or
request.user and
request.user.is_authenticated()):
@@ -89,6 +104,8 @@ class DjangoModelPermissions(BasePermission):
'DELETE': ['%(app_label)s.delete_%(model_name)s'],
}
+ authenticated_users_only = True
+
def get_required_permissions(self, method, model_cls):
"""
Given a model and an HTTP method, return the list of permission
@@ -100,15 +117,43 @@ class DjangoModelPermissions(BasePermission):
}
return [perm % kwargs for perm in self.perms_map[method]]
- def has_permission(self, request, view, obj=None):
+ def has_permission(self, request, view):
model_cls = getattr(view, 'model', None)
- if not model_cls:
- return True
+ queryset = getattr(view, 'queryset', None)
+
+ if model_cls is None and queryset is not None:
+ model_cls = queryset.model
+
+ assert model_cls, ('Cannot apply DjangoModelPermissions on a view that'
+ ' does not have `.model` or `.queryset` property.')
perms = self.get_required_permissions(request.method, model_cls)
if (request.user and
- request.user.is_authenticated() and
- request.user.has_perms(perms, obj)):
+ (request.user.is_authenticated() or not self.authenticated_users_only) and
+ request.user.has_perms(perms)):
return True
return False
+
+
+class TokenHasReadWriteScope(BasePermission):
+ """
+ The request is authenticated as a user and the token used has the right scope
+ """
+
+ def has_permission(self, request, view):
+ token = request.auth
+ read_only = request.method in SAFE_METHODS
+
+ if not token:
+ return False
+
+ if hasattr(token, 'resource'): # OAuth 1
+ return read_only or not request.auth.resource.is_readonly
+ elif hasattr(token, 'scope'): # OAuth 2
+ required = oauth2_constants.READ if read_only else oauth2_constants.WRITE
+ return oauth2_provider_scope.check(required, request.auth.scope)
+
+ assert False, ('TokenHasReadWriteScope requires either the'
+ '`OAuthAuthentication` or `OAuth2Authentication` authentication '
+ 'class to be used.')
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 5e4552b7..2a10e9af 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -1,13 +1,16 @@
+from __future__ import unicode_literals
from django.core.exceptions import ObjectDoesNotExist, ValidationError
-from django.core.urlresolvers import resolve, get_script_prefix
+from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch
from django import forms
from django.forms import widgets
from django.forms.models import ModelChoiceIterator
-from django.utils.encoding import smart_unicode
from django.utils.translation import ugettext_lazy as _
-from rest_framework.fields import Field, WritableField
+from rest_framework.fields import Field, WritableField, get_component
from rest_framework.reverse import reverse
-from urlparse import urlparse
+from rest_framework.compat import urlparse
+from rest_framework.compat import smart_text
+import warnings
+
##### Relational fields #####
@@ -17,19 +20,35 @@ class RelatedField(WritableField):
"""
Base class for related model fields.
- If not overridden, this represents a to-one relationship, using the unicode
- representation of the target.
+ This represents a relationship using the unicode representation of the target.
"""
widget = widgets.Select
+ many_widget = widgets.SelectMultiple
+ form_field_class = forms.ChoiceField
+ many_form_field_class = forms.MultipleChoiceField
+
cache_choices = False
empty_label = None
- default_read_only = True # TODO: Remove this
+ read_only = True
+ many = False
def __init__(self, *args, **kwargs):
+
+ # 'null' is to be deprecated in favor of 'required'
+ if 'null' in kwargs:
+ warnings.warn('The `null` keyword argument is due to be deprecated. '
+ 'Use the `required` keyword argument instead.',
+ PendingDeprecationWarning, stacklevel=2)
+ kwargs['required'] = not kwargs.pop('null')
+
self.queryset = kwargs.pop('queryset', None)
- self.null = kwargs.pop('null', False)
+ self.many = kwargs.pop('many', self.many)
+ if self.many:
+ self.widget = self.many_widget
+ self.form_field_class = self.many_form_field_class
+
+ kwargs['read_only'] = kwargs.pop('read_only', self.read_only)
super(RelatedField, self).__init__(*args, **kwargs)
- self.read_only = kwargs.pop('read_only', self.default_read_only)
def initialize(self, parent, field_name):
super(RelatedField, self).initialize(parent, field_name)
@@ -40,7 +59,7 @@ class RelatedField(WritableField):
self.queryset = manager.related.model._default_manager.all()
else: # Reverse
self.queryset = manager.field.rel.to._default_manager.all()
- except:
+ except Exception:
raise
msg = ('Serializer related fields must include a `queryset`' +
' argument or set `read_only=True')
@@ -48,11 +67,6 @@ class RelatedField(WritableField):
### We need this stuff to make form choices work...
- # def __deepcopy__(self, memo):
- # result = super(RelatedField, self).__deepcopy__(memo)
- # result.queryset = result.queryset
- # return result
-
def prepare_value(self, obj):
return self.to_native(obj)
@@ -60,8 +74,8 @@ class RelatedField(WritableField):
"""
Return a readable representation for use with eg. select widgets.
"""
- desc = smart_unicode(obj)
- ident = smart_unicode(self.to_native(obj))
+ desc = smart_text(obj)
+ ident = smart_text(self.to_native(obj))
if desc == ident:
return desc
return "%s - %s" % (desc, ident)
@@ -102,9 +116,24 @@ class RelatedField(WritableField):
def field_to_native(self, obj, field_name):
try:
- value = getattr(obj, self.source or field_name)
+ if self.source == '*':
+ return self.to_native(obj)
+
+ source = self.source or field_name
+ value = obj
+
+ for component in source.split('.'):
+ value = get_component(value, component)
+ if value is None:
+ break
except ObjectDoesNotExist:
return None
+
+ if value is None:
+ return None
+
+ if self.many:
+ return [self.to_native(item) for item in value.all()]
return self.to_native(value)
def field_from_native(self, data, files, field_name, into):
@@ -112,69 +141,43 @@ class RelatedField(WritableField):
return
try:
- value = data[field_name]
+ if self.many:
+ try:
+ # Form data
+ value = data.getlist(field_name)
+ if value == [''] or value == []:
+ raise KeyError
+ except AttributeError:
+ # Non-form data
+ value = data[field_name]
+ else:
+ value = data[field_name]
except KeyError:
- if self.required:
- raise ValidationError(self.error_messages['required'])
- return
+ if self.partial:
+ return
+ value = [] if self.many else None
- if value in (None, '') and not self.null:
- raise ValidationError('Value may not be null')
- elif value in (None, '') and self.null:
+ if value in (None, '') and self.required:
+ raise ValidationError(self.error_messages['required'])
+ elif value in (None, ''):
into[(self.source or field_name)] = None
+ elif self.many:
+ into[(self.source or field_name)] = [self.from_native(item) for item in value]
else:
into[(self.source or field_name)] = self.from_native(value)
-class ManyRelatedMixin(object):
- """
- Mixin to convert a related field to a many related field.
- """
- widget = widgets.SelectMultiple
-
- def field_to_native(self, obj, field_name):
- value = getattr(obj, self.source or field_name)
- return [self.to_native(item) for item in value.all()]
-
- def field_from_native(self, data, files, field_name, into):
- if self.read_only:
- return
-
- try:
- # Form data
- value = data.getlist(self.source or field_name)
- except:
- # Non-form data
- value = data.get(self.source or field_name)
- else:
- if value == ['']:
- value = []
-
- into[field_name] = [self.from_native(item) for item in value]
-
-
-class ManyRelatedField(ManyRelatedMixin, RelatedField):
- """
- Base class for related model managers.
-
- If not overridden, this represents a to-many relationship, using the unicode
- representations of the target, and is read-only.
- """
- pass
-
-
### PrimaryKey relationships
class PrimaryKeyRelatedField(RelatedField):
"""
- Represents a to-one relationship as a pk value.
+ Represents a relationship as a pk value.
"""
- default_read_only = False
- form_field_class = forms.ChoiceField
+ read_only = False
default_error_messages = {
'does_not_exist': _("Invalid pk '%s' - object does not exist."),
- 'invalid': _('Invalid value.'),
+ 'incorrect_type': _('Incorrect type. Expected pk value, received %s.'),
}
# TODO: Remove these field hacks...
@@ -185,8 +188,8 @@ class PrimaryKeyRelatedField(RelatedField):
"""
Return a readable representation for use with eg. select widgets.
"""
- desc = smart_unicode(obj)
- ident = smart_unicode(self.to_native(obj.pk))
+ desc = smart_text(obj)
+ ident = smart_text(self.to_native(obj.pk))
if desc == ident:
return desc
return "%s - %s" % (desc, ident)
@@ -202,85 +205,49 @@ class PrimaryKeyRelatedField(RelatedField):
try:
return self.queryset.get(pk=data)
except ObjectDoesNotExist:
- msg = self.error_messages['does_not_exist'] % smart_unicode(data)
+ msg = self.error_messages['does_not_exist'] % smart_text(data)
raise ValidationError(msg)
except (TypeError, ValueError):
- msg = self.error_messages['invalid']
+ received = type(data).__name__
+ msg = self.error_messages['incorrect_type'] % received
raise ValidationError(msg)
def field_to_native(self, obj, field_name):
+ if self.many:
+ # To-many relationship
+ try:
+ # Prefer obj.serializable_value for performance reasons
+ queryset = obj.serializable_value(self.source or field_name)
+ except AttributeError:
+ # RelatedManager (reverse relationship)
+ queryset = getattr(obj, self.source or field_name)
+
+ # Forward relationship
+ return [self.to_native(item.pk) for item in queryset.all()]
+
+ # To-one relationship
try:
# Prefer obj.serializable_value for performance reasons
pk = obj.serializable_value(self.source or field_name)
except AttributeError:
# RelatedObject (reverse relationship)
try:
- obj = getattr(obj, self.source or field_name)
+ pk = getattr(obj, self.source or field_name).pk
except ObjectDoesNotExist:
return None
- return self.to_native(obj.pk)
- # Forward relationship
- return self.to_native(pk)
-
-
-class ManyPrimaryKeyRelatedField(ManyRelatedField):
- """
- Represents a to-many relationship as a pk value.
- """
- default_read_only = False
- form_field_class = forms.MultipleChoiceField
-
- default_error_messages = {
- 'does_not_exist': _("Invalid pk '%s' - object does not exist."),
- 'invalid': _('Invalid value.'),
- }
- def prepare_value(self, obj):
- return self.to_native(obj.pk)
-
- def label_from_instance(self, obj):
- """
- Return a readable representation for use with eg. select widgets.
- """
- desc = smart_unicode(obj)
- ident = smart_unicode(self.to_native(obj.pk))
- if desc == ident:
- return desc
- return "%s - %s" % (desc, ident)
-
- def to_native(self, pk):
- return pk
-
- def field_to_native(self, obj, field_name):
- try:
- # Prefer obj.serializable_value for performance reasons
- queryset = obj.serializable_value(self.source or field_name)
- except AttributeError:
- # RelatedManager (reverse relationship)
- queryset = getattr(obj, self.source or field_name)
- return [self.to_native(item.pk) for item in queryset.all()]
# Forward relationship
- return [self.to_native(item.pk) for item in queryset.all()]
-
- def from_native(self, data):
- if self.queryset is None:
- raise Exception('Writable related fields must include a `queryset` argument')
+ return self.to_native(pk)
- try:
- return self.queryset.get(pk=data)
- except ObjectDoesNotExist:
- msg = self.error_messages['does_not_exist'] % smart_unicode(data)
- raise ValidationError(msg)
- except (TypeError, ValueError):
- msg = self.error_messages['invalid']
- raise ValidationError(msg)
### Slug relationships
class SlugRelatedField(RelatedField):
- default_read_only = False
- form_field_class = forms.ChoiceField
+ """
+ Represents a relationship using a unique field on the target.
+ """
+ read_only = False
default_error_messages = {
'does_not_exist': _("Object with %s=%s does not exist."),
@@ -303,40 +270,35 @@ class SlugRelatedField(RelatedField):
return self.queryset.get(**{self.slug_field: data})
except ObjectDoesNotExist:
raise ValidationError(self.error_messages['does_not_exist'] %
- (self.slug_field, unicode(data)))
+ (self.slug_field, smart_text(data)))
except (TypeError, ValueError):
msg = self.error_messages['invalid']
raise ValidationError(msg)
-class ManySlugRelatedField(ManyRelatedMixin, SlugRelatedField):
- form_field_class = forms.MultipleChoiceField
-
-
### Hyperlinked relationships
class HyperlinkedRelatedField(RelatedField):
"""
- Represents a to-one relationship, using hyperlinking.
+ Represents a relationship using hyperlinking.
"""
pk_url_kwarg = 'pk'
slug_field = 'slug'
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
- default_read_only = False
- form_field_class = forms.ChoiceField
+ read_only = False
default_error_messages = {
'no_match': _('Invalid hyperlink - No URL match'),
'incorrect_match': _('Invalid hyperlink - Incorrect URL match'),
'configuration_error': _('Invalid hyperlink due to configuration error'),
'does_not_exist': _("Invalid hyperlink - object does not exist."),
- 'invalid': _('Invalid value.'),
+ 'incorrect_type': _('Incorrect type. Expected url string, received %s.'),
}
def __init__(self, *args, **kwargs):
try:
self.view_name = kwargs.pop('view_name')
- except:
+ except KeyError:
raise ValueError("Hyperlinked field requires 'view_name' kwarg")
self.slug_field = kwargs.pop('slug_field', self.slug_field)
@@ -357,13 +319,20 @@ class HyperlinkedRelatedField(RelatedField):
view_name = self.view_name
request = self.context.get('request', None)
format = self.format or self.context.get('format', None)
+
+ if request is None:
+ warnings.warn("Using `HyperlinkedRelatedField` without including the "
+ "request in the serializer context is due to be deprecated. "
+ "Add `context={'request': request}` when instantiating the serializer.",
+ PendingDeprecationWarning, stacklevel=4)
+
pk = getattr(obj, 'pk', None)
if pk is None:
return
kwargs = {self.pk_url_kwarg: pk}
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except:
+ except NoReverseMatch:
pass
slug = getattr(obj, self.slug_field, None)
@@ -374,13 +343,13 @@ class HyperlinkedRelatedField(RelatedField):
kwargs = {self.slug_url_kwarg: slug}
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except:
+ except NoReverseMatch:
pass
kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except:
+ except NoReverseMatch:
pass
raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
@@ -394,19 +363,19 @@ class HyperlinkedRelatedField(RelatedField):
try:
http_prefix = value.startswith('http:') or value.startswith('https:')
except AttributeError:
- msg = self.error_messages['invalid']
- raise ValidationError(msg)
+ msg = self.error_messages['incorrect_type']
+ raise ValidationError(msg % type(value).__name__)
if http_prefix:
# If needed convert absolute URLs to relative path
- value = urlparse(value).path
+ value = urlparse.urlparse(value).path
prefix = get_script_prefix()
if value.startswith(prefix):
value = '/' + value[len(prefix):]
try:
match = resolve(value)
- except:
+ except Exception:
raise ValidationError(self.error_messages['no_match'])
if match.view_name != self.view_name:
@@ -431,19 +400,12 @@ class HyperlinkedRelatedField(RelatedField):
except ObjectDoesNotExist:
raise ValidationError(self.error_messages['does_not_exist'])
except (TypeError, ValueError):
- msg = self.error_messages['invalid']
- raise ValidationError(msg)
+ msg = self.error_messages['incorrect_type']
+ raise ValidationError(msg % type(value).__name__)
return obj
-class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField):
- """
- Represents a to-many relationship, using hyperlinking.
- """
- form_field_class = forms.MultipleChoiceField
-
-
class HyperlinkedIdentityField(Field):
"""
Represents the instance, or a property on the instance, using hyperlinking.
@@ -451,6 +413,7 @@ class HyperlinkedIdentityField(Field):
pk_url_kwarg = 'pk'
slug_field = 'slug'
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
+ read_only = True
def __init__(self, *args, **kwargs):
# TODO: Make view_name mandatory, and have the
@@ -472,6 +435,12 @@ class HyperlinkedIdentityField(Field):
view_name = self.view_name or self.parent.opts.view_name
kwargs = {self.pk_url_kwarg: obj.pk}
+ if request is None:
+ warnings.warn("Using `HyperlinkedIdentityField` without including the "
+ "request in the serializer context is due to be deprecated. "
+ "Add `context={'request': request}` when instantiating the serializer.",
+ PendingDeprecationWarning, stacklevel=4)
+
# By default use whatever format is given for the current context
# unless the target is a different type to the source.
#
@@ -486,7 +455,7 @@ class HyperlinkedIdentityField(Field):
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except:
+ except NoReverseMatch:
pass
slug = getattr(obj, self.slug_field, None)
@@ -497,13 +466,51 @@ class HyperlinkedIdentityField(Field):
kwargs = {self.slug_url_kwarg: slug}
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except:
+ except NoReverseMatch:
pass
kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except:
+ except NoReverseMatch:
pass
raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
+
+
+### Old-style many classes for backwards compat
+
+class ManyRelatedField(RelatedField):
+ def __init__(self, *args, **kwargs):
+ warnings.warn('`ManyRelatedField()` is due to be deprecated. '
+ 'Use `RelatedField(many=True)` instead.',
+ PendingDeprecationWarning, stacklevel=2)
+ kwargs['many'] = True
+ super(ManyRelatedField, self).__init__(*args, **kwargs)
+
+
+class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
+ def __init__(self, *args, **kwargs):
+ warnings.warn('`ManyPrimaryKeyRelatedField()` is due to be deprecated. '
+ 'Use `PrimaryKeyRelatedField(many=True)` instead.',
+ PendingDeprecationWarning, stacklevel=2)
+ kwargs['many'] = True
+ super(ManyPrimaryKeyRelatedField, self).__init__(*args, **kwargs)
+
+
+class ManySlugRelatedField(SlugRelatedField):
+ def __init__(self, *args, **kwargs):
+ warnings.warn('`ManySlugRelatedField()` is due to be deprecated. '
+ 'Use `SlugRelatedField(many=True)` instead.',
+ PendingDeprecationWarning, stacklevel=2)
+ kwargs['many'] = True
+ super(ManySlugRelatedField, self).__init__(*args, **kwargs)
+
+
+class ManyHyperlinkedRelatedField(HyperlinkedRelatedField):
+ def __init__(self, *args, **kwargs):
+ warnings.warn('`ManyHyperlinkedRelatedField()` is due to be deprecated. '
+ 'Use `HyperlinkedRelatedField(many=True)` instead.',
+ PendingDeprecationWarning, stacklevel=2)
+ kwargs['many'] = True
+ super(ManyHyperlinkedRelatedField, self).__init__(*args, **kwargs)
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 0a34abaa..4c15e0db 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -6,21 +6,25 @@ on the response, such as JSON encoded data or HTML output.
REST framework also provides an HTML renderer the renders the browsable API.
"""
+from __future__ import unicode_literals
+
import copy
import string
import json
from django import forms
from django.http.multipartparser import parse_header
from django.template import RequestContext, loader, Template
+from django.utils.xmlutils import SimplerXMLGenerator
+from rest_framework.compat import StringIO
+from rest_framework.compat import six
+from rest_framework.compat import smart_text
from rest_framework.compat import yaml
from rest_framework.exceptions import ConfigurationError
from rest_framework.settings import api_settings
from rest_framework.request import clone_request
-from rest_framework.utils import dict2xml
from rest_framework.utils import encoders
from rest_framework.utils.breadcrumbs import get_breadcrumbs
-from rest_framework import VERSION, status
-from rest_framework import parsers
+from rest_framework import exceptions, parsers, status, VERSION
class BaseRenderer(object):
@@ -60,7 +64,7 @@ class JSONRenderer(BaseRenderer):
if accepted_media_type:
# If the media type looks like 'application/json; indent=4',
# then pretty print the result.
- base_media_type, params = parse_header(accepted_media_type)
+ base_media_type, params = parse_header(accepted_media_type.encode('ascii'))
indent = params.get('indent', indent)
try:
indent = max(min(int(indent), 8), 0)
@@ -86,7 +90,7 @@ class JSONPRenderer(JSONRenderer):
Determine the name of the callback to wrap around the json output.
"""
request = renderer_context.get('request', None)
- params = request and request.GET or {}
+ params = request and request.QUERY_PARAMS or {}
return params.get(self.callback_parameter, self.default_callback)
def render(self, data, accepted_media_type=None, renderer_context=None):
@@ -100,7 +104,7 @@ class JSONPRenderer(JSONRenderer):
callback = self.get_callback(renderer_context)
json = super(JSONPRenderer, self).render(data, accepted_media_type,
renderer_context)
- return u"%s(%s);" % (callback, json)
+ return "%s(%s);" % (callback, json)
class XMLRenderer(BaseRenderer):
@@ -117,7 +121,38 @@ class XMLRenderer(BaseRenderer):
"""
if data is None:
return ''
- return dict2xml(data)
+
+ stream = StringIO()
+
+ xml = SimplerXMLGenerator(stream, "utf-8")
+ xml.startDocument()
+ xml.startElement("root", {})
+
+ self._to_xml(xml, data)
+
+ xml.endElement("root")
+ xml.endDocument()
+ return stream.getvalue()
+
+ def _to_xml(self, xml, data):
+ if isinstance(data, (list, tuple)):
+ for item in data:
+ xml.startElement("list-item", {})
+ self._to_xml(xml, item)
+ xml.endElement("list-item")
+
+ elif isinstance(data, dict):
+ for key, value in six.iteritems(data):
+ xml.startElement(key, {})
+ self._to_xml(xml, value)
+ xml.endElement(key)
+
+ elif data is None:
+ # Don't output any value
+ pass
+
+ else:
+ xml.characters(smart_text(data))
class YAMLRenderer(BaseRenderer):
@@ -133,6 +168,8 @@ class YAMLRenderer(BaseRenderer):
"""
Renders *obj* into serialized YAML.
"""
+ assert yaml, 'YAMLRenderer requires pyyaml to be installed'
+
if data is None:
return ''
@@ -215,7 +252,7 @@ class TemplateHTMLRenderer(BaseRenderer):
try:
# Try to find an appropriate error template
return self.resolve_template(template_names)
- except:
+ except Exception:
# Fall back to using eg '404 Not Found'
return Template('%d %s' % (response.status_code,
response.status_text.title()))
@@ -297,12 +334,10 @@ class BrowsableAPIRenderer(BaseRenderer):
if not api_settings.FORM_METHOD_OVERRIDE:
return # Cannot use form overloading
- request = clone_request(request, method)
try:
- if not view.has_permission(request, obj):
- return # Don't have permission
- except:
- return # Don't have permission and exception explicitly raise
+ view.check_permissions(clone_request(request, method))
+ except exceptions.APIException:
+ return False # Doesn't have permissions
return True
def serializer_to_form_fields(self, serializer):
@@ -333,6 +368,7 @@ class BrowsableAPIRenderer(BaseRenderer):
kwargs['label'] = k
fields[k] = v.form_field_class(**kwargs)
+
return fields
def get_form(self, view, method, request):
@@ -345,24 +381,23 @@ class BrowsableAPIRenderer(BaseRenderer):
if not self.show_form_for_method(view, method, request, obj):
return
- if method == 'DELETE' or method == 'OPTIONS':
+ if method in ('DELETE', 'OPTIONS'):
return True # Don't actually need to return a form
if not getattr(view, 'get_serializer', None) or not parsers.FormParser in view.parser_classes:
- media_types = [parser.media_type for parser in view.parser_classes]
- return self.get_generic_content_form(media_types)
+ return
serializer = view.get_serializer(instance=obj)
fields = self.serializer_to_form_fields(serializer)
# Creating an on the fly form see:
# http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
- OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields)
+ OnTheFlyForm = type(str("OnTheFlyForm"), (forms.Form,), fields)
data = (obj is not None) and serializer.data or None
form_instance = OnTheFlyForm(data)
return form_instance
- def get_generic_content_form(self, media_types):
+ def get_raw_data_form(self, view, method, request, media_types):
"""
Returns a form that allows for arbitrary content types to be tunneled
via standard HTML forms.
@@ -375,6 +410,11 @@ class BrowsableAPIRenderer(BaseRenderer):
and api_settings.FORM_CONTENTTYPE_OVERRIDE):
return None
+ # Check permissions
+ obj = getattr(view, 'object', None)
+ if not self.show_form_for_method(view, method, request, obj):
+ return
+
content_type_field = api_settings.FORM_CONTENTTYPE_OVERRIDE
content_field = api_settings.FORM_CONTENT_OVERRIDE
choices = [(media_type, media_type) for media_type in media_types]
@@ -386,7 +426,7 @@ class BrowsableAPIRenderer(BaseRenderer):
super(GenericContentForm, self).__init__()
self.fields[content_type_field] = forms.ChoiceField(
- label='Content Type',
+ label='Media type',
choices=choices,
initial=initial
)
@@ -401,13 +441,13 @@ class BrowsableAPIRenderer(BaseRenderer):
try:
return view.get_name()
except AttributeError:
- return view.__doc__
+ return smart_text(view.__class__.__name__)
def get_description(self, view):
try:
return view.get_description(html=True)
except AttributeError:
- return view.__doc__
+ return smart_text(view.__doc__ or '')
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
@@ -422,15 +462,22 @@ class BrowsableAPIRenderer(BaseRenderer):
view = renderer_context['view']
request = renderer_context['request']
response = renderer_context['response']
+ media_types = [parser.media_type for parser in view.parser_classes]
renderer = self.get_default_renderer(view)
content = self.get_content(renderer, data, accepted_media_type, renderer_context)
put_form = self.get_form(view, 'PUT', request)
post_form = self.get_form(view, 'POST', request)
+ patch_form = self.get_form(view, 'PATCH', request)
delete_form = self.get_form(view, 'DELETE', request)
options_form = self.get_form(view, 'OPTIONS', request)
+ raw_data_put_form = self.get_raw_data_form(view, 'PUT', request, media_types)
+ raw_data_post_form = self.get_raw_data_form(view, 'POST', request, media_types)
+ raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request, media_types)
+ raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form
+
name = self.get_name(view)
description = self.get_description(view)
breadcrumb_list = get_breadcrumbs(request.path)
@@ -447,10 +494,18 @@ class BrowsableAPIRenderer(BaseRenderer):
'breadcrumblist': breadcrumb_list,
'allowed_methods': view.allowed_methods,
'available_formats': [renderer.format for renderer in view.renderer_classes],
+
'put_form': put_form,
'post_form': post_form,
+ 'patch_form': patch_form,
'delete_form': delete_form,
'options_form': options_form,
+
+ 'raw_data_put_form': raw_data_put_form,
+ 'raw_data_post_form': raw_data_post_form,
+ 'raw_data_patch_form': raw_data_patch_form,
+ 'raw_data_put_or_patch_form': raw_data_put_or_patch_form,
+
'api_settings': api_settings
})
diff --git a/rest_framework/request.py b/rest_framework/request.py
index b7133608..ffbbab33 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -9,10 +9,14 @@ The wrapped request then offers a richer API, in particular :
- full support of PUT method, including support for file uploads
- form overloading of HTTP method, content type and content
"""
-from StringIO import StringIO
-
+from __future__ import unicode_literals
+from django.conf import settings
+from django.http import QueryDict
from django.http.multipartparser import parse_header
+from django.utils.datastructures import MultiValueDict
+from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions
+from rest_framework.compat import BytesIO
from rest_framework.settings import api_settings
@@ -20,7 +24,7 @@ def is_form_media_type(media_type):
"""
Return True if the media type is a valid form media type.
"""
- base_media_type, params = parse_header(media_type)
+ base_media_type, params = parse_header(media_type.encode(HTTP_HEADER_ENCODING))
return (base_media_type == 'application/x-www-form-urlencoded' or
base_media_type == 'multipart/form-data')
@@ -42,10 +46,11 @@ def clone_request(request, method):
Internal helper method to clone a request, replacing with a different
HTTP method. Used for checking permissions against other methods.
"""
- ret = Request(request._request,
- request.parsers,
- request.authenticators,
- request.parser_context)
+ ret = Request(request=request._request,
+ parsers=request.parsers,
+ authenticators=request.authenticators,
+ negotiator=request.negotiator,
+ parser_context=request.parser_context)
ret._data = request._data
ret._files = request._files
ret._content_type = request._content_type
@@ -55,6 +60,8 @@ def clone_request(request, method):
ret._user = request._user
if hasattr(request, '_auth'):
ret._auth = request._auth
+ if hasattr(request, '_authenticator'):
+ ret._authenticator = request._authenticator
return ret
@@ -90,6 +97,7 @@ class Request(object):
if self.parser_context is None:
self.parser_context = {}
self.parser_context['request'] = self
+ self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET
def _default_negotiator(self):
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()
@@ -166,17 +174,17 @@ class Request(object):
by the authentication classes provided to the request.
"""
if not hasattr(self, '_user'):
- self._user, self._auth = self._authenticate()
+ self._authenticator, self._user, self._auth = self._authenticate()
return self._user
@user.setter
def user(self, value):
- """
- Sets the user on the current request. This is necessary to maintain
- compatilbility with django.contrib.auth where the user proprety is
- set in the login and logout functions.
- """
- self._user = value
+ """
+ Sets the user on the current request. This is necessary to maintain
+ compatilbility with django.contrib.auth where the user proprety is
+ set in the login and logout functions.
+ """
+ self._user = value
@property
def auth(self):
@@ -185,7 +193,7 @@ class Request(object):
request, such as an authentication token.
"""
if not hasattr(self, '_auth'):
- self._user, self._auth = self._authenticate()
+ self._authenticator, self._user, self._auth = self._authenticate()
return self._auth
@auth.setter
@@ -196,6 +204,16 @@ class Request(object):
"""
self._auth = value
+ @property
+ def successful_authenticator(self):
+ """
+ Return the instance of the authentication instance class that was used
+ to authenticate the request, or `None`.
+ """
+ if not hasattr(self, '_authenticator'):
+ self._authenticator, self._user, self._auth = self._authenticate()
+ return self._authenticator
+
def _load_data_and_files(self):
"""
Parses the request content into self.DATA and self.FILES.
@@ -213,11 +231,17 @@ class Request(object):
"""
self._content_type = self.META.get('HTTP_CONTENT_TYPE',
self.META.get('CONTENT_TYPE', ''))
+
self._perform_form_overloading()
- # if the HTTP method was not overloaded, we take the raw HTTP method
+
if not _hasattr(self, '_method'):
self._method = self._request.method
+ if self._method == 'POST':
+ # Allow X-HTTP-METHOD-OVERRIDE header
+ self._method = self.META.get('HTTP_X_HTTP_METHOD_OVERRIDE',
+ self._method)
+
def _load_stream(self):
"""
Return the content body of the request, as a stream.
@@ -233,7 +257,7 @@ class Request(object):
elif hasattr(self._request, 'read'):
self._stream = self._request
else:
- self._stream = StringIO(self.raw_post_data)
+ self._stream = BytesIO(self.raw_post_data)
def _perform_form_overloading(self):
"""
@@ -268,7 +292,7 @@ class Request(object):
self._CONTENT_PARAM in self._data and
self._CONTENTTYPE_PARAM in self._data):
self._content_type = self._data[self._CONTENTTYPE_PARAM]
- self._stream = StringIO(self._data[self._CONTENT_PARAM])
+ self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(HTTP_HEADER_ENCODING))
self._data, self._files = (Empty, Empty)
def _parse(self):
@@ -281,7 +305,9 @@ class Request(object):
media_type = self.content_type
if stream is None or media_type is None:
- return (None, None)
+ empty_data = QueryDict('', self._request._encoding)
+ empty_files = MultiValueDict()
+ return (empty_data, empty_files)
parser = self.negotiator.select_parser(self, self.parsers)
@@ -295,25 +321,28 @@ class Request(object):
try:
return (parsed.data, parsed.files)
except AttributeError:
- return (parsed, None)
+ empty_files = MultiValueDict()
+ return (parsed, empty_files)
def _authenticate(self):
"""
- Attempt to authenticate the request using each authentication instance in turn.
- Returns a two-tuple of (user, authtoken).
+ Attempt to authenticate the request using each authentication instance
+ in turn.
+ Returns a three-tuple of (authenticator, user, authtoken).
"""
for authenticator in self.authenticators:
user_auth_tuple = authenticator.authenticate(self)
if not user_auth_tuple is None:
- return user_auth_tuple
+ user, auth = user_auth_tuple
+ return (authenticator, user, auth)
return self._not_authenticated()
def _not_authenticated(self):
"""
- Return a two-tuple of (user, authtoken), representing an
- unauthenticated request.
+ Return a three-tuple of (authenticator, user, authtoken), representing
+ an unauthenticated request.
- By default this will be (AnonymousUser, None).
+ By default this will be (None, AnonymousUser, None).
"""
if api_settings.UNAUTHENTICATED_USER:
user = api_settings.UNAUTHENTICATED_USER()
@@ -325,7 +354,7 @@ class Request(object):
else:
auth = None
- return (user, auth)
+ return (None, user, auth)
def __getattr__(self, attr):
"""
diff --git a/rest_framework/response.py b/rest_framework/response.py
index be78c43a..5e1bf46e 100644
--- a/rest_framework/response.py
+++ b/rest_framework/response.py
@@ -1,5 +1,7 @@
+from __future__ import unicode_literals
from django.core.handlers.wsgi import STATUS_CODE_TEXT
from django.template.response import SimpleTemplateResponse
+from rest_framework.compat import six
class Response(SimpleTemplateResponse):
@@ -22,9 +24,9 @@ class Response(SimpleTemplateResponse):
self.data = data
self.template_name = template_name
self.exception = exception
-
+
if headers:
- for name,value in headers.iteritems():
+ for name, value in six.iteritems(headers):
self[name] = value
@property
diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py
index c9db02f0..a51b07f5 100644
--- a/rest_framework/reverse.py
+++ b/rest_framework/reverse.py
@@ -1,6 +1,7 @@
"""
Provide reverse functions that return fully qualified URLs
"""
+from __future__ import unicode_literals
from django.core.urlresolvers import reverse as django_reverse
from django.utils.functional import lazy
diff --git a/rest_framework/runtests/runcoverage.py b/rest_framework/runtests/runcoverage.py
index bcab1d14..ce11b213 100755
--- a/rest_framework/runtests/runcoverage.py
+++ b/rest_framework/runtests/runcoverage.py
@@ -52,12 +52,15 @@ def main():
if os.path.basename(path) in ['tests', 'runtests', 'migrations']:
continue
- # Drop the compat module from coverage, since we're not interested in the coverage
- # of a module which is specifically for resolving environment dependant imports.
+ # Drop the compat and six modules from coverage, since we're not interested in the coverage
+ # of modules which are specifically for resolving environment dependant imports.
# (Because we'll end up getting different coverage reports for it for each environment)
if 'compat.py' in files:
files.remove('compat.py')
+ if 'six.py' in files:
+ files.remove('six.py')
+
# Same applies to template tags module.
# This module has to include branching on Django versions,
# so it's never possible for it to have full coverage.
diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py
index 505994e2..4a333fb3 100755
--- a/rest_framework/runtests/runtests.py
+++ b/rest_framework/runtests/runtests.py
@@ -33,7 +33,7 @@ def main():
elif len(sys.argv) == 1:
test_case = ''
else:
- print usage()
+ print(usage())
sys.exit(1)
failures = test_runner.run_tests(['tests' + test_case])
diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py
index dd5d9dc3..9b519f27 100644
--- a/rest_framework/runtests/settings.py
+++ b/rest_framework/runtests/settings.py
@@ -97,11 +97,41 @@ INSTALLED_APPS = (
# 'django.contrib.admindocs',
'rest_framework',
'rest_framework.authtoken',
- 'rest_framework.tests'
+ 'rest_framework.tests',
)
+# OAuth is optional and won't work if there is no oauth_provider & oauth2
+try:
+ import oauth_provider
+ import oauth2
+except ImportError:
+ pass
+else:
+ INSTALLED_APPS += (
+ 'oauth_provider',
+ )
+
+try:
+ import provider
+except ImportError:
+ pass
+else:
+ INSTALLED_APPS += (
+ 'provider',
+ 'provider.oauth2',
+ )
+
STATIC_URL = '/static/'
+PASSWORD_HASHERS = (
+ 'django.contrib.auth.hashers.SHA1PasswordHasher',
+ 'django.contrib.auth.hashers.PBKDF2PasswordHasher',
+ 'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher',
+ 'django.contrib.auth.hashers.BCryptPasswordHasher',
+ 'django.contrib.auth.hashers.MD5PasswordHasher',
+ 'django.contrib.auth.hashers.CryptPasswordHasher',
+)
+
import django
if django.VERSION < (1, 3):
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 27458f96..4fe857a6 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -1,11 +1,13 @@
+from __future__ import unicode_literals
import copy
import datetime
import types
from decimal import Decimal
+from django.core.paginator import Page
from django.db import models
from django.forms import widgets
from django.utils.datastructures import SortedDict
-from rest_framework.compat import get_concrete_model
+from rest_framework.compat import get_concrete_model, six
# Note: We do the following so that users of the framework can use this style:
#
@@ -25,20 +27,23 @@ class DictWithMetadata(dict):
def __getstate__(self):
"""
Used by pickle (e.g., caching).
- Overriden to remove metadata from the dict, since it shouldn't be pickled
- and may in some instances be unpickleable.
+ Overriden to remove the metadata from the dict, since it shouldn't be
+ pickled and may in some instances be unpickleable.
"""
- # return an instance of the first dict in MRO that isn't a DictWithMetadata
- for base in self.__class__.__mro__:
- if not isinstance(base, DictWithMetadata) and isinstance(base, dict):
- return base(self)
+ return dict(self)
-class SortedDictWithMetadata(SortedDict, DictWithMetadata):
+class SortedDictWithMetadata(SortedDict):
"""
A sorted dict-like object, that can have additional properties attached.
"""
- pass
+ def __getstate__(self):
+ """
+ Used by pickle (e.g., caching).
+ Overriden to remove the metadata from the dict, since it shouldn't be
+ pickle and may in some instances be unpickleable.
+ """
+ return SortedDict(self).__dict__
def _is_protected_type(obj):
@@ -63,7 +68,7 @@ def _get_declared_fields(bases, attrs):
Note that all fields from the base classes are used.
"""
fields = [(field_name, attrs.pop(field_name))
- for field_name, obj in attrs.items()
+ for field_name, obj in list(six.iteritems(attrs))
if isinstance(obj, Field)]
fields.sort(key=lambda x: x[1].creation_counter)
@@ -72,7 +77,7 @@ def _get_declared_fields(bases, attrs):
# in order to maintain the correct order of fields.
for base in bases[::-1]:
if hasattr(base, 'base_fields'):
- fields = base.base_fields.items() + fields
+ fields = list(base.base_fields.items()) + fields
return SortedDict(fields)
@@ -94,19 +99,24 @@ class SerializerOptions(object):
class BaseSerializer(Field):
+ """
+ This is the Serializer implementation.
+ We need to implement it as `BaseSerializer` due to metaclass magicks.
+ """
class Meta(object):
pass
_options_class = SerializerOptions
- _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatibility with unsorted implementations.
+ _dict_class = SortedDictWithMetadata
def __init__(self, instance=None, data=None, files=None,
- context=None, partial=False, **kwargs):
- super(BaseSerializer, self).__init__(**kwargs)
+ context=None, partial=False, many=None, source=None):
+ super(BaseSerializer, self).__init__(source=source)
self.opts = self._options_class(self.Meta)
self.parent = None
self.root = None
self.partial = partial
+ self.many = many
self.context = context or {}
@@ -150,6 +160,7 @@ class BaseSerializer(Field):
# If 'fields' is specified, use those fields, in that order.
if self.opts.fields:
+ assert isinstance(self.opts.fields, (list, tuple)), '`include` must be a list or tuple'
new = SortedDict()
for key in self.opts.fields:
new[key] = ret[key]
@@ -157,6 +168,7 @@ class BaseSerializer(Field):
# Remove anything in 'exclude'
if self.opts.exclude:
+ assert isinstance(self.opts.fields, (list, tuple)), '`exclude` must be a list or tuple'
for key in self.opts.exclude:
ret.pop(key, None)
@@ -186,22 +198,6 @@ class BaseSerializer(Field):
"""
return field_name
- def convert_object(self, obj):
- """
- Core of serialization.
- Convert an object into a dictionary of serialized field values.
- """
- ret = self._dict_class()
- ret.fields = {}
-
- for field_name, field in self.fields.items():
- field.initialize(parent=self, field_name=field_name)
- key = self.get_field_key(field_name)
- value = field.field_to_native(obj, field_name)
- ret[key] = value
- ret.fields[key] = field
- return ret
-
def restore_fields(self, data, files):
"""
Core of deserialization, together with `restore_object`.
@@ -210,7 +206,7 @@ class BaseSerializer(Field):
reverted_data = {}
if data is not None and not isinstance(data, dict):
- self._errors['non_field_errors'] = [u'Invalid data']
+ self._errors['non_field_errors'] = ['Invalid data']
return None
for field_name, field in self.fields.items():
@@ -227,6 +223,8 @@ class BaseSerializer(Field):
Run `validate_<fieldname>()` and `validate()` methods on the serializer
"""
for field_name, field in self.fields.items():
+ if field_name in self._errors:
+ continue
try:
validate_method = getattr(self, 'validate_%s' % field_name, None)
if validate_method:
@@ -271,18 +269,21 @@ class BaseSerializer(Field):
"""
Serialize objects -> primitives.
"""
- if hasattr(obj, '__iter__'):
- return [self.convert_object(item) for item in obj]
- return self.convert_object(obj)
+ ret = self._dict_class()
+ ret.fields = {}
+
+ for field_name, field in self.fields.items():
+ field.initialize(parent=self, field_name=field_name)
+ key = self.get_field_key(field_name)
+ value = field.field_to_native(obj, field_name)
+ ret[key] = value
+ ret.fields[key] = field
+ return ret
def from_native(self, data, files):
"""
Deserialize primitives -> objects.
"""
- if hasattr(data, '__iter__') and not isinstance(data, dict):
- # TODO: error data when deserializing lists
- return [self.from_native(item, None) for item in data]
-
self._errors = {}
if data is not None or files is not None:
attrs = self.restore_fields(data, files)
@@ -298,6 +299,9 @@ class BaseSerializer(Field):
Override default so that we can apply ModelSerializer as a nested
field to relationships.
"""
+ if self.source == '*':
+ return self.to_native(obj)
+
try:
if self.source:
for component in self.source.split('.'):
@@ -318,6 +322,13 @@ class BaseSerializer(Field):
if obj is None:
return None
+ if self.many is not None:
+ many = self.many
+ else:
+ many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict, six.text_type))
+
+ if many:
+ return [self.to_native(item) for item in obj]
return self.to_native(obj)
@property
@@ -327,9 +338,30 @@ class BaseSerializer(Field):
setting self.object if no errors occurred.
"""
if self._errors is None:
- obj = self.from_native(self.init_data, self.init_files)
+ data, files = self.init_data, self.init_files
+
+ if self.many is not None:
+ many = self.many
+ else:
+ many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type))
+ if many:
+ warnings.warn('Implict list/queryset serialization is due to be deprecated. '
+ 'Use the `many=True` flag when instantiating the serializer.',
+ PendingDeprecationWarning, stacklevel=3)
+
+ if many:
+ ret = []
+ errors = []
+ for item in data:
+ ret.append(self.from_native(item, None))
+ errors.append(self._errors)
+ self._errors = any(errors) and errors or []
+ else:
+ ret = self.from_native(data, files)
+
if not self._errors:
- self.object = obj
+ self.object = ret
+
return self._errors
def is_valid(self):
@@ -337,20 +369,44 @@ class BaseSerializer(Field):
@property
def data(self):
+ """
+ Returns the serialized data on the serializer.
+ """
if self._data is None:
- self._data = self.to_native(self.object)
+ obj = self.object
+
+ if self.many is not None:
+ many = self.many
+ else:
+ many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict))
+ if many:
+ warnings.warn('Implict list/queryset serialization is due to be deprecated. '
+ 'Use the `many=True` flag when instantiating the serializer.',
+ PendingDeprecationWarning, stacklevel=2)
+
+ if many:
+ self._data = [self.to_native(item) for item in obj]
+ else:
+ self._data = self.to_native(obj)
+
return self._data
- def save(self):
+ def save_object(self, obj, **kwargs):
+ obj.save(**kwargs)
+
+ def save(self, **kwargs):
"""
Save the deserialized object and return it.
"""
- self.object.save()
+ if isinstance(self.object, list):
+ [self.save_object(item, **kwargs) for item in self.object]
+ else:
+ self.save_object(self.object, **kwargs)
return self.object
-class Serializer(BaseSerializer):
- __metaclass__ = SerializerMetaclass
+class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)):
+ pass
class ModelSerializerOptions(SerializerOptions):
@@ -369,16 +425,42 @@ class ModelSerializer(Serializer):
"""
_options_class = ModelSerializerOptions
+ field_mapping = {
+ models.AutoField: IntegerField,
+ models.FloatField: FloatField,
+ models.IntegerField: IntegerField,
+ models.PositiveIntegerField: IntegerField,
+ models.SmallIntegerField: IntegerField,
+ models.PositiveSmallIntegerField: IntegerField,
+ models.DateTimeField: DateTimeField,
+ models.DateField: DateField,
+ models.TimeField: TimeField,
+ models.EmailField: EmailField,
+ models.CharField: CharField,
+ models.URLField: URLField,
+ models.SlugField: SlugField,
+ models.TextField: CharField,
+ models.CommaSeparatedIntegerField: CharField,
+ models.BooleanField: BooleanField,
+ models.FileField: FileField,
+ models.ImageField: ImageField,
+ }
+
def get_default_fields(self):
"""
Return all the fields that should be serialized for the model.
"""
cls = self.opts.model
+ assert cls is not None, \
+ "Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__
opts = get_concrete_model(cls)._meta
pk_field = opts.pk
- while pk_field.rel:
+
+ # If model is a child via multitable inheritance, use parent's pk
+ while pk_field.rel and pk_field.rel.parent_link:
pk_field = pk_field.rel.to._meta.pk
+
fields = [pk_field]
fields += [field for field in opts.fields if field.serialize]
fields += [field for field in opts.many_to_many if field.serialize]
@@ -433,12 +515,11 @@ class ModelSerializer(Serializer):
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
kwargs = {
- 'null': model_field.null or model_field.blank,
- 'queryset': model_field.rel.to._default_manager
+ 'required': not(model_field.null or model_field.blank),
+ 'queryset': model_field.rel.to._default_manager,
+ 'many': to_many
}
- if to_many:
- return ManyPrimaryKeyRelatedField(**kwargs)
return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field):
@@ -446,20 +527,18 @@ class ModelSerializer(Serializer):
Creates a default instance of a basic non-relational field.
"""
kwargs = {}
+ has_default = model_field.has_default()
- kwargs['blank'] = model_field.blank
-
- if model_field.null or model_field.blank:
+ if model_field.null or model_field.blank or has_default:
kwargs['required'] = False
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
- if model_field.has_default():
- kwargs['required'] = False
+ if has_default:
kwargs['default'] = model_field.get_default()
- if model_field.__class__ == models.TextField:
+ if issubclass(model_field.__class__, models.TextField):
kwargs['widget'] = widgets.Textarea
# TODO: TypedChoiceField?
@@ -467,27 +546,8 @@ class ModelSerializer(Serializer):
kwargs['choices'] = model_field.flatchoices
return ChoiceField(**kwargs)
- field_mapping = {
- models.AutoField: IntegerField,
- models.FloatField: FloatField,
- models.IntegerField: IntegerField,
- models.PositiveIntegerField: IntegerField,
- models.SmallIntegerField: IntegerField,
- models.PositiveSmallIntegerField: IntegerField,
- models.DateTimeField: DateTimeField,
- models.DateField: DateField,
- models.EmailField: EmailField,
- models.CharField: CharField,
- models.URLField: URLField,
- models.SlugField: SlugField,
- models.TextField: CharField,
- models.CommaSeparatedIntegerField: CharField,
- models.BooleanField: BooleanField,
- models.FileField: FileField,
- models.ImageField: ImageField,
- }
try:
- return field_mapping[model_field.__class__](**kwargs)
+ return self.field_mapping[model_field.__class__](**kwargs)
except KeyError:
return ModelField(model_field=model_field, **kwargs)
@@ -499,10 +559,27 @@ class ModelSerializer(Serializer):
opts = get_concrete_model(cls)._meta
exclusions = [field.name for field in opts.fields + opts.many_to_many]
for field_name, field in self.fields.items():
+ field_name = field.source or field_name
if field_name in exclusions and not field.read_only:
exclusions.remove(field_name)
return exclusions
+ def full_clean(self, instance):
+ """
+ Perform Django's full_clean, and populate the `errors` dictionary
+ if any validation errors occur.
+
+ Note that we don't perform this inside the `.restore_object()` method,
+ so that subclasses can override `.restore_object()`, and still get
+ the full_clean validation checking.
+ """
+ try:
+ instance.full_clean(exclude=self.get_validation_exclusions())
+ except ValidationError as err:
+ self._errors = err.message_dict
+ return None
+ return instance
+
def restore_object(self, attrs, instance=None):
"""
Restore the model instance.
@@ -534,19 +611,21 @@ class ModelSerializer(Serializer):
else:
instance = self.opts.model(**attrs)
- try:
- instance.full_clean(exclude=self.get_validation_exclusions())
- except ValidationError, err:
- self._errors = err.message_dict
- return None
-
return instance
- def save(self):
+ def from_native(self, data, files):
+ """
+ Override the default method to also include model field validation.
+ """
+ instance = super(ModelSerializer, self).from_native(data, files)
+ if instance:
+ return self.full_clean(instance)
+
+ def save_object(self, obj, **kwargs):
"""
Save the deserialized object and return it.
"""
- self.object.save()
+ obj.save(**kwargs)
if getattr(self, 'm2m_data', None):
for accessor_name, object_list in self.m2m_data.items():
@@ -558,8 +637,6 @@ class ModelSerializer(Serializer):
setattr(self.object, accessor_name, object_list)
self.related_data = {}
- return self.object
-
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
"""
@@ -572,6 +649,8 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
class HyperlinkedModelSerializer(ModelSerializer):
"""
+ A subclass of ModelSerializer that uses hyperlinked relationships,
+ instead of primary key relationships.
"""
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
@@ -605,10 +684,9 @@ class HyperlinkedModelSerializer(ModelSerializer):
# .using(db).complex_filter(self.rel.limit_choices_to)
rel = model_field.rel.to
kwargs = {
- 'null': model_field.null,
+ 'required': not(model_field.null or model_field.blank),
'queryset': rel._default_manager,
- 'view_name': self._get_default_view_name(rel)
+ 'view_name': self._get_default_view_name(rel),
+ 'many': to_many
}
- if to_many:
- return ManyHyperlinkedRelatedField(**kwargs)
return HyperlinkedRelatedField(**kwargs)
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 5c77c55c..eede0c5a 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -17,9 +17,14 @@ This module provides the `api_setting` object, that is used to access
REST framework settings, checking for user settings first, then falling
back to the defaults.
"""
+from __future__ import unicode_literals
+
from django.conf import settings
from django.utils import importlib
+from rest_framework import ISO_8601
+from rest_framework.compat import six
+
USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None)
@@ -74,6 +79,22 @@ DEFAULTS = {
'URL_FORMAT_OVERRIDE': 'format',
'FORMAT_SUFFIX_KWARG': 'format',
+
+ # Input and output formats
+ 'DATE_INPUT_FORMATS': (
+ ISO_8601,
+ ),
+ 'DATE_FORMAT': ISO_8601,
+
+ 'DATETIME_INPUT_FORMATS': (
+ ISO_8601,
+ ),
+ 'DATETIME_FORMAT': ISO_8601,
+
+ 'TIME_INPUT_FORMATS': (
+ ISO_8601,
+ ),
+ 'TIME_FORMAT': ISO_8601,
}
@@ -98,7 +119,7 @@ def perform_import(val, setting_name):
If the given setting is a string import notation,
then perform the necessary import or imports.
"""
- if isinstance(val, basestring):
+ if isinstance(val, six.string_types):
return import_from_string(val, setting_name)
elif isinstance(val, (list, tuple)):
return [import_from_string(item, setting_name) for item in val]
diff --git a/rest_framework/six.py b/rest_framework/six.py
new file mode 100644
index 00000000..9e382312
--- /dev/null
+++ b/rest_framework/six.py
@@ -0,0 +1,389 @@
+"""Utilities for writing code that runs on Python 2 and 3"""
+
+import operator
+import sys
+import types
+
+__author__ = "Benjamin Peterson <benjamin@python.org>"
+__version__ = "1.2.0"
+
+
+# True if we are running on Python 3.
+PY3 = sys.version_info[0] == 3
+
+if PY3:
+ string_types = str,
+ integer_types = int,
+ class_types = type,
+ text_type = str
+ binary_type = bytes
+
+ MAXSIZE = sys.maxsize
+else:
+ string_types = basestring,
+ integer_types = (int, long)
+ class_types = (type, types.ClassType)
+ text_type = unicode
+ binary_type = str
+
+ if sys.platform == "java":
+ # Jython always uses 32 bits.
+ MAXSIZE = int((1 << 31) - 1)
+ else:
+ # It's possible to have sizeof(long) != sizeof(Py_ssize_t).
+ class X(object):
+ def __len__(self):
+ return 1 << 31
+ try:
+ len(X())
+ except OverflowError:
+ # 32-bit
+ MAXSIZE = int((1 << 31) - 1)
+ else:
+ # 64-bit
+ MAXSIZE = int((1 << 63) - 1)
+ del X
+
+
+def _add_doc(func, doc):
+ """Add documentation to a function."""
+ func.__doc__ = doc
+
+
+def _import_module(name):
+ """Import module, returning the module after the last dot."""
+ __import__(name)
+ return sys.modules[name]
+
+
+class _LazyDescr(object):
+
+ def __init__(self, name):
+ self.name = name
+
+ def __get__(self, obj, tp):
+ result = self._resolve()
+ setattr(obj, self.name, result)
+ # This is a bit ugly, but it avoids running this again.
+ delattr(tp, self.name)
+ return result
+
+
+class MovedModule(_LazyDescr):
+
+ def __init__(self, name, old, new=None):
+ super(MovedModule, self).__init__(name)
+ if PY3:
+ if new is None:
+ new = name
+ self.mod = new
+ else:
+ self.mod = old
+
+ def _resolve(self):
+ return _import_module(self.mod)
+
+
+class MovedAttribute(_LazyDescr):
+
+ def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None):
+ super(MovedAttribute, self).__init__(name)
+ if PY3:
+ if new_mod is None:
+ new_mod = name
+ self.mod = new_mod
+ if new_attr is None:
+ if old_attr is None:
+ new_attr = name
+ else:
+ new_attr = old_attr
+ self.attr = new_attr
+ else:
+ self.mod = old_mod
+ if old_attr is None:
+ old_attr = name
+ self.attr = old_attr
+
+ def _resolve(self):
+ module = _import_module(self.mod)
+ return getattr(module, self.attr)
+
+
+
+class _MovedItems(types.ModuleType):
+ """Lazy loading of moved objects"""
+
+
+_moved_attributes = [
+ MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"),
+ MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"),
+ MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"),
+ MovedAttribute("map", "itertools", "builtins", "imap", "map"),
+ MovedAttribute("reload_module", "__builtin__", "imp", "reload"),
+ MovedAttribute("reduce", "__builtin__", "functools"),
+ MovedAttribute("StringIO", "StringIO", "io"),
+ MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"),
+ MovedAttribute("zip", "itertools", "builtins", "izip", "zip"),
+
+ MovedModule("builtins", "__builtin__"),
+ MovedModule("configparser", "ConfigParser"),
+ MovedModule("copyreg", "copy_reg"),
+ MovedModule("http_cookiejar", "cookielib", "http.cookiejar"),
+ MovedModule("http_cookies", "Cookie", "http.cookies"),
+ MovedModule("html_entities", "htmlentitydefs", "html.entities"),
+ MovedModule("html_parser", "HTMLParser", "html.parser"),
+ MovedModule("http_client", "httplib", "http.client"),
+ MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"),
+ MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"),
+ MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"),
+ MovedModule("cPickle", "cPickle", "pickle"),
+ MovedModule("queue", "Queue"),
+ MovedModule("reprlib", "repr"),
+ MovedModule("socketserver", "SocketServer"),
+ MovedModule("tkinter", "Tkinter"),
+ MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"),
+ MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"),
+ MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"),
+ MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"),
+ MovedModule("tkinter_tix", "Tix", "tkinter.tix"),
+ MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"),
+ MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"),
+ MovedModule("tkinter_colorchooser", "tkColorChooser",
+ "tkinter.colorchooser"),
+ MovedModule("tkinter_commondialog", "tkCommonDialog",
+ "tkinter.commondialog"),
+ MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"),
+ MovedModule("tkinter_font", "tkFont", "tkinter.font"),
+ MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"),
+ MovedModule("tkinter_tksimpledialog", "tkSimpleDialog",
+ "tkinter.simpledialog"),
+ MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"),
+ MovedModule("winreg", "_winreg"),
+]
+for attr in _moved_attributes:
+ setattr(_MovedItems, attr.name, attr)
+del attr
+
+moves = sys.modules["django.utils.six.moves"] = _MovedItems("moves")
+
+
+def add_move(move):
+ """Add an item to six.moves."""
+ setattr(_MovedItems, move.name, move)
+
+
+def remove_move(name):
+ """Remove item from six.moves."""
+ try:
+ delattr(_MovedItems, name)
+ except AttributeError:
+ try:
+ del moves.__dict__[name]
+ except KeyError:
+ raise AttributeError("no such move, %r" % (name,))
+
+
+if PY3:
+ _meth_func = "__func__"
+ _meth_self = "__self__"
+
+ _func_code = "__code__"
+ _func_defaults = "__defaults__"
+
+ _iterkeys = "keys"
+ _itervalues = "values"
+ _iteritems = "items"
+else:
+ _meth_func = "im_func"
+ _meth_self = "im_self"
+
+ _func_code = "func_code"
+ _func_defaults = "func_defaults"
+
+ _iterkeys = "iterkeys"
+ _itervalues = "itervalues"
+ _iteritems = "iteritems"
+
+
+try:
+ advance_iterator = next
+except NameError:
+ def advance_iterator(it):
+ return it.next()
+next = advance_iterator
+
+
+if PY3:
+ def get_unbound_function(unbound):
+ return unbound
+
+ Iterator = object
+
+ def callable(obj):
+ return any("__call__" in klass.__dict__ for klass in type(obj).__mro__)
+else:
+ def get_unbound_function(unbound):
+ return unbound.im_func
+
+ class Iterator(object):
+
+ def next(self):
+ return type(self).__next__(self)
+
+ callable = callable
+_add_doc(get_unbound_function,
+ """Get the function out of a possibly unbound function""")
+
+
+get_method_function = operator.attrgetter(_meth_func)
+get_method_self = operator.attrgetter(_meth_self)
+get_function_code = operator.attrgetter(_func_code)
+get_function_defaults = operator.attrgetter(_func_defaults)
+
+
+def iterkeys(d):
+ """Return an iterator over the keys of a dictionary."""
+ return iter(getattr(d, _iterkeys)())
+
+def itervalues(d):
+ """Return an iterator over the values of a dictionary."""
+ return iter(getattr(d, _itervalues)())
+
+def iteritems(d):
+ """Return an iterator over the (key, value) pairs of a dictionary."""
+ return iter(getattr(d, _iteritems)())
+
+
+if PY3:
+ def b(s):
+ return s.encode("latin-1")
+ def u(s):
+ return s
+ if sys.version_info[1] <= 1:
+ def int2byte(i):
+ return bytes((i,))
+ else:
+ # This is about 2x faster than the implementation above on 3.2+
+ int2byte = operator.methodcaller("to_bytes", 1, "big")
+ import io
+ StringIO = io.StringIO
+ BytesIO = io.BytesIO
+else:
+ def b(s):
+ return s
+ def u(s):
+ return unicode(s, "unicode_escape")
+ int2byte = chr
+ import StringIO
+ StringIO = BytesIO = StringIO.StringIO
+_add_doc(b, """Byte literal""")
+_add_doc(u, """Text literal""")
+
+
+if PY3:
+ import builtins
+ exec_ = getattr(builtins, "exec")
+
+
+ def reraise(tp, value, tb=None):
+ if value.__traceback__ is not tb:
+ raise value.with_traceback(tb)
+ raise value
+
+
+ print_ = getattr(builtins, "print")
+ del builtins
+
+else:
+ def exec_(code, globs=None, locs=None):
+ """Execute code in a namespace."""
+ if globs is None:
+ frame = sys._getframe(1)
+ globs = frame.f_globals
+ if locs is None:
+ locs = frame.f_locals
+ del frame
+ elif locs is None:
+ locs = globs
+ exec("""exec code in globs, locs""")
+
+
+ exec_("""def reraise(tp, value, tb=None):
+ raise tp, value, tb
+""")
+
+
+ def print_(*args, **kwargs):
+ """The new-style print function."""
+ fp = kwargs.pop("file", sys.stdout)
+ if fp is None:
+ return
+ def write(data):
+ if not isinstance(data, basestring):
+ data = str(data)
+ fp.write(data)
+ want_unicode = False
+ sep = kwargs.pop("sep", None)
+ if sep is not None:
+ if isinstance(sep, unicode):
+ want_unicode = True
+ elif not isinstance(sep, str):
+ raise TypeError("sep must be None or a string")
+ end = kwargs.pop("end", None)
+ if end is not None:
+ if isinstance(end, unicode):
+ want_unicode = True
+ elif not isinstance(end, str):
+ raise TypeError("end must be None or a string")
+ if kwargs:
+ raise TypeError("invalid keyword arguments to print()")
+ if not want_unicode:
+ for arg in args:
+ if isinstance(arg, unicode):
+ want_unicode = True
+ break
+ if want_unicode:
+ newline = unicode("\n")
+ space = unicode(" ")
+ else:
+ newline = "\n"
+ space = " "
+ if sep is None:
+ sep = space
+ if end is None:
+ end = newline
+ for i, arg in enumerate(args):
+ if i:
+ write(sep)
+ write(arg)
+ write(end)
+
+_add_doc(reraise, """Reraise an exception.""")
+
+
+def with_metaclass(meta, base=object):
+ """Create a base class with a metaclass."""
+ return meta("NewBase", (base,), {})
+
+
+### Additional customizations for Django ###
+
+if PY3:
+ _iterlists = "lists"
+ _assertRaisesRegex = "assertRaisesRegex"
+else:
+ _iterlists = "iterlists"
+ _assertRaisesRegex = "assertRaisesRegexp"
+
+
+def iterlists(d):
+ """Return an iterator over the values of a MultiValueDict."""
+ return getattr(d, _iterlists)()
+
+
+def assertRaisesRegex(self, *args, **kwargs):
+ return getattr(self, _assertRaisesRegex)(*args, **kwargs)
+
+
+add_move(MovedModule("_dummy_thread", "dummy_thread"))
+add_move(MovedModule("_thread", "thread"))
diff --git a/rest_framework/static/rest_framework/css/default.css b/rest_framework/static/rest_framework/css/default.css
index b2e41b99..d806267b 100644
--- a/rest_framework/static/rest_framework/css/default.css
+++ b/rest_framework/static/rest_framework/css/default.css
@@ -150,6 +150,49 @@ html, body {
margin: 0 auto -60px;
}
+.form-switcher {
+ margin-bottom: 0;
+}
+
+.well {
+ -webkit-box-shadow: none;
+ -moz-box-shadow: none;
+ box-shadow: none;
+}
+
+.well .form-actions {
+ padding-bottom: 0;
+ margin-bottom: 0;
+}
+
+.well form {
+ margin-bottom: 0;
+}
+
+.nav-tabs {
+ border: 0;
+}
+
+.nav-tabs > li {
+ float: right;
+}
+
+.nav-tabs li a {
+ margin-right: 0;
+}
+
+.nav-tabs > .active > a {
+ background: #f5f5f5;
+}
+
+.nav-tabs > .active > a:hover {
+ background: #f5f5f5;
+}
+
+.tabbable.first-tab-active .tab-content
+{
+ border-top-right-radius: 0;
+}
#footer, #push {
height: 60px; /* .push must be the same height as .footer */
diff --git a/rest_framework/static/rest_framework/js/default.js b/rest_framework/static/rest_framework/js/default.js
index ecaccc0f..c74829d7 100644
--- a/rest_framework/static/rest_framework/js/default.js
+++ b/rest_framework/static/rest_framework/js/default.js
@@ -3,3 +3,11 @@ prettyPrint();
$('.js-tooltip').tooltip({
delay: 1000
});
+
+$('a[data-toggle="tab"]:first').on('shown', function (e) {
+ $(e.target).parents('.tabbable').addClass('first-tab-active');
+});
+$('a[data-toggle="tab"]:not(:first)').on('shown', function (e) {
+ $(e.target).parents('.tabbable').removeClass('first-tab-active');
+});
+$('.form-switcher a:first').tab('show');
diff --git a/rest_framework/status.py b/rest_framework/status.py
index a1eb48da..b9f249f9 100644
--- a/rest_framework/status.py
+++ b/rest_framework/status.py
@@ -4,6 +4,7 @@ Descriptive HTTP status codes, for code readability.
See RFC 2616 - http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html
And RFC 6585 - http://tools.ietf.org/html/rfc6585
"""
+from __future__ import unicode_literals
HTTP_100_CONTINUE = 100
HTTP_101_SWITCHING_PROTOCOLS = 101
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 092bf2e4..44633f5a 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -13,7 +13,7 @@
<title>{% block title %}Django REST framework{% endblock %}</title>
{% block style %}
- <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>
+ {% block bootstrap_theme %}<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>{% endblock %}
<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/>
<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/prettify.css" %}"/>
<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/>
@@ -123,56 +123,88 @@
{% if response.status_code != 403 %}
- {% if post_form %}
- <div class="well">
- <form action="{{ request.get_full_path }}" method="POST" {% if post_form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
- <fieldset>
- {% csrf_token %}
- {{ post_form.non_field_errors }}
- {% for field in post_form %}
- <div class="control-group"> <!--{% if field.errors %}error{% endif %}-->
- {{ field.label_tag|add_class:"control-label" }}
- <div class="controls">
- {{ field }}
- <span class="help-inline">{{ field.help_text }}</span>
- <!--{{ field.errors|add_class:"help-block" }}-->
+ {% if post_form or raw_data_post_form %}
+ <div {% if post_form %}class="tabbable"{% endif %}>
+ {% if post_form %}
+ <ul class="nav nav-tabs form-switcher">
+ <li><a href="#object-form" data-toggle="tab">HTML form</a></li>
+ <li><a href="#generic-content-form" data-toggle="tab">Raw data</a></li>
+ </ul>
+ {% endif %}
+ <div class="well tab-content">
+ {% if post_form %}
+ <div class="tab-pane" id="object-form">
+ {% with form=post_form %}
+ <form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
+ <fieldset>
+ {% include "rest_framework/form.html" %}
+ <div class="form-actions">
+ <button class="btn btn-primary" title="Make a POST request on the {{ name }} resource">POST</button>
</div>
- </div>
- {% endfor %}
- <div class="form-actions">
- <button class="btn btn-primary" title="Make a POST request on the {{ name }} resource">POST</button>
- </div>
- </fieldset>
- </form>
+ </fieldset>
+ </form>
+ {% endwith %}
+ </div>
+ {% endif %}
+ <div {% if post_form %}class="tab-pane"{% endif %} id="generic-content-form">
+ {% with form=raw_data_post_form %}
+ <form action="{{ request.get_full_path }}" method="POST" class="form-horizontal">
+ <fieldset>
+ {% include "rest_framework/form.html" %}
+ <div class="form-actions">
+ <button class="btn btn-primary" title="Make a POST request on the {{ name }} resource">POST</button>
+ </div>
+ </fieldset>
+ </form>
+ {% endwith %}
+ </div>
+ </div>
</div>
{% endif %}
- {% if put_form %}
- <div class="well">
- <form action="{{ request.get_full_path }}" method="POST" {% if put_form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
- <fieldset>
- <input type="hidden" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" />
- {% csrf_token %}
- {{ put_form.non_field_errors }}
- {% for field in put_form %}
- <div class="control-group"> <!--{% if field.errors %}error{% endif %}-->
- {{ field.label_tag|add_class:"control-label" }}
- <div class="controls">
- {{ field }}
- <span class='help-inline'>{{ field.help_text }}</span>
- <!--{{ field.errors|add_class:"help-block" }}-->
+ {% if put_form or raw_data_put_form or raw_data_patch_form %}
+ <div {% if put_form %}class="tabbable"{% endif %}>
+ {% if put_form %}
+ <ul class="nav nav-tabs form-switcher">
+ <li><a href="#object-form" data-toggle="tab">HTML form</a></li>
+ <li><a href="#generic-content-form" data-toggle="tab">Raw data</a></li>
+ </ul>
+ {% endif %}
+ <div class="well tab-content">
+ {% if put_form %}
+ <div class="tab-pane" id="object-form">
+ {% with form=put_form %}
+ <form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
+ <fieldset>
+ {% include "rest_framework/form.html" %}
+ <div class="form-actions">
+ <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button>
</div>
- </div>
- {% endfor %}
- <div class="form-actions">
- <button class="btn btn-primary js-tooltip" title="Make a PUT request on the {{ name }} resource">PUT</button>
- </div>
-
- </fieldset>
- </form>
+ </fieldset>
+ </form>
+ {% endwith %}
+ </div>
+ {% endif %}
+ <div {% if put_form %}class="tab-pane"{% endif %} id="generic-content-form">
+ {% with form=raw_data_put_or_patch_form %}
+ <form action="{{ request.get_full_path }}" method="POST" class="form-horizontal">
+ <fieldset>
+ {% include "rest_framework/form.html" %}
+ <div class="form-actions">
+ {% if raw_data_put_form %}
+ <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button>
+ {% endif %}
+ {% if raw_data_patch_form %}
+ <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PATCH" title="Make a PUT request on the {{ name }} resource">PATCH</button>
+ {% endif %}
+ </div>
+ </fieldset>
+ </form>
+ {% endwith %}
+ </div>
+ </div>
</div>
{% endif %}
-
{% endif %}
</div>
diff --git a/rest_framework/templates/rest_framework/form.html b/rest_framework/templates/rest_framework/form.html
new file mode 100644
index 00000000..dc7acc70
--- /dev/null
+++ b/rest_framework/templates/rest_framework/form.html
@@ -0,0 +1,13 @@
+{% load rest_framework %}
+{% csrf_token %}
+{{ form.non_field_errors }}
+{% for field in form %}
+ <div class="control-group"> <!--{% if field.errors %}error{% endif %}-->
+ {{ field.label_tag|add_class:"control-label" }}
+ <div class="controls">
+ {{ field }}
+ <span class="help-inline">{{ field.help_text }}</span>
+ <!--{{ field.errors|add_class:"help-block" }}-->
+ </div>
+ </div>
+{% endfor %}
diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html
index 6e2bd8d4..e10ce20f 100644
--- a/rest_framework/templates/rest_framework/login.html
+++ b/rest_framework/templates/rest_framework/login.html
@@ -25,14 +25,14 @@
<form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post">
{% csrf_token %}
<div id="div_id_username" class="clearfix control-group">
- <div class="controls" style="height: 30px">
- <Label class="span4" style="margin-top: 3px">Username:</label>
+ <div class="controls">
+ <Label class="span4">Username:</label>
<input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username">
</div>
</div>
<div id="div_id_password" class="clearfix control-group">
- <div class="controls" style="height: 30px">
- <Label class="span4" style="margin-top: 3px">Password:</label>
+ <div class="controls">
+ <Label class="span4">Password:</label>
<input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password">
</div>
</div>
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index 82fcdfe7..c21ddcd7 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -1,10 +1,12 @@
+from __future__ import unicode_literals, absolute_import
from django import template
-from django.core.urlresolvers import reverse
+from django.core.urlresolvers import reverse, NoReverseMatch
from django.http import QueryDict
-from django.utils.encoding import force_unicode
from django.utils.html import escape
from django.utils.safestring import SafeData, mark_safe
-from urlparse import urlsplit, urlunsplit
+from rest_framework.compat import urlparse
+from rest_framework.compat import force_text
+from rest_framework.compat import six
import re
import string
@@ -29,7 +31,7 @@ try: # Django 1.5+
def do_static(parser, token):
return StaticFilesNode.handle_token(parser, token)
-except:
+except ImportError:
try: # Django 1.4
from django.contrib.staticfiles.storage import staticfiles_storage
@@ -41,7 +43,7 @@ except:
"""
return staticfiles_storage.url(path)
- except: # Django 1.3
+ except ImportError: # Django 1.3
from urlparse import urljoin
from django import template
from django.templatetags.static import PrefixNode
@@ -99,11 +101,11 @@ def replace_query_param(url, key, val):
Given a URL and a key/val pair, set or replace an item in the query
parameters of the URL, and return the new URL.
"""
- (scheme, netloc, path, query, fragment) = urlsplit(url)
+ (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url)
query_dict = QueryDict(query).copy()
query_dict[key] = val
query = query_dict.urlencode()
- return urlunsplit((scheme, netloc, path, query, fragment))
+ return urlparse.urlunsplit((scheme, netloc, path, query, fragment))
# Regex for adding classes to html snippets
@@ -135,7 +137,7 @@ def optional_login(request):
"""
try:
login_url = reverse('rest_framework:login')
- except:
+ except NoReverseMatch:
return ''
snippet = "<a href='%s?next=%s'>Log in</a>" % (login_url, request.path)
@@ -149,7 +151,7 @@ def optional_logout(request):
"""
try:
logout_url = reverse('rest_framework:logout')
- except:
+ except NoReverseMatch:
return ''
snippet = "<a href='%s?next=%s'>Log out</a>" % (logout_url, request.path)
@@ -179,7 +181,7 @@ def add_class(value, css_class):
In the case of REST Framework, the filter is used to add Bootstrap-specific
classes to the forms.
"""
- html = unicode(value)
+ html = six.text_type(value)
match = class_re.search(html)
if match:
m = re.search(r'^%s$|^%s\s|\s%s\s|\s%s$' % (css_class, css_class,
@@ -213,7 +215,7 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
"""
trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x
safe_input = isinstance(text, SafeData)
- words = word_split_re.split(force_unicode(text))
+ words = word_split_re.split(force_text(text))
nofollow_attr = nofollow and ' rel="nofollow"' or ''
for i, word in enumerate(words):
match = None
@@ -249,4 +251,4 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
words[i] = mark_safe(word)
elif autoescape:
words[i] = escape(word)
- return mark_safe(u''.join(words))
+ return mark_safe(''.join(words))
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index e86041bc..b663ca48 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -1,33 +1,65 @@
+from __future__ import unicode_literals
from django.contrib.auth.models import User
from django.http import HttpResponse
from django.test import Client, TestCase
-
+from django.utils import unittest
+from rest_framework import HTTP_HEADER_ENCODING
+from rest_framework import exceptions
from rest_framework import permissions
+from rest_framework import status
+from rest_framework.authentication import (
+ BaseAuthentication,
+ TokenAuthentication,
+ BasicAuthentication,
+ SessionAuthentication,
+ OAuthAuthentication,
+ OAuth2Authentication
+)
from rest_framework.authtoken.models import Token
-from rest_framework.authentication import TokenAuthentication
-from rest_framework.compat import patterns
+from rest_framework.compat import patterns, url, include
+from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope
+from rest_framework.compat import oauth, oauth_provider
+from rest_framework.tests.utils import RequestFactory
from rest_framework.views import APIView
-
import json
import base64
+import time
+import datetime
+
+factory = RequestFactory()
class MockView(APIView):
permission_classes = (permissions.IsAuthenticated,)
+ def get(self, request):
+ return HttpResponse({'a': 1, 'b': 2, 'c': 3})
+
def post(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3})
def put(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3})
-MockView.authentication_classes += (TokenAuthentication,)
urlpatterns = patterns('',
- (r'^$', MockView.as_view()),
+ (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
+ (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
+ (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
+ (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])),
+ (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication],
+ permission_classes=[permissions.TokenHasReadWriteScope]))
)
+if oauth2_provider is not None:
+ urlpatterns += patterns('',
+ url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
+ url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
+ url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
+ permission_classes=[permissions.TokenHasReadWriteScope])),
+ )
+
class BasicAuthTests(TestCase):
"""Basic authentication"""
@@ -42,25 +74,30 @@ class BasicAuthTests(TestCase):
def test_post_form_passing_basic_auth(self):
"""Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
- auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip()
- response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
+ credentials = ('%s:%s' % (self.username, self.password))
+ base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
+ auth = 'Basic %s' % base64_credentials
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_json_passing_basic_auth(self):
"""Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
- auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip()
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
+ credentials = ('%s:%s' % (self.username, self.password))
+ base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
+ auth = 'Basic %s' % base64_credentials
+ response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_form_failing_basic_auth(self):
"""Ensure POSTing form over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/basic/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json')
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+ self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
class SessionAuthTests(TestCase):
@@ -83,31 +120,31 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication without CSRF token fails.
"""
self.csrf_client.login(username=self.username, password=self.password)
- response = self.csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_post_form_session_auth_passing(self):
"""
Ensure POSTing form over session authentication with logged in user and CSRF token passes.
"""
self.non_csrf_client.login(username=self.username, password=self.password)
- response = self.non_csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 200)
+ response = self.non_csrf_client.post('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_put_form_session_auth_passing(self):
"""
Ensure PUTting form over session authentication with logged in user and CSRF token passes.
"""
self.non_csrf_client.login(username=self.username, password=self.password)
- response = self.non_csrf_client.put('/', {'example': 'example'})
- self.assertEqual(response.status_code, 200)
+ response = self.non_csrf_client.put('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_form_session_auth_failing(self):
"""
Ensure POSTing form over session authentication without logged in user fails.
"""
- response = self.csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
class TokenAuthTests(TestCase):
@@ -126,25 +163,25 @@ class TokenAuthTests(TestCase):
def test_post_form_passing_token_auth(self):
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
- auth = "Token " + self.key
- response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
+ auth = 'Token ' + self.key
+ response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
+ response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_form_failing_token_auth(self):
"""Ensure POSTing form over token auth without correct credentials fails"""
- response = self.csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/token/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails"""
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json')
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_token_has_auto_assigned_key_if_none_provided(self):
"""Ensure creating a token with no key will auto-assign a key"""
@@ -157,8 +194,8 @@ class TokenAuthTests(TestCase):
client = Client(enforce_csrf_checks=True)
response = client.post('/auth-token/',
json.dumps({'username': self.username, 'password': self.password}), 'application/json')
- self.assertEqual(response.status_code, 200)
- self.assertEqual(json.loads(response.content)['token'], self.key)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key)
def test_token_login_json_bad_creds(self):
"""Ensure token login view using JSON POST fails if bad credentials are used."""
@@ -179,5 +216,362 @@ class TokenAuthTests(TestCase):
client = Client(enforce_csrf_checks=True)
response = client.post('/auth-token/',
{'username': self.username, 'password': self.password})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key)
+
+
+class IncorrectCredentialsTests(TestCase):
+ def test_incorrect_credentials(self):
+ """
+ If a request contains bad authentication credentials, then
+ authentication should run and error, even if no permissions
+ are set on the view.
+ """
+ class IncorrectCredentialsAuth(BaseAuthentication):
+ def authenticate(self, request):
+ raise exceptions.AuthenticationFailed('Bad credentials')
+
+ request = factory.get('/')
+ view = MockView.as_view(
+ authentication_classes=(IncorrectCredentialsAuth,),
+ permission_classes=()
+ )
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(response.data, {'detail': 'Bad credentials'})
+
+
+class OAuthTests(TestCase):
+ """OAuth 1.0a authentication"""
+ urls = 'rest_framework.tests.authentication'
+
+ def setUp(self):
+ # these imports are here because oauth is optional and hiding them in try..except block or compat
+ # could obscure problems if something breaks
+ from oauth_provider.models import Consumer, Resource
+ from oauth_provider.models import Token as OAuthToken
+ from oauth_provider import consts
+
+ self.consts = consts
+
+ self.csrf_client = Client(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.CONSUMER_KEY = 'consumer_key'
+ self.CONSUMER_SECRET = 'consumer_secret'
+ self.TOKEN_KEY = "token_key"
+ self.TOKEN_SECRET = "token_secret"
+
+ self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
+ name='example', user=self.user, status=self.consts.ACCEPTED)
+
+ self.resource = Resource.objects.create(name="resource name", url="api/")
+ self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, resource=self.resource,
+ token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True
+ )
+
+ def _create_authorization_header(self):
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="GET", url="http://example.com", parameters=params)
+
+ signature_method = oauth.SignatureMethod_PLAINTEXT()
+ req.sign_request(signature_method, self.consumer, self.token)
+
+ return req.to_header()["Authorization"]
+
+ def _create_authorization_url_parameters(self):
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="GET", url="http://example.com", parameters=params)
+
+ signature_method = oauth.SignatureMethod_PLAINTEXT()
+ req.sign_request(signature_method, self.consumer, self.token)
+ return dict(req)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_passing_oauth(self):
+ """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_repeated_nonce_failing_oauth(self):
+ """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ # simulate reply attack auth header containes already used (nonce, timestamp) pair
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_token_removed_failing_oauth(self):
+ """Ensure POSTing when there is no OAuth access token in db fails"""
+ self.token.delete()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_consumer_status_not_accepted_failing_oauth(self):
+ """Ensure POSTing when consumer status is anything other than ACCEPTED fails"""
+ for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED):
+ self.consumer.status = consumer_status
+ self.consumer.save()
+
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_request_token_failing_oauth(self):
+ """Ensure POSTing with unauthorized request token instead of access token fails"""
+ self.token.token_type = self.token.REQUEST
+ self.token.save()
+
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_urlencoded_parameters(self):
+ """Ensure POSTing with x-www-form-urlencoded auth parameters passes"""
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.post('/oauth/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_get_form_with_url_parameters(self):
+ """Ensure GETing with auth in url parameters passes"""
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.get('/oauth/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_hmac_sha1_signature_passes(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_get_form_with_readonly_resource_passing_auth(self):
+ """Ensure POSTing with a readonly resource instead of a write scope fails"""
+ read_only_access_token = self.token
+ read_only_access_token.resource.is_readonly = True
+ read_only_access_token.resource.save()
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.get('/oauth-with-scope/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_readonly_resource_failing_auth(self):
+ """Ensure POSTing with a readonly resource instead of a write scope fails"""
+ read_only_access_token = self.token
+ read_only_access_token.resource.is_readonly = True
+ read_only_access_token.resource.save()
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.post('/oauth-with-scope/', params)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_write_resource_passing_auth(self):
+ """Ensure POSTing with a write resource succeed"""
+ read_write_access_token = self.token
+ read_write_access_token.resource.is_readonly = False
+ read_write_access_token.resource.save()
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.post('/oauth-with-scope/', params)
+ self.assertEqual(response.status_code, 200)
+
+
+class OAuth2Tests(TestCase):
+ """OAuth 2.0 authentication"""
+ urls = 'rest_framework.tests.authentication'
+
+ def setUp(self):
+ self.csrf_client = Client(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.CLIENT_ID = 'client_key'
+ self.CLIENT_SECRET = 'client_secret'
+ self.ACCESS_TOKEN = "access_token"
+ self.REFRESH_TOKEN = "refresh_token"
+
+ self.oauth2_client = oauth2_provider_models.Client.objects.create(
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ redirect_uri='',
+ client_type=0,
+ name='example',
+ user=None,
+ )
+
+ self.access_token = oauth2_provider_models.AccessToken.objects.create(
+ token=self.ACCESS_TOKEN,
+ client=self.oauth2_client,
+ user=self.user,
+ )
+ self.refresh_token = oauth2_provider_models.RefreshToken.objects.create(
+ user=self.user,
+ access_token=self.access_token,
+ client=self.oauth2_client
+ )
+
+ def _create_authorization_header(self, token=None):
+ return "Bearer {0}".format(token or self.access_token.token)
+
+ def _client_credentials_params(self):
+ return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET}
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_type_failing(self):
+ """Ensure that a wrong token type lead to the correct HTTP error status code"""
+ auth = "Wrong token-type-obsviously"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ params = self._client_credentials_params()
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_format_failing(self):
+ """Ensure that a wrong token format lead to the correct HTTP error status code"""
+ auth = "Bearer wrong token format"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ params = self._client_credentials_params()
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_failing(self):
+ """Ensure that a wrong token lead to the correct HTTP error status code"""
+ auth = "Bearer wrong-token"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ params = self._client_credentials_params()
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_client_data_failing_auth(self):
+ """Ensure GETing form over OAuth with incorrect client credentials fails"""
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ params['client_id'] += 'a'
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_passing_auth(self):
+ """Ensure GETing form over OAuth with correct client credentials succeed"""
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_passing_auth(self):
+ """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_token_removed_failing_auth(self):
+ """Ensure POSTing when there is no OAuth access token in db fails"""
+ self.access_token.delete()
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_refresh_token_failing_auth(self):
+ """Ensure POSTing with refresh token instead of access token fails"""
+ auth = self._create_authorization_header(token=self.refresh_token.token)
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_expired_access_token_failing_auth(self):
+ """Ensure POSTing with expired access token fails with an 'Invalid token' error"""
+ self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
+ self.access_token.save()
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+ self.assertIn('Invalid token', response.content)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_invalid_scope_failing_auth(self):
+ """Ensure POSTing with a readonly scope instead of a write scope fails"""
+ read_only_access_token = self.access_token
+ read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']
+ read_only_access_token.save()
+ auth = self._create_authorization_header(token=read_only_access_token.token)
+ params = self._client_credentials_params()
+ response = self.csrf_client.get('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+ response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_valid_scope_passing_auth(self):
+ """Ensure POSTing with a write scope succeed"""
+ read_write_access_token = self.access_token
+ read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']
+ read_write_access_token.save()
+ auth = self._create_authorization_header(token=read_write_access_token.token)
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
- self.assertEqual(json.loads(response.content)['token'], self.key)
diff --git a/rest_framework/tests/breadcrumbs.py b/rest_framework/tests/breadcrumbs.py
index df891683..d9ed647e 100644
--- a/rest_framework/tests/breadcrumbs.py
+++ b/rest_framework/tests/breadcrumbs.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.compat import patterns, url
from rest_framework.utils.breadcrumbs import get_breadcrumbs
diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py
index 5e6bce4e..1016fed3 100644
--- a/rest_framework/tests/decorators.py
+++ b/rest_framework/tests/decorators.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
from django.test import TestCase
from rest_framework import status
from rest_framework.response import Response
@@ -28,13 +29,27 @@ class DecoratorTestCase(TestCase):
response.request = request
return APIView.finalize_response(self, request, response, *args, **kwargs)
- def test_wrap_view(self):
+ def test_api_view_incorrect(self):
+ """
+ If @api_view is not applied correct, we should raise an assertion.
+ """
- @api_view(['GET'])
+ @api_view
def view(request):
- return Response({})
+ return Response()
+
+ request = self.factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ def test_api_view_incorrect_arguments(self):
+ """
+ If @api_view is missing arguments, we should raise an assertion.
+ """
- self.assertTrue(isinstance(view.cls_instance, APIView))
+ with self.assertRaises(AssertionError):
+ @api_view('GET')
+ def view(request):
+ return Response()
def test_calling_method(self):
@@ -44,11 +59,11 @@ class DecoratorTestCase(TestCase):
request = self.factory.get('/')
response = view(request)
- self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
request = self.factory.post('/')
response = view(request)
- self.assertEqual(response.status_code, 405)
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_calling_put_method(self):
@@ -58,11 +73,11 @@ class DecoratorTestCase(TestCase):
request = self.factory.put('/')
response = view(request)
- self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
request = self.factory.post('/')
response = view(request)
- self.assertEqual(response.status_code, 405)
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_calling_patch_method(self):
@@ -72,11 +87,11 @@ class DecoratorTestCase(TestCase):
request = self.factory.patch('/')
response = view(request)
- self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
request = self.factory.post('/')
response = view(request)
- self.assertEqual(response.status_code, 405)
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_renderer_classes(self):
@@ -124,7 +139,7 @@ class DecoratorTestCase(TestCase):
request = self.factory.get('/')
response = view(request)
- self.assertEquals(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_throttle_classes(self):
class OncePerDayUserThrottle(UserRateThrottle):
@@ -137,7 +152,7 @@ class DecoratorTestCase(TestCase):
request = self.factory.get('/')
response = view(request)
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
response = view(request)
- self.assertEquals(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
+ self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py
index d958b840..5b3315bc 100644
--- a/rest_framework/tests/description.py
+++ b/rest_framework/tests/description.py
@@ -1,3 +1,6 @@
+# -- coding: utf-8 --
+
+from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.views import APIView
from rest_framework.compat import apply_markdown
@@ -50,7 +53,7 @@ class TestViewNamesAndDescriptions(TestCase):
"""Ensure Resource names are based on the classname by default."""
class MockView(APIView):
pass
- self.assertEquals(MockView().get_name(), 'Mock')
+ self.assertEqual(MockView().get_name(), 'Mock')
def test_resource_name_can_be_set_explicitly(self):
"""Ensure Resource names can be set using the 'get_name' method."""
@@ -58,7 +61,7 @@ class TestViewNamesAndDescriptions(TestCase):
class MockView(APIView):
def get_name(self):
return example
- self.assertEquals(MockView().get_name(), example)
+ self.assertEqual(MockView().get_name(), example)
def test_resource_description_uses_docstring_by_default(self):
"""Ensure Resource names are based on the docstring by default."""
@@ -78,7 +81,7 @@ class TestViewNamesAndDescriptions(TestCase):
# hash style header #"""
- self.assertEquals(MockView().get_description(), DESCRIPTION)
+ self.assertEqual(MockView().get_description(), DESCRIPTION)
def test_resource_description_can_be_set_explicitly(self):
"""Ensure Resource descriptions can be set using the 'get_description' method."""
@@ -88,7 +91,16 @@ class TestViewNamesAndDescriptions(TestCase):
"""docstring"""
def get_description(self):
return example
- self.assertEquals(MockView().get_description(), example)
+ self.assertEqual(MockView().get_description(), example)
+
+ def test_resource_description_supports_unicode(self):
+
+ class MockView(APIView):
+ """Проверка"""
+ pass
+
+ self.assertEqual(MockView().get_description(), "Проверка")
+
def test_resource_description_does_not_require_docstring(self):
"""Ensure that empty docstrings do not affect the Resource's description if it has been set using the 'get_description' method."""
@@ -97,13 +109,13 @@ class TestViewNamesAndDescriptions(TestCase):
class MockView(APIView):
def get_description(self):
return example
- self.assertEquals(MockView().get_description(), example)
+ self.assertEqual(MockView().get_description(), example)
def test_resource_description_can_be_empty(self):
"""Ensure that if a resource has no doctring or 'description' class attribute, then it's description is the empty string."""
class MockView(APIView):
pass
- self.assertEquals(MockView().get_description(), '')
+ self.assertEqual(MockView().get_description(), '')
def test_markdown(self):
"""Ensure markdown to HTML works as expected"""
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py
index 8068272d..fd6de779 100644
--- a/rest_framework/tests/fields.py
+++ b/rest_framework/tests/fields.py
@@ -1,9 +1,13 @@
"""
General serializer field tests.
"""
+from __future__ import unicode_literals
+import datetime
from django.db import models
from django.test import TestCase
+from django.core import validators
+
from rest_framework import serializers
@@ -26,24 +30,415 @@ class CharPrimaryKeyModelSerializer(serializers.ModelSerializer):
model = CharPrimaryKeyModel
-class ReadOnlyFieldTests(TestCase):
+class TimeFieldModel(models.Model):
+ clock = models.TimeField()
+
+
+class TimeFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TimeFieldModel
+
+
+class BasicFieldTests(TestCase):
def test_auto_now_fields_read_only(self):
"""
auto_now and auto_now_add fields should be read_only by default.
"""
serializer = TimestampedModelSerializer()
- self.assertEquals(serializer.fields['added'].read_only, True)
+ self.assertEqual(serializer.fields['added'].read_only, True)
def test_auto_pk_fields_read_only(self):
"""
AutoField fields should be read_only by default.
"""
serializer = TimestampedModelSerializer()
- self.assertEquals(serializer.fields['id'].read_only, True)
+ self.assertEqual(serializer.fields['id'].read_only, True)
def test_non_auto_pk_fields_not_read_only(self):
"""
PK fields other than AutoField fields should not be read_only by default.
"""
serializer = CharPrimaryKeyModelSerializer()
- self.assertEquals(serializer.fields['id'].read_only, False)
+ self.assertEqual(serializer.fields['id'].read_only, False)
+
+
+class DateFieldTest(TestCase):
+ """
+ Tests for the DateFieldTest from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.DateField()
+ result_1 = f.from_native('1984-07-31')
+
+ self.assertEqual(datetime.date(1984, 7, 31), result_1)
+
+ def test_from_native_datetime_date(self):
+ """
+ Make sure from_native() accepts a datetime.date instance.
+ """
+ f = serializers.DateField()
+ result_1 = f.from_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual(result_1, datetime.date(1984, 7, 31))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.DateField(input_formats=['%Y -- %d'])
+ result = f.from_native('1984 -- 31')
+
+ self.assertEqual(datetime.date(1984, 1, 31), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.DateField(input_formats=['%Y -- %d'])
+
+ try:
+ f.from_native('1984-07-31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DateField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_date(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid date.
+ """
+ f = serializers.DateField()
+
+ try:
+ f.from_native('1984-13-31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.DateField()
+
+ try:
+ f.from_native('1984 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns isoformat as default.
+ """
+ f = serializers.DateField()
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual('1984-07-31', result_1)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.DateField(format="%Y - %m.%d")
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual('1984 - 07.31', result_1)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+
+class DateTimeFieldTest(TestCase):
+ """
+ Tests for the DateTimeField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.DateTimeField()
+ result_1 = f.from_native('1984-07-31 04:31')
+ result_2 = f.from_native('1984-07-31 04:31:59')
+ result_3 = f.from_native('1984-07-31 04:31:59.000200')
+
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3)
+
+ def test_from_native_datetime_datetime(self):
+ """
+ Make sure from_native() accepts a datetime.datetime instance.
+ """
+ f = serializers.DateTimeField()
+ result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31))
+ self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59))
+ self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
+ result = f.from_native('1984 -- 04:59')
+
+ self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
+
+ try:
+ f.from_native('1984-07-31 04:31:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DateTimeField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateTimeField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_datetime(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid datetime.
+ """
+ f = serializers.DateTimeField()
+
+ try:
+ f.from_native('04:61:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
+ "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.DateTimeField()
+
+ try:
+ f.from_native('04 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
+ "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns isoformat as default.
+ """
+ f = serializers.DateTimeField()
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual('1984-07-31T00:00:00', result_1)
+ self.assertEqual('1984-07-31T04:31:00', result_2)
+ self.assertEqual('1984-07-31T04:31:59', result_3)
+ self.assertEqual('1984-07-31T04:31:59.000200', result_4)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.DateTimeField(format="%Y - %H:%M")
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual('1984 - 00:00', result_1)
+ self.assertEqual('1984 - 04:31', result_2)
+ self.assertEqual('1984 - 04:31', result_3)
+ self.assertEqual('1984 - 04:31', result_4)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateTimeField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+
+class TimeFieldTest(TestCase):
+ """
+ Tests for the TimeField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.TimeField()
+ result_1 = f.from_native('04:31')
+ result_2 = f.from_native('04:31:59')
+ result_3 = f.from_native('04:31:59.000200')
+
+ self.assertEqual(datetime.time(4, 31), result_1)
+ self.assertEqual(datetime.time(4, 31, 59), result_2)
+ self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
+
+ def test_from_native_datetime_time(self):
+ """
+ Make sure from_native() accepts a datetime.time instance.
+ """
+ f = serializers.TimeField()
+ result_1 = f.from_native(datetime.time(4, 31))
+ result_2 = f.from_native(datetime.time(4, 31, 59))
+ result_3 = f.from_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual(result_1, datetime.time(4, 31))
+ self.assertEqual(result_2, datetime.time(4, 31, 59))
+ self.assertEqual(result_3, datetime.time(4, 31, 59, 200))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.TimeField(input_formats=['%H -- %M'])
+ result = f.from_native('04 -- 31')
+
+ self.assertEqual(datetime.time(4, 31), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.TimeField(input_formats=['%H -- %M'])
+
+ try:
+ f.from_native('04:31:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.TimeField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.TimeField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_time(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid time.
+ """
+ f = serializers.TimeField()
+
+ try:
+ f.from_native('04:61:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
+ "hh:mm[:ss[.uuuuuu]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.TimeField()
+
+ try:
+ f.from_native('04 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
+ "hh:mm[:ss[.uuuuuu]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns isoformat as default.
+ """
+ f = serializers.TimeField()
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual('04:31:00', result_1)
+ self.assertEqual('04:31:59', result_2)
+ self.assertEqual('04:31:59.000200', result_3)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.TimeField(format="%H - %S [%f]")
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual('04 - 00 [000000]', result_1)
+ self.assertEqual('04 - 59 [000000]', result_2)
+ self.assertEqual('04 - 59 [000200]', result_3)
diff --git a/rest_framework/tests/files.py b/rest_framework/tests/files.py
index 446e23c0..487046ac 100644
--- a/rest_framework/tests/files.py
+++ b/rest_framework/tests/files.py
@@ -1,9 +1,9 @@
-import StringIO
-import datetime
-
+from __future__ import unicode_literals
from django.test import TestCase
-
from rest_framework import serializers
+from rest_framework.compat import BytesIO
+from rest_framework.compat import six
+import datetime
class UploadedFile(object):
@@ -27,14 +27,14 @@ class UploadedFileSerializer(serializers.Serializer):
class FileSerializerTests(TestCase):
def test_create(self):
now = datetime.datetime.now()
- file = StringIO.StringIO('stuff')
+ file = BytesIO(six.b('stuff'))
file.name = 'stuff.txt'
- file.size = file.len
+ file.size = len(file.getvalue())
serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})
uploaded_file = UploadedFile(file=file, created=now)
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.object.created, uploaded_file.created)
- self.assertEquals(serializer.object.file, uploaded_file.file)
+ self.assertEqual(serializer.object.created, uploaded_file.created)
+ self.assertEqual(serializer.object.file, uploaded_file.file)
self.assertFalse(serializer.object is uploaded_file)
def test_creation_failure(self):
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
index af2e6c2e..fe92e0bc 100644
--- a/rest_framework/tests/filterset.py
+++ b/rest_framework/tests/filterset.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
import datetime
from decimal import Decimal
from django.test import TestCase
@@ -64,8 +65,8 @@ class IntegrationTestFiltering(TestCase):
self.objects = FilterableItem.objects
self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
- for obj in self.objects.all()
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()}
+ for obj in self.objects.all()
]
@unittest.skipUnless(django_filters, 'django-filters not installed')
@@ -78,24 +79,24 @@ class IntegrationTestFiltering(TestCase):
# Basic test with no filter.
request = factory.get('/')
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
# Tests that the decimal filter works.
search_decimal = Decimal('2.25')
request = factory.get('/?decimal=%s' % search_decimal)
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal]
- self.assertEquals(response.data, expected_data)
+ self.assertEqual(response.data, expected_data)
# Tests that the date filter works.
search_date = datetime.date(2012, 9, 22)
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] == search_date]
- self.assertEquals(response.data, expected_data)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() == search_date]
+ self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filters not installed')
def test_get_filtered_class_root_view(self):
@@ -108,42 +109,43 @@ class IntegrationTestFiltering(TestCase):
# Basic test with no filter.
request = factory.get('/')
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
# Tests that the decimal filter set with 'lt' in the filter class works.
search_decimal = Decimal('4.25')
request = factory.get('/?decimal=%s' % search_decimal)
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] < search_decimal]
- self.assertEquals(response.data, expected_data)
+ self.assertEqual(response.data, expected_data)
# Tests that the date filter set with 'gt' in the filter class works.
search_date = datetime.date(2012, 10, 2)
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date]
- self.assertEquals(response.data, expected_data)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date]
+ self.assertEqual(response.data, expected_data)
# Tests that the text filter set with 'icontains' in the filter class works.
search_text = 'ff'
request = factory.get('/?text=%s' % search_text)
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if search_text in f['text'].lower()]
- self.assertEquals(response.data, expected_data)
+ self.assertEqual(response.data, expected_data)
# Tests that multiple filters works.
search_decimal = Decimal('5.25')
search_date = datetime.date(2012, 10, 2)
request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date and
- f['decimal'] < search_decimal]
- self.assertEquals(response.data, expected_data)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if
+ datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date and
+ f['decimal'] < search_decimal]
+ self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filters not installed')
def test_incorrectly_configured_filter(self):
@@ -165,4 +167,4 @@ class IntegrationTestFiltering(TestCase):
search_integer = 10
request = factory.get('/?integer=%s' % search_integer)
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py
index bc7378e1..c38bfb9f 100644
--- a/rest_framework/tests/genericrelations.py
+++ b/rest_framework/tests/genericrelations.py
@@ -1,25 +1,62 @@
+from __future__ import unicode_literals
+from django.contrib.contenttypes.models import ContentType
+from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey
+from django.db import models
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import *
+
+
+class Tag(models.Model):
+ """
+ Tags have a descriptive slug, and are attached to an arbitrary object.
+ """
+ tag = models.SlugField()
+ content_type = models.ForeignKey(ContentType)
+ object_id = models.PositiveIntegerField()
+ tagged_item = GenericForeignKey('content_type', 'object_id')
+
+ def __unicode__(self):
+ return self.tag
+
+
+class Bookmark(models.Model):
+ """
+ A URL bookmark that may have multiple tags attached.
+ """
+ url = models.URLField()
+ tags = GenericRelation(Tag)
+
+ def __unicode__(self):
+ return 'Bookmark: %s' % self.url
+
+
+class Note(models.Model):
+ """
+ A textual note that may have multiple tags attached.
+ """
+ text = models.TextField()
+ tags = GenericRelation(Tag)
+
+ def __unicode__(self):
+ return 'Note: %s' % self.text
class TestGenericRelations(TestCase):
def setUp(self):
- bookmark = Bookmark(url='https://www.djangoproject.com/')
- bookmark.save()
- django = Tag(tag_name='django')
- django.save()
- python = Tag(tag_name='python')
- python.save()
- t1 = TaggedItem(content_object=bookmark, tag=django)
- t1.save()
- t2 = TaggedItem(content_object=bookmark, tag=python)
- t2.save()
- self.bookmark = bookmark
-
- def test_reverse_generic_relation(self):
+ self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
+ Tag.objects.create(tagged_item=self.bookmark, tag='django')
+ Tag.objects.create(tagged_item=self.bookmark, tag='python')
+ self.note = Note.objects.create(text='Remember the milk')
+ Tag.objects.create(tagged_item=self.note, tag='reminder')
+
+ def test_generic_relation(self):
+ """
+ Test a relationship that spans a GenericRelation field.
+ IE. A reverse generic relationship.
+ """
+
class BookmarkSerializer(serializers.ModelSerializer):
- tags = serializers.ManyRelatedField(source='tags')
+ tags = serializers.RelatedField(many=True)
class Meta:
model = Bookmark
@@ -27,7 +64,37 @@ class TestGenericRelations(TestCase):
serializer = BookmarkSerializer(self.bookmark)
expected = {
- 'tags': [u'django', u'python'],
- 'url': u'https://www.djangoproject.com/'
+ 'tags': ['django', 'python'],
+ 'url': 'https://www.djangoproject.com/'
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_generic_fk(self):
+ """
+ Test a relationship that spans a GenericForeignKey field.
+ IE. A forward generic relationship.
+ """
+
+ class TagSerializer(serializers.ModelSerializer):
+ tagged_item = serializers.RelatedField()
+
+ class Meta:
+ model = Tag
+ exclude = ('id', 'content_type', 'object_id')
+
+ serializer = TagSerializer(Tag.objects.all(), many=True)
+ expected = [
+ {
+ 'tag': 'django',
+ 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
+ },
+ {
+ 'tag': 'python',
+ 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
+ },
+ {
+ 'tag': 'reminder',
+ 'tagged_item': 'Note: Remember the milk'
}
- self.assertEquals(serializer.data, expected)
+ ]
+ self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
index 4799a04b..f564890c 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/generics.py
@@ -1,10 +1,11 @@
-import json
+from __future__ import unicode_literals
from django.db import models
from django.test import TestCase
from rest_framework import generics, serializers, status
from rest_framework.tests.utils import RequestFactory
from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
-
+from rest_framework.compat import six
+import json
factory = RequestFactory()
@@ -42,7 +43,7 @@ class SlugBasedInstanceView(InstanceView):
class TestRootView(TestCase):
def setUp(self):
"""
- Create 3 BasicModel intances.
+ Create 3 BasicModel instances.
"""
items = ['foo', 'bar', 'baz']
for item in items:
@@ -59,9 +60,10 @@ class TestRootView(TestCase):
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/')
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data)
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
def test_post_root_view(self):
"""
@@ -70,11 +72,12 @@ class TestRootView(TestCase):
content = {'text': 'foobar'}
request = factory.post('/', json.dumps(content),
content_type='application/json')
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_201_CREATED)
- self.assertEquals(response.data, {'id': 4, 'text': u'foobar'})
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
created = self.objects.get(id=4)
- self.assertEquals(created.text, 'foobar')
+ self.assertEqual(created.text, 'foobar')
def test_put_root_view(self):
"""
@@ -83,25 +86,28 @@ class TestRootView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/', json.dumps(content),
content_type='application/json')
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEquals(response.data, {"detail": "Method 'PUT' not allowed."})
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."})
def test_delete_root_view(self):
"""
DELETE requests to ListCreateAPIView should not be allowed
"""
request = factory.delete('/')
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEquals(response.data, {"detail": "Method 'DELETE' not allowed."})
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."})
def test_options_root_view(self):
"""
OPTIONS requests to ListCreateAPIView should return metadata
"""
request = factory.options('/')
- response = self.view(request).render()
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
expected = {
'parses': [
'application/json',
@@ -115,8 +121,8 @@ class TestRootView(TestCase):
'name': 'Root',
'description': 'Example description for OPTIONS.'
}
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, expected)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, expected)
def test_post_cannot_set_id(self):
"""
@@ -125,11 +131,12 @@ class TestRootView(TestCase):
content = {'id': 999, 'text': 'foobar'}
request = factory.post('/', json.dumps(content),
content_type='application/json')
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_201_CREATED)
- self.assertEquals(response.data, {'id': 4, 'text': u'foobar'})
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
created = self.objects.get(id=4)
- self.assertEquals(created.text, 'foobar')
+ self.assertEqual(created.text, 'foobar')
class TestInstanceView(TestCase):
@@ -153,9 +160,10 @@ class TestInstanceView(TestCase):
GET requests to RetrieveUpdateDestroyAPIView should return a single object.
"""
request = factory.get('/1')
- response = self.view(request, pk=1).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data[0])
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
def test_post_instance_view(self):
"""
@@ -164,9 +172,10 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.post('/', json.dumps(content),
content_type='application/json')
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEquals(response.data, {"detail": "Method 'POST' not allowed."})
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."})
def test_put_instance_view(self):
"""
@@ -175,11 +184,12 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk='1').render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
+ with self.assertNumQueries(2):
+ response = self.view(request, pk='1').render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
- self.assertEquals(updated.text, 'foobar')
+ self.assertEqual(updated.text, 'foobar')
def test_patch_instance_view(self):
"""
@@ -189,29 +199,32 @@ class TestInstanceView(TestCase):
request = factory.patch('/1', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk=1).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
- self.assertEquals(updated.text, 'foobar')
+ self.assertEqual(updated.text, 'foobar')
def test_delete_instance_view(self):
"""
DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
"""
request = factory.delete('/1')
- response = self.view(request, pk=1).render()
- self.assertEquals(response.status_code, status.HTTP_204_NO_CONTENT)
- self.assertEquals(response.content, '')
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+ self.assertEqual(response.content, six.b(''))
ids = [obj.id for obj in self.objects.all()]
- self.assertEquals(ids, [2, 3])
+ self.assertEqual(ids, [2, 3])
def test_options_instance_view(self):
"""
OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
"""
request = factory.options('/')
- response = self.view(request).render()
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
expected = {
'parses': [
'application/json',
@@ -225,8 +238,8 @@ class TestInstanceView(TestCase):
'name': 'Instance',
'description': 'Example description for OPTIONS.'
}
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, expected)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, expected)
def test_put_cannot_set_id(self):
"""
@@ -235,11 +248,12 @@ class TestInstanceView(TestCase):
content = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk=1).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
- self.assertEquals(updated.text, 'foobar')
+ self.assertEqual(updated.text, 'foobar')
def test_put_to_deleted_instance(self):
"""
@@ -250,11 +264,12 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk=1).render()
- self.assertEquals(response.status_code, status.HTTP_201_CREATED)
- self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
+ with self.assertNumQueries(3):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
- self.assertEquals(updated.text, 'foobar')
+ self.assertEqual(updated.text, 'foobar')
def test_put_as_create_on_id_based_url(self):
"""
@@ -262,13 +277,14 @@ class TestInstanceView(TestCase):
at the requested url if it doesn't exist.
"""
content = {'text': 'foobar'}
- # pk fields can not be created on demand, only the database can set th pk for a new object
+ # pk fields can not be created on demand, only the database can set the pk for a new object
request = factory.put('/5', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk=5).render()
- self.assertEquals(response.status_code, status.HTTP_201_CREATED)
+ with self.assertNumQueries(3):
+ response = self.view(request, pk=5).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
new_obj = self.objects.get(pk=5)
- self.assertEquals(new_obj.text, 'foobar')
+ self.assertEqual(new_obj.text, 'foobar')
def test_put_as_create_on_slug_based_url(self):
"""
@@ -278,11 +294,12 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/test_slug', json.dumps(content),
content_type='application/json')
- response = self.slug_based_view(request, slug='test_slug').render()
- self.assertEquals(response.status_code, status.HTTP_201_CREATED)
- self.assertEquals(response.data, {'slug': 'test_slug', 'text': 'foobar'})
+ with self.assertNumQueries(2):
+ response = self.slug_based_view(request, slug='test_slug').render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'slug': 'test_slug', 'text': 'foobar'})
new_obj = SlugBasedModel.objects.get(slug='test_slug')
- self.assertEquals(new_obj.text, 'foobar')
+ self.assertEqual(new_obj.text, 'foobar')
# Regression test for #285
@@ -313,12 +330,12 @@ class TestCreateModelWithAutoNowAddField(TestCase):
request = factory.post('/', json.dumps(content),
content_type='application/json')
response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
created = self.objects.get(id=1)
- self.assertEquals(created.content, 'foobar')
+ self.assertEqual(created.content, 'foobar')
-# Test for particularly ugly reression with m2m in browseable API
+# Test for particularly ugly regression with m2m in browseable API
class ClassB(models.Model):
name = models.CharField(max_length=255)
@@ -329,7 +346,7 @@ class ClassA(models.Model):
class ClassASerializer(serializers.ModelSerializer):
- childs = serializers.ManyPrimaryKeyRelatedField(source='childs')
+ childs = serializers.PrimaryKeyRelatedField(many=True, source='childs')
class Meta:
model = ClassA
@@ -343,9 +360,84 @@ class ExampleView(generics.ListCreateAPIView):
class TestM2MBrowseableAPI(TestCase):
def test_m2m_in_browseable_api(self):
"""
- Test for particularly ugly reression with m2m in browseable API
+ Test for particularly ugly regression with m2m in browseable API
"""
request = factory.get('/', HTTP_ACCEPT='text/html')
view = ExampleView().as_view()
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class InclusiveFilterBackend(object):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(text='foo')
+
+
+class ExclusiveFilterBackend(object):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(text='other')
+
+
+class TestFilterBackendAppliedToViews(TestCase):
+
+ def setUp(self):
+ """
+ Create 3 BasicModel instances to filter on.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.root_view = RootView.as_view()
+ self.instance_view = InstanceView.as_view()
+ self.original_root_backend = getattr(RootView, 'filter_backend')
+ self.original_instance_backend = getattr(InstanceView, 'filter_backend')
+
+ def tearDown(self):
+ setattr(RootView, 'filter_backend', self.original_root_backend)
+ setattr(InstanceView, 'filter_backend', self.original_instance_backend)
+
+ def test_get_root_view_filters_by_name_with_filter_backend(self):
+ """
+ GET requests to ListCreateAPIView should return filtered list.
+ """
+ setattr(RootView, 'filter_backend', InclusiveFilterBackend)
+ request = factory.get('/')
+ response = self.root_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(len(response.data), 1)
+ self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}])
+
+ def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
+ """
+ GET requests to ListCreateAPIView should return empty list when all models are filtered out.
+ """
+ setattr(RootView, 'filter_backend', ExclusiveFilterBackend)
+ request = factory.get('/')
+ response = self.root_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, [])
+
+ def test_get_instance_view_filters_out_name_with_filter_backend(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
+ """
+ setattr(InstanceView, 'filter_backend', ExclusiveFilterBackend)
+ request = factory.get('/1')
+ response = self.instance_view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(response.data, {'detail': 'Not found'})
+
+ def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
+ """
+ setattr(InstanceView, 'filter_backend', InclusiveFilterBackend)
+ request = factory.get('/1')
+ response = self.instance_view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foo'})
diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/htmlrenderer.py
index 54096206..8f2e2b5a 100644
--- a/rest_framework/tests/htmlrenderer.py
+++ b/rest_framework/tests/htmlrenderer.py
@@ -1,12 +1,15 @@
+from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied
from django.http import Http404
from django.test import TestCase
from django.template import TemplateDoesNotExist, Template
import django.template.loader
+from rest_framework import status
from rest_framework.compat import patterns, url
from rest_framework.decorators import api_view, renderer_classes
from rest_framework.renderers import TemplateHTMLRenderer
from rest_framework.response import Response
+from rest_framework.compat import six
@api_view(('GET',))
@@ -63,19 +66,19 @@ class TemplateHTMLRendererTests(TestCase):
def test_simple_html_view(self):
response = self.client.get('/')
self.assertContains(response, "example: foobar")
- self.assertEquals(response['Content-Type'], 'text/html')
+ self.assertEqual(response['Content-Type'], 'text/html')
def test_not_found_html_view(self):
response = self.client.get('/not_found')
- self.assertEquals(response.status_code, 404)
- self.assertEquals(response.content, "404 Not Found")
- self.assertEquals(response['Content-Type'], 'text/html')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(response.content, six.b("404 Not Found"))
+ self.assertEqual(response['Content-Type'], 'text/html')
def test_permission_denied_html_view(self):
response = self.client.get('/permission_denied')
- self.assertEquals(response.status_code, 403)
- self.assertEquals(response.content, "403 Forbidden")
- self.assertEquals(response['Content-Type'], 'text/html')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(response.content, six.b("403 Forbidden"))
+ self.assertEqual(response['Content-Type'], 'text/html')
class TemplateHTMLRendererExceptionTests(TestCase):
@@ -104,12 +107,12 @@ class TemplateHTMLRendererExceptionTests(TestCase):
def test_not_found_html_view_with_template(self):
response = self.client.get('/not_found')
- self.assertEquals(response.status_code, 404)
- self.assertEquals(response.content, "404: Not found")
- self.assertEquals(response['Content-Type'], 'text/html')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(response.content, six.b("404: Not found"))
+ self.assertEqual(response['Content-Type'], 'text/html')
def test_permission_denied_html_view_with_template(self):
response = self.client.get('/permission_denied')
- self.assertEquals(response.status_code, 403)
- self.assertEquals(response.content, "403: Permission denied")
- self.assertEquals(response['Content-Type'], 'text/html')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(response.content, six.b("403: Permission denied"))
+ self.assertEqual(response['Content-Type'], 'text/html')
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py
index c6a8224b..9a61f299 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/hyperlinkedserializers.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
import json
from django.test import TestCase
from django.test.client import RequestFactory
@@ -99,7 +100,7 @@ class TestBasicHyperlinkedView(TestCase):
def setUp(self):
"""
- Create 3 BasicModel intances.
+ Create 3 BasicModel instances.
"""
items = ['foo', 'bar', 'baz']
for item in items:
@@ -118,8 +119,8 @@ class TestBasicHyperlinkedView(TestCase):
"""
request = factory.get('/basic/')
response = self.list_view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
def test_get_detail_view(self):
"""
@@ -127,8 +128,8 @@ class TestBasicHyperlinkedView(TestCase):
"""
request = factory.get('/basic/1')
response = self.detail_view(request, pk=1).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data[0])
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
class TestManyToManyHyperlinkedView(TestCase):
@@ -136,7 +137,7 @@ class TestManyToManyHyperlinkedView(TestCase):
def setUp(self):
"""
- Create 3 BasicModel intances.
+ Create 3 BasicModel instances.
"""
items = ['foo', 'bar', 'baz']
anchors = []
@@ -166,8 +167,8 @@ class TestManyToManyHyperlinkedView(TestCase):
"""
request = factory.get('/manytomany/')
response = self.list_view(request)
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
def test_get_detail_view(self):
"""
@@ -175,8 +176,8 @@ class TestManyToManyHyperlinkedView(TestCase):
"""
request = factory.get('/manytomany/1/')
response = self.detail_view(request, pk=1)
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data[0])
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
class TestCreateWithForeignKeys(TestCase):
@@ -234,7 +235,7 @@ class TestOptionalRelationHyperlinkedView(TestCase):
def setUp(self):
"""
- Create 1 OptionalRelationModel intances.
+ Create 1 OptionalRelationModel instances.
"""
OptionalRelationModel().save()
self.objects = OptionalRelationModel.objects
@@ -248,8 +249,8 @@ class TestOptionalRelationHyperlinkedView(TestCase):
"""
request = factory.get('/optionalrelationmodel-detail/1')
response = self.detail_view(request, pk=1)
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
def test_put_detail_view(self):
"""
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index 93f09761..f2117538 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -1,35 +1,6 @@
+from __future__ import unicode_literals
from django.db import models
-from django.contrib.contenttypes.models import ContentType
-from django.contrib.contenttypes.generic import GenericForeignKey, GenericRelation
-# from django.contrib.auth.models import Group
-
-
-# class CustomUser(models.Model):
-# """
-# A custom user model, which uses a 'through' table for the foreign key
-# """
-# username = models.CharField(max_length=255, unique=True)
-# groups = models.ManyToManyField(
-# to=Group, blank=True, null=True, through='UserGroupMap'
-# )
-
-# @models.permalink
-# def get_absolute_url(self):
-# return ('custom_user', (), {
-# 'pk': self.id
-# })
-
-
-# class UserGroupMap(models.Model):
-# user = models.ForeignKey(to=CustomUser)
-# group = models.ForeignKey(to=Group)
-
-# @models.permalink
-# def get_absolute_url(self):
-# return ('user_group_map', (), {
-# 'pk': self.id
-# })
def foobar():
return 'foobar'
@@ -86,27 +57,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel):
text = models.CharField(max_length=100, default='anchor')
rel = models.ManyToManyField(Anchor)
-# Models to test generic relations
-
-
-class Tag(RESTFrameworkModel):
- tag_name = models.SlugField()
-
-
-class TaggedItem(RESTFrameworkModel):
- tag = models.ForeignKey(Tag, related_name='items')
- content_type = models.ForeignKey(ContentType)
- object_id = models.PositiveIntegerField()
- content_object = GenericForeignKey('content_type', 'object_id')
-
- def __unicode__(self):
- return self.tag.tag_name
-
-
-class Bookmark(RESTFrameworkModel):
- url = models.URLField()
- tags = GenericRelation(TaggedItem)
-
# Model to test filtering.
class FilterableItem(RESTFrameworkModel):
diff --git a/rest_framework/tests/modelviews.py b/rest_framework/tests/modelviews.py
deleted file mode 100644
index f12e3b97..00000000
--- a/rest_framework/tests/modelviews.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# from rest_framework.compat import patterns, url
-# from django.forms import ModelForm
-# from django.contrib.auth.models import Group, User
-# from rest_framework.resources import ModelResource
-# from rest_framework.views import ListOrCreateModelView, InstanceModelView
-# from rest_framework.tests.models import CustomUser
-# from rest_framework.tests.testcases import TestModelsTestCase
-
-
-# class GroupResource(ModelResource):
-# model = Group
-
-
-# class UserForm(ModelForm):
-# class Meta:
-# model = User
-# exclude = ('last_login', 'date_joined')
-
-
-# class UserResource(ModelResource):
-# model = User
-# form = UserForm
-
-
-# class CustomUserResource(ModelResource):
-# model = CustomUser
-
-# urlpatterns = patterns('',
-# url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'),
-# url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)),
-# url(r'^customusers/$', ListOrCreateModelView.as_view(resource=CustomUserResource), name='customusers'),
-# url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=CustomUserResource)),
-# url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'),
-# url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)),
-# )
-
-
-# class ModelViewTests(TestModelsTestCase):
-# """Test the model views rest_framework provides"""
-# urls = 'rest_framework.tests.modelviews'
-
-# def test_creation(self):
-# """Ensure that a model object can be created"""
-# self.assertEqual(0, Group.objects.count())
-
-# response = self.client.post('/groups/', {'name': 'foo'})
-
-# self.assertEqual(response.status_code, 201)
-# self.assertEqual(1, Group.objects.count())
-# self.assertEqual('foo', Group.objects.all()[0].name)
-
-# def test_creation_with_m2m_relation(self):
-# """Ensure that a model object with a m2m relation can be created"""
-# group = Group(name='foo')
-# group.save()
-# self.assertEqual(0, User.objects.count())
-
-# response = self.client.post('/users/', {'username': 'bar', 'password': 'baz', 'groups': [group.id]})
-
-# self.assertEqual(response.status_code, 201)
-# self.assertEqual(1, User.objects.count())
-
-# user = User.objects.all()[0]
-# self.assertEqual('bar', user.username)
-# self.assertEqual('baz', user.password)
-# self.assertEqual(1, user.groups.count())
-
-# group = user.groups.all()[0]
-# self.assertEqual('foo', group.name)
-
-# def test_creation_with_m2m_relation_through(self):
-# """
-# Ensure that a model object with a m2m relation can be created where that
-# relation uses a through table
-# """
-# group = Group(name='foo')
-# group.save()
-# self.assertEqual(0, User.objects.count())
-
-# response = self.client.post('/customusers/', {'username': 'bar', 'groups': [group.id]})
-
-# self.assertEqual(response.status_code, 201)
-# self.assertEqual(1, CustomUser.objects.count())
-
-# user = CustomUser.objects.all()[0]
-# self.assertEqual('bar', user.username)
-# self.assertEqual(1, user.groups.count())
-
-# group = user.groups.all()[0]
-# self.assertEqual('foo', group.name)
diff --git a/rest_framework/tests/multitable_inheritance.py b/rest_framework/tests/multitable_inheritance.py
new file mode 100644
index 00000000..00c15327
--- /dev/null
+++ b/rest_framework/tests/multitable_inheritance.py
@@ -0,0 +1,67 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.tests.models import RESTFrameworkModel
+
+
+# Models
+class ParentModel(RESTFrameworkModel):
+ name1 = models.CharField(max_length=100)
+
+
+class ChildModel(ParentModel):
+ name2 = models.CharField(max_length=100)
+
+
+class AssociatedModel(RESTFrameworkModel):
+ ref = models.OneToOneField(ParentModel, primary_key=True)
+ name = models.CharField(max_length=100)
+
+
+# Serializers
+class DerivedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChildModel
+
+
+class AssociatedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = AssociatedModel
+
+
+# Tests
+class IneritedModelSerializationTests(TestCase):
+
+ def test_multitable_inherited_model_fields_as_expected(self):
+ """
+ Assert that the parent pointer field is not included in the fields
+ serialized fields
+ """
+ child = ChildModel(name1='parent name', name2='child name')
+ serializer = DerivedModelSerializer(child)
+ self.assertEqual(set(serializer.data.keys()),
+ set(['name1', 'name2', 'id']))
+
+ def test_onetoone_primary_key_model_fields_as_expected(self):
+ """
+ Assert that a model with a onetoone field that is the primary key is
+ not treated like a derived model
+ """
+ parent = ParentModel(name1='parent name')
+ associate = AssociatedModel(name='hello', ref=parent)
+ serializer = AssociatedModelSerializer(associate)
+ self.assertEqual(set(serializer.data.keys()),
+ set(['name', 'ref']))
+
+ def test_data_is_valid_without_parent_ptr(self):
+ """
+ Assert that the pointer to the parent table is not a required field
+ for input data
+ """
+ data = {
+ 'name1': 'parent name',
+ 'name2': 'child name',
+ }
+ serializer = DerivedModelSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
diff --git a/rest_framework/tests/negotiation.py b/rest_framework/tests/negotiation.py
index e06354ea..43721b84 100644
--- a/rest_framework/tests/negotiation.py
+++ b/rest_framework/tests/negotiation.py
@@ -1,6 +1,9 @@
+from __future__ import unicode_literals
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework.negotiation import DefaultContentNegotiation
+from rest_framework.request import Request
+
factory = RequestFactory()
@@ -22,16 +25,16 @@ class TestAcceptedMediaType(TestCase):
return self.negotiator.select_renderer(request, self.renderers)
def test_client_without_accept_use_renderer(self):
- request = factory.get('/')
+ request = Request(factory.get('/'))
accepted_renderer, accepted_media_type = self.select_renderer(request)
- self.assertEquals(accepted_media_type, 'application/json')
+ self.assertEqual(accepted_media_type, 'application/json')
def test_client_underspecifies_accept_use_renderer(self):
- request = factory.get('/', HTTP_ACCEPT='*/*')
+ request = Request(factory.get('/', HTTP_ACCEPT='*/*'))
accepted_renderer, accepted_media_type = self.select_renderer(request)
- self.assertEquals(accepted_media_type, 'application/json')
+ self.assertEqual(accepted_media_type, 'application/json')
def test_client_overspecifies_accept_use_client(self):
- request = factory.get('/', HTTP_ACCEPT='application/json; indent=8')
+ request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8'))
accepted_renderer, accepted_media_type = self.select_renderer(request)
- self.assertEquals(accepted_media_type, 'application/json; indent=8')
+ self.assertEqual(accepted_media_type, 'application/json; indent=8')
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index 3b550877..1a2d68a6 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -1,5 +1,7 @@
+from __future__ import unicode_literals
import datetime
from decimal import Decimal
+import django
from django.core.paginator import Paginator
from django.test import TestCase
from django.test.client import RequestFactory
@@ -19,21 +21,6 @@ class RootView(generics.ListCreateAPIView):
paginate_by = 10
-if django_filters:
- class DecimalFilter(django_filters.FilterSet):
- decimal = django_filters.NumberFilter(lookup_type='lt')
-
- class Meta:
- model = FilterableItem
- fields = ['text', 'decimal', 'date']
-
- class FilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
- paginate_by = 10
- filter_class = DecimalFilter
- filter_backend = filters.DjangoFilterBackend
-
-
class DefaultPageSizeKwargView(generics.ListAPIView):
"""
View for testing default paginate_by_param usage
@@ -72,28 +59,32 @@ class IntegrationTestPagination(TestCase):
GET requests to paginated ListCreateAPIView should return paginated results.
"""
request = factory.get('/')
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data['count'], 26)
- self.assertEquals(response.data['results'], self.data[:10])
- self.assertNotEquals(response.data['next'], None)
- self.assertEquals(response.data['previous'], None)
+ # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>`
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 26)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next'])
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data['count'], 26)
- self.assertEquals(response.data['results'], self.data[10:20])
- self.assertNotEquals(response.data['next'], None)
- self.assertNotEquals(response.data['previous'], None)
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 26)
+ self.assertEqual(response.data['results'], self.data[10:20])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['next'])
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data['count'], 26)
- self.assertEquals(response.data['results'], self.data[20:])
- self.assertEquals(response.data['next'], None)
- self.assertNotEquals(response.data['previous'], None)
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 26)
+ self.assertEqual(response.data['results'], self.data[20:])
+ self.assertEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
class IntegrationTestPaginationAndFiltering(TestCase):
@@ -111,41 +102,115 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.objects = FilterableItem.objects
self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
- for obj in self.objects.all()
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()}
+ for obj in self.objects.all()
]
- self.view = FilterFieldsRootView.as_view()
@unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_get_paginated_filtered_root_view(self):
+ def test_get_django_filter_paginated_filtered_root_view(self):
"""
GET requests to paginated filtered ListCreateAPIView should return
paginated results. The next and previous links should preserve the
filtered parameters.
"""
+ class DecimalFilter(django_filters.FilterSet):
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_class = DecimalFilter
+ filter_backend = filters.DjangoFilterBackend
+
+ view = FilterFieldsRootView.as_view()
+
+ EXPECTED_NUM_QUERIES = 2
+ if django.VERSION < (1, 4):
+ # On Django 1.3 we need to use django-filter 0.5.4
+ #
+ # The filter objects there don't expose a `.count()` method,
+ # which means we only make a single query *but* it's a single
+ # query across *all* of the queryset, instead of a COUNT and then
+ # a SELECT with a LIMIT.
+ #
+ # Although this is fewer queries, it's actually a regression.
+ EXPECTED_NUM_QUERIES = 1
+
request = factory.get('/?decimal=15.20')
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data['count'], 15)
- self.assertEquals(response.data['results'], self.data[:10])
- self.assertNotEquals(response.data['next'], None)
- self.assertEquals(response.data['previous'], None)
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next'])
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data['count'], 15)
- self.assertEquals(response.data['results'], self.data[10:15])
- self.assertEquals(response.data['next'], None)
- self.assertNotEquals(response.data['previous'], None)
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[10:15])
+ self.assertEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['previous'])
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data['count'], 15)
- self.assertEquals(response.data['results'], self.data[:10])
- self.assertNotEquals(response.data['next'], None)
- self.assertEquals(response.data['previous'], None)
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+ def test_get_basic_paginated_filtered_root_view(self):
+ """
+ Same as `test_get_django_filter_paginated_filtered_root_view`,
+ except using a custom filter backend instead of the django-filter
+ backend,
+ """
+
+ class DecimalFilterBackend(filters.BaseFilterBackend):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(decimal__lt=Decimal(request.GET['decimal']))
+
+ class BasicFilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_backend = DecimalFilterBackend
+
+ view = BasicFilterFieldsRootView.as_view()
+
+ request = factory.get('/?decimal=15.20')
+ with self.assertNumQueries(2):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ with self.assertNumQueries(2):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[10:15])
+ self.assertEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['previous'])
+ with self.assertNumQueries(2):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
class PassOnContextPaginationSerializer(pagination.PaginationSerializer):
@@ -166,16 +231,16 @@ class UnitTestPagination(TestCase):
def test_native_pagination(self):
serializer = pagination.PaginationSerializer(self.first_page)
- self.assertEquals(serializer.data['count'], 26)
- self.assertEquals(serializer.data['next'], '?page=2')
- self.assertEquals(serializer.data['previous'], None)
- self.assertEquals(serializer.data['results'], self.objects[:10])
+ self.assertEqual(serializer.data['count'], 26)
+ self.assertEqual(serializer.data['next'], '?page=2')
+ self.assertEqual(serializer.data['previous'], None)
+ self.assertEqual(serializer.data['results'], self.objects[:10])
serializer = pagination.PaginationSerializer(self.last_page)
- self.assertEquals(serializer.data['count'], 26)
- self.assertEquals(serializer.data['next'], None)
- self.assertEquals(serializer.data['previous'], '?page=2')
- self.assertEquals(serializer.data['results'], self.objects[20:])
+ self.assertEqual(serializer.data['count'], 26)
+ self.assertEqual(serializer.data['next'], None)
+ self.assertEqual(serializer.data['previous'], '?page=2')
+ self.assertEqual(serializer.data['results'], self.objects[20:])
def test_context_available_in_result(self):
"""
@@ -184,7 +249,7 @@ class UnitTestPagination(TestCase):
serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})
serializer.data
results = serializer.fields[serializer.results_field]
- self.assertEquals(serializer.context, results.context)
+ self.assertEqual(serializer.context, results.context)
class TestUnpaginated(TestCase):
@@ -212,7 +277,7 @@ class TestUnpaginated(TestCase):
"""
request = factory.get('/')
response = self.view(request)
- self.assertEquals(response.data, self.data)
+ self.assertEqual(response.data, self.data)
class TestCustomPaginateByParam(TestCase):
@@ -240,7 +305,7 @@ class TestCustomPaginateByParam(TestCase):
"""
request = factory.get('/')
response = self.view(request).render()
- self.assertEquals(response.data, self.data)
+ self.assertEqual(response.data, self.data)
def test_paginate_by_param(self):
"""
@@ -248,9 +313,11 @@ class TestCustomPaginateByParam(TestCase):
"""
request = factory.get('/?page_size=5')
response = self.view(request).render()
- self.assertEquals(response.data['count'], 13)
- self.assertEquals(response.data['results'], self.data[:5])
+ self.assertEqual(response.data['count'], 13)
+ self.assertEqual(response.data['results'], self.data[:5])
+
+### Tests for context in pagination serializers
class CustomField(serializers.Field):
def to_native(self, value):
@@ -262,6 +329,11 @@ class CustomField(serializers.Field):
class BasicModelSerializer(serializers.Serializer):
text = CustomField()
+ def __init__(self, *args, **kwargs):
+ super(BasicModelSerializer, self).__init__(*args, **kwargs)
+ if not 'view' in self.context:
+ raise RuntimeError("context isn't getting passed into serializer init")
+
class TestContextPassedToCustomField(TestCase):
def setUp(self):
@@ -277,5 +349,41 @@ class TestContextPassedToCustomField(TestCase):
request = factory.get('/')
response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+### Tests for custom pagination serializers
+
+class LinksSerializer(serializers.Serializer):
+ next = pagination.NextPageField(source='*')
+ prev = pagination.PreviousPageField(source='*')
+
+class CustomPaginationSerializer(pagination.BasePaginationSerializer):
+ links = LinksSerializer(source='*') # Takes the page object as the source
+ total_results = serializers.Field(source='paginator.count')
+
+ results_field = 'objects'
+
+
+class TestCustomPaginationSerializer(TestCase):
+ def setUp(self):
+ objects = ['john', 'paul', 'george', 'ringo']
+ paginator = Paginator(objects, 2)
+ self.page = paginator.page(1)
+
+ def test_custom_pagination_serializer(self):
+ request = RequestFactory().get('/foobar')
+ serializer = CustomPaginationSerializer(
+ instance=self.page,
+ context={'request': request}
+ )
+ expected = {
+ 'links': {
+ 'next': 'http://testserver/foobar?page=2',
+ 'prev': None
+ },
+ 'total_results': 4,
+ 'objects': ['john', 'paul']
+ }
+ self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/parsers.py
index 8ab8a52f..539c5b44 100644
--- a/rest_framework/tests/parsers.py
+++ b/rest_framework/tests/parsers.py
@@ -1,139 +1,9 @@
-# """
-# ..
-# >>> from rest_framework.parsers import FormParser
-# >>> from django.test.client import RequestFactory
-# >>> from rest_framework.views import View
-# >>> from StringIO import StringIO
-# >>> from urllib import urlencode
-# >>> req = RequestFactory().get('/')
-# >>> some_view = View()
-# >>> some_view.request = req # Make as if this request had been dispatched
-#
-# FormParser
-# ============
-#
-# Data flatening
-# ----------------
-#
-# Here is some example data, which would eventually be sent along with a post request :
-#
-# >>> inpt = urlencode([
-# ... ('key1', 'bla1'),
-# ... ('key2', 'blo1'), ('key2', 'blo2'),
-# ... ])
-#
-# Default behaviour for :class:`parsers.FormParser`, is to return a single value for each parameter :
-#
-# >>> (data, files) = FormParser(some_view).parse(StringIO(inpt))
-# >>> data == {'key1': 'bla1', 'key2': 'blo1'}
-# True
-#
-# However, you can customize this behaviour by subclassing :class:`parsers.FormParser`, and overriding :meth:`parsers.FormParser.is_a_list` :
-#
-# >>> class MyFormParser(FormParser):
-# ...
-# ... def is_a_list(self, key, val_list):
-# ... return len(val_list) > 1
-#
-# This new parser only flattens the lists of parameters that contain a single value.
-#
-# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt))
-# >>> data == {'key1': 'bla1', 'key2': ['blo1', 'blo2']}
-# True
-#
-# .. note:: The same functionality is available for :class:`parsers.MultiPartParser`.
-#
-# Submitting an empty list
-# --------------------------
-#
-# When submitting an empty select multiple, like this one ::
-#
-# <select multiple="multiple" name="key2"></select>
-#
-# The browsers usually strip the parameter completely. A hack to avoid this, and therefore being able to submit an empty select multiple, is to submit a value that tells the server that the list is empty ::
-#
-# <select multiple="multiple" name="key2"><option value="_empty"></select>
-#
-# :class:`parsers.FormParser` provides the server-side implementation for this hack. Considering the following posted data :
-#
-# >>> inpt = urlencode([
-# ... ('key1', 'blo1'), ('key1', '_empty'),
-# ... ('key2', '_empty'),
-# ... ])
-#
-# :class:`parsers.FormParser` strips the values ``_empty`` from all the lists.
-#
-# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt))
-# >>> data == {'key1': 'blo1'}
-# True
-#
-# Oh ... but wait a second, the parameter ``key2`` isn't even supposed to be a list, so the parser just stripped it.
-#
-# >>> class MyFormParser(FormParser):
-# ...
-# ... def is_a_list(self, key, val_list):
-# ... return key == 'key2'
-# ...
-# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt))
-# >>> data == {'key1': 'blo1', 'key2': []}
-# True
-#
-# Better like that. Note that you can configure something else than ``_empty`` for the empty value by setting :attr:`parsers.FormParser.EMPTY_VALUE`.
-# """
-# import httplib, mimetypes
-# from tempfile import TemporaryFile
-# from django.test import TestCase
-# from django.test.client import RequestFactory
-# from rest_framework.parsers import MultiPartParser
-# from rest_framework.views import View
-# from StringIO import StringIO
-#
-# def encode_multipart_formdata(fields, files):
-# """For testing multipart parser.
-# fields is a sequence of (name, value) elements for regular form fields.
-# files is a sequence of (name, filename, value) elements for data to be uploaded as files
-# Return (content_type, body)."""
-# BOUNDARY = '----------ThIs_Is_tHe_bouNdaRY_$'
-# CRLF = '\r\n'
-# L = []
-# for (key, value) in fields:
-# L.append('--' + BOUNDARY)
-# L.append('Content-Disposition: form-data; name="%s"' % key)
-# L.append('')
-# L.append(value)
-# for (key, filename, value) in files:
-# L.append('--' + BOUNDARY)
-# L.append('Content-Disposition: form-data; name="%s"; filename="%s"' % (key, filename))
-# L.append('Content-Type: %s' % get_content_type(filename))
-# L.append('')
-# L.append(value)
-# L.append('--' + BOUNDARY + '--')
-# L.append('')
-# body = CRLF.join(L)
-# content_type = 'multipart/form-data; boundary=%s' % BOUNDARY
-# return content_type, body
-#
-# def get_content_type(filename):
-# return mimetypes.guess_type(filename)[0] or 'application/octet-stream'
-#
-#class TestMultiPartParser(TestCase):
-# def setUp(self):
-# self.req = RequestFactory()
-# self.content_type, self.body = encode_multipart_formdata([('key1', 'val1'), ('key1', 'val2')],
-# [('file1', 'pic.jpg', 'blablabla'), ('file1', 't.txt', 'blobloblo')])
-#
-# def test_multipartparser(self):
-# """Ensure that MultiPartParser can parse multipart/form-data that contains a mix of several files and parameters."""
-# post_req = RequestFactory().post('/', self.body, content_type=self.content_type)
-# view = View()
-# view.request = post_req
-# (data, files) = MultiPartParser(view).parse(StringIO(self.body))
-# self.assertEqual(data['key1'], 'val1')
-# self.assertEqual(files['file1'].read(), 'blablabla')
-
-from StringIO import StringIO
+from __future__ import unicode_literals
+from rest_framework.compat import StringIO
from django import forms
from django.test import TestCase
+from django.utils import unittest
+from rest_framework.compat import etree
from rest_framework.parsers import FormParser
from rest_framework.parsers import XMLParser
import datetime
@@ -201,11 +71,13 @@ class TestXMLParser(TestCase):
]
}
+ @unittest.skipUnless(etree, 'defusedxml not installed')
def test_parse(self):
parser = XMLParser()
data = parser.parse(self._input)
self.assertEqual(data, self._data)
+ @unittest.skipUnless(etree, 'defusedxml not installed')
def test_complex_data_parse(self):
parser = XMLParser()
data = parser.parse(self._complex_data_input)
diff --git a/rest_framework/tests/permissions.py b/rest_framework/tests/permissions.py
new file mode 100644
index 00000000..b3993be5
--- /dev/null
+++ b/rest_framework/tests/permissions.py
@@ -0,0 +1,153 @@
+from __future__ import unicode_literals
+from django.contrib.auth.models import User, Permission
+from django.db import models
+from django.test import TestCase
+from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING
+from rest_framework.tests.utils import RequestFactory
+import base64
+import json
+
+factory = RequestFactory()
+
+
+class BasicModel(models.Model):
+ text = models.CharField(max_length=100)
+
+
+class RootView(generics.ListCreateAPIView):
+ model = BasicModel
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [permissions.DjangoModelPermissions]
+
+
+class InstanceView(generics.RetrieveUpdateDestroyAPIView):
+ model = BasicModel
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [permissions.DjangoModelPermissions]
+
+root_view = RootView.as_view()
+instance_view = InstanceView.as_view()
+
+
+def basic_auth_header(username, password):
+ credentials = ('%s:%s' % (username, password))
+ base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
+ return 'Basic %s' % base64_credentials
+
+
+class ModelPermissionsIntegrationTests(TestCase):
+ def setUp(self):
+ User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
+ user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
+ user.user_permissions = [
+ Permission.objects.get(codename='add_basicmodel'),
+ Permission.objects.get(codename='change_basicmodel'),
+ Permission.objects.get(codename='delete_basicmodel')
+ ]
+ user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
+ user.user_permissions = [
+ Permission.objects.get(codename='change_basicmodel'),
+ ]
+
+ self.permitted_credentials = basic_auth_header('permitted', 'password')
+ self.disallowed_credentials = basic_auth_header('disallowed', 'password')
+ self.updateonly_credentials = basic_auth_header('updateonly', 'password')
+
+ BasicModel(text='foo').save()
+
+ def test_has_create_permissions(self):
+ request = factory.post('/', json.dumps({'text': 'foobar'}),
+ content_type='application/json',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = root_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+
+ def test_has_put_permissions(self):
+ request = factory.put('/1', json.dumps({'text': 'foobar'}),
+ content_type='application/json',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_has_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+
+ def test_does_not_have_create_permissions(self):
+ request = factory.post('/', json.dumps({'text': 'foobar'}),
+ content_type='application/json',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = root_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_does_not_have_put_permissions(self):
+ request = factory.put('/1', json.dumps({'text': 'foobar'}),
+ content_type='application/json',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_does_not_have_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_has_put_as_create_permissions(self):
+ # User only has update permissions - should be able to update an entity.
+ request = factory.put('/1', json.dumps({'text': 'foobar'}),
+ content_type='application/json',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ # But if PUTing to a new entity, permission should be denied.
+ request = factory.put('/2', json.dumps({'text': 'foobar'}),
+ content_type='application/json',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = instance_view(request, pk='2')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+
+class OwnerModel(models.Model):
+ text = models.CharField(max_length=100)
+ owner = models.ForeignKey(User)
+
+
+class IsOwnerPermission(permissions.BasePermission):
+ def has_object_permission(self, request, view, obj):
+ return request.user == obj.owner
+
+
+class OwnerInstanceView(generics.RetrieveUpdateDestroyAPIView):
+ model = OwnerModel
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [IsOwnerPermission]
+
+
+owner_instance_view = OwnerInstanceView.as_view()
+
+
+class ObjectPermissionsIntegrationTests(TestCase):
+ """
+ Integration tests for the object level permissions API.
+ """
+
+ def setUp(self):
+ User.objects.create_user('not_owner', 'not_owner@example.com', 'password')
+ user = User.objects.create_user('owner', 'owner@example.com', 'password')
+
+ self.not_owner_credentials = basic_auth_header('not_owner', 'password')
+ self.owner_credentials = basic_auth_header('owner', 'password')
+
+ OwnerModel(text='foo', owner=user).save()
+
+ def test_owner_has_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.owner_credentials)
+ response = owner_instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+
+ def test_non_owner_does_not_have_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.not_owner_credentials)
+ response = owner_instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py
index 91daea8a..cbf93c65 100644
--- a/rest_framework/tests/relations.py
+++ b/rest_framework/tests/relations.py
@@ -1,7 +1,7 @@
"""
General tests for relational fields.
"""
-
+from __future__ import unicode_literals
from django.db import models
from django.test import TestCase
from rest_framework import serializers
@@ -31,3 +31,17 @@ class FieldTests(TestCase):
field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
self.assertRaises(serializers.ValidationError, field.from_native, '')
self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+
+class TestManyRelateMixin(TestCase):
+ def test_missing_many_to_many_related_field(self):
+ '''
+ Regression test for #632
+
+ https://github.com/tomchristie/django-rest-framework/pull/632
+ '''
+ field = serializers.RelatedField(many=True, read_only=False)
+
+ into = {}
+ field.field_from_native({}, None, 'field_name', into)
+ self.assertEqual(into['field_name'], [])
diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py
index 57913670..b5702a48 100644
--- a/rest_framework/tests/relations_hyperlink.py
+++ b/rest_framework/tests/relations_hyperlink.py
@@ -1,7 +1,16 @@
+from __future__ import unicode_literals
from django.test import TestCase
+from django.test.client import RequestFactory
from rest_framework import serializers
from rest_framework.compat import patterns, url
-from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+from rest_framework.tests.models import (
+ ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
+ NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+)
+
+factory = RequestFactory()
+request = factory.get('/') # Just to ensure we have a request in the serializer context
+
def dummy_view(request, pk):
pass
@@ -16,8 +25,9 @@ urlpatterns = patterns('',
url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),
)
+
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.ManyHyperlinkedRelatedField(view_name='manytomanysource-detail')
+ sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail')
class Meta:
model = ManyToManyTarget
@@ -29,7 +39,7 @@ class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.ManyHyperlinkedRelatedField(view_name='foreignkeysource-detail')
+ sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail')
class Meta:
model = ForeignKeyTarget
@@ -70,98 +80,98 @@ class HyperlinkedManyToManyTests(TestCase):
def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset)
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']},
- {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
- {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
+ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
+ {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
+ {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset)
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']},
- {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
- {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']}
+ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_many_to_many_update(self):
- data = {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
+ data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
instance = ManyToManySource.objects.get(pk=1)
- serializer = ManyToManySourceSerializer(instance, data=data)
+ serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
serializer.save()
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
# Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset)
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']},
- {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
- {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
+ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
+ {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
+ {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_update(self):
- data = {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']}
+ data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}
instance = ManyToManyTarget.objects.get(pk=1)
- serializer = ManyToManyTargetSerializer(instance, data=data)
+ serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
serializer.save()
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
# Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset)
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']},
- {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
- {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']}
+ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']},
+ {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_many_to_many_create(self):
- data = {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']}
- serializer = ManyToManySourceSerializer(data=data)
+ data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
+ serializer = ManyToManySourceSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'source-4')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset)
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']},
- {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
- {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']},
- {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']}
+ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
+ {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
+ {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
+ {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_create(self):
- data = {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']}
- serializer = ManyToManyTargetSerializer(data=data)
+ data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
+ serializer = ManyToManyTargetSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'target-4')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-4')
# Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset)
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']},
- {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
- {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']},
- {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']}
+ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
class HyperlinkedForeignKeyTests(TestCase):
@@ -178,111 +188,118 @@ class HyperlinkedForeignKeyTests(TestCase):
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset)
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
- {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}
+ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset)
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
- {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []},
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_update(self):
- data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'}
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset)
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'},
- {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}
+ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'},
+ {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected url string, received int.']})
def test_reverse_foreign_key_update(self):
- data = {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}
+ data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
instance = ForeignKeyTarget.objects.get(pk=2)
- serializer = ForeignKeyTargetSerializer(instance, data=data)
+ serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all()
- new_serializer = ForeignKeyTargetSerializer(queryset)
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
- {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []},
- ]
- self.assertEquals(new_serializer.data, expected)
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(new_serializer.data, expected)
serializer.save()
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset)
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']},
- {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']},
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_create(self):
- data = {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'}
- serializer = ForeignKeySourceSerializer(data=data)
+ data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}
+ serializer = ForeignKeySourceSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'source-4')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset)
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
- {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'},
- {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'},
+ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_create(self):
- data = {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}
- serializer = ForeignKeyTargetSerializer(data=data)
+ data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
+ serializer = ForeignKeyTargetSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'target-3')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3')
# Ensure target 4 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset)
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']},
- {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []},
- {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']},
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
+ {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self):
- data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': None}
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'target': [u'Value may not be null']})
+ self.assertEqual(serializer.errors, {'target': ['This field is required.']})
class HyperlinkedNullableForeignKeyTests(TestCase):
@@ -299,118 +316,118 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_null(self):
- data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data)
+ data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'source-4')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
- {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_emptystring(self):
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': ''}
- expected_data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data)
+ data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
+ expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, expected_data)
- self.assertEqual(obj.name, u'source-4')
+ self.assertEqual(serializer.data, expected_data)
+ self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
- {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_null(self):
- data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}
+ data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None},
- {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_emptystring(self):
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': ''}
- expected_data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}
+ data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
+ expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.data, expected_data)
+ self.assertEqual(serializer.data, expected_data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None},
- {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
# reverse foreign keys MUST be read_only
# In the general case they do not provide .remove() or .clear()
# and cannot be arbitrarily set.
# def test_reverse_foreign_key_update(self):
- # data = {'id': 1, 'name': u'target-1', 'sources': [1]}
+ # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
# instance = ForeignKeyTarget.objects.get(pk=1)
# serializer = ForeignKeyTargetSerializer(instance, data=data)
# self.assertTrue(serializer.is_valid())
- # self.assertEquals(serializer.data, data)
+ # self.assertEqual(serializer.data, data)
# serializer.save()
# # Ensure target 1 is updated, and everything else is as expected
# queryset = ForeignKeyTarget.objects.all()
- # serializer = ForeignKeyTargetSerializer(queryset)
+ # serializer = ForeignKeyTargetSerializer(queryset, many=True)
# expected = [
- # {'id': 1, 'name': u'target-1', 'sources': [1]},
- # {'id': 2, 'name': u'target-2', 'sources': []},
+ # {'id': 1, 'name': 'target-1', 'sources': [1]},
+ # {'id': 2, 'name': 'target-2', 'sources': []},
# ]
- # self.assertEquals(serializer.data, expected)
+ # self.assertEqual(serializer.data, expected)
class HyperlinkedNullableOneToOneTests(TestCase):
@@ -426,9 +443,9 @@ class HyperlinkedNullableOneToOneTests(TestCase):
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset)
+ serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/onetoonetarget/1/', 'name': u'target-1', 'nullable_source': '/nullableonetoonesource/1/'},
- {'url': '/onetoonetarget/2/', 'name': u'target-2', 'nullable_source': None},
+ {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'},
+ {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py
index 0e129fae..a125ba65 100644
--- a/rest_framework/tests/relations_nested.py
+++ b/rest_framework/tests/relations_nested.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
from django.test import TestCase
from rest_framework import serializers
from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
@@ -15,7 +16,7 @@ class FlatForeignKeySourceSerializer(serializers.ModelSerializer):
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = FlatForeignKeySourceSerializer()
+ sources = FlatForeignKeySourceSerializer(many=True)
class Meta:
model = ForeignKeyTarget
@@ -51,27 +52,27 @@ class ReverseForeignKeyTests(TestCase):
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset)
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}},
- {'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}},
- {'id': 3, 'name': u'source-3', 'target': {'id': 1, 'name': u'target-1'}},
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset)
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'sources': [
- {'id': 1, 'name': u'source-1', 'target': 1},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': 1},
+ {'id': 1, 'name': 'target-1', 'sources': [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1},
]},
- {'id': 2, 'name': u'target-2', 'sources': [
+ {'id': 2, 'name': 'target-2', 'sources': [
]}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
class NestedNullableForeignKeyTests(TestCase):
@@ -86,13 +87,13 @@ class NestedNullableForeignKeyTests(TestCase):
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}},
- {'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}},
- {'id': 3, 'name': u'source-3', 'target': None},
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 3, 'name': 'source-3', 'target': None},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
class NestedNullableOneToOneTests(TestCase):
@@ -106,9 +107,9 @@ class NestedNullableOneToOneTests(TestCase):
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset)
+ serializer = NullableOneToOneTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'nullable_source': {'id': 1, 'name': u'source-1', 'target': 1}},
- {'id': 2, 'name': u'target-2', 'nullable_source': None},
+ {'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}},
+ {'id': 2, 'name': 'target-2', 'nullable_source': None},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py
index 54835860..f08e1808 100644
--- a/rest_framework/tests/relations_pk.py
+++ b/rest_framework/tests/relations_pk.py
@@ -1,10 +1,12 @@
+from __future__ import unicode_literals
from django.test import TestCase
from rest_framework import serializers
from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+from rest_framework.compat import six
class ManyToManyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.ManyPrimaryKeyRelatedField()
+ sources = serializers.PrimaryKeyRelatedField(many=True)
class Meta:
model = ManyToManyTarget
@@ -16,7 +18,7 @@ class ManyToManySourceSerializer(serializers.ModelSerializer):
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.ManyPrimaryKeyRelatedField()
+ sources = serializers.PrimaryKeyRelatedField(many=True)
class Meta:
model = ForeignKeyTarget
@@ -54,97 +56,97 @@ class PKManyToManyTests(TestCase):
def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset)
+ serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'targets': [1]},
- {'id': 2, 'name': u'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]}
+ {'id': 1, 'name': 'source-1', 'targets': [1]},
+ {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset)
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': u'target-3', 'sources': [3]}
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': 'target-3', 'sources': [3]}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_many_to_many_update(self):
- data = {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]}
+ data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}
instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
serializer.save()
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
# Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset)
+ serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]},
- {'id': 2, 'name': u'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]}
+ {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
+ {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_update(self):
- data = {'id': 1, 'name': u'target-1', 'sources': [1]}
+ data = {'id': 1, 'name': 'target-1', 'sources': [1]}
instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
serializer.save()
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
# Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset)
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'sources': [1]},
- {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': u'target-3', 'sources': [3]}
+ {'id': 1, 'name': 'target-1', 'sources': [1]},
+ {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': 'target-3', 'sources': [3]}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_many_to_many_create(self):
- data = {'id': 4, 'name': u'source-4', 'targets': [1, 3]}
+ data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
serializer = ManyToManySourceSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'source-4')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset)
+ serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'targets': [1]},
- {'id': 2, 'name': u'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]},
- {'id': 4, 'name': u'source-4', 'targets': [1, 3]},
+ {'id': 1, 'name': 'source-1', 'targets': [1]},
+ {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]},
+ {'id': 4, 'name': 'source-4', 'targets': [1, 3]},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_create(self):
- data = {'id': 4, 'name': u'target-4', 'sources': [1, 3]}
+ data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
serializer = ManyToManyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'target-4')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-4')
# Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset)
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': u'target-3', 'sources': [3]},
- {'id': 4, 'name': u'target-4', 'sources': [1, 3]}
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': 'target-3', 'sources': [3]},
+ {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
class PKForeignKeyTests(TestCase):
@@ -159,111 +161,118 @@ class PKForeignKeyTests(TestCase):
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset)
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': 1},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': 1}
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset)
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': u'target-2', 'sources': []},
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': []},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_update(self):
- data = {'id': 1, 'name': u'source-1', 'target': 2}
+ data = {'id': 1, 'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset)
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': 2},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': 1}
+ {'id': 1, 'name': 'source-1', 'target': 2},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 'foo'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]})
def test_reverse_foreign_key_update(self):
- data = {'id': 2, 'name': u'target-2', 'sources': [1, 3]}
+ data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all()
- new_serializer = ForeignKeyTargetSerializer(queryset)
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': u'target-2', 'sources': []},
- ]
- self.assertEquals(new_serializer.data, expected)
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(new_serializer.data, expected)
serializer.save()
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset)
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'sources': [2]},
- {'id': 2, 'name': u'target-2', 'sources': [1, 3]},
+ {'id': 1, 'name': 'target-1', 'sources': [2]},
+ {'id': 2, 'name': 'target-2', 'sources': [1, 3]},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_create(self):
- data = {'id': 4, 'name': u'source-4', 'target': 2}
+ data = {'id': 4, 'name': 'source-4', 'target': 2}
serializer = ForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'source-4')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is added, and everything else is as expected
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset)
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': 1},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': 1},
- {'id': 4, 'name': u'source-4', 'target': 2},
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1},
+ {'id': 4, 'name': 'source-4', 'target': 2},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_create(self):
- data = {'id': 3, 'name': u'target-3', 'sources': [1, 3]}
+ data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]}
serializer = ForeignKeyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'target-3')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3')
# Ensure target 3 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset)
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'sources': [2]},
- {'id': 2, 'name': u'target-2', 'sources': []},
- {'id': 3, 'name': u'target-3', 'sources': [1, 3]},
+ {'id': 1, 'name': 'target-1', 'sources': [2]},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ {'id': 3, 'name': 'target-3', 'sources': [1, 3]},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self):
- data = {'id': 1, 'name': u'source-1', 'target': None}
+ data = {'id': 1, 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'target': [u'Value may not be null']})
+ self.assertEqual(serializer.errors, {'target': ['This field is required.']})
class PKNullableForeignKeyTests(TestCase):
@@ -278,118 +287,118 @@ class PKNullableForeignKeyTests(TestCase):
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': 1},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': None},
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_null(self):
- data = {'id': 4, 'name': u'source-4', 'target': None}
+ data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'source-4')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': 1},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': None},
- {'id': 4, 'name': u'source-4', 'target': None}
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_emptystring(self):
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'id': 4, 'name': u'source-4', 'target': ''}
- expected_data = {'id': 4, 'name': u'source-4', 'target': None}
+ data = {'id': 4, 'name': 'source-4', 'target': ''}
+ expected_data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, expected_data)
- self.assertEqual(obj.name, u'source-4')
+ self.assertEqual(serializer.data, expected_data)
+ self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': 1},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': None},
- {'id': 4, 'name': u'source-4', 'target': None}
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_null(self):
- data = {'id': 1, 'name': u'source-1', 'target': None}
+ data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': None},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': None}
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_emptystring(self):
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'id': 1, 'name': u'source-1', 'target': ''}
- expected_data = {'id': 1, 'name': u'source-1', 'target': None}
+ data = {'id': 1, 'name': 'source-1', 'target': ''}
+ expected_data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.data, expected_data)
+ self.assertEqual(serializer.data, expected_data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': None},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': None}
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
# reverse foreign keys MUST be read_only
# In the general case they do not provide .remove() or .clear()
# and cannot be arbitrarily set.
# def test_reverse_foreign_key_update(self):
- # data = {'id': 1, 'name': u'target-1', 'sources': [1]}
+ # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
# instance = ForeignKeyTarget.objects.get(pk=1)
# serializer = ForeignKeyTargetSerializer(instance, data=data)
# self.assertTrue(serializer.is_valid())
- # self.assertEquals(serializer.data, data)
+ # self.assertEqual(serializer.data, data)
# serializer.save()
# # Ensure target 1 is updated, and everything else is as expected
# queryset = ForeignKeyTarget.objects.all()
- # serializer = ForeignKeyTargetSerializer(queryset)
+ # serializer = ForeignKeyTargetSerializer(queryset, many=True)
# expected = [
- # {'id': 1, 'name': u'target-1', 'sources': [1]},
- # {'id': 2, 'name': u'target-2', 'sources': []},
+ # {'id': 1, 'name': 'target-1', 'sources': [1]},
+ # {'id': 2, 'name': 'target-2', 'sources': []},
# ]
- # self.assertEquals(serializer.data, expected)
+ # self.assertEqual(serializer.data, expected)
class PKNullableOneToOneTests(TestCase):
@@ -398,14 +407,14 @@ class PKNullableOneToOneTests(TestCase):
target.save()
new_target = OneToOneTarget(name='target-2')
new_target.save()
- source = NullableOneToOneSource(name='source-1', target=target)
+ source = NullableOneToOneSource(name='source-1', target=new_target)
source.save()
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset)
+ serializer = NullableOneToOneTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'nullable_source': 1},
- {'id': 2, 'name': u'target-2', 'nullable_source': None},
+ {'id': 1, 'name': 'target-1', 'nullable_source': None},
+ {'id': 2, 'name': 'target-2', 'nullable_source': 1},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/relations_slug.py b/rest_framework/tests/relations_slug.py
new file mode 100644
index 00000000..435c821c
--- /dev/null
+++ b/rest_framework/tests/relations_slug.py
@@ -0,0 +1,257 @@
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget
+
+
+class ForeignKeyTargetSerializer(serializers.ModelSerializer):
+ sources = serializers.SlugRelatedField(many=True, slug_field='name')
+
+ class Meta:
+ model = ForeignKeyTarget
+
+
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ target = serializers.SlugRelatedField(slug_field='name')
+
+ class Meta:
+ model = ForeignKeySource
+
+
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
+ target = serializers.SlugRelatedField(slug_field='name', required=False)
+
+ class Meta:
+ model = NullableForeignKeySource
+
+
+# TODO: M2M Tests, FKTests (Non-nullable), One2One
+class SlugForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': 'target-1'}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-2'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': 'target-1'}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 123}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']})
+
+ def test_reverse_foreign_key_update(self):
+ data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
+ {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'id': 4, 'name': 'source-4', 'target': 'target-2'}
+ serializer = ForeignKeySourceSerializer(data=data)
+ serializer.is_valid()
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': 'target-1'},
+ {'id': 4, 'name': 'source-4', 'target': 'target-2'},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3')
+
+ # Ensure target 3 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['This field is required.']})
+
+
+class SlugNullableForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 4, 'name': 'source-4', 'target': ''}
+ expected_data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, expected_data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 1, 'name': 'source-1', 'target': ''}
+ expected_data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, expected_data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/renderers.py
index c1b4e624..40bac9cb 100644
--- a/rest_framework/tests/renderers.py
+++ b/rest_framework/tests/renderers.py
@@ -1,29 +1,28 @@
-import pickle
-import re
-
+from decimal import Decimal
from django.core.cache import cache
from django.test import TestCase
from django.test.client import RequestFactory
-
+from django.utils import unittest
from rest_framework import status, permissions
-from rest_framework.compat import yaml, patterns, url, include
+from rest_framework.compat import yaml, etree, patterns, url, include
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
XMLRenderer, JSONPRenderer, BrowsableAPIRenderer
from rest_framework.parsers import YAMLParser, XMLParser
from rest_framework.settings import api_settings
-
-from StringIO import StringIO
+from rest_framework.compat import StringIO
+from rest_framework.compat import six
import datetime
-from decimal import Decimal
+import pickle
+import re
DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = 'dummycontent'
-RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x
-RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x
+RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii')
+RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
expected_results = [
@@ -35,7 +34,7 @@ class BasicRendererTests(TestCase):
def test_expected_results(self):
for value, renderer_cls, expected in expected_results:
output = renderer_cls().render(value)
- self.assertEquals(output, expected)
+ self.assertEqual(output, expected)
class RendererA(BaseRenderer):
@@ -94,7 +93,7 @@ urlpatterns = patterns('',
class POSTDeniedPermission(permissions.BasePermission):
- def has_permission(self, request, view, obj=None):
+ def has_permission(self, request, view):
return request.method != 'POST'
@@ -111,6 +110,9 @@ class POSTDeniedView(APIView):
def put(self, request):
return Response()
+ def patch(self, request):
+ return Response()
+
class DocumentingRendererTests(TestCase):
def test_only_permitted_forms_are_displayed(self):
@@ -119,6 +121,7 @@ class DocumentingRendererTests(TestCase):
response = view(request).render()
self.assertNotContains(response, '>POST<')
self.assertContains(response, '>PUT<')
+ self.assertContains(response, '>PATCH<')
class RendererEndToEndTests(TestCase):
@@ -131,39 +134,39 @@ class RendererEndToEndTests(TestCase):
def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/')
- self.assertEquals(resp['Content-Type'], RendererA.media_type)
- self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_head_method_serializes_no_content(self):
"""No response must be included in HEAD requests."""
resp = self.client.head('/')
- self.assertEquals(resp.status_code, DUMMYSTATUS)
- self.assertEquals(resp['Content-Type'], RendererA.media_type)
- self.assertEquals(resp.content, '')
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp.content, six.b(''))
def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*')
- self.assertEquals(resp['Content-Type'], RendererA.media_type)
- self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
- self.assertEquals(resp['Content-Type'], RendererA.media_type)
- self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_accept_query(self):
"""The '_accept' query string should behave in the same way as the Accept header."""
@@ -172,14 +175,14 @@ class RendererEndToEndTests(TestCase):
RendererB.media_type
)
resp = self.client.get('/' + param)
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
"""If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
- self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
+ self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching
@@ -189,17 +192,17 @@ class RendererEndToEndTests(TestCase):
RendererB.format
)
resp = self.client.get('/' + param)
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/something.formatb')
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
"""If both a 'format' query and a matching Accept header specified,
@@ -210,9 +213,9 @@ class RendererEndToEndTests(TestCase):
)
resp = self.client.get('/' + param,
HTTP_ACCEPT=RendererB.media_type)
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
_flat_repr = '{"foo": ["bar", "baz"]}'
@@ -240,7 +243,7 @@ class JSONRendererTests(TestCase):
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json')
# Fix failing test case which depends on version of JSON library.
- self.assertEquals(content, _flat_repr)
+ self.assertEqual(content, _flat_repr)
def test_with_content_type_args(self):
"""
@@ -249,7 +252,7 @@ class JSONRendererTests(TestCase):
obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json; indent=2')
- self.assertEquals(strip_trailing_whitespace(content), _indented_repr)
+ self.assertEqual(strip_trailing_whitespace(content), _indented_repr)
class JSONPRendererTests(TestCase):
@@ -265,9 +268,10 @@ class JSONPRendererTests(TestCase):
"""
resp = self.client.get('/jsonp/jsonrenderer',
HTTP_ACCEPT='application/javascript')
- self.assertEquals(resp.status_code, 200)
- self.assertEquals(resp['Content-Type'], 'application/javascript')
- self.assertEquals(resp.content, 'callback(%s);' % _flat_repr)
+ self.assertEqual(resp.status_code, status.HTTP_200_OK)
+ self.assertEqual(resp['Content-Type'], 'application/javascript')
+ self.assertEqual(resp.content,
+ ('callback(%s);' % _flat_repr).encode('ascii'))
def test_without_callback_without_json_renderer(self):
"""
@@ -275,9 +279,10 @@ class JSONPRendererTests(TestCase):
"""
resp = self.client.get('/jsonp/nojsonrenderer',
HTTP_ACCEPT='application/javascript')
- self.assertEquals(resp.status_code, 200)
- self.assertEquals(resp['Content-Type'], 'application/javascript')
- self.assertEquals(resp.content, 'callback(%s);' % _flat_repr)
+ self.assertEqual(resp.status_code, status.HTTP_200_OK)
+ self.assertEqual(resp['Content-Type'], 'application/javascript')
+ self.assertEqual(resp.content,
+ ('callback(%s);' % _flat_repr).encode('ascii'))
def test_with_callback(self):
"""
@@ -286,9 +291,10 @@ class JSONPRendererTests(TestCase):
callback_func = 'myjsonpcallback'
resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func,
HTTP_ACCEPT='application/javascript')
- self.assertEquals(resp.status_code, 200)
- self.assertEquals(resp['Content-Type'], 'application/javascript')
- self.assertEquals(resp.content, '%s(%s);' % (callback_func, _flat_repr))
+ self.assertEqual(resp.status_code, status.HTTP_200_OK)
+ self.assertEqual(resp['Content-Type'], 'application/javascript')
+ self.assertEqual(resp.content,
+ ('%s(%s);' % (callback_func, _flat_repr)).encode('ascii'))
if yaml:
@@ -306,7 +312,7 @@ if yaml:
obj = {'foo': ['bar', 'baz']}
renderer = YAMLRenderer()
content = renderer.render(obj, 'application/yaml')
- self.assertEquals(content, _yaml_repr)
+ self.assertEqual(content, _yaml_repr)
def test_render_and_parse(self):
"""
@@ -320,7 +326,7 @@ if yaml:
content = renderer.render(obj, 'application/yaml')
data = parser.parse(StringIO(content))
- self.assertEquals(obj, data)
+ self.assertEqual(obj, data)
class XMLRendererTestCase(TestCase):
@@ -402,6 +408,7 @@ class XMLRendererTestCase(TestCase):
self.assertXMLContains(content, '<sub_name>first</sub_name>')
self.assertXMLContains(content, '<sub_name>second</sub_name>')
+ @unittest.skipUnless(etree, 'defusedxml not installed')
def test_render_and_parse_complex_data(self):
"""
Test XML rendering.
diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py
index 4b032405..97e5af20 100644
--- a/rest_framework/tests/request.py
+++ b/rest_framework/tests/request.py
@@ -1,7 +1,7 @@
"""
Tests for content parsing, and form-overloaded content parsing.
"""
-import json
+from __future__ import unicode_literals
from django.contrib.auth.models import User
from django.contrib.auth import authenticate, login, logout
from django.contrib.sessions.middleware import SessionMiddleware
@@ -20,6 +20,8 @@ from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.views import APIView
+from rest_framework.compat import six
+import json
factory = RequestFactory()
@@ -56,21 +58,29 @@ class TestMethodOverloading(TestCase):
request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'}))
self.assertEqual(request.method, 'DELETE')
+ def test_x_http_method_override_header(self):
+ """
+ POST requests can also be overloaded to another method by setting
+ the X-HTTP-Method-Override header.
+ """
+ request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE'))
+ self.assertEqual(request.method, 'DELETE')
+
class TestContentParsing(TestCase):
def test_standard_behaviour_determines_no_content_GET(self):
"""
- Ensure request.DATA returns None for GET request with no content.
+ Ensure request.DATA returns empty QueryDict for GET request.
"""
request = Request(factory.get('/'))
- self.assertEqual(request.DATA, None)
+ self.assertEqual(request.DATA, {})
def test_standard_behaviour_determines_no_content_HEAD(self):
"""
- Ensure request.DATA returns None for HEAD request.
+ Ensure request.DATA returns empty QueryDict for HEAD request.
"""
request = Request(factory.head('/'))
- self.assertEqual(request.DATA, None)
+ self.assertEqual(request.DATA, {})
def test_request_DATA_with_form_content(self):
"""
@@ -79,14 +89,14 @@ class TestContentParsing(TestCase):
data = {'qwerty': 'uiop'}
request = Request(factory.post('/', data))
request.parsers = (FormParser(), MultiPartParser())
- self.assertEqual(request.DATA.items(), data.items())
+ self.assertEqual(list(request.DATA.items()), list(data.items()))
def test_request_DATA_with_text_content(self):
"""
Ensure request.DATA returns content for POST request with
non-form content.
"""
- content = 'qwerty'
+ content = six.b('qwerty')
content_type = 'text/plain'
request = Request(factory.post('/', content, content_type=content_type))
request.parsers = (PlainTextParser(),)
@@ -99,7 +109,7 @@ class TestContentParsing(TestCase):
data = {'qwerty': 'uiop'}
request = Request(factory.post('/', data))
request.parsers = (FormParser(), MultiPartParser())
- self.assertEqual(request.POST.items(), data.items())
+ self.assertEqual(list(request.POST.items()), list(data.items()))
def test_standard_behaviour_determines_form_content_PUT(self):
"""
@@ -117,14 +127,14 @@ class TestContentParsing(TestCase):
request = Request(factory.put('/', data))
request.parsers = (FormParser(), MultiPartParser())
- self.assertEqual(request.DATA.items(), data.items())
+ self.assertEqual(list(request.DATA.items()), list(data.items()))
def test_standard_behaviour_determines_non_form_content_PUT(self):
"""
Ensure request.DATA returns content for PUT request with
non-form content.
"""
- content = 'qwerty'
+ content = six.b('qwerty')
content_type = 'text/plain'
request = Request(factory.put('/', content, content_type=content_type))
request.parsers = (PlainTextParser(), )
diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py
index 875f4d42..aecf83f4 100644
--- a/rest_framework/tests/response.py
+++ b/rest_framework/tests/response.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.compat import patterns, url, include
from rest_framework.response import Response
@@ -9,6 +10,7 @@ from rest_framework.renderers import (
BrowsableAPIRenderer
)
from rest_framework.settings import api_settings
+from rest_framework.compat import six
class MockPickleRenderer(BaseRenderer):
@@ -22,8 +24,8 @@ class MockJsonRenderer(BaseRenderer):
DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = 'dummycontent'
-RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x
-RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x
+RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii')
+RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
class RendererA(BaseRenderer):
@@ -83,39 +85,39 @@ class RendererIntegrationTests(TestCase):
def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/')
- self.assertEquals(resp['Content-Type'], RendererA.media_type)
- self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_head_method_serializes_no_content(self):
"""No response must be included in HEAD requests."""
resp = self.client.head('/')
- self.assertEquals(resp.status_code, DUMMYSTATUS)
- self.assertEquals(resp['Content-Type'], RendererA.media_type)
- self.assertEquals(resp.content, '')
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp.content, six.b(''))
def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*')
- self.assertEquals(resp['Content-Type'], RendererA.media_type)
- self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
- self.assertEquals(resp['Content-Type'], RendererA.media_type)
- self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_accept_query(self):
"""The '_accept' query string should behave in the same way as the Accept header."""
@@ -124,34 +126,34 @@ class RendererIntegrationTests(TestCase):
RendererB.media_type
)
resp = self.client.get('/' + param)
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format)
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/something.formatb')
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
"""If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format,
HTTP_ACCEPT=RendererB.media_type)
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
class Issue122Tests(TestCase):
diff --git a/rest_framework/tests/reverse.py b/rest_framework/tests/reverse.py
index 8c86e1fb..cb8d8132 100644
--- a/rest_framework/tests/reverse.py
+++ b/rest_framework/tests/reverse.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework.compat import patterns, url
@@ -16,7 +17,7 @@ urlpatterns = patterns('',
class ReverseTests(TestCase):
"""
- Tests for fully qualifed URLs when using `reverse`.
+ Tests for fully qualified URLs when using `reverse`.
"""
urls = 'rest_framework.tests.reverse'
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index bd96ba23..beb372c2 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -1,10 +1,12 @@
-import datetime
-import pickle
+from __future__ import unicode_literals
+from django.utils.datastructures import MultiValueDict
from django.test import TestCase
from rest_framework import serializers
from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel,
ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo)
+import datetime
+import pickle
class SubComment(object):
@@ -54,6 +56,19 @@ class ActionItemSerializer(serializers.ModelSerializer):
model = ActionItem
+class ActionItemSerializerCustomRestore(serializers.ModelSerializer):
+
+ class Meta:
+ model = ActionItem
+
+ def restore_object(self, data, instance=None):
+ if instance is None:
+ return ActionItem(**data)
+ for key, val in data.items():
+ setattr(instance, key, val)
+ return instance
+
+
class PersonSerializer(serializers.ModelSerializer):
info = serializers.Field(source='info')
@@ -76,6 +91,11 @@ class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):
fields = ['some_integer']
+class BrokenModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ fields = ['some_field']
+
+
class BasicTests(TestCase):
def setUp(self):
self.comment = Comment(
@@ -92,7 +112,7 @@ class BasicTests(TestCase):
self.expected = {
'email': 'tom@example.com',
'content': 'Happy new year!',
- 'created': datetime.datetime(2012, 1, 1),
+ 'created': '2012-01-01T00:00:00',
'sub_comment': 'And Merry Christmas!'
}
self.person_data = {'name': 'dwight', 'age': 35}
@@ -107,39 +127,39 @@ class BasicTests(TestCase):
'created': None,
'sub_comment': ''
}
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_retrieve(self):
serializer = CommentSerializer(self.comment)
- self.assertEquals(serializer.data, self.expected)
+ self.assertEqual(serializer.data, self.expected)
def test_create(self):
serializer = CommentSerializer(data=self.data)
expected = self.comment
- self.assertEquals(serializer.is_valid(), True)
- self.assertEquals(serializer.object, expected)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
self.assertFalse(serializer.object is expected)
- self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!')
+ self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!')
def test_update(self):
serializer = CommentSerializer(self.comment, data=self.data)
expected = self.comment
- self.assertEquals(serializer.is_valid(), True)
- self.assertEquals(serializer.object, expected)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
self.assertTrue(serializer.object is expected)
- self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!')
+ self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!')
def test_partial_update(self):
msg = 'Merry New Year!'
partial_data = {'content': msg}
serializer = CommentSerializer(self.comment, data=partial_data)
- self.assertEquals(serializer.is_valid(), False)
+ self.assertEqual(serializer.is_valid(), False)
serializer = CommentSerializer(self.comment, data=partial_data, partial=True)
expected = self.comment
self.assertEqual(serializer.is_valid(), True)
- self.assertEquals(serializer.object, expected)
+ self.assertEqual(serializer.object, expected)
self.assertTrue(serializer.object is expected)
- self.assertEquals(serializer.data['content'], msg)
+ self.assertEqual(serializer.data['content'], msg)
def test_model_fields_as_expected(self):
"""
@@ -147,7 +167,7 @@ class BasicTests(TestCase):
in the Meta data
"""
serializer = PersonSerializer(self.person)
- self.assertEquals(set(serializer.data.keys()),
+ self.assertEqual(set(serializer.data.keys()),
set(['name', 'age', 'info']))
def test_field_with_dictionary(self):
@@ -156,19 +176,45 @@ class BasicTests(TestCase):
"""
serializer = PersonSerializer(self.person)
expected = self.person_data
- self.assertEquals(serializer.data['info'], expected)
+ self.assertEqual(serializer.data['info'], expected)
def test_read_only_fields(self):
"""
Attempting to update fields set as read_only should have no effect.
"""
-
serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99})
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(serializer.errors, {})
+ self.assertEqual(serializer.errors, {})
# Assert age is unchanged (35)
- self.assertEquals(instance.age, self.person_data['age'])
+ self.assertEqual(instance.age, self.person_data['age'])
+
+
+class DictStyleSerializer(serializers.Serializer):
+ """
+ Note that we don't have any `restore_object` method, so the default
+ case of simply returning a dict will apply.
+ """
+ email = serializers.EmailField()
+
+
+class DictStyleSerializerTests(TestCase):
+ def test_dict_style_deserialize(self):
+ """
+ Ensure serializers can deserialize into a dict.
+ """
+ data = {'email': 'foo@example.com'}
+ serializer = DictStyleSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+
+ def test_dict_style_serialize(self):
+ """
+ Ensure serializers can serialize dict objects.
+ """
+ data = {'email': 'foo@example.com'}
+ serializer = DictStyleSerializer(data)
+ self.assertEqual(serializer.data, data)
class ValidationTests(TestCase):
@@ -183,18 +229,17 @@ class ValidationTests(TestCase):
'content': 'x' * 1001,
'created': datetime.datetime(2012, 1, 1)
}
- self.actionitem = ActionItem(title='Some to do item',
- )
+ self.actionitem = ActionItem(title='Some to do item',)
def test_create(self):
serializer = CommentSerializer(data=self.data)
- self.assertEquals(serializer.is_valid(), False)
- self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']})
def test_update(self):
serializer = CommentSerializer(self.comment, data=self.data)
- self.assertEquals(serializer.is_valid(), False)
- self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']})
def test_update_missing_field(self):
data = {
@@ -202,8 +247,8 @@ class ValidationTests(TestCase):
'created': datetime.datetime(2012, 1, 1)
}
serializer = CommentSerializer(self.comment, data=data)
- self.assertEquals(serializer.is_valid(), False)
- self.assertEquals(serializer.errors, {'email': [u'This field is required.']})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'email': ['This field is required.']})
def test_missing_bool_with_default(self):
"""Make sure that a boolean value with a 'False' value is not
@@ -213,52 +258,36 @@ class ValidationTests(TestCase):
#No 'done' value.
}
serializer = ActionItemSerializer(self.actionitem, data=data)
- self.assertEquals(serializer.is_valid(), True)
- self.assertEquals(serializer.errors, {})
-
- def test_field_validation(self):
-
- class CommentSerializerWithFieldValidator(CommentSerializer):
-
- def validate_content(self, attrs, source):
- value = attrs[source]
- if "test" not in value:
- raise serializers.ValidationError("Test not in value")
- return attrs
-
- data = {
- 'email': 'tom@example.com',
- 'content': 'A test comment',
- 'created': datetime.datetime(2012, 1, 1)
- }
-
- serializer = CommentSerializerWithFieldValidator(data=data)
- self.assertTrue(serializer.is_valid())
-
- data['content'] = 'This should not validate'
-
- serializer = CommentSerializerWithFieldValidator(data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'content': [u'Test not in value']})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.errors, {})
def test_bad_type_data_is_false(self):
"""
Data of the wrong type is not valid.
"""
data = ['i am', 'a', 'list']
- serializer = CommentSerializer(self.comment, data=data)
- self.assertEquals(serializer.is_valid(), False)
- self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']})
+ serializer = CommentSerializer(self.comment, data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertTrue(isinstance(serializer.errors, list))
+
+ self.assertEqual(
+ serializer.errors,
+ [
+ {'non_field_errors': ['Invalid data']},
+ {'non_field_errors': ['Invalid data']},
+ {'non_field_errors': ['Invalid data']}
+ ]
+ )
data = 'and i am a string'
serializer = CommentSerializer(self.comment, data=data)
- self.assertEquals(serializer.is_valid(), False)
- self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']})
data = 42
serializer = CommentSerializer(self.comment, data=data)
- self.assertEquals(serializer.is_valid(), False)
- self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']})
def test_cross_field_validation(self):
@@ -282,23 +311,37 @@ class ValidationTests(TestCase):
serializer = CommentSerializerWithCrossFieldValidator(data=data)
self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'non_field_errors': [u'Email address not in content']})
+ self.assertEqual(serializer.errors, {'non_field_errors': ['Email address not in content']})
def test_null_is_true_fields(self):
"""
Omitting a value for null-field should validate.
"""
serializer = PersonSerializer(data={'name': 'marko'})
- self.assertEquals(serializer.is_valid(), True)
- self.assertEquals(serializer.errors, {})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.errors, {})
def test_modelserializer_max_length_exceeded(self):
data = {
'title': 'x' * 201,
}
serializer = ActionItemSerializer(data=data)
- self.assertEquals(serializer.is_valid(), False)
- self.assertEquals(serializer.errors, {'title': [u'Ensure this value has at most 200 characters (it has 201).']})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']})
+
+ def test_modelserializer_max_length_exceeded_with_custom_restore(self):
+ """
+ When overriding ModelSerializer.restore_object, validation tests should still apply.
+ Regression test for #623.
+
+ https://github.com/tomchristie/django-rest-framework/pull/623
+ """
+ data = {
+ 'title': 'x' * 201,
+ }
+ serializer = ActionItemSerializerCustomRestore(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']})
def test_default_modelfield_max_length_exceeded(self):
data = {
@@ -306,15 +349,99 @@ class ValidationTests(TestCase):
'info': 'x' * 13,
}
serializer = ActionItemSerializer(data=data)
- self.assertEquals(serializer.is_valid(), False)
- self.assertEquals(serializer.errors, {'info': [u'Ensure this value has at most 12 characters (it has 13).']})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'info': ['Ensure this value has at most 12 characters (it has 13).']})
+
+ def test_datetime_validation_failure(self):
+ """
+ Test DateTimeField validation errors on non-str values.
+ Regression test for #669.
+
+ https://github.com/tomchristie/django-rest-framework/issues/669
+ """
+ data = self.data
+ data['created'] = 0
+
+ serializer = CommentSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+
+ self.assertIn('created', serializer.errors)
+
+ def test_missing_model_field_exception_msg(self):
+ """
+ Assert that a meaningful exception message is outputted when the model
+ field is missing (e.g. when mistyping ``model``).
+ """
+ try:
+ serializer = BrokenModelSerializer()
+ except AssertionError as e:
+ self.assertEqual(e.args[0], "Serializer class 'BrokenModelSerializer' is missing 'model' Meta option")
+ except:
+ self.fail('Wrong exception type thrown.')
+
+
+class CustomValidationTests(TestCase):
+ class CommentSerializerWithFieldValidator(CommentSerializer):
+
+ def validate_email(self, attrs, source):
+ value = attrs[source]
+
+ return attrs
+
+ def validate_content(self, attrs, source):
+ value = attrs[source]
+ if "test" not in value:
+ raise serializers.ValidationError("Test not in value")
+ return attrs
+
+ def test_field_validation(self):
+ data = {
+ 'email': 'tom@example.com',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = self.CommentSerializerWithFieldValidator(data=data)
+ self.assertTrue(serializer.is_valid())
+
+ data['content'] = 'This should not validate'
+
+ serializer = self.CommentSerializerWithFieldValidator(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'content': ['Test not in value']})
+
+ def test_missing_data(self):
+ """
+ Make sure that validate_content isn't called if the field is missing
+ """
+ incomplete_data = {
+ 'email': 'tom@example.com',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'content': ['This field is required.']})
+
+ def test_wrong_data(self):
+ """
+ Make sure that validate_content isn't called if the field input is wrong
+ """
+ wrong_data = {
+ 'email': 'not an email',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = self.CommentSerializerWithFieldValidator(data=wrong_data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'email': ['Enter a valid e-mail address.']})
class PositiveIntegerAsChoiceTests(TestCase):
def test_positive_integer_in_json_is_correctly_parsed(self):
- data = {'some_integer':1}
+ data = {'some_integer': 1}
serializer = PositiveIntegerAsChoiceSerializer(data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
+
class ModelValidationTests(TestCase):
def test_validate_unique(self):
@@ -326,7 +453,7 @@ class ModelValidationTests(TestCase):
serializer.save()
second_serializer = AlbumsSerializer(data={'title': 'a'})
self.assertFalse(second_serializer.is_valid())
- self.assertEqual(second_serializer.errors, {'title': [u'Album with this Title already exists.']})
+ self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']})
def test_foreign_key_with_partial(self):
"""
@@ -364,15 +491,15 @@ class RegexValidationTest(TestCase):
def test_create_failed(self):
serializer = BookSerializer(data={'isbn': '1234567890'})
self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']})
+ self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
serializer = BookSerializer(data={'isbn': '12345678901234'})
self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']})
+ self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
serializer = BookSerializer(data={'isbn': 'abcdefghijklm'})
self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']})
+ self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
def test_create_success(self):
serializer = BookSerializer(data={'isbn': '1234567890123'})
@@ -417,7 +544,7 @@ class ManyToManyTests(TestCase):
"""
serializer = self.serializer_class(instance=self.instance)
expected = self.data
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_create(self):
"""
@@ -425,11 +552,11 @@ class ManyToManyTests(TestCase):
"""
data = {'rel': [self.anchor.id]}
serializer = self.serializer_class(data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(ManyToManyModel.objects.all()), 2)
- self.assertEquals(instance.pk, 2)
- self.assertEquals(list(instance.rel.all()), [self.anchor])
+ self.assertEqual(len(ManyToManyModel.objects.all()), 2)
+ self.assertEqual(instance.pk, 2)
+ self.assertEqual(list(instance.rel.all()), [self.anchor])
def test_update(self):
"""
@@ -439,11 +566,11 @@ class ManyToManyTests(TestCase):
new_anchor.save()
data = {'rel': [self.anchor.id, new_anchor.id]}
serializer = self.serializer_class(self.instance, data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(ManyToManyModel.objects.all()), 1)
- self.assertEquals(instance.pk, 1)
- self.assertEquals(list(instance.rel.all()), [self.anchor, new_anchor])
+ self.assertEqual(len(ManyToManyModel.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(list(instance.rel.all()), [self.anchor, new_anchor])
def test_create_empty_relationship(self):
"""
@@ -452,11 +579,11 @@ class ManyToManyTests(TestCase):
"""
data = {'rel': []}
serializer = self.serializer_class(data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(ManyToManyModel.objects.all()), 2)
- self.assertEquals(instance.pk, 2)
- self.assertEquals(list(instance.rel.all()), [])
+ self.assertEqual(len(ManyToManyModel.objects.all()), 2)
+ self.assertEqual(instance.pk, 2)
+ self.assertEqual(list(instance.rel.all()), [])
def test_update_empty_relationship(self):
"""
@@ -467,11 +594,11 @@ class ManyToManyTests(TestCase):
new_anchor.save()
data = {'rel': []}
serializer = self.serializer_class(self.instance, data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(ManyToManyModel.objects.all()), 1)
- self.assertEquals(instance.pk, 1)
- self.assertEquals(list(instance.rel.all()), [])
+ self.assertEqual(len(ManyToManyModel.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(list(instance.rel.all()), [])
def test_create_empty_relationship_flat_data(self):
"""
@@ -479,19 +606,20 @@ class ManyToManyTests(TestCase):
containing no items, using a representation that does not support
lists (eg form data).
"""
- data = {'rel': ''}
+ data = MultiValueDict()
+ data.setlist('rel', [''])
serializer = self.serializer_class(data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(ManyToManyModel.objects.all()), 2)
- self.assertEquals(instance.pk, 2)
- self.assertEquals(list(instance.rel.all()), [])
+ self.assertEqual(len(ManyToManyModel.objects.all()), 2)
+ self.assertEqual(instance.pk, 2)
+ self.assertEqual(list(instance.rel.all()), [])
class ReadOnlyManyToManyTests(TestCase):
def setUp(self):
class ReadOnlyManyToManySerializer(serializers.ModelSerializer):
- rel = serializers.ManyRelatedField(read_only=True)
+ rel = serializers.RelatedField(many=True, read_only=True)
class Meta:
model = ReadOnlyManyToManyModel
@@ -519,12 +647,12 @@ class ReadOnlyManyToManyTests(TestCase):
new_anchor.save()
data = {'rel': [self.anchor.id, new_anchor.id]}
serializer = self.serializer_class(self.instance, data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1)
- self.assertEquals(instance.pk, 1)
+ self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
# rel is still as original (1 entry)
- self.assertEquals(list(instance.rel.all()), [self.anchor])
+ self.assertEqual(list(instance.rel.all()), [self.anchor])
def test_update_without_relationship(self):
"""
@@ -535,12 +663,12 @@ class ReadOnlyManyToManyTests(TestCase):
new_anchor.save()
data = {}
serializer = self.serializer_class(self.instance, data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1)
- self.assertEquals(instance.pk, 1)
+ self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
# rel is still as original (1 entry)
- self.assertEquals(list(instance.rel.all()), [self.anchor])
+ self.assertEqual(list(instance.rel.all()), [self.anchor])
class DefaultValueTests(TestCase):
@@ -555,35 +683,35 @@ class DefaultValueTests(TestCase):
def test_create_using_default(self):
data = {}
serializer = self.serializer_class(data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(self.objects.all()), 1)
- self.assertEquals(instance.pk, 1)
- self.assertEquals(instance.text, 'foobar')
+ self.assertEqual(len(self.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(instance.text, 'foobar')
def test_create_overriding_default(self):
data = {'text': 'overridden'}
serializer = self.serializer_class(data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(self.objects.all()), 1)
- self.assertEquals(instance.pk, 1)
- self.assertEquals(instance.text, 'overridden')
+ self.assertEqual(len(self.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(instance.text, 'overridden')
def test_partial_update_default(self):
""" Regression test for issue #532 """
data = {'text': 'overridden'}
serializer = self.serializer_class(data=data, partial=True)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
data = {'extra': 'extra_value'}
serializer = self.serializer_class(instance=instance, data=data, partial=True)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(instance.extra, 'extra_value')
- self.assertEquals(instance.text, 'overridden')
+ self.assertEqual(instance.extra, 'extra_value')
+ self.assertEqual(instance.text, 'overridden')
class CallableDefaultValueTests(TestCase):
@@ -598,20 +726,20 @@ class CallableDefaultValueTests(TestCase):
def test_create_using_default(self):
data = {}
serializer = self.serializer_class(data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(self.objects.all()), 1)
- self.assertEquals(instance.pk, 1)
- self.assertEquals(instance.text, 'foobar')
+ self.assertEqual(len(self.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(instance.text, 'foobar')
def test_create_overriding_default(self):
data = {'text': 'overridden'}
serializer = self.serializer_class(data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(self.objects.all()), 1)
- self.assertEquals(instance.pk, 1)
- self.assertEquals(instance.text, 'overridden')
+ self.assertEqual(len(self.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(instance.text, 'overridden')
class ManyRelatedTests(TestCase):
@@ -660,6 +788,9 @@ class ManyRelatedTests(TestCase):
class RelatedTraversalTest(TestCase):
def test_nested_traversal(self):
+ """
+ Source argument should support dotted.source notation.
+ """
user = Person.objects.create(name="django")
post = BlogPost.objects.create(title="Test blog post", writer=user)
post.blogpostcomment_set.create(text="I love this blog post")
@@ -686,11 +817,11 @@ class RelatedTraversalTest(TestCase):
serializer = BlogPostSerializer(instance=post)
expected = {
- 'title': u'Test blog post',
+ 'title': 'Test blog post',
'comments': [{
- 'text': u'I love this blog post',
+ 'text': 'I love this blog post',
'post_owner': {
- "name": u"django",
+ "name": "django",
"age": None
}
}]
@@ -698,6 +829,41 @@ class RelatedTraversalTest(TestCase):
self.assertEqual(serializer.data, expected)
+ def test_nested_traversal_with_none(self):
+ """
+ If a component of the dotted.source is None, return None for the field.
+ """
+ from rest_framework.tests.models import NullableForeignKeySource
+ instance = NullableForeignKeySource.objects.create(name='Source with null FK')
+
+ class NullableSourceSerializer(serializers.Serializer):
+ target_name = serializers.Field(source='target.name')
+
+ serializer = NullableSourceSerializer(instance=instance)
+
+ expected = {
+ 'target_name': None,
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+ def test_queryset_nested_traversal(self):
+ """
+ Relational fields should be able to use methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+
+ class QuerysetMethodSerializer(serializers.Serializer):
+ blogposts = serializers.RelatedField(many=True, source='get_all_blogposts')
+
+ class ClassWithQuerysetMethod(object):
+ def get_all_blogposts(self):
+ return BlogPost.objects
+
+ obj = ClassWithQuerysetMethod()
+ serializer = QuerysetMethodSerializer(obj)
+ self.assertEqual(serializer.data, {'blogposts': ['BlogPost object']})
+
class SerializerMethodFieldTests(TestCase):
def setUp(self):
@@ -725,8 +891,8 @@ class SerializerMethodFieldTests(TestCase):
serializer = self.serializer_class(source_data)
expected = {
- 'beep': u'hello!',
- 'boop': [u'a', u'b', u'c'],
+ 'beep': 'hello!',
+ 'boop': ['a', 'b', 'c'],
'boop_count': 3,
}
@@ -742,7 +908,7 @@ class BlankFieldTests(TestCase):
model = BlankFieldModel
class BlankFieldSerializer(serializers.Serializer):
- title = serializers.CharField(blank=True)
+ title = serializers.CharField(required=False)
class NotBlankFieldModelSerializer(serializers.ModelSerializer):
class Meta:
@@ -759,15 +925,15 @@ class BlankFieldTests(TestCase):
def test_create_blank_field(self):
serializer = self.serializer_class(data=self.data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
def test_create_model_blank_field(self):
serializer = self.model_serializer_class(data=self.data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
def test_create_model_null_field(self):
serializer = self.model_serializer_class(data={'title': None})
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
def test_create_not_blank_field(self):
"""
@@ -775,7 +941,7 @@ class BlankFieldTests(TestCase):
is considered invalid in a non-model serializer
"""
serializer = self.not_blank_serializer_class(data=self.data)
- self.assertEquals(serializer.is_valid(), False)
+ self.assertEqual(serializer.is_valid(), False)
def test_create_model_not_blank_field(self):
"""
@@ -783,11 +949,11 @@ class BlankFieldTests(TestCase):
is considered invalid in a model serializer
"""
serializer = self.not_blank_model_serializer_class(data=self.data)
- self.assertEquals(serializer.is_valid(), False)
+ self.assertEqual(serializer.is_valid(), False)
- def test_create_model_null_field(self):
+ def test_create_model_empty_field(self):
serializer = self.model_serializer_class(data={})
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
#test for issue #460
@@ -811,7 +977,21 @@ class SerializerPickleTests(TestCase):
class Meta:
model = Person
fields = ('name', 'age')
- pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data)
+ pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data, 0)
+
+ def test_getstate_method_should_not_return_none(self):
+ """
+ Regression test for #645.
+ """
+ data = serializers.DictWithMetadata({1: 1})
+ self.assertEqual(data.__getstate__(), serializers.SortedDict({1: 1}))
+
+ def test_serializer_data_is_pickleable(self):
+ """
+ Another regression test for #645.
+ """
+ data = serializers.SortedDictWithMetadata({1: 1})
+ repr(pickle.loads(pickle.dumps(data, 0)))
class DepthTest(TestCase):
@@ -825,8 +1005,8 @@ class DepthTest(TestCase):
depth = 1
serializer = BlogPostSerializer(instance=post)
- expected = {'id': 1, 'title': u'Test blog post',
- 'writer': {'id': 1, 'name': u'django', 'age': 1}}
+ expected = {'id': 1, 'title': 'Test blog post',
+ 'writer': {'id': 1, 'name': 'django', 'age': 1}}
self.assertEqual(serializer.data, expected)
@@ -845,8 +1025,8 @@ class DepthTest(TestCase):
model = BlogPost
serializer = BlogPostSerializer(instance=post)
- expected = {'id': 1, 'title': u'Test blog post',
- 'writer': {'id': 1, 'name': u'django', 'age': 1}}
+ expected = {'id': 1, 'title': 'Test blog post',
+ 'writer': {'id': 1, 'name': 'django', 'age': 1}}
self.assertEqual(serializer.data, expected)
@@ -901,3 +1081,32 @@ class NestedSerializerContextTests(TestCase):
# This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers
AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data
+
+
+class DeserializeListTestCase(TestCase):
+
+ def setUp(self):
+ self.data = {
+ 'email': 'nobody@nowhere.com',
+ 'content': 'This is some test content',
+ 'created': datetime.datetime(2013, 3, 7),
+ }
+
+ def test_no_errors(self):
+ data = [self.data.copy() for x in range(0, 3)]
+ serializer = CommentSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertTrue(isinstance(serializer.object, list))
+ self.assertTrue(
+ all((isinstance(item, Comment) for item in serializer.object))
+ )
+
+ def test_errors_return_as_list(self):
+ invalid_item = self.data.copy()
+ invalid_item['email'] = ''
+ data = [self.data.copy(), invalid_item, self.data.copy()]
+
+ serializer = CommentSerializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ expected = [{}, {'email': ['This field is required.']}, {}]
+ self.assertEqual(serializer.errors, expected)
diff --git a/rest_framework/tests/settings.py b/rest_framework/tests/settings.py
index 0293fdc3..857375c2 100644
--- a/rest_framework/tests/settings.py
+++ b/rest_framework/tests/settings.py
@@ -1,4 +1,5 @@
"""Tests for the settings module"""
+from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS
diff --git a/rest_framework/tests/status.py b/rest_framework/tests/status.py
index 30df5cef..e1644a6b 100644
--- a/rest_framework/tests/status.py
+++ b/rest_framework/tests/status.py
@@ -1,4 +1,5 @@
"""Tests for the status module"""
+from __future__ import unicode_literals
from django.test import TestCase
from rest_framework import status
@@ -8,5 +9,5 @@ class TestStatus(TestCase):
def test_status(self):
"""Ensure the status module is present and correct."""
- self.assertEquals(200, status.HTTP_200_OK)
- self.assertEquals(404, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(200, status.HTTP_200_OK)
+ self.assertEqual(404, status.HTTP_404_NOT_FOUND)
diff --git a/rest_framework/tests/testcases.py b/rest_framework/tests/testcases.py
index 97f492ff..f8c2579e 100644
--- a/rest_framework/tests/testcases.py
+++ b/rest_framework/tests/testcases.py
@@ -1,4 +1,5 @@
# http://djangosnippets.org/snippets/1011/
+from __future__ import unicode_literals
from django.conf import settings
from django.core.management import call_command
from django.db.models import loading
diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py
index adeaf6da..08f88e11 100644
--- a/rest_framework/tests/tests.py
+++ b/rest_framework/tests/tests.py
@@ -2,6 +2,7 @@
Force import of all modules in this package in order to get the standard test
runner to pick up the tests. Yowzers.
"""
+from __future__ import unicode_literals
import os
modules = [filename.rsplit('.', 1)[0]
diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/throttling.py
index 4b98b941..11cbd8eb 100644
--- a/rest_framework/tests/throttling.py
+++ b/rest_framework/tests/throttling.py
@@ -1,11 +1,10 @@
"""
Tests for the throttling implementations in the permissions module.
"""
-
+from __future__ import unicode_literals
from django.test import TestCase
from django.contrib.auth.models import User
from django.core.cache import cache
-
from django.test.client import RequestFactory
from rest_framework.views import APIView
from rest_framework.throttling import UserRateThrottle
@@ -104,7 +103,7 @@ class ThrottlingTests(TestCase):
self.set_throttle_timer(view, timer)
response = view.as_view()(request)
if expect is not None:
- self.assertEquals(response['X-Throttle-Wait-Seconds'], expect)
+ self.assertEqual(response['X-Throttle-Wait-Seconds'], expect)
else:
self.assertFalse('X-Throttle-Wait-Seconds' in response)
diff --git a/rest_framework/tests/urlpatterns.py b/rest_framework/tests/urlpatterns.py
new file mode 100644
index 00000000..29ed4a96
--- /dev/null
+++ b/rest_framework/tests/urlpatterns.py
@@ -0,0 +1,76 @@
+from __future__ import unicode_literals
+from collections import namedtuple
+from django.core import urlresolvers
+from django.test import TestCase
+from django.test.client import RequestFactory
+from rest_framework.compat import patterns, url, include
+from rest_framework.urlpatterns import format_suffix_patterns
+
+
+# A container class for test paths for the test case
+URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs'])
+
+
+def dummy_view(request, *args, **kwargs):
+ pass
+
+
+class FormatSuffixTests(TestCase):
+ """
+ Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters.
+ """
+ def _resolve_urlpatterns(self, urlpatterns, test_paths):
+ factory = RequestFactory()
+ try:
+ urlpatterns = format_suffix_patterns(urlpatterns)
+ except Exception:
+ self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
+ resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
+ for test_path in test_paths:
+ request = factory.get(test_path.path)
+ try:
+ callback, callback_args, callback_kwargs = resolver.resolve(request.path_info)
+ except Exception:
+ self.fail("Failed to resolve URL: %s" % request.path_info)
+ self.assertEqual(callback_args, test_path.args)
+ self.assertEqual(callback_kwargs, test_path.kwargs)
+
+ def test_format_suffix(self):
+ urlpatterns = patterns(
+ '',
+ url(r'^test$', dummy_view),
+ )
+ test_paths = [
+ URLTestPath('/test', (), {}),
+ URLTestPath('/test.api', (), {'format': 'api'}),
+ URLTestPath('/test.asdf', (), {'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
+
+ def test_default_args(self):
+ urlpatterns = patterns(
+ '',
+ url(r'^test$', dummy_view, {'foo': 'bar'}),
+ )
+ test_paths = [
+ URLTestPath('/test', (), {'foo': 'bar', }),
+ URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}),
+ URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
+
+ def test_included_urls(self):
+ nested_patterns = patterns(
+ '',
+ url(r'^path$', dummy_view)
+ )
+ urlpatterns = patterns(
+ '',
+ url(r'^test/', include(nested_patterns), {'foo': 'bar'}),
+ )
+ test_paths = [
+ URLTestPath('/test/path', (), {'foo': 'bar', }),
+ URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}),
+ URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py
index 3906adb9..8c87917d 100644
--- a/rest_framework/tests/utils.py
+++ b/rest_framework/tests/utils.py
@@ -1,9 +1,10 @@
-from django.test.client import RequestFactory, FakePayload
+from __future__ import unicode_literals
+from django.test.client import FakePayload, Client as _Client, RequestFactory as _RequestFactory
from django.test.client import MULTIPART_CONTENT
-from urlparse import urlparse
+from rest_framework.compat import urlparse
-class RequestFactory(RequestFactory):
+class RequestFactory(_RequestFactory):
def __init__(self, **defaults):
super(RequestFactory, self).__init__(**defaults)
@@ -14,7 +15,7 @@ class RequestFactory(RequestFactory):
patch_data = self._encode_data(data, content_type)
- parsed = urlparse(path)
+ parsed = urlparse.urlparse(path)
r = {
'CONTENT_LENGTH': len(patch_data),
'CONTENT_TYPE': content_type,
@@ -25,3 +26,15 @@ class RequestFactory(RequestFactory):
}
r.update(extra)
return self.request(**r)
+
+
+class Client(_Client, RequestFactory):
+ def patch(self, path, data={}, content_type=MULTIPART_CONTENT,
+ follow=False, **extra):
+ """
+ Send a resource to the server using PATCH.
+ """
+ response = super(Client, self).patch(path, data=data, content_type=content_type, **extra)
+ if follow:
+ response = self._handle_redirects(response, **extra)
+ return response
diff --git a/rest_framework/tests/validation.py b/rest_framework/tests/validation.py
new file mode 100644
index 00000000..cbdd6515
--- /dev/null
+++ b/rest_framework/tests/validation.py
@@ -0,0 +1,65 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import generics, serializers, status
+from rest_framework.tests.utils import RequestFactory
+import json
+
+factory = RequestFactory()
+
+
+# Regression for #666
+
+class ValidationModel(models.Model):
+ blank_validated_field = models.CharField(max_length=255)
+
+
+class ValidationModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ValidationModel
+ fields = ('blank_validated_field',)
+ read_only_fields = ('blank_validated_field',)
+
+
+class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView):
+ model = ValidationModel
+ serializer_class = ValidationModelSerializer
+
+
+class TestPreSaveValidationExclusions(TestCase):
+ def test_pre_save_validation_exclusions(self):
+ """
+ Somewhat weird test case to ensure that we don't perform model
+ validation on read only fields.
+ """
+ obj = ValidationModel.objects.create(blank_validated_field='')
+ request = factory.put('/', json.dumps({}),
+ content_type='application/json')
+ view = UpdateValidationModel().as_view()
+ response = view(request, pk=obj.pk).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+# Regression for #653
+
+class ShouldValidateModel(models.Model):
+ should_validate_field = models.CharField(max_length=255)
+
+
+class ShouldValidateModelSerializer(serializers.ModelSerializer):
+ renamed = serializers.CharField(source='should_validate_field', required=False)
+
+ class Meta:
+ model = ShouldValidateModel
+ fields = ('renamed',)
+
+
+class TestPreSaveValidationExclusions(TestCase):
+ def test_renamed_fields_are_model_validated(self):
+ """
+ Ensure fields with 'source' applied do get still get model validation.
+ """
+ # We've set `required=False` on the serializer, but the model
+ # does not have `blank=True`, so this serializer should not validate.
+ serializer = ShouldValidateModelSerializer(data={'renamed': ''})
+ self.assertEqual(serializer.is_valid(), False)
diff --git a/rest_framework/tests/validators.py b/rest_framework/tests/validators.py
deleted file mode 100644
index c032985e..00000000
--- a/rest_framework/tests/validators.py
+++ /dev/null
@@ -1,329 +0,0 @@
-# from django import forms
-# from django.db import models
-# from django.test import TestCase
-# from rest_framework.response import ImmediateResponse
-# from rest_framework.views import View
-
-
-# class TestDisabledValidations(TestCase):
-# """Tests on FormValidator with validation disabled by setting form to None"""
-
-# def test_disabled_form_validator_returns_content_unchanged(self):
-# """If the view's form attribute is None then FormValidator(view).validate_request(content, None)
-# should just return the content unmodified."""
-# class DisabledFormResource(FormResource):
-# form = None
-
-# class MockView(View):
-# resource = DisabledFormResource
-
-# view = MockView()
-# content = {'qwerty': 'uiop'}
-# self.assertEqual(FormResource(view).validate_request(content, None), content)
-
-# def test_disabled_form_validator_get_bound_form_returns_none(self):
-# """If the view's form attribute is None on then
-# FormValidator(view).get_bound_form(content) should just return None."""
-# class DisabledFormResource(FormResource):
-# form = None
-
-# class MockView(View):
-# resource = DisabledFormResource
-
-# view = MockView()
-# content = {'qwerty': 'uiop'}
-# self.assertEqual(FormResource(view).get_bound_form(content), None)
-
-# def test_disabled_model_form_validator_returns_content_unchanged(self):
-# """If the view's form is None and does not have a Resource with a model set then
-# ModelFormValidator(view).validate_request(content, None) should just return the content unmodified."""
-
-# class DisabledModelFormView(View):
-# resource = ModelResource
-
-# view = DisabledModelFormView()
-# content = {'qwerty': 'uiop'}
-# self.assertEqual(ModelResource(view).get_bound_form(content), None)
-
-# def test_disabled_model_form_validator_get_bound_form_returns_none(self):
-# """If the form attribute is None on FormValidatorMixin then get_bound_form(content) should just return None."""
-# class DisabledModelFormView(View):
-# resource = ModelResource
-
-# view = DisabledModelFormView()
-# content = {'qwerty': 'uiop'}
-# self.assertEqual(ModelResource(view).get_bound_form(content), None)
-
-
-# class TestNonFieldErrors(TestCase):
-# """Tests against form validation errors caused by non-field errors. (eg as might be caused by some custom form validation)"""
-
-# def test_validate_failed_due_to_non_field_error_returns_appropriate_message(self):
-# """If validation fails with a non-field error, ensure the response a non-field error"""
-# class MockForm(forms.Form):
-# field1 = forms.CharField(required=False)
-# field2 = forms.CharField(required=False)
-# ERROR_TEXT = 'You may not supply both field1 and field2'
-
-# def clean(self):
-# if 'field1' in self.cleaned_data and 'field2' in self.cleaned_data:
-# raise forms.ValidationError(self.ERROR_TEXT)
-# return self.cleaned_data
-
-# class MockResource(FormResource):
-# form = MockForm
-
-# class MockView(View):
-# pass
-
-# view = MockView()
-# content = {'field1': 'example1', 'field2': 'example2'}
-# try:
-# MockResource(view).validate_request(content, None)
-# except ImmediateResponse, exc:
-# response = exc.response
-# self.assertEqual(response.raw_content, {'errors': [MockForm.ERROR_TEXT]})
-# else:
-# self.fail('ImmediateResponse was not raised')
-
-
-# class TestFormValidation(TestCase):
-# """Tests which check basic form validation.
-# Also includes the same set of tests with a ModelFormValidator for which the form has been explicitly set.
-# (ModelFormValidator should behave as FormValidator if a form is set rather than relying on the default ModelForm)"""
-# def setUp(self):
-# class MockForm(forms.Form):
-# qwerty = forms.CharField(required=True)
-
-# class MockFormResource(FormResource):
-# form = MockForm
-
-# class MockModelResource(ModelResource):
-# form = MockForm
-
-# class MockFormView(View):
-# resource = MockFormResource
-
-# class MockModelFormView(View):
-# resource = MockModelResource
-
-# self.MockFormResource = MockFormResource
-# self.MockModelResource = MockModelResource
-# self.MockFormView = MockFormView
-# self.MockModelFormView = MockModelFormView
-
-# def validation_returns_content_unchanged_if_already_valid_and_clean(self, validator):
-# """If the content is already valid and clean then validate(content) should just return the content unmodified."""
-# content = {'qwerty': 'uiop'}
-# self.assertEqual(validator.validate_request(content, None), content)
-
-# def validation_failure_raises_response_exception(self, validator):
-# """If form validation fails a ResourceException 400 (Bad Request) should be raised."""
-# content = {}
-# self.assertRaises(ImmediateResponse, validator.validate_request, content, None)
-
-# def validation_does_not_allow_extra_fields_by_default(self, validator):
-# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
-# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
-# broken clients more easily (eg submitting content with a misnamed field)"""
-# content = {'qwerty': 'uiop', 'extra': 'extra'}
-# self.assertRaises(ImmediateResponse, validator.validate_request, content, None)
-
-# def validation_allows_extra_fields_if_explicitly_set(self, validator):
-# """If we include an allowed_extra_fields paramater on _validate, then allow fields with those names."""
-# content = {'qwerty': 'uiop', 'extra': 'extra'}
-# validator._validate(content, None, allowed_extra_fields=('extra',))
-
-# def validation_allows_unknown_fields_if_explicitly_allowed(self, validator):
-# """If we set ``unknown_form_fields`` on the form resource, then don't
-# raise errors on unexpected request data"""
-# content = {'qwerty': 'uiop', 'extra': 'extra'}
-# validator.allow_unknown_form_fields = True
-# self.assertEqual({'qwerty': u'uiop'},
-# validator.validate_request(content, None),
-# "Resource didn't accept unknown fields.")
-# validator.allow_unknown_form_fields = False
-
-# def validation_does_not_require_extra_fields_if_explicitly_set(self, validator):
-# """If we include an allowed_extra_fields paramater on _validate, then do not fail if we do not have fields with those names."""
-# content = {'qwerty': 'uiop'}
-# self.assertEqual(validator._validate(content, None, allowed_extra_fields=('extra',)), content)
-
-# def validation_failed_due_to_no_content_returns_appropriate_message(self, validator):
-# """If validation fails due to no content, ensure the response contains a single non-field error"""
-# content = {}
-# try:
-# validator.validate_request(content, None)
-# except ImmediateResponse, exc:
-# response = exc.response
-# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.']}})
-# else:
-# self.fail('ResourceException was not raised')
-
-# def validation_failed_due_to_field_error_returns_appropriate_message(self, validator):
-# """If validation fails due to a field error, ensure the response contains a single field error"""
-# content = {'qwerty': ''}
-# try:
-# validator.validate_request(content, None)
-# except ImmediateResponse, exc:
-# response = exc.response
-# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.']}})
-# else:
-# self.fail('ResourceException was not raised')
-
-# def validation_failed_due_to_invalid_field_returns_appropriate_message(self, validator):
-# """If validation fails due to an invalid field, ensure the response contains a single field error"""
-# content = {'qwerty': 'uiop', 'extra': 'extra'}
-# try:
-# validator.validate_request(content, None)
-# except ImmediateResponse, exc:
-# response = exc.response
-# self.assertEqual(response.raw_content, {'field_errors': {'extra': ['This field does not exist.']}})
-# else:
-# self.fail('ResourceException was not raised')
-
-# def validation_failed_due_to_multiple_errors_returns_appropriate_message(self, validator):
-# """If validation for multiple reasons, ensure the response contains each error"""
-# content = {'qwerty': '', 'extra': 'extra'}
-# try:
-# validator.validate_request(content, None)
-# except ImmediateResponse, exc:
-# response = exc.response
-# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.'],
-# 'extra': ['This field does not exist.']}})
-# else:
-# self.fail('ResourceException was not raised')
-
-# # Tests on FormResource
-
-# def test_form_validation_returns_content_unchanged_if_already_valid_and_clean(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_returns_content_unchanged_if_already_valid_and_clean(validator)
-
-# def test_form_validation_failure_raises_response_exception(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_failure_raises_response_exception(validator)
-
-# def test_validation_does_not_allow_extra_fields_by_default(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_does_not_allow_extra_fields_by_default(validator)
-
-# def test_validation_allows_extra_fields_if_explicitly_set(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_allows_extra_fields_if_explicitly_set(validator)
-
-# def test_validation_allows_unknown_fields_if_explicitly_allowed(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_allows_unknown_fields_if_explicitly_allowed(validator)
-
-# def test_validation_does_not_require_extra_fields_if_explicitly_set(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_does_not_require_extra_fields_if_explicitly_set(validator)
-
-# def test_validation_failed_due_to_no_content_returns_appropriate_message(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_failed_due_to_no_content_returns_appropriate_message(validator)
-
-# def test_validation_failed_due_to_field_error_returns_appropriate_message(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_failed_due_to_field_error_returns_appropriate_message(validator)
-
-# def test_validation_failed_due_to_invalid_field_returns_appropriate_message(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator)
-
-# def test_validation_failed_due_to_multiple_errors_returns_appropriate_message(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator)
-
-# # Same tests on ModelResource
-
-# def test_modelform_validation_returns_content_unchanged_if_already_valid_and_clean(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_returns_content_unchanged_if_already_valid_and_clean(validator)
-
-# def test_modelform_validation_failure_raises_response_exception(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_failure_raises_response_exception(validator)
-
-# def test_modelform_validation_does_not_allow_extra_fields_by_default(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_does_not_allow_extra_fields_by_default(validator)
-
-# def test_modelform_validation_allows_extra_fields_if_explicitly_set(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_allows_extra_fields_if_explicitly_set(validator)
-
-# def test_modelform_validation_does_not_require_extra_fields_if_explicitly_set(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_does_not_require_extra_fields_if_explicitly_set(validator)
-
-# def test_modelform_validation_failed_due_to_no_content_returns_appropriate_message(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_failed_due_to_no_content_returns_appropriate_message(validator)
-
-# def test_modelform_validation_failed_due_to_field_error_returns_appropriate_message(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_failed_due_to_field_error_returns_appropriate_message(validator)
-
-# def test_modelform_validation_failed_due_to_invalid_field_returns_appropriate_message(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator)
-
-# def test_modelform_validation_failed_due_to_multiple_errors_returns_appropriate_message(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator)
-
-
-# class TestModelFormValidator(TestCase):
-# """Tests specific to ModelFormValidatorMixin"""
-
-# def setUp(self):
-# """Create a validator for a model with two fields and a property."""
-# class MockModel(models.Model):
-# qwerty = models.CharField(max_length=256)
-# uiop = models.CharField(max_length=256, blank=True)
-
-# @property
-# def read_only(self):
-# return 'read only'
-
-# class MockResource(ModelResource):
-# model = MockModel
-
-# class MockView(View):
-# resource = MockResource
-
-# self.validator = MockResource(MockView)
-
-# def test_property_fields_are_allowed_on_model_forms(self):
-# """Validation on ModelForms may include property fields that exist on the Model to be included in the input."""
-# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only'}
-# self.assertEqual(self.validator.validate_request(content, None), content)
-
-# def test_property_fields_are_not_required_on_model_forms(self):
-# """Validation on ModelForms does not require property fields that exist on the Model to be included in the input."""
-# content = {'qwerty': 'example', 'uiop': 'example'}
-# self.assertEqual(self.validator.validate_request(content, None), content)
-
-# def test_extra_fields_not_allowed_on_model_forms(self):
-# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
-# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
-# broken clients more easily (eg submitting content with a misnamed field)"""
-# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only', 'extra': 'extra'}
-# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None)
-
-# def test_validate_requires_fields_on_model_forms(self):
-# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
-# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
-# broken clients more easily (eg submitting content with a misnamed field)"""
-# content = {'read_only': 'read only'}
-# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None)
-
-# def test_validate_does_not_require_blankable_fields_on_model_forms(self):
-# """Test standard ModelForm validation behaviour - fields with blank=True are not required."""
-# content = {'qwerty': 'example', 'read_only': 'read only'}
-# self.validator.validate_request(content, None)
-
-# def test_model_form_validator_uses_model_forms(self):
-# self.assertTrue(isinstance(self.validator.get_bound_form(), forms.ModelForm))
diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py
index 7cd82656..994cf6dc 100644
--- a/rest_framework/tests/views.py
+++ b/rest_framework/tests/views.py
@@ -1,4 +1,4 @@
-import copy
+from __future__ import unicode_literals
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import status
@@ -6,6 +6,7 @@ from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.views import APIView
+import copy
factory = RequestFactory()
@@ -49,10 +50,10 @@ class ClassBasedViewIntegrationTests(TestCase):
request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request)
expected = {
- 'detail': u'JSON parse error - No JSON object could be decoded'
+ 'detail': 'JSON parse error - No JSON object could be decoded'
}
- self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEquals(sanitise_json_error(response.data), expected)
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
def test_400_parse_error_tunneled_content(self):
content = 'f00bar'
@@ -64,10 +65,10 @@ class ClassBasedViewIntegrationTests(TestCase):
request = factory.post('/', form_data)
response = self.view(request)
expected = {
- 'detail': u'JSON parse error - No JSON object could be decoded'
+ 'detail': 'JSON parse error - No JSON object could be decoded'
}
- self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEquals(sanitise_json_error(response.data), expected)
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
class FunctionBasedViewIntegrationTests(TestCase):
@@ -78,10 +79,10 @@ class FunctionBasedViewIntegrationTests(TestCase):
request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request)
expected = {
- 'detail': u'JSON parse error - No JSON object could be decoded'
+ 'detail': 'JSON parse error - No JSON object could be decoded'
}
- self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEquals(sanitise_json_error(response.data), expected)
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
def test_400_parse_error_tunneled_content(self):
content = 'f00bar'
@@ -93,7 +94,7 @@ class FunctionBasedViewIntegrationTests(TestCase):
request = factory.post('/', form_data)
response = self.view(request)
expected = {
- 'detail': u'JSON parse error - No JSON object could be decoded'
+ 'detail': 'JSON parse error - No JSON object could be decoded'
}
- self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEquals(sanitise_json_error(response.data), expected)
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index 8fe64248..810cad63 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -1,7 +1,8 @@
-import time
+from __future__ import unicode_literals
from django.core.cache import cache
from rest_framework import exceptions
from rest_framework.settings import api_settings
+import time
class BaseThrottle(object):
diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py
index 143928c9..d9143bb4 100644
--- a/rest_framework/urlpatterns.py
+++ b/rest_framework/urlpatterns.py
@@ -1,7 +1,38 @@
-from rest_framework.compat import url
+from __future__ import unicode_literals
+from django.core.urlresolvers import RegexURLResolver
+from rest_framework.compat import url, include
from rest_framework.settings import api_settings
+def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required):
+ ret = []
+ for urlpattern in urlpatterns:
+ if isinstance(urlpattern, RegexURLResolver):
+ # Set of included URL patterns
+ regex = urlpattern.regex.pattern
+ namespace = urlpattern.namespace
+ app_name = urlpattern.app_name
+ kwargs = urlpattern.default_kwargs
+ # Add in the included patterns, after applying the suffixes
+ patterns = apply_suffix_patterns(urlpattern.url_patterns,
+ suffix_pattern,
+ suffix_required)
+ ret.append(url(regex, include(patterns, namespace, app_name), kwargs))
+
+ else:
+ # Regular URL pattern
+ regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern
+ view = urlpattern._callback or urlpattern._callback_str
+ kwargs = urlpattern.default_args
+ name = urlpattern.name
+ # Add in both the existing and the new urlpattern
+ if not suffix_required:
+ ret.append(urlpattern)
+ ret.append(url(regex, view, kwargs, name))
+
+ return ret
+
+
def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
"""
Supplement existing urlpatterns with corresponding patterns that also
@@ -28,15 +59,4 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
else:
suffix_pattern = r'\.(?P<%s>[a-z]+)$' % suffix_kwarg
- ret = []
- for urlpattern in urlpatterns:
- # Form our complementing '.format' urlpattern
- regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern
- view = urlpattern._callback or urlpattern._callback_str
- kwargs = urlpattern.default_args
- name = urlpattern.name
- # Add in both the existing and the new urlpattern
- if not suffix_required:
- ret.append(urlpattern)
- ret.append(url(regex, view, kwargs, name))
- return ret
+ return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required)
diff --git a/rest_framework/urls.py b/rest_framework/urls.py
index fbe4bc07..9c4719f1 100644
--- a/rest_framework/urls.py
+++ b/rest_framework/urls.py
@@ -12,6 +12,7 @@ your authentication settings include `SessionAuthentication`.
url(r'^auth', include('rest_framework.urls', namespace='rest_framework'))
)
"""
+from __future__ import unicode_literals
from rest_framework.compat import patterns, url
diff --git a/rest_framework/utils/__init__.py b/rest_framework/utils/__init__.py
index 84fcb5db..e69de29b 100644
--- a/rest_framework/utils/__init__.py
+++ b/rest_framework/utils/__init__.py
@@ -1,100 +0,0 @@
-from django.utils.encoding import smart_unicode
-from django.utils.xmlutils import SimplerXMLGenerator
-from rest_framework.compat import StringIO
-import re
-import xml.etree.ElementTree as ET
-
-
-# From xml2dict
-class XML2Dict(object):
-
- def __init__(self):
- pass
-
- def _parse_node(self, node):
- node_tree = {}
- # Save attrs and text, hope there will not be a child with same name
- if node.text:
- node_tree = node.text
- for (k, v) in node.attrib.items():
- k, v = self._namespace_split(k, v)
- node_tree[k] = v
- #Save childrens
- for child in node.getchildren():
- tag, tree = self._namespace_split(child.tag, self._parse_node(child))
- if tag not in node_tree: # the first time, so store it in dict
- node_tree[tag] = tree
- continue
- old = node_tree[tag]
- if not isinstance(old, list):
- node_tree.pop(tag)
- node_tree[tag] = [old] # multi times, so change old dict to a list
- node_tree[tag].append(tree) # add the new one
-
- return node_tree
-
- def _namespace_split(self, tag, value):
- """
- Split the tag '{http://cs.sfsu.edu/csc867/myscheduler}patients'
- ns = http://cs.sfsu.edu/csc867/myscheduler
- name = patients
- """
- result = re.compile("\{(.*)\}(.*)").search(tag)
- if result:
- value.namespace, tag = result.groups()
- return (tag, value)
-
- def parse(self, file):
- """parse a xml file to a dict"""
- f = open(file, 'r')
- return self.fromstring(f.read())
-
- def fromstring(self, s):
- """parse a string"""
- t = ET.fromstring(s)
- unused_root_tag, root_tree = self._namespace_split(t.tag, self._parse_node(t))
- return root_tree
-
-
-def xml2dict(input):
- return XML2Dict().fromstring(input)
-
-
-# Piston:
-class XMLRenderer():
- def _to_xml(self, xml, data):
- if isinstance(data, (list, tuple)):
- for item in data:
- xml.startElement("list-item", {})
- self._to_xml(xml, item)
- xml.endElement("list-item")
-
- elif isinstance(data, dict):
- for key, value in data.iteritems():
- xml.startElement(key, {})
- self._to_xml(xml, value)
- xml.endElement(key)
-
- elif data is None:
- # Don't output any value
- pass
-
- else:
- xml.characters(smart_unicode(data))
-
- def dict2xml(self, data):
- stream = StringIO.StringIO()
-
- xml = SimplerXMLGenerator(stream, "utf-8")
- xml.startDocument()
- xml.startElement("root", {})
-
- self._to_xml(xml, data)
-
- xml.endElement("root")
- xml.endDocument()
- return stream.getvalue()
-
-
-def dict2xml(input):
- return XMLRenderer().dict2xml(input)
diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py
index 80e39d46..af21ac79 100644
--- a/rest_framework/utils/breadcrumbs.py
+++ b/rest_framework/utils/breadcrumbs.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
from django.core.urlresolvers import resolve, get_script_prefix
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index 7afe100a..b6de18a8 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -1,13 +1,14 @@
"""
Helper classes for parsers.
"""
+from __future__ import unicode_literals
+from django.utils.datastructures import SortedDict
+from rest_framework.compat import timezone
+from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
import datetime
import decimal
import types
import json
-from django.utils.datastructures import SortedDict
-from rest_framework.compat import timezone
-from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
class JSONEncoder(json.JSONEncoder):
diff --git a/rest_framework/utils/mediatypes.py b/rest_framework/utils/mediatypes.py
index ee7f3a54..c09c2933 100644
--- a/rest_framework/utils/mediatypes.py
+++ b/rest_framework/utils/mediatypes.py
@@ -3,8 +3,9 @@ Handling of media types, as found in HTTP Content-Type and Accept headers.
See http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7
"""
-
+from __future__ import unicode_literals
from django.http.multipartparser import parse_header
+from rest_framework import HTTP_HEADER_ENCODING
def media_type_matches(lhs, rhs):
@@ -47,7 +48,7 @@ class _MediaType(object):
if media_type_str is None:
media_type_str = ''
self.orig = media_type_str
- self.full_type, self.params = parse_header(media_type_str)
+ self.full_type, self.params = parse_header(media_type_str.encode(HTTP_HEADER_ENCODING))
self.main_type, sep, self.sub_type = self.full_type.partition('/')
def match(self, other):
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 10bdd5a5..81cbdcbb 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -1,8 +1,7 @@
"""
Provides an APIView class that is used as the base of all class-based views.
"""
-
-import re
+from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied
from django.http import Http404
from django.utils.html import escape
@@ -13,6 +12,7 @@ from rest_framework.compat import View, apply_markdown
from rest_framework.response import Response
from rest_framework.request import Request
from rest_framework.settings import api_settings
+import re
def _remove_trailing_string(content, trailing):
@@ -148,6 +148,8 @@ class APIView(View):
"""
If request is not permitted, determine what kind of exception to raise.
"""
+ if not self.request.successful_authenticator:
+ raise exceptions.NotAuthenticated()
raise exceptions.PermissionDenied()
def throttled(self, request, wait):
@@ -156,6 +158,15 @@ class APIView(View):
"""
raise exceptions.Throttled(wait)
+ def get_authenticate_header(self, request):
+ """
+ If a request is unauthenticated, determine the WWW-Authenticate
+ header to use for 401 responses, if any.
+ """
+ authenticators = self.get_authenticators()
+ if authenticators:
+ return authenticators[0].authenticate_header(request)
+
def get_parser_context(self, http_request):
"""
Returns a dict that is passed through to Parser.parse(),
@@ -200,13 +211,13 @@ class APIView(View):
def get_parsers(self):
"""
- Instantiates and returns the list of renderers that this view can use.
+ Instantiates and returns the list of parsers that this view can use.
"""
return [parser() for parser in self.parser_classes]
def get_authenticators(self):
"""
- Instantiates and returns the list of renderers that this view can use.
+ Instantiates and returns the list of authenticators that this view can use.
"""
return [auth() for auth in self.authentication_classes]
@@ -241,23 +252,43 @@ class APIView(View):
try:
return conneg.select_renderer(request, renderers, self.format_kwarg)
- except:
+ except Exception:
if force:
return (renderers[0], renderers[0].media_type)
raise
- def has_permission(self, request, obj=None):
+ def perform_authentication(self, request):
"""
- Return `True` if the request should be permitted.
+ Perform authentication on the incoming request.
+
+ Note that if you override this and simply 'pass', then authentication
+ will instead be performed lazily, the first time either
+ `request.user` or `request.auth` is accessed.
+ """
+ request.user
+
+ def check_permissions(self, request):
+ """
+ Check if the request should be permitted.
+ Raises an appropriate exception if the request is not permitted.
"""
for permission in self.get_permissions():
- if not permission.has_permission(request, self, obj):
- return False
- return True
+ if not permission.has_permission(request, self):
+ self.permission_denied(request)
+
+ def check_object_permissions(self, request, obj):
+ """
+ Check if the request should be permitted for a given object.
+ Raises an appropriate exception if the request is not permitted.
+ """
+ for permission in self.get_permissions():
+ if not permission.has_object_permission(request, self, obj):
+ self.permission_denied(request)
def check_throttles(self, request):
"""
Check if request should be throttled.
+ Raises an appropriate exception if the request is throttled.
"""
for throttle in self.get_throttles():
if not throttle.allow_request(request, self):
@@ -284,8 +315,8 @@ class APIView(View):
self.format_kwarg = self.get_format_suffix(**kwargs)
# Ensure that the incoming request is permitted
- if not self.has_permission(request):
- self.permission_denied(request)
+ self.perform_authentication(request)
+ self.check_permissions(request)
self.check_throttles(request)
# Perform content negotiation and store the accepted info on the request
@@ -319,6 +350,16 @@ class APIView(View):
# Throttle wait header
self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
+ if isinstance(exc, (exceptions.NotAuthenticated,
+ exceptions.AuthenticationFailed)):
+ # WWW-Authenticate header for 401 responses, else coerce to 403
+ auth_header = self.get_authenticate_header(self.request)
+
+ if auth_header:
+ self.headers['WWW-Authenticate'] = auth_header
+ else:
+ exc.status_code = status.HTTP_403_FORBIDDEN
+
if isinstance(exc, exceptions.APIException):
return Response({'detail': exc.detail},
status=exc.status_code,