aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py10
-rw-r--r--rest_framework/authentication.py18
-rw-r--r--rest_framework/authtoken/migrations/0001_initial.py94
-rw-r--r--rest_framework/authtoken/models.py3
-rw-r--r--rest_framework/authtoken/serializers.py11
-rw-r--r--rest_framework/authtoken/south_migrations/0001_initial.py60
-rw-r--r--rest_framework/authtoken/south_migrations/__init__.py (renamed from rest_framework/runtests/__init__.py)0
-rw-r--r--rest_framework/compat.py388
-rw-r--r--rest_framework/decorators.py45
-rw-r--r--rest_framework/exceptions.py1
-rw-r--r--rest_framework/fields.py57
-rw-r--r--rest_framework/filters.py9
-rw-r--r--rest_framework/generics.py98
-rw-r--r--rest_framework/mixins.py8
-rw-r--r--rest_framework/negotiation.py6
-rw-r--r--rest_framework/parsers.py7
-rw-r--r--rest_framework/permissions.py36
-rw-r--r--rest_framework/relations.py123
-rw-r--r--rest_framework/renderers.py71
-rw-r--r--rest_framework/request.py38
-rw-r--r--rest_framework/response.py14
-rw-r--r--rest_framework/routers.py71
-rwxr-xr-xrest_framework/runtests/runcoverage.py78
-rwxr-xr-xrest_framework/runtests/runtests.py52
-rw-r--r--rest_framework/runtests/settings.py169
-rw-r--r--rest_framework/runtests/urls.py7
-rw-r--r--rest_framework/serializers.py126
-rw-r--r--rest_framework/settings.py19
-rw-r--r--rest_framework/six.py389
-rw-r--r--rest_framework/status.py4
-rw-r--r--rest_framework/templates/rest_framework/base.html7
-rw-r--r--rest_framework/templates/rest_framework/login_base.html16
-rw-r--r--rest_framework/templatetags/rest_framework.py101
-rw-r--r--rest_framework/test.py16
-rw-r--r--rest_framework/tests/__init__.py0
-rw-r--r--rest_framework/tests/accounts/__init__.py0
-rw-r--r--rest_framework/tests/accounts/models.py8
-rw-r--r--rest_framework/tests/accounts/serializers.py11
-rw-r--r--rest_framework/tests/description.py26
-rw-r--r--rest_framework/tests/extras/__init__.py0
-rw-r--r--rest_framework/tests/extras/bad_import.py1
-rw-r--r--rest_framework/tests/models.py177
-rw-r--r--rest_framework/tests/records/__init__.py0
-rw-r--r--rest_framework/tests/records/models.py6
-rw-r--r--rest_framework/tests/serializers.py8
-rw-r--r--rest_framework/tests/test_authentication.py663
-rw-r--r--rest_framework/tests/test_breadcrumbs.py73
-rw-r--r--rest_framework/tests/test_decorators.py157
-rw-r--r--rest_framework/tests/test_description.py108
-rw-r--r--rest_framework/tests/test_fields.py984
-rw-r--r--rest_framework/tests/test_files.py95
-rw-r--r--rest_framework/tests/test_filters.py661
-rw-r--r--rest_framework/tests/test_genericrelations.py133
-rw-r--r--rest_framework/tests/test_generics.py609
-rw-r--r--rest_framework/tests/test_htmlrenderer.py120
-rw-r--r--rest_framework/tests/test_hyperlinkedserializers.py379
-rw-r--r--rest_framework/tests/test_multitable_inheritance.py67
-rw-r--r--rest_framework/tests/test_negotiation.py45
-rw-r--r--rest_framework/tests/test_nullable_fields.py30
-rw-r--r--rest_framework/tests/test_pagination.py521
-rw-r--r--rest_framework/tests/test_parsers.py115
-rw-r--r--rest_framework/tests/test_permissions.py300
-rw-r--r--rest_framework/tests/test_relations.py120
-rw-r--r--rest_framework/tests/test_relations_hyperlink.py524
-rw-r--r--rest_framework/tests/test_relations_nested.py326
-rw-r--r--rest_framework/tests/test_relations_pk.py551
-rw-r--r--rest_framework/tests/test_relations_slug.py257
-rw-r--r--rest_framework/tests/test_renderers.py655
-rw-r--r--rest_framework/tests/test_request.py347
-rw-r--r--rest_framework/tests/test_response.py278
-rw-r--r--rest_framework/tests/test_reverse.py27
-rw-r--r--rest_framework/tests/test_routers.py216
-rw-r--r--rest_framework/tests/test_serializer.py1949
-rw-r--r--rest_framework/tests/test_serializer_bulk_update.py278
-rw-r--r--rest_framework/tests/test_serializer_empty.py15
-rw-r--r--rest_framework/tests/test_serializer_import.py19
-rw-r--r--rest_framework/tests/test_serializer_nested.py347
-rw-r--r--rest_framework/tests/test_serializers.py28
-rw-r--r--rest_framework/tests/test_settings.py22
-rw-r--r--rest_framework/tests/test_status.py33
-rw-r--r--rest_framework/tests/test_templatetags.py51
-rw-r--r--rest_framework/tests/test_testing.py164
-rw-r--r--rest_framework/tests/test_throttling.py281
-rw-r--r--rest_framework/tests/test_urlpatterns.py76
-rw-r--r--rest_framework/tests/test_validation.py148
-rw-r--r--rest_framework/tests/test_views.py142
-rw-r--r--rest_framework/tests/test_write_only_fields.py42
-rw-r--r--rest_framework/tests/tests.py16
-rw-r--r--rest_framework/tests/users/__init__.py0
-rw-r--r--rest_framework/tests/users/models.py6
-rw-r--r--rest_framework/tests/users/serializers.py8
-rw-r--r--rest_framework/tests/utils.py25
-rw-r--r--rest_framework/tests/views.py8
-rw-r--r--rest_framework/throttling.py29
-rw-r--r--rest_framework/urlpatterns.py2
-rw-r--r--rest_framework/urls.py18
-rw-r--r--rest_framework/utils/encoders.py34
-rw-r--r--rest_framework/utils/formatting.py4
-rw-r--r--rest_framework/utils/mediatypes.py6
-rw-r--r--rest_framework/views.py2
-rw-r--r--rest_framework/viewsets.py10
101 files changed, 656 insertions, 13857 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 2d76b55d..f30012b9 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,14 +1,14 @@
"""
-______ _____ _____ _____ __ _
-| ___ \ ___/ ___|_ _| / _| | |
-| |_/ / |__ \ `--. | | | |_ _ __ __ _ _ __ ___ _____ _____ _ __| | __
+______ _____ _____ _____ __
+| ___ \ ___/ ___|_ _| / _| | |
+| |_/ / |__ \ `--. | | | |_ _ __ __ _ _ __ ___ _____ _____ _ __| |__
| /| __| `--. \ | | | _| '__/ _` | '_ ` _ \ / _ \ \ /\ / / _ \| '__| |/ /
-| |\ \| |___/\__/ / | | | | | | | (_| | | | | | | __/\ V V / (_) | | | <
+| |\ \| |___/\__/ / | | | | | | | (_| | | | | | | __/\ V V / (_) | | | <
\_| \_\____/\____/ \_/ |_| |_| \__,_|_| |_| |_|\___| \_/\_/ \___/|_| |_|\_|
"""
__title__ = 'Django REST framework'
-__version__ = '2.3.13'
+__version__ = '2.3.14'
__author__ = 'Tom Christie'
__license__ = 'BSD 2-Clause'
__copyright__ = 'Copyright 2011-2014 Tom Christie'
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index da9ca510..5721a869 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -6,9 +6,9 @@ import base64
from django.contrib.auth import authenticate
from django.core.exceptions import ImproperlyConfigured
+from django.middleware.csrf import CsrfViewMiddleware
from django.conf import settings
from rest_framework import exceptions, HTTP_HEADER_ENCODING
-from rest_framework.compat import CsrfViewMiddleware
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
from rest_framework.compat import oauth2_provider, provider_now, check_nonce
from rest_framework.authtoken.models import Token
@@ -21,7 +21,7 @@ def get_authorization_header(request):
Hide some test client ickyness where the header can be unicode.
"""
auth = request.META.get('HTTP_AUTHORIZATION', b'')
- if type(auth) == type(''):
+ if isinstance(auth, type('')):
# Work around django test client oddness
auth = auth.encode(HTTP_HEADER_ENCODING)
return auth
@@ -310,6 +310,13 @@ class OAuth2Authentication(BaseAuthentication):
auth = get_authorization_header(request).split()
+ if len(auth) == 1:
+ msg = 'Invalid bearer header. No credentials provided.'
+ raise exceptions.AuthenticationFailed(msg)
+ elif len(auth) > 2:
+ msg = 'Invalid bearer header. Token string should not contain spaces.'
+ raise exceptions.AuthenticationFailed(msg)
+
if auth and auth[0].lower() == b'bearer':
access_token = auth[1]
elif 'access_token' in request.POST:
@@ -319,13 +326,6 @@ class OAuth2Authentication(BaseAuthentication):
else:
return None
- if len(auth) == 1:
- msg = 'Invalid bearer header. No credentials provided.'
- raise exceptions.AuthenticationFailed(msg)
- elif len(auth) > 2:
- msg = 'Invalid bearer header. Token string should not contain spaces.'
- raise exceptions.AuthenticationFailed(msg)
-
return self.authenticate_credentials(request, access_token)
def authenticate_credentials(self, request, access_token):
diff --git a/rest_framework/authtoken/migrations/0001_initial.py b/rest_framework/authtoken/migrations/0001_initial.py
index d5965e40..2e5d6b47 100644
--- a/rest_framework/authtoken/migrations/0001_initial.py
+++ b/rest_framework/authtoken/migrations/0001_initial.py
@@ -1,67 +1,27 @@
-# -*- coding: utf-8 -*-
-import datetime
-from south.db import db
-from south.v2 import SchemaMigration
-from django.db import models
-
-from rest_framework.settings import api_settings
-
-
-try:
- from django.contrib.auth import get_user_model
-except ImportError: # django < 1.5
- from django.contrib.auth.models import User
-else:
- User = get_user_model()
-
-
-class Migration(SchemaMigration):
-
- def forwards(self, orm):
- # Adding model 'Token'
- db.create_table('authtoken_token', (
- ('key', self.gf('django.db.models.fields.CharField')(max_length=40, primary_key=True)),
- ('user', self.gf('django.db.models.fields.related.OneToOneField')(related_name='auth_token', unique=True, to=orm['%s.%s' % (User._meta.app_label, User._meta.object_name)])),
- ('created', self.gf('django.db.models.fields.DateTimeField')(auto_now_add=True, blank=True)),
- ))
- db.send_create_signal('authtoken', ['Token'])
-
-
- def backwards(self, orm):
- # Deleting model 'Token'
- db.delete_table('authtoken_token')
-
-
- models = {
- 'auth.group': {
- 'Meta': {'object_name': 'Group'},
- 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
- 'name': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '80'}),
- 'permissions': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Permission']", 'symmetrical': 'False', 'blank': 'True'})
- },
- 'auth.permission': {
- 'Meta': {'ordering': "('content_type__app_label', 'content_type__model', 'codename')", 'unique_together': "(('content_type', 'codename'),)", 'object_name': 'Permission'},
- 'codename': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
- 'content_type': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['contenttypes.ContentType']"}),
- 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
- 'name': ('django.db.models.fields.CharField', [], {'max_length': '50'})
- },
- "%s.%s" % (User._meta.app_label, User._meta.module_name): {
- 'Meta': {'object_name': User._meta.module_name},
- },
- 'authtoken.token': {
- 'Meta': {'object_name': 'Token'},
- 'created': ('django.db.models.fields.DateTimeField', [], {'auto_now_add': 'True', 'blank': 'True'}),
- 'key': ('django.db.models.fields.CharField', [], {'max_length': '40', 'primary_key': 'True'}),
- 'user': ('django.db.models.fields.related.OneToOneField', [], {'related_name': "'auth_token'", 'unique': 'True', 'to': "orm['%s.%s']" % (User._meta.app_label, User._meta.object_name)})
- },
- 'contenttypes.contenttype': {
- 'Meta': {'ordering': "('name',)", 'unique_together': "(('app_label', 'model'),)", 'object_name': 'ContentType', 'db_table': "'django_content_type'"},
- 'app_label': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
- 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
- 'model': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
- 'name': ('django.db.models.fields.CharField', [], {'max_length': '100'})
- }
- }
-
- complete_apps = ['authtoken']
+# encoding: utf8
+from __future__ import unicode_literals
+
+from django.db import models, migrations
+from django.conf import settings
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ migrations.swappable_dependency(settings.AUTH_USER_MODEL),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name='Token',
+ fields=[
+ ('key', models.CharField(max_length=40, serialize=False, primary_key=True)),
+ ('user', models.OneToOneField(to=settings.AUTH_USER_MODEL, to_field='id')),
+ ('created', models.DateTimeField(auto_now_add=True)),
+ ],
+ options={
+ 'abstract': False,
+ },
+ bases=(models.Model,),
+ ),
+ ]
diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py
index 8eac2cc4..db21d44c 100644
--- a/rest_framework/authtoken/models.py
+++ b/rest_framework/authtoken/models.py
@@ -1,6 +1,5 @@
import binascii
import os
-from hashlib import sha1
from django.conf import settings
from django.db import models
@@ -34,7 +33,7 @@ class Token(models.Model):
return super(Token, self).save(*args, **kwargs)
def generate_key(self):
- return binascii.hexlify(os.urandom(20))
+ return binascii.hexlify(os.urandom(20)).decode()
def __unicode__(self):
return self.key
diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py
index 60a3740e..99e99ae3 100644
--- a/rest_framework/authtoken/serializers.py
+++ b/rest_framework/authtoken/serializers.py
@@ -1,4 +1,6 @@
from django.contrib.auth import authenticate
+from django.utils.translation import ugettext_lazy as _
+
from rest_framework import serializers
@@ -15,10 +17,13 @@ class AuthTokenSerializer(serializers.Serializer):
if user:
if not user.is_active:
- raise serializers.ValidationError('User account is disabled.')
+ msg = _('User account is disabled.')
+ raise serializers.ValidationError(msg)
attrs['user'] = user
return attrs
else:
- raise serializers.ValidationError('Unable to login with provided credentials.')
+ msg = _('Unable to login with provided credentials.')
+ raise serializers.ValidationError(msg)
else:
- raise serializers.ValidationError('Must include "username" and "password"')
+ msg = _('Must include "username" and "password"')
+ raise serializers.ValidationError(msg)
diff --git a/rest_framework/authtoken/south_migrations/0001_initial.py b/rest_framework/authtoken/south_migrations/0001_initial.py
new file mode 100644
index 00000000..926de02b
--- /dev/null
+++ b/rest_framework/authtoken/south_migrations/0001_initial.py
@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+from south.db import db
+from south.v2 import SchemaMigration
+
+try:
+ from django.contrib.auth import get_user_model
+except ImportError: # django < 1.5
+ from django.contrib.auth.models import User
+else:
+ User = get_user_model()
+
+
+class Migration(SchemaMigration):
+
+ def forwards(self, orm):
+ # Adding model 'Token'
+ db.create_table('authtoken_token', (
+ ('key', self.gf('django.db.models.fields.CharField')(max_length=40, primary_key=True)),
+ ('user', self.gf('django.db.models.fields.related.OneToOneField')(related_name='auth_token', unique=True, to=orm['%s.%s' % (User._meta.app_label, User._meta.object_name)])),
+ ('created', self.gf('django.db.models.fields.DateTimeField')(auto_now_add=True, blank=True)),
+ ))
+ db.send_create_signal('authtoken', ['Token'])
+
+ def backwards(self, orm):
+ # Deleting model 'Token'
+ db.delete_table('authtoken_token')
+
+ models = {
+ 'auth.group': {
+ 'Meta': {'object_name': 'Group'},
+ 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
+ 'name': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '80'}),
+ 'permissions': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Permission']", 'symmetrical': 'False', 'blank': 'True'})
+ },
+ 'auth.permission': {
+ 'Meta': {'ordering': "('content_type__app_label', 'content_type__model', 'codename')", 'unique_together': "(('content_type', 'codename'),)", 'object_name': 'Permission'},
+ 'codename': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
+ 'content_type': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['contenttypes.ContentType']"}),
+ 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
+ 'name': ('django.db.models.fields.CharField', [], {'max_length': '50'})
+ },
+ "%s.%s" % (User._meta.app_label, User._meta.module_name): {
+ 'Meta': {'object_name': User._meta.module_name},
+ },
+ 'authtoken.token': {
+ 'Meta': {'object_name': 'Token'},
+ 'created': ('django.db.models.fields.DateTimeField', [], {'auto_now_add': 'True', 'blank': 'True'}),
+ 'key': ('django.db.models.fields.CharField', [], {'max_length': '40', 'primary_key': 'True'}),
+ 'user': ('django.db.models.fields.related.OneToOneField', [], {'related_name': "'auth_token'", 'unique': 'True', 'to': "orm['%s.%s']" % (User._meta.app_label, User._meta.object_name)})
+ },
+ 'contenttypes.contenttype': {
+ 'Meta': {'ordering': "('name',)", 'unique_together': "(('app_label', 'model'),)", 'object_name': 'ContentType', 'db_table': "'django_content_type'"},
+ 'app_label': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
+ 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
+ 'model': ('django.db.models.fields.CharField', [], {'max_length': '100'}),
+ 'name': ('django.db.models.fields.CharField', [], {'max_length': '100'})
+ }
+ }
+
+ complete_apps = ['authtoken']
diff --git a/rest_framework/runtests/__init__.py b/rest_framework/authtoken/south_migrations/__init__.py
index e69de29b..e69de29b 100644
--- a/rest_framework/runtests/__init__.py
+++ b/rest_framework/authtoken/south_migrations/__init__.py
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index d155f554..fa0f0bfb 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -5,25 +5,14 @@ versions of django/python, and compatibility wrappers around optional packages.
# flake8: noqa
from __future__ import unicode_literals
-
import django
import inspect
from django.core.exceptions import ImproperlyConfigured
from django.conf import settings
+from django.utils import six
-# Try to import six from Django, fallback to included `six`.
-try:
- from django.utils import six
-except ImportError:
- from rest_framework import six
-
-# location of patterns, url, include changes in 1.4 onwards
-try:
- from django.conf.urls import patterns, url, include
-except ImportError:
- from django.conf.urls.defaults import patterns, url, include
-# Handle django.utils.encoding rename:
+# Handle django.utils.encoding rename in 1.5 onwards.
# smart_unicode -> smart_text
# force_unicode -> force_text
try:
@@ -42,17 +31,23 @@ try:
except ImportError:
from django.http import HttpResponse as HttpResponseBase
+
# django-filter is optional
try:
import django_filters
except ImportError:
django_filters = None
-# guardian is optional
-try:
- import guardian
-except ImportError:
- guardian = None
+
+# Django-guardian is optional. Import only if guardian is in INSTALLED_APPS
+# Fixes (#1712). We keep the try/except for the test suite.
+guardian = None
+if 'guardian' in settings.INSTALLED_APPS:
+ try:
+ import guardian
+ import guardian.shortcuts # Fixes #1624
+ except ImportError:
+ pass
# cStringIO only if it's available, otherwise StringIO
@@ -104,46 +99,13 @@ def get_concrete_model(model_cls):
return model_cls
+# View._allowed_methods only present from 1.5 onwards
if django.VERSION >= (1, 5):
from django.views.generic import View
else:
- from django.views.generic import View as _View
- from django.utils.decorators import classonlymethod
- from django.utils.functional import update_wrapper
-
- class View(_View):
- # 1.3 does not include head method in base View class
- # See: https://code.djangoproject.com/ticket/15668
- @classonlymethod
- def as_view(cls, **initkwargs):
- """
- Main entry point for a request-response process.
- """
- # sanitize keyword arguments
- for key in initkwargs:
- if key in cls.http_method_names:
- raise TypeError("You tried to pass in the %s method name as a "
- "keyword argument to %s(). Don't do that."
- % (key, cls.__name__))
- if not hasattr(cls, key):
- raise TypeError("%s() received an invalid keyword %r" % (
- cls.__name__, key))
-
- def view(request, *args, **kwargs):
- self = cls(**initkwargs)
- if hasattr(self, 'get') and not hasattr(self, 'head'):
- self.head = self.get
- return self.dispatch(request, *args, **kwargs)
-
- # take name and docstring from class
- update_wrapper(view, cls, updated=())
-
- # and possible attributes set by decorators
- # like csrf_exempt from dispatch
- update_wrapper(view, cls.dispatch, assigned=())
- return view
-
- # _allowed_methods only present from 1.5 onwards
+ from django.views.generic import View as DjangoView
+
+ class View(DjangoView):
def _allowed_methods(self):
return [m.upper() for m in self.http_method_names if hasattr(self, m)]
@@ -153,316 +115,16 @@ if 'patch' not in View.http_method_names:
View.http_method_names = View.http_method_names + ['patch']
-# PUT, DELETE do not require CSRF until 1.4. They should. Make it better.
-if django.VERSION >= (1, 4):
- from django.middleware.csrf import CsrfViewMiddleware
-else:
- import hashlib
- import re
- import random
- import logging
-
- from django.conf import settings
- from django.core.urlresolvers import get_callable
-
- try:
- from logging import NullHandler
- except ImportError:
- class NullHandler(logging.Handler):
- def emit(self, record):
- pass
-
- logger = logging.getLogger('django.request')
-
- if not logger.handlers:
- logger.addHandler(NullHandler())
-
- def same_origin(url1, url2):
- """
- Checks if two URLs are 'same-origin'
- """
- p1, p2 = urlparse.urlparse(url1), urlparse.urlparse(url2)
- return p1[0:2] == p2[0:2]
-
- def constant_time_compare(val1, val2):
- """
- Returns True if the two strings are equal, False otherwise.
-
- The time taken is independent of the number of characters that match.
- """
- if len(val1) != len(val2):
- return False
- result = 0
- for x, y in zip(val1, val2):
- result |= ord(x) ^ ord(y)
- return result == 0
-
- # Use the system (hardware-based) random number generator if it exists.
- if hasattr(random, 'SystemRandom'):
- randrange = random.SystemRandom().randrange
- else:
- randrange = random.randrange
-
- _MAX_CSRF_KEY = 18446744073709551616 # 2 << 63
-
- REASON_NO_REFERER = "Referer checking failed - no Referer."
- REASON_BAD_REFERER = "Referer checking failed - %s does not match %s."
- REASON_NO_CSRF_COOKIE = "CSRF cookie not set."
- REASON_BAD_TOKEN = "CSRF token missing or incorrect."
-
- def _get_failure_view():
- """
- Returns the view to be used for CSRF rejections
- """
- return get_callable(settings.CSRF_FAILURE_VIEW)
-
- def _get_new_csrf_key():
- return hashlib.md5("%s%s" % (randrange(0, _MAX_CSRF_KEY), settings.SECRET_KEY)).hexdigest()
-
- def get_token(request):
- """
- Returns the the CSRF token required for a POST form. The token is an
- alphanumeric value.
-
- A side effect of calling this function is to make the the csrf_protect
- decorator and the CsrfViewMiddleware add a CSRF cookie and a 'Vary: Cookie'
- header to the outgoing response. For this reason, you may need to use this
- function lazily, as is done by the csrf context processor.
- """
- request.META["CSRF_COOKIE_USED"] = True
- return request.META.get("CSRF_COOKIE", None)
-
- def _sanitize_token(token):
- # Allow only alphanum, and ensure we return a 'str' for the sake of the post
- # processing middleware.
- token = re.sub('[^a-zA-Z0-9]', '', str(token.decode('ascii', 'ignore')))
- if token == "":
- # In case the cookie has been truncated to nothing at some point.
- return _get_new_csrf_key()
- else:
- return token
-
- class CsrfViewMiddleware(object):
- """
- Middleware that requires a present and correct csrfmiddlewaretoken
- for POST requests that have a CSRF cookie, and sets an outgoing
- CSRF cookie.
-
- This middleware should be used in conjunction with the csrf_token template
- tag.
- """
- # The _accept and _reject methods currently only exist for the sake of the
- # requires_csrf_token decorator.
- def _accept(self, request):
- # Avoid checking the request twice by adding a custom attribute to
- # request. This will be relevant when both decorator and middleware
- # are used.
- request.csrf_processing_done = True
- return None
-
- def _reject(self, request, reason):
- return _get_failure_view()(request, reason=reason)
-
- def process_view(self, request, callback, callback_args, callback_kwargs):
-
- if getattr(request, 'csrf_processing_done', False):
- return None
-
- try:
- csrf_token = _sanitize_token(request.COOKIES[settings.CSRF_COOKIE_NAME])
- # Use same token next time
- request.META['CSRF_COOKIE'] = csrf_token
- except KeyError:
- csrf_token = None
- # Generate token and store it in the request, so it's available to the view.
- request.META["CSRF_COOKIE"] = _get_new_csrf_key()
-
- # Wait until request.META["CSRF_COOKIE"] has been manipulated before
- # bailing out, so that get_token still works
- if getattr(callback, 'csrf_exempt', False):
- return None
-
- # Assume that anything not defined as 'safe' by RC2616 needs protection.
- if request.method not in ('GET', 'HEAD', 'OPTIONS', 'TRACE'):
- if getattr(request, '_dont_enforce_csrf_checks', False):
- # Mechanism to turn off CSRF checks for test suite. It comes after
- # the creation of CSRF cookies, so that everything else continues to
- # work exactly the same (e.g. cookies are sent etc), but before the
- # any branches that call reject()
- return self._accept(request)
-
- if request.is_secure():
- # Suppose user visits http://example.com/
- # An active network attacker,(man-in-the-middle, MITM) sends a
- # POST form which targets https://example.com/detonate-bomb/ and
- # submits it via javascript.
- #
- # The attacker will need to provide a CSRF cookie and token, but
- # that is no problem for a MITM and the session independent
- # nonce we are using. So the MITM can circumvent the CSRF
- # protection. This is true for any HTTP connection, but anyone
- # using HTTPS expects better! For this reason, for
- # https://example.com/ we need additional protection that treats
- # http://example.com/ as completely untrusted. Under HTTPS,
- # Barth et al. found that the Referer header is missing for
- # same-domain requests in only about 0.2% of cases or less, so
- # we can use strict Referer checking.
- referer = request.META.get('HTTP_REFERER')
- if referer is None:
- logger.warning('Forbidden (%s): %s' % (REASON_NO_REFERER, request.path),
- extra={
- 'status_code': 403,
- 'request': request,
- }
- )
- return self._reject(request, REASON_NO_REFERER)
-
- # Note that request.get_host() includes the port
- good_referer = 'https://%s/' % request.get_host()
- if not same_origin(referer, good_referer):
- reason = REASON_BAD_REFERER % (referer, good_referer)
- logger.warning('Forbidden (%s): %s' % (reason, request.path),
- extra={
- 'status_code': 403,
- 'request': request,
- }
- )
- return self._reject(request, reason)
-
- if csrf_token is None:
- # No CSRF cookie. For POST requests, we insist on a CSRF cookie,
- # and in this way we can avoid all CSRF attacks, including login
- # CSRF.
- logger.warning('Forbidden (%s): %s' % (REASON_NO_CSRF_COOKIE, request.path),
- extra={
- 'status_code': 403,
- 'request': request,
- }
- )
- return self._reject(request, REASON_NO_CSRF_COOKIE)
-
- # check non-cookie token for match
- request_csrf_token = ""
- if request.method == "POST":
- request_csrf_token = request.POST.get('csrfmiddlewaretoken', '')
-
- if request_csrf_token == "":
- # Fall back to X-CSRFToken, to make things easier for AJAX,
- # and possible for PUT/DELETE
- request_csrf_token = request.META.get('HTTP_X_CSRFTOKEN', '')
-
- if not constant_time_compare(request_csrf_token, csrf_token):
- logger.warning('Forbidden (%s): %s' % (REASON_BAD_TOKEN, request.path),
- extra={
- 'status_code': 403,
- 'request': request,
- }
- )
- return self._reject(request, REASON_BAD_TOKEN)
-
- return self._accept(request)
-
-# timezone support is new in Django 1.4
-try:
- from django.utils import timezone
-except ImportError:
- timezone = None
-
-# dateparse is ALSO new in Django 1.4
-try:
- from django.utils.dateparse import parse_date, parse_datetime, parse_time
-except ImportError:
- import datetime
- import re
-
- date_re = re.compile(
- r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})$'
- )
-
- datetime_re = re.compile(
- r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})'
- r'[T ](?P<hour>\d{1,2}):(?P<minute>\d{1,2})'
- r'(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?'
- r'(?P<tzinfo>Z|[+-]\d{1,2}:\d{1,2})?$'
- )
-
- time_re = re.compile(
- r'(?P<hour>\d{1,2}):(?P<minute>\d{1,2})'
- r'(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?'
- )
-
- def parse_date(value):
- match = date_re.match(value)
- if match:
- kw = dict((k, int(v)) for k, v in match.groupdict().iteritems())
- return datetime.date(**kw)
-
- def parse_time(value):
- match = time_re.match(value)
- if match:
- kw = match.groupdict()
- if kw['microsecond']:
- kw['microsecond'] = kw['microsecond'].ljust(6, '0')
- kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None)
- return datetime.time(**kw)
-
- def parse_datetime(value):
- """Parse datetime, but w/o the timezone awareness in 1.4"""
- match = datetime_re.match(value)
- if match:
- kw = match.groupdict()
- if kw['microsecond']:
- kw['microsecond'] = kw['microsecond'].ljust(6, '0')
- kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None)
- return datetime.datetime(**kw)
-
-
-# smart_urlquote is new on Django 1.4
-try:
- from django.utils.html import smart_urlquote
-except ImportError:
- import re
- from django.utils.encoding import smart_str
- try:
- from urllib.parse import quote, urlsplit, urlunsplit
- except ImportError: # Python 2
- from urllib import quote
- from urlparse import urlsplit, urlunsplit
-
- unquoted_percents_re = re.compile(r'%(?![0-9A-Fa-f]{2})')
-
- def smart_urlquote(url):
- "Quotes a URL if it isn't already quoted."
- # Handle IDN before quoting.
- scheme, netloc, path, query, fragment = urlsplit(url)
- try:
- netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE
- except UnicodeError: # invalid domain part
- pass
- else:
- url = urlunsplit((scheme, netloc, path, query, fragment))
-
- # An URL is considered unquoted if it contains no % characters or
- # contains a % not followed by two hexadecimal digits. See #9655.
- if '%' not in url or unquoted_percents_re.search(url):
- # See http://bugs.python.org/issue2637
- url = quote(smart_str(url), safe=b'!*\'();:@&=+$,/?#[]~')
-
- return force_text(url)
-
-
-# RequestFactory only provide `generic` from 1.5 onwards
-
+# RequestFactory only provides `generic` from 1.5 onwards
from django.test.client import RequestFactory as DjangoRequestFactory
from django.test.client import FakePayload
try:
# In 1.5 the test client uses force_bytes
from django.utils.encoding import force_bytes as force_bytes_or_smart_bytes
except ImportError:
- # In 1.3 and 1.4 the test client just uses smart_str
+ # In 1.4 the test client just uses smart_str
from django.utils.encoding import smart_str as force_bytes_or_smart_bytes
-
class RequestFactory(DjangoRequestFactory):
def generic(self, method, path,
data='', content_type='application/octet-stream', **extra):
@@ -487,6 +149,7 @@ class RequestFactory(DjangoRequestFactory):
r.update(extra)
return self.request(**r)
+
# Markdown is optional
try:
import markdown
@@ -501,7 +164,6 @@ try:
safe_mode = False
md = markdown.Markdown(extensions=extensions, safe_mode=safe_mode)
return md.convert(text)
-
except ImportError:
apply_markdown = None
@@ -519,14 +181,16 @@ try:
except ImportError:
etree = None
-# OAuth is optional
+
+# OAuth2 is optional
try:
# Note: The `oauth2` package actually provides oauth1.0a support. Urg.
import oauth2 as oauth
except ImportError:
oauth = None
-# OAuth is optional
+
+# OAuthProvider is optional
try:
import oauth_provider
from oauth_provider.store import store as oauth_provider_store
@@ -548,6 +212,7 @@ except (ImportError, ImproperlyConfigured):
oauth_provider_store = None
check_nonce = None
+
# OAuth 2 support is optional
try:
import provider as oauth2_provider
@@ -567,7 +232,8 @@ except ImportError:
oauth2_constants = None
provider_now = None
-# Handle lazy strings
+
+# Handle lazy strings across Py2/Py3
from django.utils.functional import Promise
if six.PY3:
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index c69756a4..449ba0a2 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -3,13 +3,14 @@ The most important decorator in this module is `@api_view`, which is used
for writing function-based views with REST framework.
There are also various decorators for setting the API policies on function
-based views, as well as the `@action` and `@link` decorators, which are
+based views, as well as the `@detail_route` and `@list_route` decorators, which are
used to annotate methods on viewsets that should be included by routers.
"""
from __future__ import unicode_literals
-from rest_framework.compat import six
+from django.utils import six
from rest_framework.views import APIView
import types
+import warnings
def api_view(http_method_names):
@@ -107,23 +108,59 @@ def permission_classes(permission_classes):
return decorator
+def detail_route(methods=['get'], **kwargs):
+ """
+ Used to mark a method on a ViewSet that should be routed for detail requests.
+ """
+ def decorator(func):
+ func.bind_to_methods = methods
+ func.detail = True
+ func.kwargs = kwargs
+ return func
+ return decorator
+
+
+def list_route(methods=['get'], **kwargs):
+ """
+ Used to mark a method on a ViewSet that should be routed for list requests.
+ """
+ def decorator(func):
+ func.bind_to_methods = methods
+ func.detail = False
+ func.kwargs = kwargs
+ return func
+ return decorator
+
+
+# These are now pending deprecation, in favor of `detail_route` and `list_route`.
+
def link(**kwargs):
"""
- Used to mark a method on a ViewSet that should be routed for GET requests.
+ Used to mark a method on a ViewSet that should be routed for detail GET requests.
"""
+ msg = 'link is pending deprecation. Use detail_route instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+
def decorator(func):
func.bind_to_methods = ['get']
+ func.detail = True
func.kwargs = kwargs
return func
+
return decorator
def action(methods=['post'], **kwargs):
"""
- Used to mark a method on a ViewSet that should be routed for POST requests.
+ Used to mark a method on a ViewSet that should be routed for detail POST requests.
"""
+ msg = 'action is pending deprecation. Use detail_route instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+
def decorator(func):
func.bind_to_methods = methods
+ func.detail = True
func.kwargs = kwargs
return func
+
return decorator
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index 389032bd..ad52d172 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -23,6 +23,7 @@ class APIException(Exception):
def __str__(self):
return self.detail
+
class ParseError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = 'Malformed request.'
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 68b95682..9d707c9b 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -18,12 +18,14 @@ from django.conf import settings
from django.db.models.fields import BLANK_CHOICE_DASH
from django.http import QueryDict
from django.forms import widgets
+from django.utils import six, timezone
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
from django.utils.datastructures import SortedDict
+from django.utils.dateparse import parse_date, parse_datetime, parse_time
from rest_framework import ISO_8601
from rest_framework.compat import (
- timezone, parse_date, parse_datetime, parse_time, BytesIO, six, smart_text,
+ BytesIO, smart_text,
force_text, is_non_str_iterable
)
from rest_framework.settings import api_settings
@@ -61,8 +63,10 @@ def get_component(obj, attr_name):
def readable_datetime_formats(formats):
- format = ', '.join(formats).replace(ISO_8601,
- 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]')
+ format = ', '.join(formats).replace(
+ ISO_8601,
+ 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
+ )
return humanize_strptime(format)
@@ -154,7 +158,12 @@ class Field(object):
def widget_html(self):
if not self.widget:
return ''
- return self.widget.render(self._name, self._value)
+
+ attrs = {}
+ if 'id' not in self.widget.attrs:
+ attrs['id'] = self._name
+
+ return self.widget.render(self._name, self._value, attrs=attrs)
def label_tag(self):
return '<label for="%s">%s:</label>' % (self._name, self.label)
@@ -164,7 +173,7 @@ class Field(object):
Called to set up a field prior to field_to_native or field_from_native.
parent - The parent serializer.
- model_field - The model field this field corresponds to, if one exists.
+ field_name - The name of the field being initialized.
"""
self.parent = parent
self.root = parent.root or parent
@@ -182,7 +191,7 @@ class Field(object):
def field_to_native(self, obj, field_name):
"""
- Given and object and a field name, returns the value that should be
+ Given an object and a field name, returns the value that should be
serialized for that field.
"""
if obj is None:
@@ -260,13 +269,6 @@ class WritableField(Field):
validators=[], error_messages=None, widget=None,
default=None, blank=None):
- # 'blank' is to be deprecated in favor of 'required'
- if blank is not None:
- warnings.warn('The `blank` keyword argument is deprecated. '
- 'Use the `required` keyword argument instead.',
- DeprecationWarning, stacklevel=2)
- required = not(blank)
-
super(WritableField, self).__init__(source=source, label=label, help_text=help_text)
self.read_only = read_only
@@ -289,7 +291,7 @@ class WritableField(Field):
self.validators = self.default_validators + validators
self.default = default if default is not None else self.default
- # Widgets are ony used for HTML forms.
+ # Widgets are only used for HTML forms.
widget = widget or self.widget
if isinstance(widget, type):
widget = widget()
@@ -425,7 +427,7 @@ class ModelField(WritableField):
}
-##### Typed Fields #####
+# Typed Fields
class BooleanField(WritableField):
type_name = 'BooleanField'
@@ -460,8 +462,9 @@ class CharField(WritableField):
type_label = 'string'
form_field_class = forms.CharField
- def __init__(self, max_length=None, min_length=None, *args, **kwargs):
+ def __init__(self, max_length=None, min_length=None, allow_none=False, *args, **kwargs):
self.max_length, self.min_length = max_length, min_length
+ self.allow_none = allow_none
super(CharField, self).__init__(*args, **kwargs)
if min_length is not None:
self.validators.append(validators.MinLengthValidator(min_length))
@@ -469,8 +472,12 @@ class CharField(WritableField):
self.validators.append(validators.MaxLengthValidator(max_length))
def from_native(self, value):
- if isinstance(value, six.string_types) or value is None:
+ if isinstance(value, six.string_types):
return value
+
+ if value is None and not self.allow_none:
+ return ''
+
return smart_text(value)
@@ -479,7 +486,7 @@ class URLField(CharField):
type_label = 'url'
def __init__(self, **kwargs):
- if not 'validators' in kwargs:
+ if 'validators' not in kwargs:
kwargs['validators'] = [validators.URLValidator()]
super(URLField, self).__init__(**kwargs)
@@ -501,7 +508,7 @@ class SlugField(CharField):
class ChoiceField(WritableField):
type_name = 'ChoiceField'
- type_label = 'multiple choice'
+ type_label = 'choice'
form_field_class = forms.ChoiceField
widget = widgets.Select
default_error_messages = {
@@ -509,12 +516,16 @@ class ChoiceField(WritableField):
'the available choices.'),
}
- def __init__(self, choices=(), *args, **kwargs):
+ def __init__(self, choices=(), blank_display_value=None, *args, **kwargs):
self.empty = kwargs.pop('empty', '')
super(ChoiceField, self).__init__(*args, **kwargs)
self.choices = choices
if not self.required:
- self.choices = BLANK_CHOICE_DASH + self.choices
+ if blank_display_value is None:
+ blank_choice = BLANK_CHOICE_DASH
+ else:
+ blank_choice = [('', blank_display_value)]
+ self.choices = blank_choice + self.choices
def _get_choices(self):
return self._choices
@@ -1018,9 +1029,9 @@ class SerializerMethodField(Field):
A field that gets its value by calling a method on the serializer it's attached to.
"""
- def __init__(self, method_name):
+ def __init__(self, method_name, *args, **kwargs):
self.method_name = method_name
- super(SerializerMethodField, self).__init__()
+ super(SerializerMethodField, self).__init__(*args, **kwargs)
def field_to_native(self, obj, field_name):
value = getattr(self.parent, self.method_name)(obj)
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index 96d15eb9..e2080013 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -5,7 +5,8 @@ returned by list views.
from __future__ import unicode_literals
from django.core.exceptions import ImproperlyConfigured
from django.db import models
-from rest_framework.compat import django_filters, six, guardian, get_model_name
+from django.utils import six
+from rest_framework.compat import django_filters, guardian, get_model_name
from rest_framework.settings import api_settings
from functools import reduce
import operator
@@ -44,7 +45,7 @@ class DjangoFilterBackend(BaseFilterBackend):
if filter_class:
filter_model = filter_class.Meta.model
- assert issubclass(filter_model, queryset.model), \
+ assert issubclass(queryset.model, filter_model), \
'FilterSet model %s does not match queryset model %s' % \
(filter_model, queryset.model)
@@ -116,6 +117,10 @@ class OrderingFilter(BaseFilterBackend):
def get_ordering(self, request):
"""
Ordering is set by a comma delimited ?ordering=... query parameter.
+
+ The `ordering` query parameter can be overridden by setting
+ the `ordering_param` value on the OrderingFilter or by
+ specifying an `ORDERING_PARAM` value in the API settings.
"""
params = request.QUERY_PARAMS.get(self.ordering_param)
if params:
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 7bac510f..77deb8e4 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -25,6 +25,7 @@ def strict_positive_int(integer_string, cutoff=None):
ret = min(ret, cutoff)
return ret
+
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
"""
Same as Django's standard shortcut, but make sure to raise 404
@@ -43,6 +44,10 @@ class GenericAPIView(views.APIView):
# You'll need to either set these attributes,
# or override `get_queryset()`/`get_serializer_class()`.
+ # If you are overriding a view method, it is important that you call
+ # `get_queryset()` instead of accessing the `queryset` property directly,
+ # as `queryset` will get evaluated only once, and those results are cached
+ # for all subsequent requests.
queryset = None
serializer_class = None
@@ -90,8 +95,8 @@ class GenericAPIView(views.APIView):
'view': self
}
- def get_serializer(self, instance=None, data=None,
- files=None, many=False, partial=False):
+ def get_serializer(self, instance=None, data=None, files=None, many=False,
+ partial=False, allow_add_remove=False):
"""
Return the serializer instance that should be used for validating and
deserializing input, and for serializing output.
@@ -99,7 +104,9 @@ class GenericAPIView(views.APIView):
serializer_class = self.get_serializer_class()
context = self.get_serializer_context()
return serializer_class(instance, data=data, files=files,
- many=many, partial=partial, context=context)
+ many=many, partial=partial,
+ allow_add_remove=allow_add_remove,
+ context=context)
def get_pagination_serializer(self, page):
"""
@@ -121,11 +128,11 @@ class GenericAPIView(views.APIView):
deprecated_style = False
if page_size is not None:
warnings.warn('The `page_size` parameter to `paginate_queryset()` '
- 'is due to be deprecated. '
+ 'is deprecated. '
'Note that the return style of this method is also '
'changed, and will simply return a page object '
'when called without a `page_size` argument.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
deprecated_style = True
else:
# Determine the required page size.
@@ -136,10 +143,10 @@ class GenericAPIView(views.APIView):
if not self.allow_empty:
warnings.warn(
- 'The `allow_empty` parameter is due to be deprecated. '
+ 'The `allow_empty` parameter is deprecated. '
'To use `allow_empty=False` style behavior, You should override '
'`get_queryset()` and explicitly raise a 404 on empty querysets.',
- PendingDeprecationWarning, stacklevel=2
+ DeprecationWarning, stacklevel=2
)
paginator = self.paginator_class(queryset, page_size,
@@ -156,10 +163,11 @@ class GenericAPIView(views.APIView):
raise Http404(_("Page is not 'last', nor can it be converted to an int."))
try:
page = paginator.page(page_number)
- except InvalidPage as e:
- raise Http404(_('Invalid page (%(page_number)s): %(message)s') % {
- 'page_number': page_number,
- 'message': str(e)
+ except InvalidPage as exc:
+ error_format = _('Invalid page (%(page_number)s): %(message)s')
+ raise Http404(error_format % {
+ 'page_number': page_number,
+ 'message': str(exc)
})
if deprecated_style:
@@ -183,22 +191,27 @@ class GenericAPIView(views.APIView):
"""
Returns the list of filter backends that this view requires.
"""
- filter_backends = self.filter_backends or []
+ if self.filter_backends is None:
+ filter_backends = []
+ else:
+ # Note that we are returning a *copy* of the class attribute,
+ # so that it is safe for the view to mutate it if needed.
+ filter_backends = list(self.filter_backends)
+
if not filter_backends and self.filter_backend:
warnings.warn(
'The `filter_backend` attribute and `FILTER_BACKEND` setting '
- 'are due to be deprecated in favor of a `filter_backends` '
+ 'are deprecated in favor of a `filter_backends` '
'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take '
'a *list* of filter backend classes.',
- PendingDeprecationWarning, stacklevel=2
+ DeprecationWarning, stacklevel=2
)
filter_backends = [self.filter_backend]
- return filter_backends
+ return filter_backends
- ########################
- ### The following methods provide default implementations
- ### that you may want to override for more complex cases.
+ # The following methods provide default implementations
+ # that you may want to override for more complex cases.
def get_paginate_by(self, queryset=None):
"""
@@ -211,8 +224,8 @@ class GenericAPIView(views.APIView):
"""
if queryset is not None:
warnings.warn('The `queryset` parameter to `get_paginate_by()` '
- 'is due to be deprecated.',
- PendingDeprecationWarning, stacklevel=2)
+ 'is deprecated.',
+ DeprecationWarning, stacklevel=2)
if self.paginate_by_param:
try:
@@ -256,6 +269,10 @@ class GenericAPIView(views.APIView):
This must be an iterable, and may be a queryset.
Defaults to using `self.queryset`.
+ This method should always be used rather than accessing `self.queryset`
+ directly, as `self.queryset` gets evaluated only once, and those results
+ are cached for all subsequent requests.
+
You may want to override this if you need to provide different
querysets depending on the incoming request.
@@ -267,8 +284,8 @@ class GenericAPIView(views.APIView):
if self.model is not None:
return self.model._default_manager.all()
- raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'"
- % self.__class__.__name__)
+ error_format = "'%s' must define 'queryset' or 'model'"
+ raise ImproperlyConfigured(error_format % self.__class__.__name__)
def get_object(self, queryset=None):
"""
@@ -295,16 +312,16 @@ class GenericAPIView(views.APIView):
filter_kwargs = {self.lookup_field: lookup}
elif pk is not None and self.lookup_field == 'pk':
warnings.warn(
- 'The `pk_url_kwarg` attribute is due to be deprecated. '
+ 'The `pk_url_kwarg` attribute is deprecated. '
'Use the `lookup_field` attribute instead',
- PendingDeprecationWarning
+ DeprecationWarning
)
filter_kwargs = {'pk': pk}
elif slug is not None and self.lookup_field == 'pk':
warnings.warn(
- 'The `slug_url_kwarg` attribute is due to be deprecated. '
+ 'The `slug_url_kwarg` attribute is deprecated. '
'Use the `lookup_field` attribute instead',
- PendingDeprecationWarning
+ DeprecationWarning
)
filter_kwargs = {self.slug_field: slug}
else:
@@ -322,12 +339,11 @@ class GenericAPIView(views.APIView):
return obj
- ########################
- ### The following are placeholder methods,
- ### and are intended to be overridden.
- ###
- ### The are not called by GenericAPIView directly,
- ### but are used by the mixin methods.
+ # The following are placeholder methods,
+ # and are intended to be overridden.
+ #
+ # The are not called by GenericAPIView directly,
+ # but are used by the mixin methods.
def pre_save(self, obj):
"""
@@ -399,10 +415,8 @@ class GenericAPIView(views.APIView):
return ret
-##########################################################
-### Concrete view classes that provide method handlers ###
-### by composing the mixin classes with the base view. ###
-##########################################################
+# Concrete view classes that provide method handlers
+# by composing the mixin classes with the base view.
class CreateAPIView(mixins.CreateModelMixin,
GenericAPIView):
@@ -517,16 +531,14 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
return self.destroy(request, *args, **kwargs)
-##########################
-### Deprecated classes ###
-##########################
+# Deprecated classes
class MultipleObjectAPIView(GenericAPIView):
def __init__(self, *args, **kwargs):
warnings.warn(
- 'Subclassing `MultipleObjectAPIView` is due to be deprecated. '
+ 'Subclassing `MultipleObjectAPIView` is deprecated. '
'You should simply subclass `GenericAPIView` instead.',
- PendingDeprecationWarning, stacklevel=2
+ DeprecationWarning, stacklevel=2
)
super(MultipleObjectAPIView, self).__init__(*args, **kwargs)
@@ -534,8 +546,8 @@ class MultipleObjectAPIView(GenericAPIView):
class SingleObjectAPIView(GenericAPIView):
def __init__(self, *args, **kwargs):
warnings.warn(
- 'Subclassing `SingleObjectAPIView` is due to be deprecated. '
+ 'Subclassing `SingleObjectAPIView` is deprecated. '
'You should simply subclass `GenericAPIView` instead.',
- PendingDeprecationWarning, stacklevel=2
+ DeprecationWarning, stacklevel=2
)
super(SingleObjectAPIView, self).__init__(*args, **kwargs)
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index e1a24dc7..2cc87eef 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -26,14 +26,14 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None)
include = []
if pk:
- # Pending deprecation
+ # Deprecated
pk_field = obj._meta.pk
while pk_field.rel:
pk_field = pk_field.rel.to._meta.pk
include.append(pk_field.name)
if slug_field:
- # Pending deprecation
+ # Deprecated
include.append(slug_field)
if lookup_field and lookup_field != 'pk':
@@ -79,10 +79,10 @@ class ListModelMixin(object):
# `.allow_empty = False`, to raise 404 errors on empty querysets.
if not self.allow_empty and not self.object_list:
warnings.warn(
- 'The `allow_empty` parameter is due to be deprecated. '
+ 'The `allow_empty` parameter is deprecated. '
'To use `allow_empty=False` style behavior, You should override '
'`get_queryset()` and explicitly raise a 404 on empty querysets.',
- PendingDeprecationWarning
+ DeprecationWarning
)
class_name = self.__class__.__name__
error_msg = self.empty_error % {'class_name': class_name}
diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py
index 4d205c0e..ca7b5397 100644
--- a/rest_framework/negotiation.py
+++ b/rest_framework/negotiation.py
@@ -54,8 +54,10 @@ class DefaultContentNegotiation(BaseContentNegotiation):
for media_type in media_type_set:
if media_type_matches(renderer.media_type, media_type):
# Return the most specific media type as accepted.
- if (_MediaType(renderer.media_type).precedence >
- _MediaType(media_type).precedence):
+ if (
+ _MediaType(renderer.media_type).precedence >
+ _MediaType(media_type).precedence
+ ):
# Eg client requests '*/*'
# Accepted media type is 'application/json'
return renderer, renderer.media_type
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index f1b3e38d..aa4fd3f1 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -10,7 +10,8 @@ from django.core.files.uploadhandler import StopFutureHandlers
from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter
-from rest_framework.compat import etree, six, yaml
+from django.utils import six
+from rest_framework.compat import etree, yaml, force_text
from rest_framework.exceptions import ParseError
from rest_framework import renderers
import json
@@ -288,7 +289,7 @@ class FileUploadParser(BaseParser):
try:
meta = parser_context['request'].META
- disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'])
- return disposition[1]['filename']
+ disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'].encode('utf-8'))
+ return force_text(disposition[1]['filename'])
except (AttributeError, KeyError):
pass
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index f24a5123..6a1a0077 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -2,15 +2,12 @@
Provides a set of pluggable permission policies.
"""
from __future__ import unicode_literals
-import inspect
-import warnings
-
-SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS']
-
from django.http import Http404
from rest_framework.compat import (get_model_name, oauth2_provider_scope,
oauth2_constants)
+SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS']
+
class BasePermission(object):
"""
@@ -27,13 +24,6 @@ class BasePermission(object):
"""
Return `True` if permission is granted, `False` otherwise.
"""
- if len(inspect.getargspec(self.has_permission).args) == 4:
- warnings.warn(
- 'The `obj` argument in `has_permission` is deprecated. '
- 'Use `has_object_permission()` instead for object permissions.',
- DeprecationWarning, stacklevel=2
- )
- return self.has_permission(request, view, obj)
return True
@@ -72,9 +62,11 @@ class IsAuthenticatedOrReadOnly(BasePermission):
"""
def has_permission(self, request, view):
- return (request.method in SAFE_METHODS or
- request.user and
- request.user.is_authenticated())
+ return (
+ request.method in SAFE_METHODS or
+ request.user and
+ request.user.is_authenticated()
+ )
class DjangoModelPermissions(BasePermission):
@@ -132,9 +124,11 @@ class DjangoModelPermissions(BasePermission):
perms = self.get_required_permissions(request.method, model_cls)
- return (request.user and
+ return (
+ request.user and
(request.user.is_authenticated() or not self.authenticated_users_only) and
- request.user.has_perms(perms))
+ request.user.has_perms(perms)
+ )
class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions):
@@ -222,6 +216,8 @@ class TokenHasReadWriteScope(BasePermission):
required = oauth2_constants.READ if read_only else oauth2_constants.WRITE
return oauth2_provider_scope.check(required, request.auth.scope)
- assert False, ('TokenHasReadWriteScope requires either the'
- '`OAuthAuthentication` or `OAuth2Authentication` authentication '
- 'class to be used.')
+ assert False, (
+ 'TokenHasReadWriteScope requires either the'
+ '`OAuthAuthentication` or `OAuth2Authentication` authentication '
+ 'class to be used.'
+ )
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 308545ce..1acbdce2 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -19,8 +19,7 @@ from rest_framework.compat import smart_text
import warnings
-##### Relational fields #####
-
+# Relational fields
# Not actually Writable, but subclasses may need to be.
class RelatedField(WritableField):
@@ -41,14 +40,6 @@ class RelatedField(WritableField):
many = False
def __init__(self, *args, **kwargs):
-
- # 'null' is to be deprecated in favor of 'required'
- if 'null' in kwargs:
- warnings.warn('The `null` keyword argument is deprecated. '
- 'Use the `required` keyword argument instead.',
- DeprecationWarning, stacklevel=2)
- kwargs['required'] = not kwargs.pop('null')
-
queryset = kwargs.pop('queryset', None)
self.many = kwargs.pop('many', self.many)
if self.many:
@@ -59,6 +50,8 @@ class RelatedField(WritableField):
super(RelatedField, self).__init__(*args, **kwargs)
if not self.required:
+ # Accessed in ModelChoiceIterator django/forms/models.py:1034
+ # If set adds empty choice.
self.empty_label = BLANK_CHOICE_DASH[0][1]
self.queryset = queryset
@@ -72,7 +65,7 @@ class RelatedField(WritableField):
else: # Reverse
self.queryset = manager.field.rel.to._default_manager.all()
- ### We need this stuff to make form choices work...
+ # We need this stuff to make form choices work...
def prepare_value(self, obj):
return self.to_native(obj)
@@ -119,7 +112,7 @@ class RelatedField(WritableField):
choices = property(_get_choices, _set_choices)
- ### Default value handling
+ # Default value handling
def get_default_value(self):
default = super(RelatedField, self).get_default_value()
@@ -127,7 +120,7 @@ class RelatedField(WritableField):
return []
return default
- ### Regular serializer stuff...
+ # Regular serializer stuff...
def field_to_native(self, obj, field_name):
try:
@@ -187,7 +180,7 @@ class RelatedField(WritableField):
into[(self.source or field_name)] = self.from_native(value)
-### PrimaryKey relationships
+# PrimaryKey relationships
class PrimaryKeyRelatedField(RelatedField):
"""
@@ -275,8 +268,7 @@ class PrimaryKeyRelatedField(RelatedField):
return self.to_native(pk)
-### Slug relationships
-
+# Slug relationships
class SlugRelatedField(RelatedField):
"""
@@ -311,7 +303,7 @@ class SlugRelatedField(RelatedField):
raise ValidationError(msg)
-### Hyperlinked relationships
+# Hyperlinked relationships
class HyperlinkedRelatedField(RelatedField):
"""
@@ -328,7 +320,7 @@ class HyperlinkedRelatedField(RelatedField):
'incorrect_type': _('Incorrect type. Expected url string, received %s.'),
}
- # These are all pending deprecation
+ # These are all deprecated
pk_url_kwarg = 'pk'
slug_field = 'slug'
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
@@ -342,16 +334,16 @@ class HyperlinkedRelatedField(RelatedField):
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.format = kwargs.pop('format', None)
- # These are pending deprecation
+ # These are deprecated
if 'pk_url_kwarg' in kwargs:
- msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.'
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
if 'slug_url_kwarg' in kwargs:
- msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.'
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
if 'slug_field' in kwargs:
- msg = 'slug_field is pending deprecation. Use lookup_field instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ msg = 'slug_field is deprecated. Use lookup_field instead.'
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
self.slug_field = kwargs.pop('slug_field', self.slug_field)
@@ -394,9 +386,9 @@ class HyperlinkedRelatedField(RelatedField):
# If the lookup succeeds using the default slug params,
# then `slug_field` is being used implicitly, and we
# we need to warn about the pending deprecation.
- msg = 'Implicit slug field hyperlinked fields are pending deprecation.' \
+ msg = 'Implicit slug field hyperlinked fields are deprecated.' \
'You should set `lookup_field=slug` on the HyperlinkedRelatedField.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
return ret
except NoReverseMatch:
pass
@@ -430,14 +422,11 @@ class HyperlinkedRelatedField(RelatedField):
request = self.context.get('request', None)
format = self.format or self.context.get('format', None)
- if request is None:
- msg = (
- "Using `HyperlinkedRelatedField` without including the request "
- "in the serializer context is deprecated. "
- "Add `context={'request': request}` when instantiating "
- "the serializer."
- )
- warnings.warn(msg, DeprecationWarning, stacklevel=4)
+ assert request is not None, (
+ "`HyperlinkedRelatedField` requires the request in the serializer "
+ "context. Add `context={'request': request}` when instantiating "
+ "the serializer."
+ )
# If the object has not yet been saved then we cannot hyperlink to it.
if getattr(obj, 'pk', None) is None:
@@ -497,7 +486,7 @@ class HyperlinkedIdentityField(Field):
lookup_field = 'pk'
read_only = True
- # These are all pending deprecation
+ # These are all deprecated
pk_url_kwarg = 'pk'
slug_field = 'slug'
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
@@ -513,16 +502,16 @@ class HyperlinkedIdentityField(Field):
lookup_field = kwargs.pop('lookup_field', None)
self.lookup_field = lookup_field or self.lookup_field
- # These are pending deprecation
+ # These are deprecated
if 'pk_url_kwarg' in kwargs:
- msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.'
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
if 'slug_url_kwarg' in kwargs:
- msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.'
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
if 'slug_field' in kwargs:
- msg = 'slug_field is pending deprecation. Use lookup_field instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ msg = 'slug_field is deprecated. Use lookup_field instead.'
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
self.slug_field = kwargs.pop('slug_field', self.slug_field)
default_slug_kwarg = self.slug_url_kwarg or self.slug_field
@@ -536,11 +525,11 @@ class HyperlinkedIdentityField(Field):
format = self.context.get('format', None)
view_name = self.view_name
- if request is None:
- warnings.warn("Using `HyperlinkedIdentityField` without including the "
- "request in the serializer context is deprecated. "
- "Add `context={'request': request}` when instantiating the serializer.",
- DeprecationWarning, stacklevel=4)
+ assert request is not None, (
+ "`HyperlinkedIdentityField` requires the request in the serializer"
+ " context. Add `context={'request': request}` when instantiating "
+ "the serializer."
+ )
# By default use whatever format is given for the current context
# unless the target is a different type to the source.
@@ -604,41 +593,3 @@ class HyperlinkedIdentityField(Field):
pass
raise NoReverseMatch()
-
-
-### Old-style many classes for backwards compat
-
-class ManyRelatedField(RelatedField):
- def __init__(self, *args, **kwargs):
- warnings.warn('`ManyRelatedField()` is deprecated. '
- 'Use `RelatedField(many=True)` instead.',
- DeprecationWarning, stacklevel=2)
- kwargs['many'] = True
- super(ManyRelatedField, self).__init__(*args, **kwargs)
-
-
-class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
- def __init__(self, *args, **kwargs):
- warnings.warn('`ManyPrimaryKeyRelatedField()` is deprecated. '
- 'Use `PrimaryKeyRelatedField(many=True)` instead.',
- DeprecationWarning, stacklevel=2)
- kwargs['many'] = True
- super(ManyPrimaryKeyRelatedField, self).__init__(*args, **kwargs)
-
-
-class ManySlugRelatedField(SlugRelatedField):
- def __init__(self, *args, **kwargs):
- warnings.warn('`ManySlugRelatedField()` is deprecated. '
- 'Use `SlugRelatedField(many=True)` instead.',
- DeprecationWarning, stacklevel=2)
- kwargs['many'] = True
- super(ManySlugRelatedField, self).__init__(*args, **kwargs)
-
-
-class ManyHyperlinkedRelatedField(HyperlinkedRelatedField):
- def __init__(self, *args, **kwargs):
- warnings.warn('`ManyHyperlinkedRelatedField()` is deprecated. '
- 'Use `HyperlinkedRelatedField(many=True)` instead.',
- DeprecationWarning, stacklevel=2)
- kwargs['many'] = True
- super(ManyHyperlinkedRelatedField, self).__init__(*args, **kwargs)
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 7a7da561..748ebac9 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -8,7 +8,6 @@ REST framework also provides an HTML renderer the renders the browsable API.
"""
from __future__ import unicode_literals
-import copy
import json
import django
from django import forms
@@ -16,11 +15,9 @@ from django.core.exceptions import ImproperlyConfigured
from django.http.multipartparser import parse_header
from django.template import RequestContext, loader, Template
from django.test.client import encode_multipart
+from django.utils import six
from django.utils.xmlutils import SimplerXMLGenerator
-from rest_framework.compat import StringIO
-from rest_framework.compat import six
-from rest_framework.compat import smart_text
-from rest_framework.compat import yaml
+from rest_framework.compat import StringIO, smart_text, yaml
from rest_framework.exceptions import ParseError
from rest_framework.settings import api_settings
from rest_framework.request import is_form_media_type, override_method
@@ -54,35 +51,41 @@ class JSONRenderer(BaseRenderer):
format = 'json'
encoder_class = encoders.JSONEncoder
ensure_ascii = True
- charset = None
- # JSON is a binary encoding, that can be encoded as utf-8, utf-16 or utf-32.
+
+ # We don't set a charset because JSON is a binary encoding,
+ # that can be encoded as utf-8, utf-16 or utf-32.
# See: http://www.ietf.org/rfc/rfc4627.txt
# Also: http://lucumr.pocoo.org/2013/7/19/application-mimetypes-and-encodings/
+ charset = None
+
+ def get_indent(self, accepted_media_type, renderer_context):
+ if accepted_media_type:
+ # If the media type looks like 'application/json; indent=4',
+ # then pretty print the result.
+ base_media_type, params = parse_header(accepted_media_type.encode('ascii'))
+ try:
+ return max(min(int(params['indent']), 8), 0)
+ except (KeyError, ValueError, TypeError):
+ pass
+
+ # If 'indent' is provided in the context, then pretty print the result.
+ # E.g. If we're being called by the BrowsableAPIRenderer.
+ return renderer_context.get('indent', None)
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
- Render `data` into JSON.
+ Render `data` into JSON, returning a bytestring.
"""
if data is None:
return bytes()
- # If 'indent' is provided in the context, then pretty print the result.
- # E.g. If we're being called by the BrowsableAPIRenderer.
renderer_context = renderer_context or {}
- indent = renderer_context.get('indent', None)
+ indent = self.get_indent(accepted_media_type, renderer_context)
- if accepted_media_type:
- # If the media type looks like 'application/json; indent=4',
- # then pretty print the result.
- base_media_type, params = parse_header(accepted_media_type.encode('ascii'))
- indent = params.get('indent', indent)
- try:
- indent = max(min(int(indent), 8), 0)
- except (ValueError, TypeError):
- indent = None
-
- ret = json.dumps(data, cls=self.encoder_class,
- indent=indent, ensure_ascii=self.ensure_ascii)
+ ret = json.dumps(
+ data, cls=self.encoder_class,
+ indent=indent, ensure_ascii=self.ensure_ascii
+ )
# On python 2.x json.dumps() returns bytestrings if ensure_ascii=True,
# but if ensure_ascii=False, the return type is underspecified,
@@ -193,6 +196,7 @@ class YAMLRenderer(BaseRenderer):
format = 'yaml'
encoder = encoders.SafeDumper
charset = 'utf-8'
+ ensure_ascii = True
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
@@ -203,7 +207,15 @@ class YAMLRenderer(BaseRenderer):
if data is None:
return ''
- return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder)
+ return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder, allow_unicode=not self.ensure_ascii)
+
+
+class UnicodeYAMLRenderer(YAMLRenderer):
+ """
+ Renderer which serializes to YAML.
+ Does *not* apply character escaping for non-ascii characters.
+ """
+ ensure_ascii = False
class TemplateHTMLRenderer(BaseRenderer):
@@ -400,7 +412,7 @@ class BrowsableAPIRenderer(BaseRenderer):
"""
Returns True if a form should be shown for this method.
"""
- if not method in view.allowed_methods:
+ if method not in view.allowed_methods:
return # Not a valid method
if not api_settings.FORM_METHOD_OVERRIDE:
@@ -440,8 +452,10 @@ class BrowsableAPIRenderer(BaseRenderer):
if method in ('DELETE', 'OPTIONS'):
return True # Don't actually need to return a form
- if (not getattr(view, 'get_serializer', None)
- or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)):
+ if (
+ not getattr(view, 'get_serializer', None)
+ or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)
+ ):
return
serializer = view.get_serializer(instance=obj, data=data, files=files)
@@ -562,7 +576,7 @@ class BrowsableAPIRenderer(BaseRenderer):
'version': VERSION,
'breadcrumblist': self.get_breadcrumbs(request),
'allowed_methods': view.allowed_methods,
- 'available_formats': [renderer.format for renderer in view.renderer_classes],
+ 'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes],
'response_headers': response_headers,
'put_form': self.get_rendered_html_form(view, 'PUT', request),
@@ -611,4 +625,3 @@ class MultiPartRenderer(BaseRenderer):
def render(self, data, accepted_media_type=None, renderer_context=None):
return encode_multipart(self.BOUNDARY, data)
-
diff --git a/rest_framework/request.py b/rest_framework/request.py
index 40467c03..27532661 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -42,13 +42,20 @@ class override_method(object):
self.view = view
self.request = request
self.method = method
+ self.action = getattr(view, 'action', None)
def __enter__(self):
self.view.request = clone_request(self.request, self.method)
+ if self.action is not None:
+ # For viewsets we also set the `.action` attribute.
+ action_map = getattr(self.view, 'action_map', {})
+ self.view.action = action_map.get(self.method.lower())
return self.view.request
def __exit__(self, *args, **kwarg):
self.view.request = self.request
+ if self.action is not None:
+ self.view.action = self.action
class Empty(object):
@@ -280,16 +287,19 @@ class Request(object):
self._method = self._request.method
# Allow X-HTTP-METHOD-OVERRIDE header
- self._method = self.META.get('HTTP_X_HTTP_METHOD_OVERRIDE',
- self._method)
+ if 'HTTP_X_HTTP_METHOD_OVERRIDE' in self.META:
+ self._method = self.META['HTTP_X_HTTP_METHOD_OVERRIDE'].upper()
def _load_stream(self):
"""
Return the content body of the request, as a stream.
"""
try:
- content_length = int(self.META.get('CONTENT_LENGTH',
- self.META.get('HTTP_CONTENT_LENGTH')))
+ content_length = int(
+ self.META.get(
+ 'CONTENT_LENGTH', self.META.get('HTTP_CONTENT_LENGTH')
+ )
+ )
except (ValueError, TypeError):
content_length = 0
@@ -313,9 +323,11 @@ class Request(object):
)
# We only need to use form overloading on form POST requests.
- if (not USE_FORM_OVERLOADING
+ if (
+ not USE_FORM_OVERLOADING
or self._request.method != 'POST'
- or not is_form_media_type(self._content_type)):
+ or not is_form_media_type(self._content_type)
+ ):
return
# At this point we're committed to parsing the request as form data.
@@ -323,15 +335,19 @@ class Request(object):
self._files = self._request.FILES
# Method overloading - change the method and remove the param from the content.
- if (self._METHOD_PARAM and
- self._METHOD_PARAM in self._data):
+ if (
+ self._METHOD_PARAM and
+ self._METHOD_PARAM in self._data
+ ):
self._method = self._data[self._METHOD_PARAM].upper()
# Content overloading - modify the content type, and force re-parse.
- if (self._CONTENT_PARAM and
+ if (
+ self._CONTENT_PARAM and
self._CONTENTTYPE_PARAM and
self._CONTENT_PARAM in self._data and
- self._CONTENTTYPE_PARAM in self._data):
+ self._CONTENTTYPE_PARAM in self._data
+ ):
self._content_type = self._data[self._CONTENTTYPE_PARAM]
self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding']))
self._data, self._files = (Empty, Empty)
@@ -387,7 +403,7 @@ class Request(object):
self._not_authenticated()
raise
- if not user_auth_tuple is None:
+ if user_auth_tuple is not None:
self._authenticator = authenticator
self._user, self._auth = user_auth_tuple
return
diff --git a/rest_framework/response.py b/rest_framework/response.py
index 1dc6abcf..0a7d313f 100644
--- a/rest_framework/response.py
+++ b/rest_framework/response.py
@@ -5,9 +5,10 @@ it is initialized with unrendered data, instead of a pre-rendered string.
The appropriate renderer is called during Django's template response rendering.
"""
from __future__ import unicode_literals
+import django
from django.core.handlers.wsgi import STATUS_CODE_TEXT
from django.template.response import SimpleTemplateResponse
-from rest_framework.compat import six
+from django.utils import six
class Response(SimpleTemplateResponse):
@@ -15,8 +16,11 @@ class Response(SimpleTemplateResponse):
An HttpResponse that allows its data to be rendered into
arbitrary media types.
"""
+ # TODO: remove that once Django 1.3 isn't supported
+ if django.VERSION >= (1, 4):
+ rendering_attrs = SimpleTemplateResponse.rendering_attrs + ['_closable_objects']
- def __init__(self, data=None, status=200,
+ def __init__(self, data=None, status=None,
template_name=None, headers=None,
exception=False, content_type=None):
"""
@@ -58,8 +62,10 @@ class Response(SimpleTemplateResponse):
ret = renderer.render(self.data, media_type, context)
if isinstance(ret, six.text_type):
- assert charset, 'renderer returned unicode, and did not specify ' \
- 'a charset value.'
+ assert charset, (
+ 'renderer returned unicode, and did not specify '
+ 'a charset value.'
+ )
return bytes(ret.encode(charset))
if not ret:
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index 97b35c10..406ebcf7 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -17,15 +17,17 @@ from __future__ import unicode_literals
import itertools
from collections import namedtuple
+from django.conf.urls import patterns, url
from django.core.exceptions import ImproperlyConfigured
from rest_framework import views
-from rest_framework.compat import patterns, url
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.urlpatterns import format_suffix_patterns
Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs'])
+DynamicDetailRoute = namedtuple('DynamicDetailRoute', ['url', 'name', 'initkwargs'])
+DynamicListRoute = namedtuple('DynamicListRoute', ['url', 'name', 'initkwargs'])
def replace_methodname(format_string, methodname):
@@ -88,6 +90,14 @@ class SimpleRouter(BaseRouter):
name='{basename}-list',
initkwargs={'suffix': 'List'}
),
+ # Dynamically generated list routes.
+ # Generated using @list_route decorator
+ # on methods of the viewset.
+ DynamicListRoute(
+ url=r'^{prefix}/{methodname}{trailing_slash}$',
+ name='{basename}-{methodnamehyphen}',
+ initkwargs={}
+ ),
# Detail route.
Route(
url=r'^{prefix}/{lookup}{trailing_slash}$',
@@ -100,13 +110,10 @@ class SimpleRouter(BaseRouter):
name='{basename}-detail',
initkwargs={'suffix': 'Instance'}
),
- # Dynamically generated routes.
- # Generated using @action or @link decorators on methods of the viewset.
- Route(
+ # Dynamically generated detail routes.
+ # Generated using @detail_route decorator on methods of the viewset.
+ DynamicDetailRoute(
url=r'^{prefix}/{lookup}/{methodname}{trailing_slash}$',
- mapping={
- '{httpmethod}': '{methodname}',
- },
name='{basename}-{methodnamehyphen}',
initkwargs={}
),
@@ -139,25 +146,42 @@ class SimpleRouter(BaseRouter):
Returns a list of the Route namedtuple.
"""
- known_actions = flatten([route.mapping.values() for route in self.routes])
+ known_actions = flatten([route.mapping.values() for route in self.routes if isinstance(route, Route)])
- # Determine any `@action` or `@link` decorated methods on the viewset
- dynamic_routes = []
+ # Determine any `@detail_route` or `@list_route` decorated methods on the viewset
+ detail_routes = []
+ list_routes = []
for methodname in dir(viewset):
attr = getattr(viewset, methodname)
httpmethods = getattr(attr, 'bind_to_methods', None)
+ detail = getattr(attr, 'detail', True)
if httpmethods:
if methodname in known_actions:
- raise ImproperlyConfigured('Cannot use @action or @link decorator on '
- 'method "%s" as it is an existing route' % methodname)
+ raise ImproperlyConfigured('Cannot use @detail_route or @list_route '
+ 'decorators on method "%s" '
+ 'as it is an existing route' % methodname)
httpmethods = [method.lower() for method in httpmethods]
- dynamic_routes.append((httpmethods, methodname))
+ if detail:
+ detail_routes.append((httpmethods, methodname))
+ else:
+ list_routes.append((httpmethods, methodname))
ret = []
for route in self.routes:
- if route.mapping == {'{httpmethod}': '{methodname}'}:
- # Dynamic routes (@link or @action decorator)
- for httpmethods, methodname in dynamic_routes:
+ if isinstance(route, DynamicDetailRoute):
+ # Dynamic detail routes (@detail_route decorator)
+ for httpmethods, methodname in detail_routes:
+ initkwargs = route.initkwargs.copy()
+ initkwargs.update(getattr(viewset, methodname).kwargs)
+ ret.append(Route(
+ url=replace_methodname(route.url, methodname),
+ mapping=dict((httpmethod, methodname) for httpmethod in httpmethods),
+ name=replace_methodname(route.name, methodname),
+ initkwargs=initkwargs,
+ ))
+ elif isinstance(route, DynamicListRoute):
+ # Dynamic list routes (@list_route decorator)
+ for httpmethods, methodname in list_routes:
initkwargs = route.initkwargs.copy()
initkwargs.update(getattr(viewset, methodname).kwargs)
ret.append(Route(
@@ -195,13 +219,16 @@ class SimpleRouter(BaseRouter):
https://github.com/alanjds/drf-nested-routers
"""
- if self.trailing_slash:
- base_regex = '(?P<{lookup_prefix}{lookup_field}>[^/]+)'
- else:
- # Don't consume `.json` style suffixes
- base_regex = '(?P<{lookup_prefix}{lookup_field}>[^/.]+)'
+ base_regex = '(?P<{lookup_prefix}{lookup_field}>{lookup_value})'
+ # Use `pk` as default field, unset set. Default regex should not
+ # consume `.json` style suffixes and should break at '/' boundaries.
lookup_field = getattr(viewset, 'lookup_field', 'pk')
- return base_regex.format(lookup_field=lookup_field, lookup_prefix=lookup_prefix)
+ lookup_value = getattr(viewset, 'lookup_value_regex', '[^/.]+')
+ return base_regex.format(
+ lookup_prefix=lookup_prefix,
+ lookup_field=lookup_field,
+ lookup_value=lookup_value
+ )
def get_urls(self):
"""
diff --git a/rest_framework/runtests/runcoverage.py b/rest_framework/runtests/runcoverage.py
deleted file mode 100755
index ce11b213..00000000
--- a/rest_framework/runtests/runcoverage.py
+++ /dev/null
@@ -1,78 +0,0 @@
-#!/usr/bin/env python
-"""
-Useful tool to run the test suite for rest_framework and generate a coverage report.
-"""
-
-# http://ericholscher.com/blog/2009/jun/29/enable-setuppy-test-your-django-apps/
-# http://www.travisswicegood.com/2010/01/17/django-virtualenv-pip-and-fabric/
-# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py
-import os
-import sys
-
-# fix sys path so we don't need to setup PYTHONPATH
-sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
-os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings'
-
-from coverage import coverage
-
-
-def main():
- """Run the tests for rest_framework and generate a coverage report."""
-
- cov = coverage()
- cov.erase()
- cov.start()
-
- from django.conf import settings
- from django.test.utils import get_runner
- TestRunner = get_runner(settings)
-
- if hasattr(TestRunner, 'func_name'):
- # Pre 1.2 test runners were just functions,
- # and did not support the 'failfast' option.
- import warnings
- warnings.warn(
- 'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.',
- DeprecationWarning
- )
- failures = TestRunner(['tests'])
- else:
- test_runner = TestRunner()
- failures = test_runner.run_tests(['tests'])
- cov.stop()
-
- # Discover the list of all modules that we should test coverage for
- import rest_framework
-
- project_dir = os.path.dirname(rest_framework.__file__)
- cov_files = []
-
- for (path, dirs, files) in os.walk(project_dir):
- # Drop tests and runtests directories from the test coverage report
- if os.path.basename(path) in ['tests', 'runtests', 'migrations']:
- continue
-
- # Drop the compat and six modules from coverage, since we're not interested in the coverage
- # of modules which are specifically for resolving environment dependant imports.
- # (Because we'll end up getting different coverage reports for it for each environment)
- if 'compat.py' in files:
- files.remove('compat.py')
-
- if 'six.py' in files:
- files.remove('six.py')
-
- # Same applies to template tags module.
- # This module has to include branching on Django versions,
- # so it's never possible for it to have full coverage.
- if 'rest_framework.py' in files:
- files.remove('rest_framework.py')
-
- cov_files.extend([os.path.join(path, file) for file in files if file.endswith('.py')])
-
- cov.report(cov_files)
- if '--html' in sys.argv:
- cov.html_report(cov_files, directory='coverage')
- sys.exit(failures)
-
-if __name__ == '__main__':
- main()
diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py
deleted file mode 100755
index 2daaae4e..00000000
--- a/rest_framework/runtests/runtests.py
+++ /dev/null
@@ -1,52 +0,0 @@
-#!/usr/bin/env python
-
-# http://ericholscher.com/blog/2009/jun/29/enable-setuppy-test-your-django-apps/
-# http://www.travisswicegood.com/2010/01/17/django-virtualenv-pip-and-fabric/
-# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py
-import os
-import sys
-
-# fix sys path so we don't need to setup PYTHONPATH
-sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
-os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings'
-
-import django
-from django.conf import settings
-from django.test.utils import get_runner
-
-
-def usage():
- return """
- Usage: python runtests.py [UnitTestClass].[method]
-
- You can pass the Class name of the `UnitTestClass` you want to test.
-
- Append a method name if you only want to test a specific method of that class.
- """
-
-
-def main():
- try:
- django.setup()
- except AttributeError:
- pass
- TestRunner = get_runner(settings)
-
- test_runner = TestRunner()
- if len(sys.argv) == 2:
- test_case = '.' + sys.argv[1]
- elif len(sys.argv) == 1:
- test_case = ''
- else:
- print(usage())
- sys.exit(1)
- test_module_name = 'rest_framework.tests'
- if django.VERSION[0] == 1 and django.VERSION[1] < 6:
- test_module_name = 'tests'
-
- failures = test_runner.run_tests([test_module_name + test_case])
-
- sys.exit(failures)
-
-if __name__ == '__main__':
- main()
diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py
deleted file mode 100644
index 3fc0eb2f..00000000
--- a/rest_framework/runtests/settings.py
+++ /dev/null
@@ -1,169 +0,0 @@
-# Django settings for testproject project.
-
-DEBUG = True
-TEMPLATE_DEBUG = DEBUG
-DEBUG_PROPAGATE_EXCEPTIONS = True
-
-ALLOWED_HOSTS = ['*']
-
-ADMINS = (
- # ('Your Name', 'your_email@domain.com'),
-)
-
-MANAGERS = ADMINS
-
-DATABASES = {
- 'default': {
- 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'.
- 'NAME': 'sqlite.db', # Or path to database file if using sqlite3.
- 'USER': '', # Not used with sqlite3.
- 'PASSWORD': '', # Not used with sqlite3.
- 'HOST': '', # Set to empty string for localhost. Not used with sqlite3.
- 'PORT': '', # Set to empty string for default. Not used with sqlite3.
- }
-}
-
-CACHES = {
- 'default': {
- 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
- }
-}
-
-# Local time zone for this installation. Choices can be found here:
-# http://en.wikipedia.org/wiki/List_of_tz_zones_by_name
-# although not all choices may be available on all operating systems.
-# On Unix systems, a value of None will cause Django to use the same
-# timezone as the operating system.
-# If running in a Windows environment this must be set to the same as your
-# system time zone.
-TIME_ZONE = 'Europe/London'
-
-# Language code for this installation. All choices can be found here:
-# http://www.i18nguy.com/unicode/language-identifiers.html
-LANGUAGE_CODE = 'en-uk'
-
-SITE_ID = 1
-
-# If you set this to False, Django will make some optimizations so as not
-# to load the internationalization machinery.
-USE_I18N = True
-
-# If you set this to False, Django will not format dates, numbers and
-# calendars according to the current locale
-USE_L10N = True
-
-# Absolute filesystem path to the directory that will hold user-uploaded files.
-# Example: "/home/media/media.lawrence.com/"
-MEDIA_ROOT = ''
-
-# URL that handles the media served from MEDIA_ROOT. Make sure to use a
-# trailing slash if there is a path component (optional in other cases).
-# Examples: "http://media.lawrence.com", "http://example.com/media/"
-MEDIA_URL = ''
-
-# Make this unique, and don't share it with anybody.
-SECRET_KEY = 'u@x-aj9(hoh#rb-^ymf#g2jx_hp0vj7u5#b@ag1n^seu9e!%cy'
-
-# List of callables that know how to import templates from various sources.
-TEMPLATE_LOADERS = (
- 'django.template.loaders.filesystem.Loader',
- 'django.template.loaders.app_directories.Loader',
-# 'django.template.loaders.eggs.Loader',
-)
-
-MIDDLEWARE_CLASSES = (
- 'django.middleware.common.CommonMiddleware',
- 'django.contrib.sessions.middleware.SessionMiddleware',
- 'django.middleware.csrf.CsrfViewMiddleware',
- 'django.contrib.auth.middleware.AuthenticationMiddleware',
- 'django.contrib.messages.middleware.MessageMiddleware',
-)
-
-ROOT_URLCONF = 'urls'
-
-TEMPLATE_DIRS = (
- # Put strings here, like "/home/html/django_templates" or "C:/www/django/templates".
- # Always use forward slashes, even on Windows.
- # Don't forget to use absolute paths, not relative paths.
-)
-
-INSTALLED_APPS = (
- 'django.contrib.auth',
- 'django.contrib.contenttypes',
- 'django.contrib.sessions',
- 'django.contrib.sites',
- 'django.contrib.messages',
- # Uncomment the next line to enable the admin:
- # 'django.contrib.admin',
- # Uncomment the next line to enable admin documentation:
- # 'django.contrib.admindocs',
- 'rest_framework',
- 'rest_framework.authtoken',
- 'rest_framework.tests',
- 'rest_framework.tests.accounts',
- 'rest_framework.tests.records',
- 'rest_framework.tests.users',
-)
-
-# OAuth is optional and won't work if there is no oauth_provider & oauth2
-try:
- import oauth_provider
- import oauth2
-except ImportError:
- pass
-else:
- INSTALLED_APPS += (
- 'oauth_provider',
- )
-
-try:
- import provider
-except ImportError:
- pass
-else:
- INSTALLED_APPS += (
- 'provider',
- 'provider.oauth2',
- )
-
-# guardian is optional
-try:
- import guardian
-except ImportError:
- pass
-else:
- ANONYMOUS_USER_ID = -1
- AUTHENTICATION_BACKENDS = (
- 'django.contrib.auth.backends.ModelBackend', # default
- 'guardian.backends.ObjectPermissionBackend',
- )
- INSTALLED_APPS += (
- 'guardian',
- )
-
-STATIC_URL = '/static/'
-
-PASSWORD_HASHERS = (
- 'django.contrib.auth.hashers.SHA1PasswordHasher',
- 'django.contrib.auth.hashers.PBKDF2PasswordHasher',
- 'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher',
- 'django.contrib.auth.hashers.BCryptPasswordHasher',
- 'django.contrib.auth.hashers.MD5PasswordHasher',
- 'django.contrib.auth.hashers.CryptPasswordHasher',
-)
-
-AUTH_USER_MODEL = 'auth.User'
-
-import django
-
-if django.VERSION < (1, 3):
- INSTALLED_APPS += ('staticfiles',)
-
-
-# If we're running on the Jenkins server we want to archive the coverage reports as XML.
-import os
-if os.environ.get('HUDSON_URL', None):
- TEST_RUNNER = 'xmlrunner.extra.djangotestrunner.XMLTestRunner'
- TEST_OUTPUT_VERBOSE = True
- TEST_OUTPUT_DESCRIPTIONS = True
- TEST_OUTPUT_DIR = 'xmlrunner'
diff --git a/rest_framework/runtests/urls.py b/rest_framework/runtests/urls.py
deleted file mode 100644
index ed5baeae..00000000
--- a/rest_framework/runtests/urls.py
+++ /dev/null
@@ -1,7 +0,0 @@
-"""
-Blank URLConf just to keep runtests.py happy.
-"""
-from rest_framework.compat import patterns
-
-urlpatterns = patterns('',
-)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index cb7539e0..be8ad3f2 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -16,11 +16,13 @@ import datetime
import inspect
import types
from decimal import Decimal
+from django.contrib.contenttypes.generic import GenericForeignKey
from django.core.paginator import Page
from django.db import models
from django.forms import widgets
+from django.utils import six
from django.utils.datastructures import SortedDict
-from rest_framework.compat import get_concrete_model, six
+from django.core.exceptions import ObjectDoesNotExist
from rest_framework.settings import api_settings
@@ -31,8 +33,8 @@ from rest_framework.settings import api_settings
# This helps keep the separation between model fields, form fields, and
# serializer fields more explicit.
-from rest_framework.relations import *
-from rest_framework.fields import *
+from rest_framework.relations import * # NOQA
+from rest_framework.fields import * # NOQA
def _resolve_model(obj):
@@ -47,7 +49,7 @@ def _resolve_model(obj):
String representations should have the format:
'appname.ModelName'
"""
- if type(obj) == str and len(obj.split('.')) == 2:
+ if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.')
return models.get_model(app_name, model_name)
elif inspect.isclass(obj) and issubclass(obj, models.Model):
@@ -180,7 +182,7 @@ class BaseSerializer(WritableField):
_dict_class = SortedDictWithMetadata
def __init__(self, instance=None, data=None, files=None,
- context=None, partial=False, many=None,
+ context=None, partial=False, many=False,
allow_add_remove=False, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta)
@@ -343,7 +345,7 @@ class BaseSerializer(WritableField):
for field_name, field in self.fields.items():
if field.read_only and obj is None:
- continue
+ continue
field.initialize(parent=self, field_name=field_name)
key = self.get_field_key(field_name)
value = field.field_to_native(obj, field_name)
@@ -410,12 +412,7 @@ class BaseSerializer(WritableField):
if value is None:
return None
- if self.many is not None:
- many = self.many
- else:
- many = hasattr(value, '__iter__') and not isinstance(value, (Page, dict, six.text_type))
-
- if many:
+ if self.many:
return [self.to_native(item) for item in value]
return self.to_native(value)
@@ -452,9 +449,11 @@ class BaseSerializer(WritableField):
# If we have a model manager or similar object then we need
# to iterate through each instance.
- if (self.many and
+ if (
+ self.many and
not hasattr(obj, '__iter__') and
- is_simple_callable(getattr(obj, 'all', None))):
+ is_simple_callable(getattr(obj, 'all', None))
+ ):
obj = obj.all()
kwargs = {
@@ -604,8 +603,10 @@ class BaseSerializer(WritableField):
API schemas for auto-documentation.
"""
return SortedDict(
- [(field_name, field.metadata())
- for field_name, field in six.iteritems(self.fields)]
+ [
+ (field_name, field.metadata())
+ for field_name, field in six.iteritems(self.fields)
+ ]
)
@@ -659,9 +660,11 @@ class ModelSerializer(Serializer):
"""
cls = self.opts.model
- assert cls is not None, \
- "Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__
- opts = get_concrete_model(cls)._meta
+ assert cls is not None, (
+ "Serializer class '%s' is missing 'model' Meta option" %
+ self.__class__.__name__
+ )
+ opts = cls._meta.concrete_model._meta
ret = SortedDict()
nested = bool(self.opts.depth)
@@ -671,9 +674,9 @@ class ModelSerializer(Serializer):
# If model is a child via multitable inheritance, use parent's pk
pk_field = pk_field.rel.to._meta.pk
- field = self.get_pk_field(pk_field)
- if field:
- ret[pk_field.name] = field
+ serializer_pk_field = self.get_pk_field(pk_field)
+ if serializer_pk_field:
+ ret[pk_field.name] = serializer_pk_field
# Deal with forward relationships
forward_rels = [field for field in opts.fields if field.serialize]
@@ -694,10 +697,10 @@ class ModelSerializer(Serializer):
if len(inspect.getargspec(self.get_nested_field).args) == 2:
warnings.warn(
'The `get_nested_field(model_field)` call signature '
- 'is due to be deprecated. '
+ 'is deprecated. '
'Use `get_nested_field(model_field, related_model, '
'to_many) instead',
- PendingDeprecationWarning
+ DeprecationWarning
)
field = self.get_nested_field(model_field)
else:
@@ -706,10 +709,10 @@ class ModelSerializer(Serializer):
if len(inspect.getargspec(self.get_nested_field).args) == 3:
warnings.warn(
'The `get_related_field(model_field, to_many)` call '
- 'signature is due to be deprecated. '
+ 'signature is deprecated. '
'Use `get_related_field(model_field, related_model, '
'to_many) instead',
- PendingDeprecationWarning
+ DeprecationWarning
)
field = self.get_related_field(model_field, to_many=to_many)
else:
@@ -742,9 +745,11 @@ class ModelSerializer(Serializer):
is_m2m = isinstance(relation.field,
models.fields.related.ManyToManyField)
- if (is_m2m and
+ if (
+ is_m2m and
hasattr(relation.field.rel, 'through') and
- not relation.field.rel.through._meta.auto_created):
+ not relation.field.rel.through._meta.auto_created
+ ):
has_through_model = True
if nested:
@@ -757,9 +762,9 @@ class ModelSerializer(Serializer):
field.read_only = True
ret[accessor_name] = field
-
+
# Ensure that 'read_only_fields' is an iterable
- assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'
+ assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'
# Add the `read_only` flag to any fields that have been specified
# in the `read_only_fields` option
@@ -774,10 +779,10 @@ class ModelSerializer(Serializer):
"on serializer '%s'." %
(field_name, self.__class__.__name__))
ret[field_name].read_only = True
-
+
# Ensure that 'write_only_fields' is an iterable
- assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'
-
+ assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'
+
for field_name in self.opts.write_only_fields:
assert field_name not in self.base_fields.keys(), (
"field '%s' on serializer '%s' specified in "
@@ -788,7 +793,7 @@ class ModelSerializer(Serializer):
"Non-existant field '%s' specified in `write_only_fields` "
"on serializer '%s'." %
(field_name, self.__class__.__name__))
- ret[field_name].write_only = True
+ ret[field_name].write_only = True
return ret
@@ -827,6 +832,19 @@ class ModelSerializer(Serializer):
if model_field:
kwargs['required'] = not(model_field.null or model_field.blank)
+ if model_field.help_text is not None:
+ kwargs['help_text'] = model_field.help_text
+ if model_field.verbose_name is not None:
+ kwargs['label'] = model_field.verbose_name
+
+ if not model_field.editable:
+ kwargs['read_only'] = True
+
+ if model_field.verbose_name is not None:
+ kwargs['label'] = model_field.verbose_name
+
+ if model_field.help_text is not None:
+ kwargs['help_text'] = model_field.help_text
return PrimaryKeyRelatedField(**kwargs)
@@ -866,6 +884,10 @@ class ModelSerializer(Serializer):
issubclass(model_field.__class__, models.PositiveSmallIntegerField):
kwargs['min_value'] = 0
+ if model_field.null and \
+ issubclass(model_field.__class__, (models.CharField, models.TextField)):
+ kwargs['allow_none'] = True
+
attribute_dict = {
models.CharField: ['max_length'],
models.CommaSeparatedIntegerField: ['max_length'],
@@ -892,15 +914,17 @@ class ModelSerializer(Serializer):
Return a list of field names to exclude from model validation.
"""
cls = self.opts.model
- opts = get_concrete_model(cls)._meta
+ opts = cls._meta.concrete_model._meta
exclusions = [field.name for field in opts.fields + opts.many_to_many]
for field_name, field in self.fields.items():
field_name = field.source or field_name
- if field_name in exclusions \
- and not field.read_only \
- and (field.required or hasattr(instance, field_name)) \
- and not isinstance(field, Serializer):
+ if (
+ field_name in exclusions
+ and not field.read_only
+ and (field.required or hasattr(instance, field_name))
+ and not isinstance(field, Serializer)
+ ):
exclusions.remove(field_name)
return exclusions
@@ -943,6 +967,8 @@ class ModelSerializer(Serializer):
# Forward m2m relations
for field in meta.many_to_many + meta.virtual_fields:
+ if isinstance(field, GenericForeignKey):
+ continue
if field.name in attrs:
m2m_data[field.name] = attrs.pop(field.name)
@@ -952,17 +978,15 @@ class ModelSerializer(Serializer):
if isinstance(self.fields.get(field_name, None), Serializer):
nested_forward_relations[field_name] = attrs[field_name]
- # Update an existing instance...
- if instance is not None:
- for key, val in attrs.items():
- try:
- setattr(instance, key, val)
- except ValueError:
- self._errors[key] = self.error_messages['required']
+ # Create an empty instance of the model
+ if instance is None:
+ instance = self.opts.model()
- # ...or create a new instance
- else:
- instance = self.opts.model(**attrs)
+ for key, val in attrs.items():
+ try:
+ setattr(instance, key, val)
+ except ValueError:
+ self._errors[key] = [self.error_messages['required']]
# Any relations that cannot be set until we've
# saved the model get hidden away on these
@@ -1087,6 +1111,10 @@ class HyperlinkedModelSerializer(ModelSerializer):
if model_field:
kwargs['required'] = not(model_field.null or model_field.blank)
+ if model_field.help_text is not None:
+ kwargs['help_text'] = model_field.help_text
+ if model_field.verbose_name is not None:
+ kwargs['label'] = model_field.verbose_name
if self.opts.lookup_field:
kwargs['lookup_field'] = self.opts.lookup_field
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 38753c96..644751f8 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -18,12 +18,9 @@ REST framework settings, checking for user settings first, then falling
back to the defaults.
"""
from __future__ import unicode_literals
-
from django.conf import settings
-from django.utils import importlib
-
+from django.utils import importlib, six
from rest_framework import ISO_8601
-from rest_framework.compat import six
USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None)
@@ -46,16 +43,12 @@ DEFAULTS = {
'DEFAULT_PERMISSION_CLASSES': (
'rest_framework.permissions.AllowAny',
),
- 'DEFAULT_THROTTLE_CLASSES': (
- ),
- 'DEFAULT_CONTENT_NEGOTIATION_CLASS':
- 'rest_framework.negotiation.DefaultContentNegotiation',
+ 'DEFAULT_THROTTLE_CLASSES': (),
+ 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',
# Genric view behavior
- 'DEFAULT_MODEL_SERIALIZER_CLASS':
- 'rest_framework.serializers.ModelSerializer',
- 'DEFAULT_PAGINATION_SERIALIZER_CLASS':
- 'rest_framework.pagination.PaginationSerializer',
+ 'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.serializers.ModelSerializer',
+ 'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer',
'DEFAULT_FILTER_BACKENDS': (),
# Throttling
@@ -63,6 +56,7 @@ DEFAULTS = {
'user': None,
'anon': None,
},
+ 'NUM_PROXIES': None,
# Pagination
'PAGINATE_BY': None,
@@ -119,6 +113,7 @@ DEFAULTS = {
# Pending deprecation
'FILTER_BACKEND': None,
+
}
diff --git a/rest_framework/six.py b/rest_framework/six.py
deleted file mode 100644
index 9e382312..00000000
--- a/rest_framework/six.py
+++ /dev/null
@@ -1,389 +0,0 @@
-"""Utilities for writing code that runs on Python 2 and 3"""
-
-import operator
-import sys
-import types
-
-__author__ = "Benjamin Peterson <benjamin@python.org>"
-__version__ = "1.2.0"
-
-
-# True if we are running on Python 3.
-PY3 = sys.version_info[0] == 3
-
-if PY3:
- string_types = str,
- integer_types = int,
- class_types = type,
- text_type = str
- binary_type = bytes
-
- MAXSIZE = sys.maxsize
-else:
- string_types = basestring,
- integer_types = (int, long)
- class_types = (type, types.ClassType)
- text_type = unicode
- binary_type = str
-
- if sys.platform == "java":
- # Jython always uses 32 bits.
- MAXSIZE = int((1 << 31) - 1)
- else:
- # It's possible to have sizeof(long) != sizeof(Py_ssize_t).
- class X(object):
- def __len__(self):
- return 1 << 31
- try:
- len(X())
- except OverflowError:
- # 32-bit
- MAXSIZE = int((1 << 31) - 1)
- else:
- # 64-bit
- MAXSIZE = int((1 << 63) - 1)
- del X
-
-
-def _add_doc(func, doc):
- """Add documentation to a function."""
- func.__doc__ = doc
-
-
-def _import_module(name):
- """Import module, returning the module after the last dot."""
- __import__(name)
- return sys.modules[name]
-
-
-class _LazyDescr(object):
-
- def __init__(self, name):
- self.name = name
-
- def __get__(self, obj, tp):
- result = self._resolve()
- setattr(obj, self.name, result)
- # This is a bit ugly, but it avoids running this again.
- delattr(tp, self.name)
- return result
-
-
-class MovedModule(_LazyDescr):
-
- def __init__(self, name, old, new=None):
- super(MovedModule, self).__init__(name)
- if PY3:
- if new is None:
- new = name
- self.mod = new
- else:
- self.mod = old
-
- def _resolve(self):
- return _import_module(self.mod)
-
-
-class MovedAttribute(_LazyDescr):
-
- def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None):
- super(MovedAttribute, self).__init__(name)
- if PY3:
- if new_mod is None:
- new_mod = name
- self.mod = new_mod
- if new_attr is None:
- if old_attr is None:
- new_attr = name
- else:
- new_attr = old_attr
- self.attr = new_attr
- else:
- self.mod = old_mod
- if old_attr is None:
- old_attr = name
- self.attr = old_attr
-
- def _resolve(self):
- module = _import_module(self.mod)
- return getattr(module, self.attr)
-
-
-
-class _MovedItems(types.ModuleType):
- """Lazy loading of moved objects"""
-
-
-_moved_attributes = [
- MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"),
- MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"),
- MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"),
- MovedAttribute("map", "itertools", "builtins", "imap", "map"),
- MovedAttribute("reload_module", "__builtin__", "imp", "reload"),
- MovedAttribute("reduce", "__builtin__", "functools"),
- MovedAttribute("StringIO", "StringIO", "io"),
- MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"),
- MovedAttribute("zip", "itertools", "builtins", "izip", "zip"),
-
- MovedModule("builtins", "__builtin__"),
- MovedModule("configparser", "ConfigParser"),
- MovedModule("copyreg", "copy_reg"),
- MovedModule("http_cookiejar", "cookielib", "http.cookiejar"),
- MovedModule("http_cookies", "Cookie", "http.cookies"),
- MovedModule("html_entities", "htmlentitydefs", "html.entities"),
- MovedModule("html_parser", "HTMLParser", "html.parser"),
- MovedModule("http_client", "httplib", "http.client"),
- MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"),
- MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"),
- MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"),
- MovedModule("cPickle", "cPickle", "pickle"),
- MovedModule("queue", "Queue"),
- MovedModule("reprlib", "repr"),
- MovedModule("socketserver", "SocketServer"),
- MovedModule("tkinter", "Tkinter"),
- MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"),
- MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"),
- MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"),
- MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"),
- MovedModule("tkinter_tix", "Tix", "tkinter.tix"),
- MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"),
- MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"),
- MovedModule("tkinter_colorchooser", "tkColorChooser",
- "tkinter.colorchooser"),
- MovedModule("tkinter_commondialog", "tkCommonDialog",
- "tkinter.commondialog"),
- MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"),
- MovedModule("tkinter_font", "tkFont", "tkinter.font"),
- MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"),
- MovedModule("tkinter_tksimpledialog", "tkSimpleDialog",
- "tkinter.simpledialog"),
- MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"),
- MovedModule("winreg", "_winreg"),
-]
-for attr in _moved_attributes:
- setattr(_MovedItems, attr.name, attr)
-del attr
-
-moves = sys.modules["django.utils.six.moves"] = _MovedItems("moves")
-
-
-def add_move(move):
- """Add an item to six.moves."""
- setattr(_MovedItems, move.name, move)
-
-
-def remove_move(name):
- """Remove item from six.moves."""
- try:
- delattr(_MovedItems, name)
- except AttributeError:
- try:
- del moves.__dict__[name]
- except KeyError:
- raise AttributeError("no such move, %r" % (name,))
-
-
-if PY3:
- _meth_func = "__func__"
- _meth_self = "__self__"
-
- _func_code = "__code__"
- _func_defaults = "__defaults__"
-
- _iterkeys = "keys"
- _itervalues = "values"
- _iteritems = "items"
-else:
- _meth_func = "im_func"
- _meth_self = "im_self"
-
- _func_code = "func_code"
- _func_defaults = "func_defaults"
-
- _iterkeys = "iterkeys"
- _itervalues = "itervalues"
- _iteritems = "iteritems"
-
-
-try:
- advance_iterator = next
-except NameError:
- def advance_iterator(it):
- return it.next()
-next = advance_iterator
-
-
-if PY3:
- def get_unbound_function(unbound):
- return unbound
-
- Iterator = object
-
- def callable(obj):
- return any("__call__" in klass.__dict__ for klass in type(obj).__mro__)
-else:
- def get_unbound_function(unbound):
- return unbound.im_func
-
- class Iterator(object):
-
- def next(self):
- return type(self).__next__(self)
-
- callable = callable
-_add_doc(get_unbound_function,
- """Get the function out of a possibly unbound function""")
-
-
-get_method_function = operator.attrgetter(_meth_func)
-get_method_self = operator.attrgetter(_meth_self)
-get_function_code = operator.attrgetter(_func_code)
-get_function_defaults = operator.attrgetter(_func_defaults)
-
-
-def iterkeys(d):
- """Return an iterator over the keys of a dictionary."""
- return iter(getattr(d, _iterkeys)())
-
-def itervalues(d):
- """Return an iterator over the values of a dictionary."""
- return iter(getattr(d, _itervalues)())
-
-def iteritems(d):
- """Return an iterator over the (key, value) pairs of a dictionary."""
- return iter(getattr(d, _iteritems)())
-
-
-if PY3:
- def b(s):
- return s.encode("latin-1")
- def u(s):
- return s
- if sys.version_info[1] <= 1:
- def int2byte(i):
- return bytes((i,))
- else:
- # This is about 2x faster than the implementation above on 3.2+
- int2byte = operator.methodcaller("to_bytes", 1, "big")
- import io
- StringIO = io.StringIO
- BytesIO = io.BytesIO
-else:
- def b(s):
- return s
- def u(s):
- return unicode(s, "unicode_escape")
- int2byte = chr
- import StringIO
- StringIO = BytesIO = StringIO.StringIO
-_add_doc(b, """Byte literal""")
-_add_doc(u, """Text literal""")
-
-
-if PY3:
- import builtins
- exec_ = getattr(builtins, "exec")
-
-
- def reraise(tp, value, tb=None):
- if value.__traceback__ is not tb:
- raise value.with_traceback(tb)
- raise value
-
-
- print_ = getattr(builtins, "print")
- del builtins
-
-else:
- def exec_(code, globs=None, locs=None):
- """Execute code in a namespace."""
- if globs is None:
- frame = sys._getframe(1)
- globs = frame.f_globals
- if locs is None:
- locs = frame.f_locals
- del frame
- elif locs is None:
- locs = globs
- exec("""exec code in globs, locs""")
-
-
- exec_("""def reraise(tp, value, tb=None):
- raise tp, value, tb
-""")
-
-
- def print_(*args, **kwargs):
- """The new-style print function."""
- fp = kwargs.pop("file", sys.stdout)
- if fp is None:
- return
- def write(data):
- if not isinstance(data, basestring):
- data = str(data)
- fp.write(data)
- want_unicode = False
- sep = kwargs.pop("sep", None)
- if sep is not None:
- if isinstance(sep, unicode):
- want_unicode = True
- elif not isinstance(sep, str):
- raise TypeError("sep must be None or a string")
- end = kwargs.pop("end", None)
- if end is not None:
- if isinstance(end, unicode):
- want_unicode = True
- elif not isinstance(end, str):
- raise TypeError("end must be None or a string")
- if kwargs:
- raise TypeError("invalid keyword arguments to print()")
- if not want_unicode:
- for arg in args:
- if isinstance(arg, unicode):
- want_unicode = True
- break
- if want_unicode:
- newline = unicode("\n")
- space = unicode(" ")
- else:
- newline = "\n"
- space = " "
- if sep is None:
- sep = space
- if end is None:
- end = newline
- for i, arg in enumerate(args):
- if i:
- write(sep)
- write(arg)
- write(end)
-
-_add_doc(reraise, """Reraise an exception.""")
-
-
-def with_metaclass(meta, base=object):
- """Create a base class with a metaclass."""
- return meta("NewBase", (base,), {})
-
-
-### Additional customizations for Django ###
-
-if PY3:
- _iterlists = "lists"
- _assertRaisesRegex = "assertRaisesRegex"
-else:
- _iterlists = "iterlists"
- _assertRaisesRegex = "assertRaisesRegexp"
-
-
-def iterlists(d):
- """Return an iterator over the values of a MultiValueDict."""
- return getattr(d, _iterlists)()
-
-
-def assertRaisesRegex(self, *args, **kwargs):
- return getattr(self, _assertRaisesRegex)(*args, **kwargs)
-
-
-add_move(MovedModule("_dummy_thread", "dummy_thread"))
-add_move(MovedModule("_thread", "thread"))
diff --git a/rest_framework/status.py b/rest_framework/status.py
index 76435371..90a75508 100644
--- a/rest_framework/status.py
+++ b/rest_framework/status.py
@@ -10,15 +10,19 @@ from __future__ import unicode_literals
def is_informational(code):
return code >= 100 and code <= 199
+
def is_success(code):
return code >= 200 and code <= 299
+
def is_redirect(code):
return code >= 300 and code <= 399
+
def is_client_error(code):
return code >= 400 and code <= 499
+
def is_server_error(code):
return code >= 500 and code <= 599
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 7067ee2f..b6e9ca5c 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -1,4 +1,5 @@
{% load url from future %}
+{% load staticfiles %}
{% load rest_framework %}
<!DOCTYPE html>
<html>
@@ -24,6 +25,7 @@
{% endblock %}
</head>
+ {% block body %}
<body class="{% block bodyclass %}{% endblock %} container">
<div class="wrapper">
@@ -93,7 +95,7 @@
{% endif %}
{% if options_form %}
- <form class="button-form" action="{{ request.get_full_path }}" method="POST" class="pull-right">
+ <form class="button-form" action="{{ request.get_full_path }}" method="POST">
{% csrf_token %}
<input type="hidden" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="OPTIONS" />
<button class="btn btn-primary js-tooltip" title="Make an OPTIONS request on the {{ name }} resource">OPTIONS</button>
@@ -101,7 +103,7 @@
{% endif %}
{% if delete_form %}
- <form class="button-form" action="{{ request.get_full_path }}" method="POST" class="pull-right">
+ <form class="button-form" action="{{ request.get_full_path }}" method="POST">
{% csrf_token %}
<input type="hidden" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="DELETE" />
<button class="btn btn-danger js-tooltip" title="Make a DELETE request on the {{ name }} resource">DELETE</button>
@@ -230,4 +232,5 @@
<script src="{% static "rest_framework/js/default.js" %}"></script>
{% endblock %}
</body>
+ {% endblock %}
</html>
diff --git a/rest_framework/templates/rest_framework/login_base.html b/rest_framework/templates/rest_framework/login_base.html
index be9a0072..43860e53 100644
--- a/rest_framework/templates/rest_framework/login_base.html
+++ b/rest_framework/templates/rest_framework/login_base.html
@@ -1,17 +1,9 @@
+{% extends "rest_framework/base.html" %}
{% load url from future %}
+{% load staticfiles %}
{% load rest_framework %}
-<html>
-
- <head>
- {% block style %}
- {% block bootstrap_theme %}
- <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>
- <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/>
- {% endblock %}
- <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/>
- {% endblock %}
- </head>
+ {% block body %}
<body class="container">
<div class="container-fluid" style="margin-top: 30px">
@@ -50,4 +42,4 @@
</div><!-- /.row-fluid -->
</div><!-- /.container-fluid -->
</body>
-</html>
+ {% endblock %}
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index beb8c5b0..b80a7d77 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -2,98 +2,17 @@ from __future__ import unicode_literals, absolute_import
from django import template
from django.core.urlresolvers import reverse, NoReverseMatch
from django.http import QueryDict
+from django.utils import six
from django.utils.encoding import iri_to_uri
from django.utils.html import escape
from django.utils.safestring import SafeData, mark_safe
-from rest_framework.compat import urlparse, force_text, six, smart_urlquote
+from django.utils.html import smart_urlquote
+from rest_framework.compat import urlparse, force_text
import re
register = template.Library()
-# Note we don't use 'load staticfiles', because we need a 1.3 compatible
-# version, so instead we include the `static` template tag ourselves.
-
-# When 1.3 becomes unsupported by REST framework, we can instead start to
-# use the {% load staticfiles %} tag, remove the following code,
-# and add a dependency that `django.contrib.staticfiles` must be installed.
-
-# Note: We can't put this into the `compat` module because the compat import
-# from rest_framework.compat import ...
-# conflicts with this rest_framework template tag module.
-
-try: # Django 1.5+
- from django.contrib.staticfiles.templatetags.staticfiles import StaticFilesNode
-
- @register.tag('static')
- def do_static(parser, token):
- return StaticFilesNode.handle_token(parser, token)
-
-except ImportError:
- try: # Django 1.4
- from django.contrib.staticfiles.storage import staticfiles_storage
-
- @register.simple_tag
- def static(path):
- """
- A template tag that returns the URL to a file
- using staticfiles' storage backend
- """
- return staticfiles_storage.url(path)
-
- except ImportError: # Django 1.3
- from urlparse import urljoin
- from django import template
- from django.templatetags.static import PrefixNode
-
- class StaticNode(template.Node):
- def __init__(self, varname=None, path=None):
- if path is None:
- raise template.TemplateSyntaxError(
- "Static template nodes must be given a path to return.")
- self.path = path
- self.varname = varname
-
- def url(self, context):
- path = self.path.resolve(context)
- return self.handle_simple(path)
-
- def render(self, context):
- url = self.url(context)
- if self.varname is None:
- return url
- context[self.varname] = url
- return ''
-
- @classmethod
- def handle_simple(cls, path):
- return urljoin(PrefixNode.handle_simple("STATIC_URL"), path)
-
- @classmethod
- def handle_token(cls, parser, token):
- """
- Class method to parse prefix node and return a Node.
- """
- bits = token.split_contents()
-
- if len(bits) < 2:
- raise template.TemplateSyntaxError(
- "'%s' takes at least one argument (path to file)" % bits[0])
-
- path = parser.compile_filter(bits[1])
-
- if len(bits) >= 2 and bits[-2] == 'as':
- varname = bits[3]
- else:
- varname = None
-
- return cls(varname, path)
-
- @register.tag('static')
- def do_static_13(parser, token):
- return StaticNode.handle_token(parser, token)
-
-
def replace_query_param(url, key, val):
"""
Given a URL and a key/val pair, set or replace an item in the query
@@ -122,7 +41,7 @@ def optional_login(request):
except NoReverseMatch:
return ''
- snippet = "<a href='%s?next=%s'>Log in</a>" % (login_url, request.path)
+ snippet = "<a href='%s?next=%s'>Log in</a>" % (login_url, escape(request.path))
return snippet
@@ -136,7 +55,7 @@ def optional_logout(request):
except NoReverseMatch:
return ''
- snippet = "<a href='%s?next=%s'>Log out</a>" % (logout_url, request.path)
+ snippet = "<a href='%s?next=%s'>Log out</a>" % (logout_url, escape(request.path))
return snippet
@@ -180,7 +99,7 @@ def add_class(value, css_class):
# Bunch of stuff cloned from urlize
-TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"]
+TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "']", "'}", "'"]
WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('&lt;', '&gt;'),
('"', '"'), ("'", "'")]
word_split_re = re.compile(r'(\s+)')
@@ -234,8 +153,10 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
middle = middle[len(opening):]
lead = lead + opening
# Keep parentheses at the end only if they're balanced.
- if (middle.endswith(closing)
- and middle.count(closing) == middle.count(opening) + 1):
+ if (
+ middle.endswith(closing)
+ and middle.count(closing) == middle.count(opening) + 1
+ ):
middle = middle[:-len(closing)]
trail = closing + trail
@@ -246,7 +167,7 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
url = smart_urlquote_wrapper(middle)
elif simple_url_2_re.match(middle):
url = smart_urlquote_wrapper('http://%s' % middle)
- elif not ':' in middle and simple_email_re.match(middle):
+ elif ':' not in middle and simple_email_re.match(middle):
local, domain = middle.rsplit('@', 1)
try:
domain = domain.encode('idna').decode('ascii')
diff --git a/rest_framework/test.py b/rest_framework/test.py
index df5a5b3b..f89a6dcd 100644
--- a/rest_framework/test.py
+++ b/rest_framework/test.py
@@ -8,10 +8,11 @@ from django.conf import settings
from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler
from django.test import testcases
+from django.utils import six
from django.utils.http import urlencode
from rest_framework.settings import api_settings
from rest_framework.compat import RequestFactory as DjangoRequestFactory
-from rest_framework.compat import force_bytes_or_smart_bytes, six
+from rest_framework.compat import force_bytes_or_smart_bytes
def force_authenticate(request, user=None, token=None):
@@ -36,7 +37,7 @@ class APIRequestFactory(DjangoRequestFactory):
"""
if not data:
- return ('', None)
+ return ('', content_type)
assert format is None or content_type is None, (
'You may not set both `format` and `content_type`.'
@@ -49,9 +50,10 @@ class APIRequestFactory(DjangoRequestFactory):
else:
format = format or self.default_format
- assert format in self.renderer_classes, ("Invalid format '{0}'. "
- "Available formats are {1}. Set TEST_REQUEST_RENDERER_CLASSES "
- "to enable extra request formats.".format(
+ assert format in self.renderer_classes, (
+ "Invalid format '{0}'. Available formats are {1}. "
+ "Set TEST_REQUEST_RENDERER_CLASSES to enable "
+ "extra request formats.".format(
format,
', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()])
)
@@ -154,6 +156,10 @@ class APIClient(APIRequestFactory, DjangoClient):
kwargs.update(self._credentials)
return super(APIClient, self).request(**kwargs)
+ def logout(self):
+ self._credentials = {}
+ return super(APIClient, self).logout()
+
class APITransactionTestCase(testcases.TransactionTestCase):
client_class = APIClient
diff --git a/rest_framework/tests/__init__.py b/rest_framework/tests/__init__.py
deleted file mode 100644
index e69de29b..00000000
--- a/rest_framework/tests/__init__.py
+++ /dev/null
diff --git a/rest_framework/tests/accounts/__init__.py b/rest_framework/tests/accounts/__init__.py
deleted file mode 100644
index e69de29b..00000000
--- a/rest_framework/tests/accounts/__init__.py
+++ /dev/null
diff --git a/rest_framework/tests/accounts/models.py b/rest_framework/tests/accounts/models.py
deleted file mode 100644
index 525e601b..00000000
--- a/rest_framework/tests/accounts/models.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from django.db import models
-
-from rest_framework.tests.users.models import User
-
-
-class Account(models.Model):
- owner = models.ForeignKey(User, related_name='accounts_owned')
- admins = models.ManyToManyField(User, blank=True, null=True, related_name='accounts_administered')
diff --git a/rest_framework/tests/accounts/serializers.py b/rest_framework/tests/accounts/serializers.py
deleted file mode 100644
index a27b9ca6..00000000
--- a/rest_framework/tests/accounts/serializers.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from rest_framework import serializers
-
-from rest_framework.tests.accounts.models import Account
-from rest_framework.tests.users.serializers import UserSerializer
-
-
-class AccountSerializer(serializers.ModelSerializer):
- admins = UserSerializer(many=True)
-
- class Meta:
- model = Account
diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py
deleted file mode 100644
index b46d7f54..00000000
--- a/rest_framework/tests/description.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# -- coding: utf-8 --
-
-# Apparently there is a python 2.6 issue where docstrings of imported view classes
-# do not retain their encoding information even if a module has a proper
-# encoding declaration at the top of its source file. Therefore for tests
-# to catch unicode related errors, a mock view has to be declared in a separate
-# module.
-
-from rest_framework.views import APIView
-
-
-# test strings snatched from http://www.columbia.edu/~fdc/utf8/,
-# http://winrus.com/utf8-jap.htm and memory
-UTF8_TEST_DOCSTRING = (
- 'zażółć gęślą jaźń'
- 'Sîne klâwen durh die wolken sint geslagen'
- 'Τη γλώσσα μου έδωσαν ελληνική'
- 'யாமறிந்த மொழிகளிலே தமிழ்மொழி'
- 'На берегу пустынных волн'
- 'てすと'
- 'アイウエオカキクケコサシスセソタチツテ'
-)
-
-
-class ViewWithNonASCIICharactersInDocstring(APIView):
- __doc__ = UTF8_TEST_DOCSTRING
diff --git a/rest_framework/tests/extras/__init__.py b/rest_framework/tests/extras/__init__.py
deleted file mode 100644
index e69de29b..00000000
--- a/rest_framework/tests/extras/__init__.py
+++ /dev/null
diff --git a/rest_framework/tests/extras/bad_import.py b/rest_framework/tests/extras/bad_import.py
deleted file mode 100644
index 68263d94..00000000
--- a/rest_framework/tests/extras/bad_import.py
+++ /dev/null
@@ -1 +0,0 @@
-raise ValueError
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
deleted file mode 100644
index 6c8f2342..00000000
--- a/rest_framework/tests/models.py
+++ /dev/null
@@ -1,177 +0,0 @@
-from __future__ import unicode_literals
-from django.db import models
-from django.utils.translation import ugettext_lazy as _
-from rest_framework import serializers
-
-
-def foobar():
- return 'foobar'
-
-
-class CustomField(models.CharField):
-
- def __init__(self, *args, **kwargs):
- kwargs['max_length'] = 12
- super(CustomField, self).__init__(*args, **kwargs)
-
-
-class RESTFrameworkModel(models.Model):
- """
- Base for test models that sets app_label, so they play nicely.
- """
- class Meta:
- app_label = 'tests'
- abstract = True
-
-
-class HasPositiveIntegerAsChoice(RESTFrameworkModel):
- some_choices = ((1, 'A'), (2, 'B'), (3, 'C'))
- some_integer = models.PositiveIntegerField(choices=some_choices)
-
-
-class Anchor(RESTFrameworkModel):
- text = models.CharField(max_length=100, default='anchor')
-
-
-class BasicModel(RESTFrameworkModel):
- text = models.CharField(max_length=100, verbose_name=_("Text comes here"), help_text=_("Text description."))
-
-
-class SlugBasedModel(RESTFrameworkModel):
- text = models.CharField(max_length=100)
- slug = models.SlugField(max_length=32)
-
-
-class DefaultValueModel(RESTFrameworkModel):
- text = models.CharField(default='foobar', max_length=100)
- extra = models.CharField(blank=True, null=True, max_length=100)
-
-
-class CallableDefaultValueModel(RESTFrameworkModel):
- text = models.CharField(default=foobar, max_length=100)
-
-
-class ManyToManyModel(RESTFrameworkModel):
- rel = models.ManyToManyField(Anchor, help_text='Some help text.')
-
-
-class ReadOnlyManyToManyModel(RESTFrameworkModel):
- text = models.CharField(max_length=100, default='anchor')
- rel = models.ManyToManyField(Anchor)
-
-
-# Model for regression test for #285
-
-class Comment(RESTFrameworkModel):
- email = models.EmailField()
- content = models.CharField(max_length=200)
- created = models.DateTimeField(auto_now_add=True)
-
-
-class ActionItem(RESTFrameworkModel):
- title = models.CharField(max_length=200)
- started = models.NullBooleanField(default=False)
- done = models.BooleanField(default=False)
- info = CustomField(default='---', max_length=12)
-
-
-# Models for reverse relations
-class Person(RESTFrameworkModel):
- name = models.CharField(max_length=10)
- age = models.IntegerField(null=True, blank=True)
-
- @property
- def info(self):
- return {
- 'name': self.name,
- 'age': self.age,
- }
-
-
-class BlogPost(RESTFrameworkModel):
- title = models.CharField(max_length=100)
- writer = models.ForeignKey(Person, null=True, blank=True)
-
- def get_first_comment(self):
- return self.blogpostcomment_set.all()[0]
-
-
-class BlogPostComment(RESTFrameworkModel):
- text = models.TextField()
- blog_post = models.ForeignKey(BlogPost)
-
-
-class Album(RESTFrameworkModel):
- title = models.CharField(max_length=100, unique=True)
- ref = models.CharField(max_length=10, unique=True, null=True, blank=True)
-
-class Photo(RESTFrameworkModel):
- description = models.TextField()
- album = models.ForeignKey(Album)
-
-
-# Model for issue #324
-class BlankFieldModel(RESTFrameworkModel):
- title = models.CharField(max_length=100, blank=True, null=False)
-
-
-# Model for issue #380
-class OptionalRelationModel(RESTFrameworkModel):
- other = models.ForeignKey('OptionalRelationModel', blank=True, null=True)
-
-
-# Model for RegexField
-class Book(RESTFrameworkModel):
- isbn = models.CharField(max_length=13)
-
-
-# Models for relations tests
-# ManyToMany
-class ManyToManyTarget(RESTFrameworkModel):
- name = models.CharField(max_length=100)
-
-
-class ManyToManySource(RESTFrameworkModel):
- name = models.CharField(max_length=100)
- targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
-
-
-# ForeignKey
-class ForeignKeyTarget(RESTFrameworkModel):
- name = models.CharField(max_length=100)
-
-
-class ForeignKeySource(RESTFrameworkModel):
- name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
-
-
-# Nullable ForeignKey
-class NullableForeignKeySource(RESTFrameworkModel):
- name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
- related_name='nullable_sources')
-
-
-# OneToOne
-class OneToOneTarget(RESTFrameworkModel):
- name = models.CharField(max_length=100)
-
-
-class NullableOneToOneSource(RESTFrameworkModel):
- name = models.CharField(max_length=100)
- target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
- related_name='nullable_source')
-
-
-# Serializer used to test BasicModel
-class BasicModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = BasicModel
-
-
-# Models to test filters
-class FilterableItem(models.Model):
- text = models.CharField(max_length=100)
- decimal = models.DecimalField(max_digits=4, decimal_places=2)
- date = models.DateField()
diff --git a/rest_framework/tests/records/__init__.py b/rest_framework/tests/records/__init__.py
deleted file mode 100644
index e69de29b..00000000
--- a/rest_framework/tests/records/__init__.py
+++ /dev/null
diff --git a/rest_framework/tests/records/models.py b/rest_framework/tests/records/models.py
deleted file mode 100644
index 76954807..00000000
--- a/rest_framework/tests/records/models.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from django.db import models
-
-
-class Record(models.Model):
- account = models.ForeignKey('accounts.Account', blank=True, null=True)
- owner = models.ForeignKey('users.User', blank=True, null=True)
diff --git a/rest_framework/tests/serializers.py b/rest_framework/tests/serializers.py
deleted file mode 100644
index cc943c7d..00000000
--- a/rest_framework/tests/serializers.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from rest_framework import serializers
-
-from rest_framework.tests.models import NullableForeignKeySource
-
-
-class NullableFKSourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = NullableForeignKeySource
diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py
deleted file mode 100644
index c37d2a51..00000000
--- a/rest_framework/tests/test_authentication.py
+++ /dev/null
@@ -1,663 +0,0 @@
-from __future__ import unicode_literals
-from django.contrib.auth.models import User
-from django.http import HttpResponse
-from django.test import TestCase
-from django.utils import unittest
-from django.utils.http import urlencode
-from rest_framework import HTTP_HEADER_ENCODING
-from rest_framework import exceptions
-from rest_framework import permissions
-from rest_framework import renderers
-from rest_framework.response import Response
-from rest_framework import status
-from rest_framework.authentication import (
- BaseAuthentication,
- TokenAuthentication,
- BasicAuthentication,
- SessionAuthentication,
- OAuthAuthentication,
- OAuth2Authentication
-)
-from rest_framework.authtoken.models import Token
-from rest_framework.compat import patterns, url, include
-from rest_framework.compat import oauth2_provider, oauth2_provider_scope
-from rest_framework.compat import oauth, oauth_provider
-from rest_framework.test import APIRequestFactory, APIClient
-from rest_framework.views import APIView
-import base64
-import time
-import datetime
-
-factory = APIRequestFactory()
-
-
-class MockView(APIView):
- permission_classes = (permissions.IsAuthenticated,)
-
- def get(self, request):
- return HttpResponse({'a': 1, 'b': 2, 'c': 3})
-
- def post(self, request):
- return HttpResponse({'a': 1, 'b': 2, 'c': 3})
-
- def put(self, request):
- return HttpResponse({'a': 1, 'b': 2, 'c': 3})
-
-
-urlpatterns = patterns('',
- (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
- (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
- (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
- (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
- (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])),
- (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication],
- permission_classes=[permissions.TokenHasReadWriteScope]))
-)
-
-class OAuth2AuthenticationDebug(OAuth2Authentication):
- allow_query_params_token = True
-
-if oauth2_provider is not None:
- urlpatterns += patterns('',
- url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
- url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
- url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),
- url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
- permission_classes=[permissions.TokenHasReadWriteScope])),
- )
-
-
-class BasicAuthTests(TestCase):
- """Basic authentication"""
- urls = 'rest_framework.tests.test_authentication'
-
- def setUp(self):
- self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(self.username, self.email, self.password)
-
- def test_post_form_passing_basic_auth(self):
- """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
- credentials = ('%s:%s' % (self.username, self.password))
- base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
- auth = 'Basic %s' % base64_credentials
- response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_post_json_passing_basic_auth(self):
- """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
- credentials = ('%s:%s' % (self.username, self.password))
- base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
- auth = 'Basic %s' % base64_credentials
- response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_post_form_failing_basic_auth(self):
- """Ensure POSTing form over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/basic/', {'example': 'example'})
- self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
- def test_post_json_failing_basic_auth(self):
- """Ensure POSTing json over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json')
- self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
- self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
-
-
-class SessionAuthTests(TestCase):
- """User session authentication"""
- urls = 'rest_framework.tests.test_authentication'
-
- def setUp(self):
- self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.non_csrf_client = APIClient(enforce_csrf_checks=False)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(self.username, self.email, self.password)
-
- def tearDown(self):
- self.csrf_client.logout()
-
- def test_post_form_session_auth_failing_csrf(self):
- """
- Ensure POSTing form over session authentication without CSRF token fails.
- """
- self.csrf_client.login(username=self.username, password=self.password)
- response = self.csrf_client.post('/session/', {'example': 'example'})
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- def test_post_form_session_auth_passing(self):
- """
- Ensure POSTing form over session authentication with logged in user and CSRF token passes.
- """
- self.non_csrf_client.login(username=self.username, password=self.password)
- response = self.non_csrf_client.post('/session/', {'example': 'example'})
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_put_form_session_auth_passing(self):
- """
- Ensure PUTting form over session authentication with logged in user and CSRF token passes.
- """
- self.non_csrf_client.login(username=self.username, password=self.password)
- response = self.non_csrf_client.put('/session/', {'example': 'example'})
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_post_form_session_auth_failing(self):
- """
- Ensure POSTing form over session authentication without logged in user fails.
- """
- response = self.csrf_client.post('/session/', {'example': 'example'})
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
-
-class TokenAuthTests(TestCase):
- """Token authentication"""
- urls = 'rest_framework.tests.test_authentication'
-
- def setUp(self):
- self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(self.username, self.email, self.password)
-
- self.key = 'abcd1234'
- self.token = Token.objects.create(key=self.key, user=self.user)
-
- def test_post_form_passing_token_auth(self):
- """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
- auth = 'Token ' + self.key
- response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_post_json_passing_token_auth(self):
- """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
- auth = "Token " + self.key
- response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_post_form_failing_token_auth(self):
- """Ensure POSTing form over token auth without correct credentials fails"""
- response = self.csrf_client.post('/token/', {'example': 'example'})
- self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
- def test_post_json_failing_token_auth(self):
- """Ensure POSTing json over token auth without correct credentials fails"""
- response = self.csrf_client.post('/token/', {'example': 'example'}, format='json')
- self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
- def test_token_has_auto_assigned_key_if_none_provided(self):
- """Ensure creating a token with no key will auto-assign a key"""
- self.token.delete()
- token = Token.objects.create(user=self.user)
- self.assertTrue(bool(token.key))
-
- def test_token_login_json(self):
- """Ensure token login view using JSON POST works."""
- client = APIClient(enforce_csrf_checks=True)
- response = client.post('/auth-token/',
- {'username': self.username, 'password': self.password}, format='json')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['token'], self.key)
-
- def test_token_login_json_bad_creds(self):
- """Ensure token login view using JSON POST fails if bad credentials are used."""
- client = APIClient(enforce_csrf_checks=True)
- response = client.post('/auth-token/',
- {'username': self.username, 'password': "badpass"}, format='json')
- self.assertEqual(response.status_code, 400)
-
- def test_token_login_json_missing_fields(self):
- """Ensure token login view using JSON POST fails if missing fields."""
- client = APIClient(enforce_csrf_checks=True)
- response = client.post('/auth-token/',
- {'username': self.username}, format='json')
- self.assertEqual(response.status_code, 400)
-
- def test_token_login_form(self):
- """Ensure token login view using form POST works."""
- client = APIClient(enforce_csrf_checks=True)
- response = client.post('/auth-token/',
- {'username': self.username, 'password': self.password})
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['token'], self.key)
-
-
-class IncorrectCredentialsTests(TestCase):
- def test_incorrect_credentials(self):
- """
- If a request contains bad authentication credentials, then
- authentication should run and error, even if no permissions
- are set on the view.
- """
- class IncorrectCredentialsAuth(BaseAuthentication):
- def authenticate(self, request):
- raise exceptions.AuthenticationFailed('Bad credentials')
-
- request = factory.get('/')
- view = MockView.as_view(
- authentication_classes=(IncorrectCredentialsAuth,),
- permission_classes=()
- )
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertEqual(response.data, {'detail': 'Bad credentials'})
-
-
-class OAuthTests(TestCase):
- """OAuth 1.0a authentication"""
- urls = 'rest_framework.tests.test_authentication'
-
- def setUp(self):
- # these imports are here because oauth is optional and hiding them in try..except block or compat
- # could obscure problems if something breaks
- from oauth_provider.models import Consumer, Scope
- from oauth_provider.models import Token as OAuthToken
- from oauth_provider import consts
-
- self.consts = consts
-
- self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(self.username, self.email, self.password)
-
- self.CONSUMER_KEY = 'consumer_key'
- self.CONSUMER_SECRET = 'consumer_secret'
- self.TOKEN_KEY = "token_key"
- self.TOKEN_SECRET = "token_secret"
-
- self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
- name='example', user=self.user, status=self.consts.ACCEPTED)
-
- self.scope = Scope.objects.create(name="resource name", url="api/")
- self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, scope=self.scope,
- token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True
- )
-
- def _create_authorization_header(self):
- params = {
- 'oauth_version': "1.0",
- 'oauth_nonce': oauth.generate_nonce(),
- 'oauth_timestamp': int(time.time()),
- 'oauth_token': self.token.key,
- 'oauth_consumer_key': self.consumer.key
- }
-
- req = oauth.Request(method="GET", url="http://example.com", parameters=params)
-
- signature_method = oauth.SignatureMethod_PLAINTEXT()
- req.sign_request(signature_method, self.consumer, self.token)
-
- return req.to_header()["Authorization"]
-
- def _create_authorization_url_parameters(self):
- params = {
- 'oauth_version': "1.0",
- 'oauth_nonce': oauth.generate_nonce(),
- 'oauth_timestamp': int(time.time()),
- 'oauth_token': self.token.key,
- 'oauth_consumer_key': self.consumer.key
- }
-
- req = oauth.Request(method="GET", url="http://example.com", parameters=params)
-
- signature_method = oauth.SignatureMethod_PLAINTEXT()
- req.sign_request(signature_method, self.consumer, self.token)
- return dict(req)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_passing_oauth(self):
- """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_repeated_nonce_failing_oauth(self):
- """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails"""
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- # simulate reply attack auth header containes already used (nonce, timestamp) pair
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_token_removed_failing_oauth(self):
- """Ensure POSTing when there is no OAuth access token in db fails"""
- self.token.delete()
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_consumer_status_not_accepted_failing_oauth(self):
- """Ensure POSTing when consumer status is anything other than ACCEPTED fails"""
- for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED):
- self.consumer.status = consumer_status
- self.consumer.save()
-
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_with_request_token_failing_oauth(self):
- """Ensure POSTing with unauthorized request token instead of access token fails"""
- self.token.token_type = self.token.REQUEST
- self.token.save()
-
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_with_urlencoded_parameters(self):
- """Ensure POSTing with x-www-form-urlencoded auth parameters passes"""
- params = self._create_authorization_url_parameters()
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', params, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_get_form_with_url_parameters(self):
- """Ensure GETing with auth in url parameters passes"""
- params = self._create_authorization_url_parameters()
- response = self.csrf_client.get('/oauth/', params)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_hmac_sha1_signature_passes(self):
- """Ensure POSTing using HMAC_SHA1 signature method passes"""
- params = {
- 'oauth_version': "1.0",
- 'oauth_nonce': oauth.generate_nonce(),
- 'oauth_timestamp': int(time.time()),
- 'oauth_token': self.token.key,
- 'oauth_consumer_key': self.consumer.key
- }
-
- req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
-
- signature_method = oauth.SignatureMethod_HMAC_SHA1()
- req.sign_request(signature_method, self.consumer, self.token)
- auth = req.to_header()["Authorization"]
-
- response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_get_form_with_readonly_resource_passing_auth(self):
- """Ensure POSTing with a readonly scope instead of a write scope fails"""
- read_only_access_token = self.token
- read_only_access_token.scope.is_readonly = True
- read_only_access_token.scope.save()
- params = self._create_authorization_url_parameters()
- response = self.csrf_client.get('/oauth-with-scope/', params)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_with_readonly_resource_failing_auth(self):
- """Ensure POSTing with a readonly resource instead of a write scope fails"""
- read_only_access_token = self.token
- read_only_access_token.scope.is_readonly = True
- read_only_access_token.scope.save()
- params = self._create_authorization_url_parameters()
- response = self.csrf_client.post('/oauth-with-scope/', params)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_with_write_resource_passing_auth(self):
- """Ensure POSTing with a write resource succeed"""
- read_write_access_token = self.token
- read_write_access_token.scope.is_readonly = False
- read_write_access_token.scope.save()
- params = self._create_authorization_url_parameters()
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth-with-scope/', params, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_bad_consumer_key(self):
- """Ensure POSTing using HMAC_SHA1 signature method passes"""
- params = {
- 'oauth_version': "1.0",
- 'oauth_nonce': oauth.generate_nonce(),
- 'oauth_timestamp': int(time.time()),
- 'oauth_token': self.token.key,
- 'oauth_consumer_key': 'badconsumerkey'
- }
-
- req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
-
- signature_method = oauth.SignatureMethod_HMAC_SHA1()
- req.sign_request(signature_method, self.consumer, self.token)
- auth = req.to_header()["Authorization"]
-
- response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_bad_token_key(self):
- """Ensure POSTing using HMAC_SHA1 signature method passes"""
- params = {
- 'oauth_version': "1.0",
- 'oauth_nonce': oauth.generate_nonce(),
- 'oauth_timestamp': int(time.time()),
- 'oauth_token': 'badtokenkey',
- 'oauth_consumer_key': self.consumer.key
- }
-
- req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
-
- signature_method = oauth.SignatureMethod_HMAC_SHA1()
- req.sign_request(signature_method, self.consumer, self.token)
- auth = req.to_header()["Authorization"]
-
- response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
-
-class OAuth2Tests(TestCase):
- """OAuth 2.0 authentication"""
- urls = 'rest_framework.tests.test_authentication'
-
- def setUp(self):
- self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(self.username, self.email, self.password)
-
- self.CLIENT_ID = 'client_key'
- self.CLIENT_SECRET = 'client_secret'
- self.ACCESS_TOKEN = "access_token"
- self.REFRESH_TOKEN = "refresh_token"
-
- self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create(
- client_id=self.CLIENT_ID,
- client_secret=self.CLIENT_SECRET,
- redirect_uri='',
- client_type=0,
- name='example',
- user=None,
- )
-
- self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create(
- token=self.ACCESS_TOKEN,
- client=self.oauth2_client,
- user=self.user,
- )
- self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create(
- user=self.user,
- access_token=self.access_token,
- client=self.oauth2_client
- )
-
- def _create_authorization_header(self, token=None):
- return "Bearer {0}".format(token or self.access_token.token)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_with_wrong_authorization_header_token_type_failing(self):
- """Ensure that a wrong token type lead to the correct HTTP error status code"""
- auth = "Wrong token-type-obsviously"
- response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
- response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_with_wrong_authorization_header_token_format_failing(self):
- """Ensure that a wrong token format lead to the correct HTTP error status code"""
- auth = "Bearer wrong token format"
- response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
- response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_with_wrong_authorization_header_token_failing(self):
- """Ensure that a wrong token lead to the correct HTTP error status code"""
- auth = "Bearer wrong-token"
- response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
- response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_passing_auth(self):
- """Ensure GETing form over OAuth with correct client credentials succeed"""
- auth = self._create_authorization_header()
- response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_passing_auth_url_transport(self):
- """Ensure GETing form over OAuth with correct client credentials in form data succeed"""
- response = self.csrf_client.post('/oauth2-test/',
- data={'access_token': self.access_token.token})
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_passing_auth_url_transport(self):
- """Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True"""
- query = urlencode({'access_token': self.access_token.token})
- response = self.csrf_client.get('/oauth2-test-debug/?%s' % query)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_failing_auth_url_transport(self):
- """Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False"""
- query = urlencode({'access_token': self.access_token.token})
- response = self.csrf_client.get('/oauth2-test/?%s' % query)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_passing_auth(self):
- """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_token_removed_failing_auth(self):
- """Ensure POSTing when there is no OAuth access token in db fails"""
- self.access_token.delete()
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_with_refresh_token_failing_auth(self):
- """Ensure POSTing with refresh token instead of access token fails"""
- auth = self._create_authorization_header(token=self.refresh_token.token)
- response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_with_expired_access_token_failing_auth(self):
- """Ensure POSTing with expired access token fails with an 'Invalid token' error"""
- self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
- self.access_token.save()
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
- self.assertIn('Invalid token', response.content)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_with_invalid_scope_failing_auth(self):
- """Ensure POSTing with a readonly scope instead of a write scope fails"""
- read_only_access_token = self.access_token
- read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']
- read_only_access_token.save()
- auth = self._create_authorization_header(token=read_only_access_token.token)
- response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
- response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_with_valid_scope_passing_auth(self):
- """Ensure POSTing with a write scope succeed"""
- read_write_access_token = self.access_token
- read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']
- read_write_access_token.save()
- auth = self._create_authorization_header(token=read_write_access_token.token)
- response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
-
-class FailingAuthAccessedInRenderer(TestCase):
- def setUp(self):
- class AuthAccessingRenderer(renderers.BaseRenderer):
- media_type = 'text/plain'
- format = 'txt'
-
- def render(self, data, media_type=None, renderer_context=None):
- request = renderer_context['request']
- if request.user.is_authenticated():
- return b'authenticated'
- return b'not authenticated'
-
- class FailingAuth(BaseAuthentication):
- def authenticate(self, request):
- raise exceptions.AuthenticationFailed('authentication failed')
-
- class ExampleView(APIView):
- authentication_classes = (FailingAuth,)
- renderer_classes = (AuthAccessingRenderer,)
-
- def get(self, request):
- return Response({'foo': 'bar'})
-
- self.view = ExampleView.as_view()
-
- def test_failing_auth_accessed_in_renderer(self):
- """
- When authentication fails the renderer should still be able to access
- `request.user` without raising an exception. Particularly relevant
- to HTML responses that might reasonably access `request.user`.
- """
- request = factory.get('/')
- response = self.view(request)
- content = response.render().content
- self.assertEqual(content, b'not authenticated')
diff --git a/rest_framework/tests/test_breadcrumbs.py b/rest_framework/tests/test_breadcrumbs.py
deleted file mode 100644
index 41ddf2ce..00000000
--- a/rest_framework/tests/test_breadcrumbs.py
+++ /dev/null
@@ -1,73 +0,0 @@
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework.compat import patterns, url
-from rest_framework.utils.breadcrumbs import get_breadcrumbs
-from rest_framework.views import APIView
-
-
-class Root(APIView):
- pass
-
-
-class ResourceRoot(APIView):
- pass
-
-
-class ResourceInstance(APIView):
- pass
-
-
-class NestedResourceRoot(APIView):
- pass
-
-
-class NestedResourceInstance(APIView):
- pass
-
-urlpatterns = patterns('',
- url(r'^$', Root.as_view()),
- url(r'^resource/$', ResourceRoot.as_view()),
- url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()),
- url(r'^resource/(?P<key>[0-9]+)/$', NestedResourceRoot.as_view()),
- url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$', NestedResourceInstance.as_view()),
-)
-
-
-class BreadcrumbTests(TestCase):
- """Tests the breadcrumb functionality used by the HTML renderer."""
-
- urls = 'rest_framework.tests.test_breadcrumbs'
-
- def test_root_breadcrumbs(self):
- url = '/'
- self.assertEqual(get_breadcrumbs(url), [('Root', '/')])
-
- def test_resource_root_breadcrumbs(self):
- url = '/resource/'
- self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
- ('Resource Root', '/resource/')])
-
- def test_resource_instance_breadcrumbs(self):
- url = '/resource/123'
- self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
- ('Resource Root', '/resource/'),
- ('Resource Instance', '/resource/123')])
-
- def test_nested_resource_breadcrumbs(self):
- url = '/resource/123/'
- self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
- ('Resource Root', '/resource/'),
- ('Resource Instance', '/resource/123'),
- ('Nested Resource Root', '/resource/123/')])
-
- def test_nested_resource_instance_breadcrumbs(self):
- url = '/resource/123/abc'
- self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
- ('Resource Root', '/resource/'),
- ('Resource Instance', '/resource/123'),
- ('Nested Resource Root', '/resource/123/'),
- ('Nested Resource Instance', '/resource/123/abc')])
-
- def test_broken_url_breadcrumbs_handled_gracefully(self):
- url = '/foobar'
- self.assertEqual(get_breadcrumbs(url), [('Root', '/')])
diff --git a/rest_framework/tests/test_decorators.py b/rest_framework/tests/test_decorators.py
deleted file mode 100644
index 195f0ba3..00000000
--- a/rest_framework/tests/test_decorators.py
+++ /dev/null
@@ -1,157 +0,0 @@
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework import status
-from rest_framework.authentication import BasicAuthentication
-from rest_framework.parsers import JSONParser
-from rest_framework.permissions import IsAuthenticated
-from rest_framework.response import Response
-from rest_framework.renderers import JSONRenderer
-from rest_framework.test import APIRequestFactory
-from rest_framework.throttling import UserRateThrottle
-from rest_framework.views import APIView
-from rest_framework.decorators import (
- api_view,
- renderer_classes,
- parser_classes,
- authentication_classes,
- throttle_classes,
- permission_classes,
-)
-
-
-class DecoratorTestCase(TestCase):
-
- def setUp(self):
- self.factory = APIRequestFactory()
-
- def _finalize_response(self, request, response, *args, **kwargs):
- response.request = request
- return APIView.finalize_response(self, request, response, *args, **kwargs)
-
- def test_api_view_incorrect(self):
- """
- If @api_view is not applied correct, we should raise an assertion.
- """
-
- @api_view
- def view(request):
- return Response()
-
- request = self.factory.get('/')
- self.assertRaises(AssertionError, view, request)
-
- def test_api_view_incorrect_arguments(self):
- """
- If @api_view is missing arguments, we should raise an assertion.
- """
-
- with self.assertRaises(AssertionError):
- @api_view('GET')
- def view(request):
- return Response()
-
- def test_calling_method(self):
-
- @api_view(['GET'])
- def view(request):
- return Response({})
-
- request = self.factory.get('/')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- request = self.factory.post('/')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
-
- def test_calling_put_method(self):
-
- @api_view(['GET', 'PUT'])
- def view(request):
- return Response({})
-
- request = self.factory.put('/')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- request = self.factory.post('/')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
-
- def test_calling_patch_method(self):
-
- @api_view(['GET', 'PATCH'])
- def view(request):
- return Response({})
-
- request = self.factory.patch('/')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- request = self.factory.post('/')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
-
- def test_renderer_classes(self):
-
- @api_view(['GET'])
- @renderer_classes([JSONRenderer])
- def view(request):
- return Response({})
-
- request = self.factory.get('/')
- response = view(request)
- self.assertTrue(isinstance(response.accepted_renderer, JSONRenderer))
-
- def test_parser_classes(self):
-
- @api_view(['GET'])
- @parser_classes([JSONParser])
- def view(request):
- self.assertEqual(len(request.parsers), 1)
- self.assertTrue(isinstance(request.parsers[0],
- JSONParser))
- return Response({})
-
- request = self.factory.get('/')
- view(request)
-
- def test_authentication_classes(self):
-
- @api_view(['GET'])
- @authentication_classes([BasicAuthentication])
- def view(request):
- self.assertEqual(len(request.authenticators), 1)
- self.assertTrue(isinstance(request.authenticators[0],
- BasicAuthentication))
- return Response({})
-
- request = self.factory.get('/')
- view(request)
-
- def test_permission_classes(self):
-
- @api_view(['GET'])
- @permission_classes([IsAuthenticated])
- def view(request):
- return Response({})
-
- request = self.factory.get('/')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- def test_throttle_classes(self):
- class OncePerDayUserThrottle(UserRateThrottle):
- rate = '1/day'
-
- @api_view(['GET'])
- @throttle_classes([OncePerDayUserThrottle])
- def view(request):
- return Response({})
-
- request = self.factory.get('/')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
diff --git a/rest_framework/tests/test_description.py b/rest_framework/tests/test_description.py
deleted file mode 100644
index 4c03c1de..00000000
--- a/rest_framework/tests/test_description.py
+++ /dev/null
@@ -1,108 +0,0 @@
-# -- coding: utf-8 --
-
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework.compat import apply_markdown, smart_text
-from rest_framework.views import APIView
-from rest_framework.tests.description import ViewWithNonASCIICharactersInDocstring
-from rest_framework.tests.description import UTF8_TEST_DOCSTRING
-
-# We check that docstrings get nicely un-indented.
-DESCRIPTION = """an example docstring
-====================
-
-* list
-* list
-
-another header
---------------
-
- code block
-
-indented
-
-# hash style header #"""
-
-# If markdown is installed we also test it's working
-# (and that our wrapped forces '=' to h2 and '-' to h3)
-
-# We support markdown < 2.1 and markdown >= 2.1
-MARKED_DOWN_lt_21 = """<h2>an example docstring</h2>
-<ul>
-<li>list</li>
-<li>list</li>
-</ul>
-<h3>another header</h3>
-<pre><code>code block
-</code></pre>
-<p>indented</p>
-<h2 id="hash_style_header">hash style header</h2>"""
-
-MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2>
-<ul>
-<li>list</li>
-<li>list</li>
-</ul>
-<h3 id="another-header">another header</h3>
-<pre><code>code block
-</code></pre>
-<p>indented</p>
-<h2 id="hash-style-header">hash style header</h2>"""
-
-
-class TestViewNamesAndDescriptions(TestCase):
- def test_view_name_uses_class_name(self):
- """
- Ensure view names are based on the class name.
- """
- class MockView(APIView):
- pass
- self.assertEqual(MockView().get_view_name(), 'Mock')
-
- def test_view_description_uses_docstring(self):
- """Ensure view descriptions are based on the docstring."""
- class MockView(APIView):
- """an example docstring
- ====================
-
- * list
- * list
-
- another header
- --------------
-
- code block
-
- indented
-
- # hash style header #"""
-
- self.assertEqual(MockView().get_view_description(), DESCRIPTION)
-
- def test_view_description_supports_unicode(self):
- """
- Unicode in docstrings should be respected.
- """
-
- self.assertEqual(
- ViewWithNonASCIICharactersInDocstring().get_view_description(),
- smart_text(UTF8_TEST_DOCSTRING)
- )
-
- def test_view_description_can_be_empty(self):
- """
- Ensure that if a view has no docstring,
- then it's description is the empty string.
- """
- class MockView(APIView):
- pass
- self.assertEqual(MockView().get_view_description(), '')
-
- def test_markdown(self):
- """
- Ensure markdown to HTML works as expected.
- """
- if apply_markdown:
- gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21
- lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21
- self.assertTrue(gte_21_match or lt_21_match)
diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py
deleted file mode 100644
index e127feef..00000000
--- a/rest_framework/tests/test_fields.py
+++ /dev/null
@@ -1,984 +0,0 @@
-"""
-General serializer field tests.
-"""
-from __future__ import unicode_literals
-
-import datetime
-from decimal import Decimal
-from uuid import uuid4
-from django.core import validators
-from django.db import models
-from django.test import TestCase
-from django.utils.datastructures import SortedDict
-from rest_framework import serializers
-from rest_framework.tests.models import RESTFrameworkModel
-
-
-class TimestampedModel(models.Model):
- added = models.DateTimeField(auto_now_add=True)
- updated = models.DateTimeField(auto_now=True)
-
-
-class CharPrimaryKeyModel(models.Model):
- id = models.CharField(max_length=20, primary_key=True)
-
-
-class TimestampedModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = TimestampedModel
-
-
-class CharPrimaryKeyModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = CharPrimaryKeyModel
-
-
-class TimeFieldModel(models.Model):
- clock = models.TimeField()
-
-
-class TimeFieldModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = TimeFieldModel
-
-
-SAMPLE_CHOICES = [
- ('red', 'Red'),
- ('green', 'Green'),
- ('blue', 'Blue'),
-]
-
-
-class ChoiceFieldModel(models.Model):
- choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, max_length=255)
-
-
-class ChoiceFieldModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = ChoiceFieldModel
-
-
-class ChoiceFieldModelWithNull(models.Model):
- choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, null=True, max_length=255)
-
-
-class ChoiceFieldModelWithNullSerializer(serializers.ModelSerializer):
- class Meta:
- model = ChoiceFieldModelWithNull
-
-
-class BasicFieldTests(TestCase):
- def test_auto_now_fields_read_only(self):
- """
- auto_now and auto_now_add fields should be read_only by default.
- """
- serializer = TimestampedModelSerializer()
- self.assertEqual(serializer.fields['added'].read_only, True)
-
- def test_auto_pk_fields_read_only(self):
- """
- AutoField fields should be read_only by default.
- """
- serializer = TimestampedModelSerializer()
- self.assertEqual(serializer.fields['id'].read_only, True)
-
- def test_non_auto_pk_fields_not_read_only(self):
- """
- PK fields other than AutoField fields should not be read_only by default.
- """
- serializer = CharPrimaryKeyModelSerializer()
- self.assertEqual(serializer.fields['id'].read_only, False)
-
- def test_dict_field_ordering(self):
- """
- Field should preserve dictionary ordering, if it exists.
- See: https://github.com/tomchristie/django-rest-framework/issues/832
- """
- ret = SortedDict()
- ret['c'] = 1
- ret['b'] = 1
- ret['a'] = 1
- ret['z'] = 1
- field = serializers.Field()
- keys = list(field.to_native(ret).keys())
- self.assertEqual(keys, ['c', 'b', 'a', 'z'])
-
-
-class DateFieldTest(TestCase):
- """
- Tests for the DateFieldTest from_native() and to_native() behavior
- """
-
- def test_from_native_string(self):
- """
- Make sure from_native() accepts default iso input formats.
- """
- f = serializers.DateField()
- result_1 = f.from_native('1984-07-31')
-
- self.assertEqual(datetime.date(1984, 7, 31), result_1)
-
- def test_from_native_datetime_date(self):
- """
- Make sure from_native() accepts a datetime.date instance.
- """
- f = serializers.DateField()
- result_1 = f.from_native(datetime.date(1984, 7, 31))
-
- self.assertEqual(result_1, datetime.date(1984, 7, 31))
-
- def test_from_native_custom_format(self):
- """
- Make sure from_native() accepts custom input formats.
- """
- f = serializers.DateField(input_formats=['%Y -- %d'])
- result = f.from_native('1984 -- 31')
-
- self.assertEqual(datetime.date(1984, 1, 31), result)
-
- def test_from_native_invalid_default_on_custom_format(self):
- """
- Make sure from_native() don't accept default formats if custom format is preset
- """
- f = serializers.DateField(input_formats=['%Y -- %d'])
-
- try:
- f.from_native('1984-07-31')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_from_native_empty(self):
- """
- Make sure from_native() returns None on empty param.
- """
- f = serializers.DateField()
- result = f.from_native('')
-
- self.assertEqual(result, None)
-
- def test_from_native_none(self):
- """
- Make sure from_native() returns None on None param.
- """
- f = serializers.DateField()
- result = f.from_native(None)
-
- self.assertEqual(result, None)
-
- def test_from_native_invalid_date(self):
- """
- Make sure from_native() raises a ValidationError on passing an invalid date.
- """
- f = serializers.DateField()
-
- try:
- f.from_native('1984-13-31')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_from_native_invalid_format(self):
- """
- Make sure from_native() raises a ValidationError on passing an invalid format.
- """
- f = serializers.DateField()
-
- try:
- f.from_native('1984 -- 31')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_to_native(self):
- """
- Make sure to_native() returns datetime as default.
- """
- f = serializers.DateField()
-
- result_1 = f.to_native(datetime.date(1984, 7, 31))
-
- self.assertEqual(datetime.date(1984, 7, 31), result_1)
-
- def test_to_native_iso(self):
- """
- Make sure to_native() with 'iso-8601' returns iso formated date.
- """
- f = serializers.DateField(format='iso-8601')
-
- result_1 = f.to_native(datetime.date(1984, 7, 31))
-
- self.assertEqual('1984-07-31', result_1)
-
- def test_to_native_custom_format(self):
- """
- Make sure to_native() returns correct custom format.
- """
- f = serializers.DateField(format="%Y - %m.%d")
-
- result_1 = f.to_native(datetime.date(1984, 7, 31))
-
- self.assertEqual('1984 - 07.31', result_1)
-
- def test_to_native_none(self):
- """
- Make sure from_native() returns None on None param.
- """
- f = serializers.DateField(required=False)
- self.assertEqual(None, f.to_native(None))
-
-
-class DateTimeFieldTest(TestCase):
- """
- Tests for the DateTimeField from_native() and to_native() behavior
- """
-
- def test_from_native_string(self):
- """
- Make sure from_native() accepts default iso input formats.
- """
- f = serializers.DateTimeField()
- result_1 = f.from_native('1984-07-31 04:31')
- result_2 = f.from_native('1984-07-31 04:31:59')
- result_3 = f.from_native('1984-07-31 04:31:59.000200')
-
- self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1)
- self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2)
- self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3)
-
- def test_from_native_datetime_datetime(self):
- """
- Make sure from_native() accepts a datetime.datetime instance.
- """
- f = serializers.DateTimeField()
- result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31))
- result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
- result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
-
- self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31))
- self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59))
- self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
-
- def test_from_native_custom_format(self):
- """
- Make sure from_native() accepts custom input formats.
- """
- f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
- result = f.from_native('1984 -- 04:59')
-
- self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result)
-
- def test_from_native_invalid_default_on_custom_format(self):
- """
- Make sure from_native() don't accept default formats if custom format is preset
- """
- f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
-
- try:
- f.from_native('1984-07-31 04:31:59')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_from_native_empty(self):
- """
- Make sure from_native() returns None on empty param.
- """
- f = serializers.DateTimeField()
- result = f.from_native('')
-
- self.assertEqual(result, None)
-
- def test_from_native_none(self):
- """
- Make sure from_native() returns None on None param.
- """
- f = serializers.DateTimeField()
- result = f.from_native(None)
-
- self.assertEqual(result, None)
-
- def test_from_native_invalid_datetime(self):
- """
- Make sure from_native() raises a ValidationError on passing an invalid datetime.
- """
- f = serializers.DateTimeField()
-
- try:
- f.from_native('04:61:59')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
- "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_from_native_invalid_format(self):
- """
- Make sure from_native() raises a ValidationError on passing an invalid format.
- """
- f = serializers.DateTimeField()
-
- try:
- f.from_native('04 -- 31')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
- "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_to_native(self):
- """
- Make sure to_native() returns isoformat as default.
- """
- f = serializers.DateTimeField()
-
- result_1 = f.to_native(datetime.datetime(1984, 7, 31))
- result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
- result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
- result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
-
- self.assertEqual(datetime.datetime(1984, 7, 31), result_1)
- self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2)
- self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3)
- self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4)
-
- def test_to_native_iso(self):
- """
- Make sure to_native() with format=iso-8601 returns iso formatted datetime.
- """
- f = serializers.DateTimeField(format='iso-8601')
-
- result_1 = f.to_native(datetime.datetime(1984, 7, 31))
- result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
- result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
- result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
-
- self.assertEqual('1984-07-31T00:00:00', result_1)
- self.assertEqual('1984-07-31T04:31:00', result_2)
- self.assertEqual('1984-07-31T04:31:59', result_3)
- self.assertEqual('1984-07-31T04:31:59.000200', result_4)
-
- def test_to_native_custom_format(self):
- """
- Make sure to_native() returns correct custom format.
- """
- f = serializers.DateTimeField(format="%Y - %H:%M")
-
- result_1 = f.to_native(datetime.datetime(1984, 7, 31))
- result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
- result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
- result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
-
- self.assertEqual('1984 - 00:00', result_1)
- self.assertEqual('1984 - 04:31', result_2)
- self.assertEqual('1984 - 04:31', result_3)
- self.assertEqual('1984 - 04:31', result_4)
-
- def test_to_native_none(self):
- """
- Make sure from_native() returns None on None param.
- """
- f = serializers.DateTimeField(required=False)
- self.assertEqual(None, f.to_native(None))
-
-
-class TimeFieldTest(TestCase):
- """
- Tests for the TimeField from_native() and to_native() behavior
- """
-
- def test_from_native_string(self):
- """
- Make sure from_native() accepts default iso input formats.
- """
- f = serializers.TimeField()
- result_1 = f.from_native('04:31')
- result_2 = f.from_native('04:31:59')
- result_3 = f.from_native('04:31:59.000200')
-
- self.assertEqual(datetime.time(4, 31), result_1)
- self.assertEqual(datetime.time(4, 31, 59), result_2)
- self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
-
- def test_from_native_datetime_time(self):
- """
- Make sure from_native() accepts a datetime.time instance.
- """
- f = serializers.TimeField()
- result_1 = f.from_native(datetime.time(4, 31))
- result_2 = f.from_native(datetime.time(4, 31, 59))
- result_3 = f.from_native(datetime.time(4, 31, 59, 200))
-
- self.assertEqual(result_1, datetime.time(4, 31))
- self.assertEqual(result_2, datetime.time(4, 31, 59))
- self.assertEqual(result_3, datetime.time(4, 31, 59, 200))
-
- def test_from_native_custom_format(self):
- """
- Make sure from_native() accepts custom input formats.
- """
- f = serializers.TimeField(input_formats=['%H -- %M'])
- result = f.from_native('04 -- 31')
-
- self.assertEqual(datetime.time(4, 31), result)
-
- def test_from_native_invalid_default_on_custom_format(self):
- """
- Make sure from_native() don't accept default formats if custom format is preset
- """
- f = serializers.TimeField(input_formats=['%H -- %M'])
-
- try:
- f.from_native('04:31:59')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_from_native_empty(self):
- """
- Make sure from_native() returns None on empty param.
- """
- f = serializers.TimeField()
- result = f.from_native('')
-
- self.assertEqual(result, None)
-
- def test_from_native_none(self):
- """
- Make sure from_native() returns None on None param.
- """
- f = serializers.TimeField()
- result = f.from_native(None)
-
- self.assertEqual(result, None)
-
- def test_from_native_invalid_time(self):
- """
- Make sure from_native() raises a ValidationError on passing an invalid time.
- """
- f = serializers.TimeField()
-
- try:
- f.from_native('04:61:59')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
- "hh:mm[:ss[.uuuuuu]]"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_from_native_invalid_format(self):
- """
- Make sure from_native() raises a ValidationError on passing an invalid format.
- """
- f = serializers.TimeField()
-
- try:
- f.from_native('04 -- 31')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
- "hh:mm[:ss[.uuuuuu]]"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_to_native(self):
- """
- Make sure to_native() returns time object as default.
- """
- f = serializers.TimeField()
- result_1 = f.to_native(datetime.time(4, 31))
- result_2 = f.to_native(datetime.time(4, 31, 59))
- result_3 = f.to_native(datetime.time(4, 31, 59, 200))
-
- self.assertEqual(datetime.time(4, 31), result_1)
- self.assertEqual(datetime.time(4, 31, 59), result_2)
- self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
-
- def test_to_native_iso(self):
- """
- Make sure to_native() with format='iso-8601' returns iso formatted time.
- """
- f = serializers.TimeField(format='iso-8601')
- result_1 = f.to_native(datetime.time(4, 31))
- result_2 = f.to_native(datetime.time(4, 31, 59))
- result_3 = f.to_native(datetime.time(4, 31, 59, 200))
-
- self.assertEqual('04:31:00', result_1)
- self.assertEqual('04:31:59', result_2)
- self.assertEqual('04:31:59.000200', result_3)
-
- def test_to_native_custom_format(self):
- """
- Make sure to_native() returns correct custom format.
- """
- f = serializers.TimeField(format="%H - %S [%f]")
- result_1 = f.to_native(datetime.time(4, 31))
- result_2 = f.to_native(datetime.time(4, 31, 59))
- result_3 = f.to_native(datetime.time(4, 31, 59, 200))
-
- self.assertEqual('04 - 00 [000000]', result_1)
- self.assertEqual('04 - 59 [000000]', result_2)
- self.assertEqual('04 - 59 [000200]', result_3)
-
-
-class DecimalFieldTest(TestCase):
- """
- Tests for the DecimalField from_native() and to_native() behavior
- """
-
- def test_from_native_string(self):
- """
- Make sure from_native() accepts string values
- """
- f = serializers.DecimalField()
- result_1 = f.from_native('9000')
- result_2 = f.from_native('1.00000001')
-
- self.assertEqual(Decimal('9000'), result_1)
- self.assertEqual(Decimal('1.00000001'), result_2)
-
- def test_from_native_invalid_string(self):
- """
- Make sure from_native() raises ValidationError on passing invalid string
- """
- f = serializers.DecimalField()
-
- try:
- f.from_native('123.45.6')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Enter a number."])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_from_native_integer(self):
- """
- Make sure from_native() accepts integer values
- """
- f = serializers.DecimalField()
- result = f.from_native(9000)
-
- self.assertEqual(Decimal('9000'), result)
-
- def test_from_native_float(self):
- """
- Make sure from_native() accepts float values
- """
- f = serializers.DecimalField()
- result = f.from_native(1.00000001)
-
- self.assertEqual(Decimal('1.00000001'), result)
-
- def test_from_native_empty(self):
- """
- Make sure from_native() returns None on empty param.
- """
- f = serializers.DecimalField()
- result = f.from_native('')
-
- self.assertEqual(result, None)
-
- def test_from_native_none(self):
- """
- Make sure from_native() returns None on None param.
- """
- f = serializers.DecimalField()
- result = f.from_native(None)
-
- self.assertEqual(result, None)
-
- def test_to_native(self):
- """
- Make sure to_native() returns Decimal as string.
- """
- f = serializers.DecimalField()
-
- result_1 = f.to_native(Decimal('9000'))
- result_2 = f.to_native(Decimal('1.00000001'))
-
- self.assertEqual(Decimal('9000'), result_1)
- self.assertEqual(Decimal('1.00000001'), result_2)
-
- def test_to_native_none(self):
- """
- Make sure from_native() returns None on None param.
- """
- f = serializers.DecimalField(required=False)
- self.assertEqual(None, f.to_native(None))
-
- def test_valid_serialization(self):
- """
- Make sure the serializer works correctly
- """
- class DecimalSerializer(serializers.Serializer):
- decimal_field = serializers.DecimalField(max_value=9010,
- min_value=9000,
- max_digits=6,
- decimal_places=2)
-
- self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid())
- self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid())
- self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid())
-
- self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid())
- self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid())
- self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid())
-
- def test_raise_max_value(self):
- """
- Make sure max_value violations raises ValidationError
- """
- class DecimalSerializer(serializers.Serializer):
- decimal_field = serializers.DecimalField(max_value=100)
-
- s = DecimalSerializer(data={'decimal_field': '123'})
-
- self.assertFalse(s.is_valid())
- self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']})
-
- def test_raise_min_value(self):
- """
- Make sure min_value violations raises ValidationError
- """
- class DecimalSerializer(serializers.Serializer):
- decimal_field = serializers.DecimalField(min_value=100)
-
- s = DecimalSerializer(data={'decimal_field': '99'})
-
- self.assertFalse(s.is_valid())
- self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']})
-
- def test_raise_max_digits(self):
- """
- Make sure max_digits violations raises ValidationError
- """
- class DecimalSerializer(serializers.Serializer):
- decimal_field = serializers.DecimalField(max_digits=5)
-
- s = DecimalSerializer(data={'decimal_field': '123.456'})
-
- self.assertFalse(s.is_valid())
- self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']})
-
- def test_raise_max_decimal_places(self):
- """
- Make sure max_decimal_places violations raises ValidationError
- """
- class DecimalSerializer(serializers.Serializer):
- decimal_field = serializers.DecimalField(decimal_places=3)
-
- s = DecimalSerializer(data={'decimal_field': '123.4567'})
-
- self.assertFalse(s.is_valid())
- self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']})
-
- def test_raise_max_whole_digits(self):
- """
- Make sure max_whole_digits violations raises ValidationError
- """
- class DecimalSerializer(serializers.Serializer):
- decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3)
-
- s = DecimalSerializer(data={'decimal_field': '12345.6'})
-
- self.assertFalse(s.is_valid())
- self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
-
-
-class ChoiceFieldTests(TestCase):
- """
- Tests for the ChoiceField options generator
- """
- def test_choices_required(self):
- """
- Make sure proper choices are rendered if field is required
- """
- f = serializers.ChoiceField(required=True, choices=SAMPLE_CHOICES)
- self.assertEqual(f.choices, SAMPLE_CHOICES)
-
- def test_choices_not_required(self):
- """
- Make sure proper choices (plus blank) are rendered if the field isn't required
- """
- f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
- self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES)
-
- def test_invalid_choice_model(self):
- s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'})
- self.assertFalse(s.is_valid())
- self.assertEqual(s.errors, {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']})
- self.assertEqual(s.data['choice'], '')
-
- def test_empty_choice_model(self):
- """
- Test that the 'empty' value is correctly passed and used depending on
- the 'null' property on the model field.
- """
- s = ChoiceFieldModelSerializer(data={'choice': ''})
- self.assertTrue(s.is_valid())
- self.assertEqual(s.data['choice'], '')
-
- s = ChoiceFieldModelWithNullSerializer(data={'choice': ''})
- self.assertTrue(s.is_valid())
- self.assertEqual(s.data['choice'], None)
-
- def test_from_native_empty(self):
- """
- Make sure from_native() returns an empty string on empty param by default.
- """
- f = serializers.ChoiceField(choices=SAMPLE_CHOICES)
- self.assertEqual(f.from_native(''), '')
- self.assertEqual(f.from_native(None), '')
-
- def test_from_native_empty_override(self):
- """
- Make sure you can override from_native() behavior regarding empty values.
- """
- f = serializers.ChoiceField(choices=SAMPLE_CHOICES, empty=None)
- self.assertEqual(f.from_native(''), None)
- self.assertEqual(f.from_native(None), None)
-
- def test_metadata_choices(self):
- """
- Make sure proper choices are included in the field's metadata.
- """
- choices = [{'value': v, 'display_name': n} for v, n in SAMPLE_CHOICES]
- f = serializers.ChoiceField(choices=SAMPLE_CHOICES)
- self.assertEqual(f.metadata()['choices'], choices)
-
- def test_metadata_choices_not_required(self):
- """
- Make sure proper choices are included in the field's metadata.
- """
- choices = [{'value': v, 'display_name': n}
- for v, n in models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES]
- f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
- self.assertEqual(f.metadata()['choices'], choices)
-
-
-class EmailFieldTests(TestCase):
- """
- Tests for EmailField attribute values
- """
-
- class EmailFieldModel(RESTFrameworkModel):
- email_field = models.EmailField(blank=True)
-
- class EmailFieldWithGivenMaxLengthModel(RESTFrameworkModel):
- email_field = models.EmailField(max_length=150, blank=True)
-
- def test_default_model_value(self):
- class EmailFieldSerializer(serializers.ModelSerializer):
- class Meta:
- model = self.EmailFieldModel
-
- serializer = EmailFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 75)
-
- def test_given_model_value(self):
- class EmailFieldSerializer(serializers.ModelSerializer):
- class Meta:
- model = self.EmailFieldWithGivenMaxLengthModel
-
- serializer = EmailFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 150)
-
- def test_given_serializer_value(self):
- class EmailFieldSerializer(serializers.ModelSerializer):
- email_field = serializers.EmailField(source='email_field', max_length=20, required=False)
-
- class Meta:
- model = self.EmailFieldModel
-
- serializer = EmailFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 20)
-
-
-class SlugFieldTests(TestCase):
- """
- Tests for SlugField attribute values
- """
-
- class SlugFieldModel(RESTFrameworkModel):
- slug_field = models.SlugField(blank=True)
-
- class SlugFieldWithGivenMaxLengthModel(RESTFrameworkModel):
- slug_field = models.SlugField(max_length=84, blank=True)
-
- def test_default_model_value(self):
- class SlugFieldSerializer(serializers.ModelSerializer):
- class Meta:
- model = self.SlugFieldModel
-
- serializer = SlugFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 50)
-
- def test_given_model_value(self):
- class SlugFieldSerializer(serializers.ModelSerializer):
- class Meta:
- model = self.SlugFieldWithGivenMaxLengthModel
-
- serializer = SlugFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 84)
-
- def test_given_serializer_value(self):
- class SlugFieldSerializer(serializers.ModelSerializer):
- slug_field = serializers.SlugField(source='slug_field',
- max_length=20, required=False)
-
- class Meta:
- model = self.SlugFieldModel
-
- serializer = SlugFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['slug_field'],
- 'max_length'), 20)
-
- def test_invalid_slug(self):
- """
- Make sure an invalid slug raises ValidationError
- """
- class SlugFieldSerializer(serializers.ModelSerializer):
- slug_field = serializers.SlugField(source='slug_field', max_length=20, required=True)
-
- class Meta:
- model = self.SlugFieldModel
-
- s = SlugFieldSerializer(data={'slug_field': 'a b'})
-
- self.assertEqual(s.is_valid(), False)
- self.assertEqual(s.errors, {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]})
-
-
-class URLFieldTests(TestCase):
- """
- Tests for URLField attribute values.
-
- (Includes test for #1210, checking that validators can be overridden.)
- """
-
- class URLFieldModel(RESTFrameworkModel):
- url_field = models.URLField(blank=True)
-
- class URLFieldWithGivenMaxLengthModel(RESTFrameworkModel):
- url_field = models.URLField(max_length=128, blank=True)
-
- def test_default_model_value(self):
- class URLFieldSerializer(serializers.ModelSerializer):
- class Meta:
- model = self.URLFieldModel
-
- serializer = URLFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['url_field'],
- 'max_length'), 200)
-
- def test_given_model_value(self):
- class URLFieldSerializer(serializers.ModelSerializer):
- class Meta:
- model = self.URLFieldWithGivenMaxLengthModel
-
- serializer = URLFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['url_field'],
- 'max_length'), 128)
-
- def test_given_serializer_value(self):
- class URLFieldSerializer(serializers.ModelSerializer):
- url_field = serializers.URLField(source='url_field',
- max_length=20, required=False)
-
- class Meta:
- model = self.URLFieldWithGivenMaxLengthModel
-
- serializer = URLFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['url_field'],
- 'max_length'), 20)
-
- def test_validators_can_be_overridden(self):
- url_field = serializers.URLField(validators=[])
- validators = url_field.validators
- self.assertEqual([], validators, 'Passing `validators` kwarg should have overridden default validators')
-
-
-class FieldMetadata(TestCase):
- def setUp(self):
- self.required_field = serializers.Field()
- self.required_field.label = uuid4().hex
- self.required_field.required = True
-
- self.optional_field = serializers.Field()
- self.optional_field.label = uuid4().hex
- self.optional_field.required = False
-
- def test_required(self):
- self.assertEqual(self.required_field.metadata()['required'], True)
-
- def test_optional(self):
- self.assertEqual(self.optional_field.metadata()['required'], False)
-
- def test_label(self):
- for field in (self.required_field, self.optional_field):
- self.assertEqual(field.metadata()['label'], field.label)
-
-
-class FieldCallableDefault(TestCase):
- def setUp(self):
- self.simple_callable = lambda: 'foo bar'
-
- def test_default_can_be_simple_callable(self):
- """
- Ensure that the 'default' argument can also be a simple callable.
- """
- field = serializers.WritableField(default=self.simple_callable)
- into = {}
- field.field_from_native({}, {}, 'field', into)
- self.assertEqual(into, {'field': 'foo bar'})
-
-
-class CustomIntegerField(TestCase):
- """
- Test that custom fields apply min_value and max_value constraints
- """
- def test_custom_fields_can_be_validated_for_value(self):
-
- class MoneyField(models.PositiveIntegerField):
- pass
-
- class EntryModel(models.Model):
- bank = MoneyField(validators=[validators.MaxValueValidator(100)])
-
- class EntrySerializer(serializers.ModelSerializer):
- class Meta:
- model = EntryModel
-
- entry = EntryModel(bank=1)
-
- serializer = EntrySerializer(entry, data={"bank": 11})
- self.assertTrue(serializer.is_valid())
-
- serializer = EntrySerializer(entry, data={"bank": -1})
- self.assertFalse(serializer.is_valid())
-
- serializer = EntrySerializer(entry, data={"bank": 101})
- self.assertFalse(serializer.is_valid())
-
-
-class BooleanField(TestCase):
- """
- Tests for BooleanField
- """
- def test_boolean_required(self):
- class BooleanRequiredSerializer(serializers.Serializer):
- bool_field = serializers.BooleanField(required=True)
-
- self.assertFalse(BooleanRequiredSerializer(data={}).is_valid())
diff --git a/rest_framework/tests/test_files.py b/rest_framework/tests/test_files.py
deleted file mode 100644
index 78f4cf42..00000000
--- a/rest_framework/tests/test_files.py
+++ /dev/null
@@ -1,95 +0,0 @@
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework import serializers
-from rest_framework.compat import BytesIO
-from rest_framework.compat import six
-import datetime
-
-
-class UploadedFile(object):
- def __init__(self, file=None, created=None):
- self.file = file
- self.created = created or datetime.datetime.now()
-
-
-class UploadedFileSerializer(serializers.Serializer):
- file = serializers.FileField(required=False)
- created = serializers.DateTimeField()
-
- def restore_object(self, attrs, instance=None):
- if instance:
- instance.file = attrs['file']
- instance.created = attrs['created']
- return instance
- return UploadedFile(**attrs)
-
-
-class FileSerializerTests(TestCase):
- def test_create(self):
- now = datetime.datetime.now()
- file = BytesIO(six.b('stuff'))
- file.name = 'stuff.txt'
- file.size = len(file.getvalue())
- serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})
- uploaded_file = UploadedFile(file=file, created=now)
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.object.created, uploaded_file.created)
- self.assertEqual(serializer.object.file, uploaded_file.file)
- self.assertFalse(serializer.object is uploaded_file)
-
- def test_creation_failure(self):
- """
- Passing files=None should result in an ValidationError
-
- Regression test for:
- https://github.com/tomchristie/django-rest-framework/issues/542
- """
- now = datetime.datetime.now()
-
- serializer = UploadedFileSerializer(data={'created': now})
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.object.created, now)
- self.assertIsNone(serializer.object.file)
-
- def test_remove_with_empty_string(self):
- """
- Passing empty string as data should cause file to be removed
-
- Test for:
- https://github.com/tomchristie/django-rest-framework/issues/937
- """
- now = datetime.datetime.now()
- file = BytesIO(six.b('stuff'))
- file.name = 'stuff.txt'
- file.size = len(file.getvalue())
-
- uploaded_file = UploadedFile(file=file, created=now)
-
- serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''})
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.object.created, uploaded_file.created)
- self.assertIsNone(serializer.object.file)
-
- def test_validation_error_with_non_file(self):
- """
- Passing non-files should raise a validation error.
- """
- now = datetime.datetime.now()
- errmsg = 'No file was submitted. Check the encoding type on the form.'
-
- serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'})
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'file': [errmsg]})
-
- def test_validation_with_no_data(self):
- """
- Validation should still function when no data dictionary is provided.
- """
- now = datetime.datetime.now()
- file = BytesIO(six.b('stuff'))
- file.name = 'stuff.txt'
- file.size = len(file.getvalue())
- uploaded_file = UploadedFile(file=file, created=now)
-
- serializer = UploadedFileSerializer(files={'file': file})
- self.assertFalse(serializer.is_valid())
diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py
deleted file mode 100644
index 23226bbc..00000000
--- a/rest_framework/tests/test_filters.py
+++ /dev/null
@@ -1,661 +0,0 @@
-from __future__ import unicode_literals
-import datetime
-from decimal import Decimal
-from django.db import models
-from django.core.urlresolvers import reverse
-from django.test import TestCase
-from django.utils import unittest
-from rest_framework import generics, serializers, status, filters
-from rest_framework.compat import django_filters, patterns, url
-from rest_framework.settings import api_settings
-from rest_framework.test import APIRequestFactory
-from rest_framework.tests.models import BasicModel
-from .models import FilterableItem
-from .utils import temporary_setting
-
-factory = APIRequestFactory()
-
-
-if django_filters:
- # Basic filter on a list view.
- class FilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_fields = ['decimal', 'date']
- filter_backends = (filters.DjangoFilterBackend,)
-
- # These class are used to test a filter class.
- class SeveralFieldsFilter(django_filters.FilterSet):
- text = django_filters.CharFilter(lookup_type='icontains')
- decimal = django_filters.NumberFilter(lookup_type='lt')
- date = django_filters.DateFilter(lookup_type='gt')
-
- class Meta:
- model = FilterableItem
- fields = ['text', 'decimal', 'date']
-
- class FilterClassRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_class = SeveralFieldsFilter
- filter_backends = (filters.DjangoFilterBackend,)
-
- # These classes are used to test a misconfigured filter class.
- class MisconfiguredFilter(django_filters.FilterSet):
- text = django_filters.CharFilter(lookup_type='icontains')
-
- class Meta:
- model = BasicModel
- fields = ['text']
-
- class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_class = MisconfiguredFilter
- filter_backends = (filters.DjangoFilterBackend,)
-
- class FilterClassDetailView(generics.RetrieveAPIView):
- model = FilterableItem
- filter_class = SeveralFieldsFilter
- filter_backends = (filters.DjangoFilterBackend,)
-
- # Regression test for #814
- class FilterableItemSerializer(serializers.ModelSerializer):
- class Meta:
- model = FilterableItem
-
- class FilterFieldsQuerysetView(generics.ListCreateAPIView):
- queryset = FilterableItem.objects.all()
- serializer_class = FilterableItemSerializer
- filter_fields = ['decimal', 'date']
- filter_backends = (filters.DjangoFilterBackend,)
-
- class GetQuerysetView(generics.ListCreateAPIView):
- serializer_class = FilterableItemSerializer
- filter_class = SeveralFieldsFilter
- filter_backends = (filters.DjangoFilterBackend,)
-
- def get_queryset(self):
- return FilterableItem.objects.all()
-
- urlpatterns = patterns('',
- url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
- url(r'^$', FilterClassRootView.as_view(), name='root-view'),
- url(r'^get-queryset/$', GetQuerysetView.as_view(),
- name='get-queryset-view'),
- )
-
-
-class CommonFilteringTestCase(TestCase):
- def _serialize_object(self, obj):
- return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
-
- def setUp(self):
- """
- Create 10 FilterableItem instances.
- """
- base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
- for i in range(10):
- text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
- decimal = base_data[1] + i
- date = base_data[2] - datetime.timedelta(days=i * 2)
- FilterableItem(text=text, decimal=decimal, date=date).save()
-
- self.objects = FilterableItem.objects
- self.data = [
- self._serialize_object(obj)
- for obj in self.objects.all()
- ]
-
-
-class IntegrationTestFiltering(CommonFilteringTestCase):
- """
- Integration tests for filtered list views.
- """
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_get_filtered_fields_root_view(self):
- """
- GET requests to paginated ListCreateAPIView should return paginated results.
- """
- view = FilterFieldsRootView.as_view()
-
- # Basic test with no filter.
- request = factory.get('/')
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- # Tests that the decimal filter works.
- search_decimal = Decimal('2.25')
- request = factory.get('/', {'decimal': '%s' % search_decimal})
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['decimal'] == search_decimal]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the date filter works.
- search_date = datetime.date(2012, 9, 22)
- request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-09-22'
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] == search_date]
- self.assertEqual(response.data, expected_data)
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_filter_with_queryset(self):
- """
- Regression test for #814.
- """
- view = FilterFieldsQuerysetView.as_view()
-
- # Tests that the decimal filter works.
- search_decimal = Decimal('2.25')
- request = factory.get('/', {'decimal': '%s' % search_decimal})
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['decimal'] == search_decimal]
- self.assertEqual(response.data, expected_data)
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_filter_with_get_queryset_only(self):
- """
- Regression test for #834.
- """
- view = GetQuerysetView.as_view()
- request = factory.get('/get-queryset/')
- view(request).render()
- # Used to raise "issubclass() arg 2 must be a class or tuple of classes"
- # here when neither `model' nor `queryset' was specified.
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_get_filtered_class_root_view(self):
- """
- GET requests to filtered ListCreateAPIView that have a filter_class set
- should return filtered results.
- """
- view = FilterClassRootView.as_view()
-
- # Basic test with no filter.
- request = factory.get('/')
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- # Tests that the decimal filter set with 'lt' in the filter class works.
- search_decimal = Decimal('4.25')
- request = factory.get('/', {'decimal': '%s' % search_decimal})
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['decimal'] < search_decimal]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the date filter set with 'gt' in the filter class works.
- search_date = datetime.date(2012, 10, 2)
- request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-10-02'
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the text filter set with 'icontains' in the filter class works.
- search_text = 'ff'
- request = factory.get('/', {'text': '%s' % search_text})
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if search_text in f['text'].lower()]
- self.assertEqual(response.data, expected_data)
-
- # Tests that multiple filters works.
- search_decimal = Decimal('5.25')
- search_date = datetime.date(2012, 10, 2)
- request = factory.get('/', {
- 'decimal': '%s' % (search_decimal,),
- 'date': '%s' % (search_date,)
- })
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date and
- f['decimal'] < search_decimal]
- self.assertEqual(response.data, expected_data)
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_incorrectly_configured_filter(self):
- """
- An error should be displayed when the filter class is misconfigured.
- """
- view = IncorrectlyConfiguredRootView.as_view()
-
- request = factory.get('/')
- self.assertRaises(AssertionError, view, request)
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_unknown_filter(self):
- """
- GET requests with filters that aren't configured should return 200.
- """
- view = FilterFieldsRootView.as_view()
-
- search_integer = 10
- request = factory.get('/', {'integer': '%s' % search_integer})
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
-
-class IntegrationTestDetailFiltering(CommonFilteringTestCase):
- """
- Integration tests for filtered detail views.
- """
- urls = 'rest_framework.tests.test_filters'
-
- def _get_url(self, item):
- return reverse('detail-view', kwargs=dict(pk=item.pk))
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_get_filtered_detail_view(self):
- """
- GET requests to filtered RetrieveAPIView that have a filter_class set
- should return filtered results.
- """
- item = self.objects.all()[0]
- data = self._serialize_object(item)
-
- # Basic test with no filter.
- response = self.client.get(self._get_url(item))
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, data)
-
- # Tests that the decimal filter set that should fail.
- search_decimal = Decimal('4.25')
- high_item = self.objects.filter(decimal__gt=search_decimal)[0]
- response = self.client.get(
- '{url}'.format(url=self._get_url(high_item)),
- {'decimal': '{param}'.format(param=search_decimal)})
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
- # Tests that the decimal filter set that should succeed.
- search_decimal = Decimal('4.25')
- low_item = self.objects.filter(decimal__lt=search_decimal)[0]
- low_item_data = self._serialize_object(low_item)
- response = self.client.get(
- '{url}'.format(url=self._get_url(low_item)),
- {'decimal': '{param}'.format(param=search_decimal)})
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, low_item_data)
-
- # Tests that multiple filters works.
- search_decimal = Decimal('5.25')
- search_date = datetime.date(2012, 10, 2)
- valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
- valid_item_data = self._serialize_object(valid_item)
- response = self.client.get(
- '{url}'.format(url=self._get_url(valid_item)), {
- 'decimal': '{decimal}'.format(decimal=search_decimal),
- 'date': '{date}'.format(date=search_date)
- })
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, valid_item_data)
-
-
-class SearchFilterModel(models.Model):
- title = models.CharField(max_length=20)
- text = models.CharField(max_length=100)
-
-
-class SearchFilterTests(TestCase):
- def setUp(self):
- # Sequence of title/text is:
- #
- # z abc
- # zz bcd
- # zzz cde
- # ...
- for idx in range(10):
- title = 'z' * (idx + 1)
- text = (
- chr(idx + ord('a')) +
- chr(idx + ord('b')) +
- chr(idx + ord('c'))
- )
- SearchFilterModel(title=title, text=text).save()
-
- def test_search(self):
- class SearchListView(generics.ListAPIView):
- model = SearchFilterModel
- filter_backends = (filters.SearchFilter,)
- search_fields = ('title', 'text')
-
- view = SearchListView.as_view()
- request = factory.get('/', {'search': 'b'})
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 1, 'title': 'z', 'text': 'abc'},
- {'id': 2, 'title': 'zz', 'text': 'bcd'}
- ]
- )
-
- def test_exact_search(self):
- class SearchListView(generics.ListAPIView):
- model = SearchFilterModel
- filter_backends = (filters.SearchFilter,)
- search_fields = ('=title', 'text')
-
- view = SearchListView.as_view()
- request = factory.get('/', {'search': 'zzz'})
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 3, 'title': 'zzz', 'text': 'cde'}
- ]
- )
-
- def test_startswith_search(self):
- class SearchListView(generics.ListAPIView):
- model = SearchFilterModel
- filter_backends = (filters.SearchFilter,)
- search_fields = ('title', '^text')
-
- view = SearchListView.as_view()
- request = factory.get('/', {'search': 'b'})
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 2, 'title': 'zz', 'text': 'bcd'}
- ]
- )
-
- def test_search_with_nonstandard_search_param(self):
- with temporary_setting('SEARCH_PARAM', 'query', module=filters):
- class SearchListView(generics.ListAPIView):
- model = SearchFilterModel
- filter_backends = (filters.SearchFilter,)
- search_fields = ('title', 'text')
-
- view = SearchListView.as_view()
- request = factory.get('/', {'query': 'b'})
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 1, 'title': 'z', 'text': 'abc'},
- {'id': 2, 'title': 'zz', 'text': 'bcd'}
- ]
- )
-
-
-class OrdringFilterModel(models.Model):
- title = models.CharField(max_length=20)
- text = models.CharField(max_length=100)
-
-
-class OrderingFilterRelatedModel(models.Model):
- related_object = models.ForeignKey(OrdringFilterModel,
- related_name="relateds")
-
-
-class OrderingFilterTests(TestCase):
- def setUp(self):
- # Sequence of title/text is:
- #
- # zyx abc
- # yxw bcd
- # xwv cde
- for idx in range(3):
- title = (
- chr(ord('z') - idx) +
- chr(ord('y') - idx) +
- chr(ord('x') - idx)
- )
- text = (
- chr(idx + ord('a')) +
- chr(idx + ord('b')) +
- chr(idx + ord('c'))
- )
- OrdringFilterModel(title=title, text=text).save()
-
- def test_ordering(self):
- class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
-
- view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'text'})
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- ]
- )
-
- def test_reverse_ordering(self):
- class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
-
- view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': '-text'})
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- ]
- )
-
- def test_incorrectfield_ordering(self):
- class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
-
- view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'foobar'})
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- ]
- )
-
- def test_default_ordering(self):
- class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- oredering_fields = ('text',)
-
- view = OrderingListView.as_view()
- request = factory.get('')
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- ]
- )
-
- def test_default_ordering_using_string(self):
- class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
- filter_backends = (filters.OrderingFilter,)
- ordering = 'title'
- ordering_fields = ('text',)
-
- view = OrderingListView.as_view()
- request = factory.get('')
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- ]
- )
-
- def test_ordering_by_aggregate_field(self):
- # create some related models to aggregate order by
- num_objs = [2, 5, 3]
- for obj, num_relateds in zip(OrdringFilterModel.objects.all(),
- num_objs):
- for _ in range(num_relateds):
- new_related = OrderingFilterRelatedModel(
- related_object=obj
- )
- new_related.save()
-
- class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
- filter_backends = (filters.OrderingFilter,)
- ordering = 'title'
- ordering_fields = '__all__'
- queryset = OrdringFilterModel.objects.all().annotate(
- models.Count("relateds"))
-
- view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'relateds__count'})
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- ]
- )
-
- def test_ordering_with_nonstandard_ordering_param(self):
- with temporary_setting('ORDERING_PARAM', 'order', filters):
- class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
-
- view = OrderingListView.as_view()
- request = factory.get('/', {'order': 'text'})
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- ]
- )
-
-
-class SensitiveOrderingFilterModel(models.Model):
- username = models.CharField(max_length=20)
- password = models.CharField(max_length=100)
-
-
-# Three different styles of serializer.
-# All should allow ordering by username, but not by password.
-class SensitiveDataSerializer1(serializers.ModelSerializer):
- username = serializers.CharField()
-
- class Meta:
- model = SensitiveOrderingFilterModel
- fields = ('id', 'username')
-
-
-class SensitiveDataSerializer2(serializers.ModelSerializer):
- username = serializers.CharField()
- password = serializers.CharField(write_only=True)
-
- class Meta:
- model = SensitiveOrderingFilterModel
- fields = ('id', 'username', 'password')
-
-
-class SensitiveDataSerializer3(serializers.ModelSerializer):
- user = serializers.CharField(source='username')
-
- class Meta:
- model = SensitiveOrderingFilterModel
- fields = ('id', 'user')
-
-
-class SensitiveOrderingFilterTests(TestCase):
- def setUp(self):
- for idx in range(3):
- username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx]
- password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx]
- SensitiveOrderingFilterModel(username=username, password=password).save()
-
- def test_order_by_serializer_fields(self):
- for serializer_cls in [
- SensitiveDataSerializer1,
- SensitiveDataSerializer2,
- SensitiveDataSerializer3
- ]:
- class OrderingListView(generics.ListAPIView):
- queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
- filter_backends = (filters.OrderingFilter,)
- serializer_class = serializer_cls
-
- view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': '-username'})
- response = view(request)
-
- if serializer_cls == SensitiveDataSerializer3:
- username_field = 'user'
- else:
- username_field = 'username'
-
- # Note: Inverse username ordering correctly applied.
- self.assertEqual(
- response.data,
- [
- {'id': 3, username_field: 'userC'},
- {'id': 2, username_field: 'userB'},
- {'id': 1, username_field: 'userA'},
- ]
- )
-
- def test_cannot_order_by_non_serializer_fields(self):
- for serializer_cls in [
- SensitiveDataSerializer1,
- SensitiveDataSerializer2,
- SensitiveDataSerializer3
- ]:
- class OrderingListView(generics.ListAPIView):
- queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
- filter_backends = (filters.OrderingFilter,)
- serializer_class = serializer_cls
-
- view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'password'})
- response = view(request)
-
- if serializer_cls == SensitiveDataSerializer3:
- username_field = 'user'
- else:
- username_field = 'username'
-
- # Note: The passwords are not in order. Default ordering is used.
- self.assertEqual(
- response.data,
- [
- {'id': 1, username_field: 'userA'}, # PassB
- {'id': 2, username_field: 'userB'}, # PassC
- {'id': 3, username_field: 'userC'}, # PassA
- ]
- )
diff --git a/rest_framework/tests/test_genericrelations.py b/rest_framework/tests/test_genericrelations.py
deleted file mode 100644
index fa09c9e6..00000000
--- a/rest_framework/tests/test_genericrelations.py
+++ /dev/null
@@ -1,133 +0,0 @@
-from __future__ import unicode_literals
-from django.contrib.contenttypes.models import ContentType
-from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey
-from django.db import models
-from django.test import TestCase
-from rest_framework import serializers
-from rest_framework.compat import python_2_unicode_compatible
-
-
-@python_2_unicode_compatible
-class Tag(models.Model):
- """
- Tags have a descriptive slug, and are attached to an arbitrary object.
- """
- tag = models.SlugField()
- content_type = models.ForeignKey(ContentType)
- object_id = models.PositiveIntegerField()
- tagged_item = GenericForeignKey('content_type', 'object_id')
-
- def __str__(self):
- return self.tag
-
-
-@python_2_unicode_compatible
-class Bookmark(models.Model):
- """
- A URL bookmark that may have multiple tags attached.
- """
- url = models.URLField()
- tags = GenericRelation(Tag)
-
- def __str__(self):
- return 'Bookmark: %s' % self.url
-
-
-@python_2_unicode_compatible
-class Note(models.Model):
- """
- A textual note that may have multiple tags attached.
- """
- text = models.TextField()
- tags = GenericRelation(Tag)
-
- def __str__(self):
- return 'Note: %s' % self.text
-
-
-class TestGenericRelations(TestCase):
- def setUp(self):
- self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
- Tag.objects.create(tagged_item=self.bookmark, tag='django')
- Tag.objects.create(tagged_item=self.bookmark, tag='python')
- self.note = Note.objects.create(text='Remember the milk')
- Tag.objects.create(tagged_item=self.note, tag='reminder')
-
- def test_generic_relation(self):
- """
- Test a relationship that spans a GenericRelation field.
- IE. A reverse generic relationship.
- """
-
- class BookmarkSerializer(serializers.ModelSerializer):
- tags = serializers.RelatedField(many=True)
-
- class Meta:
- model = Bookmark
- exclude = ('id',)
-
- serializer = BookmarkSerializer(self.bookmark)
- expected = {
- 'tags': ['django', 'python'],
- 'url': 'https://www.djangoproject.com/'
- }
- self.assertEqual(serializer.data, expected)
-
- def test_generic_nested_relation(self):
- """
- Test saving a GenericRelation field via a nested serializer.
- """
-
- class TagSerializer(serializers.ModelSerializer):
- class Meta:
- model = Tag
- exclude = ('content_type', 'object_id')
-
- class BookmarkSerializer(serializers.ModelSerializer):
- tags = TagSerializer()
-
- class Meta:
- model = Bookmark
- exclude = ('id',)
-
- data = {
- 'url': 'https://docs.djangoproject.com/',
- 'tags': [
- {'tag': 'contenttypes'},
- {'tag': 'genericrelations'},
- ]
- }
- serializer = BookmarkSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- serializer.save()
- self.assertEqual(serializer.object.tags.count(), 2)
-
- def test_generic_fk(self):
- """
- Test a relationship that spans a GenericForeignKey field.
- IE. A forward generic relationship.
- """
-
- class TagSerializer(serializers.ModelSerializer):
- tagged_item = serializers.RelatedField()
-
- class Meta:
- model = Tag
- exclude = ('id', 'content_type', 'object_id')
-
- serializer = TagSerializer(Tag.objects.all(), many=True)
- expected = [
- {
- 'tag': 'django',
- 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
- },
- {
- 'tag': 'python',
- 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
- },
- {
- 'tag': 'reminder',
- 'tagged_item': 'Note: Remember the milk'
- }
- ]
- self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py
deleted file mode 100644
index 996bd5b0..00000000
--- a/rest_framework/tests/test_generics.py
+++ /dev/null
@@ -1,609 +0,0 @@
-from __future__ import unicode_literals
-from django.db import models
-from django.shortcuts import get_object_or_404
-from django.test import TestCase
-from rest_framework import generics, renderers, serializers, status
-from rest_framework.test import APIRequestFactory
-from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
-from rest_framework.compat import six
-
-factory = APIRequestFactory()
-
-
-class RootView(generics.ListCreateAPIView):
- """
- Example description for OPTIONS.
- """
- model = BasicModel
-
-
-class InstanceView(generics.RetrieveUpdateDestroyAPIView):
- """
- Example description for OPTIONS.
- """
- model = BasicModel
-
- def get_queryset(self):
- queryset = super(InstanceView, self).get_queryset()
- return queryset.exclude(text='filtered out')
-
-
-class SlugSerializer(serializers.ModelSerializer):
- slug = serializers.Field() # read only
-
- class Meta:
- model = SlugBasedModel
- exclude = ('id',)
-
-
-class SlugBasedInstanceView(InstanceView):
- """
- A model with a slug-field.
- """
- model = SlugBasedModel
- serializer_class = SlugSerializer
- lookup_field = 'slug'
-
-
-class TestRootView(TestCase):
- def setUp(self):
- """
- Create 3 BasicModel instances.
- """
- items = ['foo', 'bar', 'baz']
- for item in items:
- BasicModel(text=item).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = RootView.as_view()
-
- def test_get_root_view(self):
- """
- GET requests to ListCreateAPIView should return list of objects.
- """
- request = factory.get('/')
- with self.assertNumQueries(1):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- def test_post_root_view(self):
- """
- POST requests to ListCreateAPIView should create a new object.
- """
- data = {'text': 'foobar'}
- request = factory.post('/', data, format='json')
- with self.assertNumQueries(1):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
- self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
- created = self.objects.get(id=4)
- self.assertEqual(created.text, 'foobar')
-
- def test_put_root_view(self):
- """
- PUT requests to ListCreateAPIView should not be allowed
- """
- data = {'text': 'foobar'}
- request = factory.put('/', data, format='json')
- with self.assertNumQueries(0):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."})
-
- def test_delete_root_view(self):
- """
- DELETE requests to ListCreateAPIView should not be allowed
- """
- request = factory.delete('/')
- with self.assertNumQueries(0):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."})
-
- def test_options_root_view(self):
- """
- OPTIONS requests to ListCreateAPIView should return metadata
- """
- request = factory.options('/')
- with self.assertNumQueries(0):
- response = self.view(request).render()
- expected = {
- 'parses': [
- 'application/json',
- 'application/x-www-form-urlencoded',
- 'multipart/form-data'
- ],
- 'renders': [
- 'application/json',
- 'text/html'
- ],
- 'name': 'Root',
- 'description': 'Example description for OPTIONS.',
- 'actions': {
- 'POST': {
- 'text': {
- 'max_length': 100,
- 'read_only': False,
- 'required': True,
- 'type': 'string',
- "label": "Text comes here",
- "help_text": "Text description."
- },
- 'id': {
- 'read_only': True,
- 'required': False,
- 'type': 'integer',
- 'label': 'ID',
- },
- }
- }
- }
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, expected)
-
- def test_post_cannot_set_id(self):
- """
- POST requests to create a new object should not be able to set the id.
- """
- data = {'id': 999, 'text': 'foobar'}
- request = factory.post('/', data, format='json')
- with self.assertNumQueries(1):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
- self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
- created = self.objects.get(id=4)
- self.assertEqual(created.text, 'foobar')
-
-
-class TestInstanceView(TestCase):
- def setUp(self):
- """
- Create 3 BasicModel intances.
- """
- items = ['foo', 'bar', 'baz', 'filtered out']
- for item in items:
- BasicModel(text=item).save()
- self.objects = BasicModel.objects.exclude(text='filtered out')
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = InstanceView.as_view()
- self.slug_based_view = SlugBasedInstanceView.as_view()
-
- def test_get_instance_view(self):
- """
- GET requests to RetrieveUpdateDestroyAPIView should return a single object.
- """
- request = factory.get('/1')
- with self.assertNumQueries(1):
- response = self.view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data[0])
-
- def test_post_instance_view(self):
- """
- POST requests to RetrieveUpdateDestroyAPIView should not be allowed
- """
- data = {'text': 'foobar'}
- request = factory.post('/', data, format='json')
- with self.assertNumQueries(0):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."})
-
- def test_put_instance_view(self):
- """
- PUT requests to RetrieveUpdateDestroyAPIView should update an object.
- """
- data = {'text': 'foobar'}
- request = factory.put('/1', data, format='json')
- with self.assertNumQueries(2):
- response = self.view(request, pk='1').render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
- updated = self.objects.get(id=1)
- self.assertEqual(updated.text, 'foobar')
-
- def test_patch_instance_view(self):
- """
- PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
- """
- data = {'text': 'foobar'}
- request = factory.patch('/1', data, format='json')
-
- with self.assertNumQueries(2):
- response = self.view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
- updated = self.objects.get(id=1)
- self.assertEqual(updated.text, 'foobar')
-
- def test_delete_instance_view(self):
- """
- DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
- """
- request = factory.delete('/1')
- with self.assertNumQueries(2):
- response = self.view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
- self.assertEqual(response.content, six.b(''))
- ids = [obj.id for obj in self.objects.all()]
- self.assertEqual(ids, [2, 3])
-
- def test_options_instance_view(self):
- """
- OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
- """
- request = factory.options('/1')
- with self.assertNumQueries(1):
- response = self.view(request, pk=1).render()
- expected = {
- 'parses': [
- 'application/json',
- 'application/x-www-form-urlencoded',
- 'multipart/form-data'
- ],
- 'renders': [
- 'application/json',
- 'text/html'
- ],
- 'name': 'Instance',
- 'description': 'Example description for OPTIONS.',
- 'actions': {
- 'PUT': {
- 'text': {
- 'max_length': 100,
- 'read_only': False,
- 'required': True,
- 'type': 'string',
- 'label': 'Text comes here',
- 'help_text': 'Text description.'
- },
- 'id': {
- 'read_only': True,
- 'required': False,
- 'type': 'integer',
- 'label': 'ID',
- },
- }
- }
- }
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, expected)
-
- def test_options_before_instance_create(self):
- """
- OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
- before the instance has been created
- """
- request = factory.options('/999')
- with self.assertNumQueries(1):
- response = self.view(request, pk=999).render()
- expected = {
- 'parses': [
- 'application/json',
- 'application/x-www-form-urlencoded',
- 'multipart/form-data'
- ],
- 'renders': [
- 'application/json',
- 'text/html'
- ],
- 'name': 'Instance',
- 'description': 'Example description for OPTIONS.',
- 'actions': {
- 'PUT': {
- 'text': {
- 'max_length': 100,
- 'read_only': False,
- 'required': True,
- 'type': 'string',
- 'label': 'Text comes here',
- 'help_text': 'Text description.'
- },
- 'id': {
- 'read_only': True,
- 'required': False,
- 'type': 'integer',
- 'label': 'ID',
- },
- }
- }
- }
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, expected)
-
- def test_get_instance_view_incorrect_arg(self):
- """
- GET requests with an incorrect pk type, should raise 404, not 500.
- Regression test for #890.
- """
- request = factory.get('/a')
- with self.assertNumQueries(0):
- response = self.view(request, pk='a').render()
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
- def test_put_cannot_set_id(self):
- """
- PUT requests to create a new object should not be able to set the id.
- """
- data = {'id': 999, 'text': 'foobar'}
- request = factory.put('/1', data, format='json')
- with self.assertNumQueries(2):
- response = self.view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
- updated = self.objects.get(id=1)
- self.assertEqual(updated.text, 'foobar')
-
- def test_put_to_deleted_instance(self):
- """
- PUT requests to RetrieveUpdateDestroyAPIView should create an object
- if it does not currently exist.
- """
- self.objects.get(id=1).delete()
- data = {'text': 'foobar'}
- request = factory.put('/1', data, format='json')
- with self.assertNumQueries(3):
- response = self.view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
- self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
- updated = self.objects.get(id=1)
- self.assertEqual(updated.text, 'foobar')
-
- def test_put_to_filtered_out_instance(self):
- """
- PUT requests to an URL of instance which is filtered out should not be
- able to create new objects.
- """
- data = {'text': 'foo'}
- filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
- request = factory.put('/{0}'.format(filtered_out_pk), data, format='json')
- response = self.view(request, pk=filtered_out_pk).render()
- self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-
- def test_put_as_create_on_id_based_url(self):
- """
- PUT requests to RetrieveUpdateDestroyAPIView should create an object
- at the requested url if it doesn't exist.
- """
- data = {'text': 'foobar'}
- # pk fields can not be created on demand, only the database can set the pk for a new object
- request = factory.put('/5', data, format='json')
- with self.assertNumQueries(3):
- response = self.view(request, pk=5).render()
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
- new_obj = self.objects.get(pk=5)
- self.assertEqual(new_obj.text, 'foobar')
-
- def test_put_as_create_on_slug_based_url(self):
- """
- PUT requests to RetrieveUpdateDestroyAPIView should create an object
- at the requested url if possible, else return HTTP_403_FORBIDDEN error-response.
- """
- data = {'text': 'foobar'}
- request = factory.put('/test_slug', data, format='json')
- with self.assertNumQueries(2):
- response = self.slug_based_view(request, slug='test_slug').render()
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
- self.assertEqual(response.data, {'slug': 'test_slug', 'text': 'foobar'})
- new_obj = SlugBasedModel.objects.get(slug='test_slug')
- self.assertEqual(new_obj.text, 'foobar')
-
- def test_patch_cannot_create_an_object(self):
- """
- PATCH requests should not be able to create objects.
- """
- data = {'text': 'foobar'}
- request = factory.patch('/999', data, format='json')
- with self.assertNumQueries(1):
- response = self.view(request, pk=999).render()
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- self.assertFalse(self.objects.filter(id=999).exists())
-
-
-class TestOverriddenGetObject(TestCase):
- """
- Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the
- queryset/model mechanism but instead overrides get_object()
- """
- def setUp(self):
- """
- Create 3 BasicModel intances.
- """
- items = ['foo', 'bar', 'baz']
- for item in items:
- BasicModel(text=item).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
-
- class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
- """
- Example detail view for override of get_object().
- """
- model = BasicModel
-
- def get_object(self):
- pk = int(self.kwargs['pk'])
- return get_object_or_404(BasicModel.objects.all(), id=pk)
-
- self.view = OverriddenGetObjectView.as_view()
-
- def test_overridden_get_object_view(self):
- """
- GET requests to RetrieveUpdateDestroyAPIView should return a single object.
- """
- request = factory.get('/1')
- with self.assertNumQueries(1):
- response = self.view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data[0])
-
-
-# Regression test for #285
-
-class CommentSerializer(serializers.ModelSerializer):
- class Meta:
- model = Comment
- exclude = ('created',)
-
-
-class CommentView(generics.ListCreateAPIView):
- serializer_class = CommentSerializer
- model = Comment
-
-
-class TestCreateModelWithAutoNowAddField(TestCase):
- def setUp(self):
- self.objects = Comment.objects
- self.view = CommentView.as_view()
-
- def test_create_model_with_auto_now_add_field(self):
- """
- Regression test for #285
-
- https://github.com/tomchristie/django-rest-framework/issues/285
- """
- data = {'email': 'foobar@example.com', 'content': 'foobar'}
- request = factory.post('/', data, format='json')
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
- created = self.objects.get(id=1)
- self.assertEqual(created.content, 'foobar')
-
-
-# Test for particularly ugly regression with m2m in browsable API
-class ClassB(models.Model):
- name = models.CharField(max_length=255)
-
-
-class ClassA(models.Model):
- name = models.CharField(max_length=255)
- childs = models.ManyToManyField(ClassB, blank=True, null=True)
-
-
-class ClassASerializer(serializers.ModelSerializer):
- childs = serializers.PrimaryKeyRelatedField(many=True, source='childs')
-
- class Meta:
- model = ClassA
-
-
-class ExampleView(generics.ListCreateAPIView):
- serializer_class = ClassASerializer
- model = ClassA
-
-
-class TestM2MBrowseableAPI(TestCase):
- def test_m2m_in_browseable_api(self):
- """
- Test for particularly ugly regression with m2m in browsable API
- """
- request = factory.get('/', HTTP_ACCEPT='text/html')
- view = ExampleView().as_view()
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
-
-class InclusiveFilterBackend(object):
- def filter_queryset(self, request, queryset, view):
- return queryset.filter(text='foo')
-
-
-class ExclusiveFilterBackend(object):
- def filter_queryset(self, request, queryset, view):
- return queryset.filter(text='other')
-
-
-class TwoFieldModel(models.Model):
- field_a = models.CharField(max_length=100)
- field_b = models.CharField(max_length=100)
-
-
-class DynamicSerializerView(generics.ListCreateAPIView):
- model = TwoFieldModel
- renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
-
- def get_serializer_class(self):
- if self.request.method == 'POST':
- class DynamicSerializer(serializers.ModelSerializer):
- class Meta:
- model = TwoFieldModel
- fields = ('field_b',)
- return DynamicSerializer
- return super(DynamicSerializerView, self).get_serializer_class()
-
-
-class TestFilterBackendAppliedToViews(TestCase):
-
- def setUp(self):
- """
- Create 3 BasicModel instances to filter on.
- """
- items = ['foo', 'bar', 'baz']
- for item in items:
- BasicModel(text=item).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
-
- def test_get_root_view_filters_by_name_with_filter_backend(self):
- """
- GET requests to ListCreateAPIView should return filtered list.
- """
- root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
- request = factory.get('/')
- response = root_view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(len(response.data), 1)
- self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}])
-
- def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
- """
- GET requests to ListCreateAPIView should return empty list when all models are filtered out.
- """
- root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
- request = factory.get('/')
- response = root_view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, [])
-
- def test_get_instance_view_filters_out_name_with_filter_backend(self):
- """
- GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
- """
- instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
- request = factory.get('/1')
- response = instance_view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- self.assertEqual(response.data, {'detail': 'Not found'})
-
- def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
- """
- GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
- """
- instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
- request = factory.get('/1')
- response = instance_view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, {'id': 1, 'text': 'foo'})
-
- def test_dynamic_serializer_form_in_browsable_api(self):
- """
- GET requests to ListCreateAPIView should return filtered list.
- """
- view = DynamicSerializerView.as_view()
- request = factory.get('/')
- response = view(request).render()
- self.assertContains(response, 'field_b')
- self.assertNotContains(response, 'field_a')
diff --git a/rest_framework/tests/test_htmlrenderer.py b/rest_framework/tests/test_htmlrenderer.py
deleted file mode 100644
index 514d9e2b..00000000
--- a/rest_framework/tests/test_htmlrenderer.py
+++ /dev/null
@@ -1,120 +0,0 @@
-from __future__ import unicode_literals
-from django.core.exceptions import PermissionDenied
-from django.http import Http404
-from django.test import TestCase
-from django.template import TemplateDoesNotExist, Template
-import django.template.loader
-from rest_framework import status
-from rest_framework.compat import patterns, url
-from rest_framework.decorators import api_view, renderer_classes
-from rest_framework.renderers import TemplateHTMLRenderer
-from rest_framework.response import Response
-from rest_framework.compat import six
-
-
-@api_view(('GET',))
-@renderer_classes((TemplateHTMLRenderer,))
-def example(request):
- """
- A view that can returns an HTML representation.
- """
- data = {'object': 'foobar'}
- return Response(data, template_name='example.html')
-
-
-@api_view(('GET',))
-@renderer_classes((TemplateHTMLRenderer,))
-def permission_denied(request):
- raise PermissionDenied()
-
-
-@api_view(('GET',))
-@renderer_classes((TemplateHTMLRenderer,))
-def not_found(request):
- raise Http404()
-
-
-urlpatterns = patterns('',
- url(r'^$', example),
- url(r'^permission_denied$', permission_denied),
- url(r'^not_found$', not_found),
-)
-
-
-class TemplateHTMLRendererTests(TestCase):
- urls = 'rest_framework.tests.test_htmlrenderer'
-
- def setUp(self):
- """
- Monkeypatch get_template
- """
- self.get_template = django.template.loader.get_template
-
- def get_template(template_name, dirs=None):
- if template_name == 'example.html':
- return Template("example: {{ object }}")
- raise TemplateDoesNotExist(template_name)
-
- django.template.loader.get_template = get_template
-
- def tearDown(self):
- """
- Revert monkeypatching
- """
- django.template.loader.get_template = self.get_template
-
- def test_simple_html_view(self):
- response = self.client.get('/')
- self.assertContains(response, "example: foobar")
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
-
- def test_not_found_html_view(self):
- response = self.client.get('/not_found')
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- self.assertEqual(response.content, six.b("404 Not Found"))
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
-
- def test_permission_denied_html_view(self):
- response = self.client.get('/permission_denied')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertEqual(response.content, six.b("403 Forbidden"))
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
-
-
-class TemplateHTMLRendererExceptionTests(TestCase):
- urls = 'rest_framework.tests.test_htmlrenderer'
-
- def setUp(self):
- """
- Monkeypatch get_template
- """
- self.get_template = django.template.loader.get_template
-
- def get_template(template_name):
- if template_name == '404.html':
- return Template("404: {{ detail }}")
- if template_name == '403.html':
- return Template("403: {{ detail }}")
- raise TemplateDoesNotExist(template_name)
-
- django.template.loader.get_template = get_template
-
- def tearDown(self):
- """
- Revert monkeypatching
- """
- django.template.loader.get_template = self.get_template
-
- def test_not_found_html_view_with_template(self):
- response = self.client.get('/not_found')
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- self.assertTrue(response.content in (
- six.b("404: Not found"), six.b("404 Not Found")))
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
-
- def test_permission_denied_html_view_with_template(self):
- response = self.client.get('/permission_denied')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertTrue(response.content in (
- six.b("403: Permission denied"), six.b("403 Forbidden")))
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
diff --git a/rest_framework/tests/test_hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py
deleted file mode 100644
index 83d46043..00000000
--- a/rest_framework/tests/test_hyperlinkedserializers.py
+++ /dev/null
@@ -1,379 +0,0 @@
-from __future__ import unicode_literals
-import json
-from django.test import TestCase
-from rest_framework import generics, status, serializers
-from rest_framework.compat import patterns, url
-from rest_framework.settings import api_settings
-from rest_framework.test import APIRequestFactory
-from rest_framework.tests.models import (
- Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment,
- Album, Photo, OptionalRelationModel
-)
-
-factory = APIRequestFactory()
-
-
-class BlogPostCommentSerializer(serializers.ModelSerializer):
- url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail')
- text = serializers.CharField()
- blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail')
-
- class Meta:
- model = BlogPostComment
- fields = ('text', 'blog_post_url', 'url')
-
-
-class PhotoSerializer(serializers.Serializer):
- description = serializers.CharField()
- album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title', slug_url_kwarg='title')
-
- def restore_object(self, attrs, instance=None):
- return Photo(**attrs)
-
-
-class AlbumSerializer(serializers.ModelSerializer):
- url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title')
-
- class Meta:
- model = Album
- fields = ('title', 'url')
-
-
-class BasicList(generics.ListCreateAPIView):
- model = BasicModel
- model_serializer_class = serializers.HyperlinkedModelSerializer
-
-
-class BasicDetail(generics.RetrieveUpdateDestroyAPIView):
- model = BasicModel
- model_serializer_class = serializers.HyperlinkedModelSerializer
-
-
-class AnchorDetail(generics.RetrieveAPIView):
- model = Anchor
- model_serializer_class = serializers.HyperlinkedModelSerializer
-
-
-class ManyToManyList(generics.ListAPIView):
- model = ManyToManyModel
- model_serializer_class = serializers.HyperlinkedModelSerializer
-
-
-class ManyToManyDetail(generics.RetrieveAPIView):
- model = ManyToManyModel
- model_serializer_class = serializers.HyperlinkedModelSerializer
-
-
-class BlogPostCommentListCreate(generics.ListCreateAPIView):
- model = BlogPostComment
- serializer_class = BlogPostCommentSerializer
-
-
-class BlogPostCommentDetail(generics.RetrieveAPIView):
- model = BlogPostComment
- serializer_class = BlogPostCommentSerializer
-
-
-class BlogPostDetail(generics.RetrieveAPIView):
- model = BlogPost
-
-
-class PhotoListCreate(generics.ListCreateAPIView):
- model = Photo
- model_serializer_class = PhotoSerializer
-
-
-class AlbumDetail(generics.RetrieveAPIView):
- model = Album
- serializer_class = AlbumSerializer
- lookup_field = 'title'
-
-
-class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
- model = OptionalRelationModel
- model_serializer_class = serializers.HyperlinkedModelSerializer
-
-
-urlpatterns = patterns('',
- url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
- url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
- url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),
- url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'),
- url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
- url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'),
- url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'),
- url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'),
- url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'),
- url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'),
- url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'),
-)
-
-
-class TestBasicHyperlinkedView(TestCase):
- urls = 'rest_framework.tests.test_hyperlinkedserializers'
-
- def setUp(self):
- """
- Create 3 BasicModel instances.
- """
- items = ['foo', 'bar', 'baz']
- for item in items:
- BasicModel(text=item).save()
- self.objects = BasicModel.objects
- self.data = [
- {'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.list_view = BasicList.as_view()
- self.detail_view = BasicDetail.as_view()
-
- def test_get_list_view(self):
- """
- GET requests to ListCreateAPIView should return list of objects.
- """
- request = factory.get('/basic/')
- response = self.list_view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- def test_get_detail_view(self):
- """
- GET requests to ListCreateAPIView should return list of objects.
- """
- request = factory.get('/basic/1')
- response = self.detail_view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data[0])
-
-
-class TestManyToManyHyperlinkedView(TestCase):
- urls = 'rest_framework.tests.test_hyperlinkedserializers'
-
- def setUp(self):
- """
- Create 3 BasicModel instances.
- """
- items = ['foo', 'bar', 'baz']
- anchors = []
- for item in items:
- anchor = Anchor(text=item)
- anchor.save()
- anchors.append(anchor)
-
- manytomany = ManyToManyModel()
- manytomany.save()
- manytomany.rel.add(*anchors)
-
- self.data = [{
- 'url': 'http://testserver/manytomany/1/',
- 'rel': [
- 'http://testserver/anchor/1/',
- 'http://testserver/anchor/2/',
- 'http://testserver/anchor/3/',
- ]
- }]
- self.list_view = ManyToManyList.as_view()
- self.detail_view = ManyToManyDetail.as_view()
-
- def test_get_list_view(self):
- """
- GET requests to ListCreateAPIView should return list of objects.
- """
- request = factory.get('/manytomany/')
- response = self.list_view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- def test_get_detail_view(self):
- """
- GET requests to ListCreateAPIView should return list of objects.
- """
- request = factory.get('/manytomany/1/')
- response = self.detail_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data[0])
-
-
-class TestHyperlinkedIdentityFieldLookup(TestCase):
- urls = 'rest_framework.tests.test_hyperlinkedserializers'
-
- def setUp(self):
- """
- Create 3 Album instances.
- """
- titles = ['foo', 'bar', 'baz']
- for title in titles:
- album = Album(title=title)
- album.save()
- self.detail_view = AlbumDetail.as_view()
- self.data = {
- 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'},
- 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'},
- 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'}
- }
-
- def test_lookup_field(self):
- """
- GET requests to AlbumDetail view should return serialized Albums
- with a url field keyed by `title`.
- """
- for album in Album.objects.all():
- request = factory.get('/albums/{0}/'.format(album.title))
- response = self.detail_view(request, title=album.title)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data[album.title])
-
-
-class TestCreateWithForeignKeys(TestCase):
- urls = 'rest_framework.tests.test_hyperlinkedserializers'
-
- def setUp(self):
- """
- Create a blog post
- """
- self.post = BlogPost.objects.create(title="Test post")
- self.create_view = BlogPostCommentListCreate.as_view()
-
- def test_create_comment(self):
-
- data = {
- 'text': 'A test comment',
- 'blog_post_url': 'http://testserver/posts/1/'
- }
-
- request = factory.post('/comments/', data=data)
- response = self.create_view(request)
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
- self.assertEqual(response['Location'], 'http://testserver/comments/1/')
- self.assertEqual(self.post.blogpostcomment_set.count(), 1)
- self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment')
-
-
-class TestCreateWithForeignKeysAndCustomSlug(TestCase):
- urls = 'rest_framework.tests.test_hyperlinkedserializers'
-
- def setUp(self):
- """
- Create an Album
- """
- self.post = Album.objects.create(title='test-album')
- self.list_create_view = PhotoListCreate.as_view()
-
- def test_create_photo(self):
-
- data = {
- 'description': 'A test photo',
- 'album_url': 'http://testserver/albums/test-album/'
- }
-
- request = factory.post('/photos/', data=data)
- response = self.list_create_view(request)
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
- self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')
- self.assertEqual(self.post.photo_set.count(), 1)
- self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo')
-
-
-class TestOptionalRelationHyperlinkedView(TestCase):
- urls = 'rest_framework.tests.test_hyperlinkedserializers'
-
- def setUp(self):
- """
- Create 1 OptionalRelationModel instances.
- """
- OptionalRelationModel().save()
- self.objects = OptionalRelationModel.objects
- self.detail_view = OptionalRelationDetail.as_view()
- self.data = {"url": "http://testserver/optionalrelation/1/", "other": None}
-
- def test_get_detail_view(self):
- """
- GET requests to RetrieveAPIView with optional relations should return None
- for non existing relations.
- """
- request = factory.get('/optionalrelationmodel-detail/1')
- response = self.detail_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- def test_put_detail_view(self):
- """
- PUT requests to RetrieveUpdateDestroyAPIView with optional relations
- should accept None for non existing relations.
- """
- response = self.client.put('/optionalrelation/1/',
- data=json.dumps(self.data),
- content_type='application/json')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
-
-class TestOverriddenURLField(TestCase):
- def setUp(self):
- class OverriddenURLSerializer(serializers.HyperlinkedModelSerializer):
- url = serializers.SerializerMethodField('get_url')
-
- class Meta:
- model = BlogPost
- fields = ('title', 'url')
-
- def get_url(self, obj):
- return 'foo bar'
-
- self.Serializer = OverriddenURLSerializer
- self.obj = BlogPost.objects.create(title='New blog post')
-
- def test_overridden_url_field(self):
- """
- The 'url' field should respect overriding.
- Regression test for #936.
- """
- serializer = self.Serializer(self.obj)
- self.assertEqual(
- serializer.data,
- {'title': 'New blog post', 'url': 'foo bar'}
- )
-
-
-class TestURLFieldNameBySettings(TestCase):
- urls = 'rest_framework.tests.test_hyperlinkedserializers'
-
- def setUp(self):
- self.saved_url_field_name = api_settings.URL_FIELD_NAME
- api_settings.URL_FIELD_NAME = 'global_url_field'
-
- class Serializer(serializers.HyperlinkedModelSerializer):
-
- class Meta:
- model = BlogPost
- fields = ('title', api_settings.URL_FIELD_NAME)
-
- self.Serializer = Serializer
- self.obj = BlogPost.objects.create(title="New blog post")
-
- def tearDown(self):
- api_settings.URL_FIELD_NAME = self.saved_url_field_name
-
- def test_overridden_url_field_name(self):
- request = factory.get('/posts/')
- serializer = self.Serializer(self.obj, context={'request': request})
- self.assertIn(api_settings.URL_FIELD_NAME, serializer.data)
-
-
-class TestURLFieldNameByOptions(TestCase):
- urls = 'rest_framework.tests.test_hyperlinkedserializers'
-
- def setUp(self):
- class Serializer(serializers.HyperlinkedModelSerializer):
-
- class Meta:
- model = BlogPost
- fields = ('title', 'serializer_url_field')
- url_field_name = 'serializer_url_field'
-
- self.Serializer = Serializer
- self.obj = BlogPost.objects.create(title="New blog post")
-
- def test_overridden_url_field_name(self):
- request = factory.get('/posts/')
- serializer = self.Serializer(self.obj, context={'request': request})
- self.assertIn(self.Serializer.Meta.url_field_name, serializer.data)
diff --git a/rest_framework/tests/test_multitable_inheritance.py b/rest_framework/tests/test_multitable_inheritance.py
deleted file mode 100644
index 00c15327..00000000
--- a/rest_framework/tests/test_multitable_inheritance.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from __future__ import unicode_literals
-from django.db import models
-from django.test import TestCase
-from rest_framework import serializers
-from rest_framework.tests.models import RESTFrameworkModel
-
-
-# Models
-class ParentModel(RESTFrameworkModel):
- name1 = models.CharField(max_length=100)
-
-
-class ChildModel(ParentModel):
- name2 = models.CharField(max_length=100)
-
-
-class AssociatedModel(RESTFrameworkModel):
- ref = models.OneToOneField(ParentModel, primary_key=True)
- name = models.CharField(max_length=100)
-
-
-# Serializers
-class DerivedModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = ChildModel
-
-
-class AssociatedModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = AssociatedModel
-
-
-# Tests
-class IneritedModelSerializationTests(TestCase):
-
- def test_multitable_inherited_model_fields_as_expected(self):
- """
- Assert that the parent pointer field is not included in the fields
- serialized fields
- """
- child = ChildModel(name1='parent name', name2='child name')
- serializer = DerivedModelSerializer(child)
- self.assertEqual(set(serializer.data.keys()),
- set(['name1', 'name2', 'id']))
-
- def test_onetoone_primary_key_model_fields_as_expected(self):
- """
- Assert that a model with a onetoone field that is the primary key is
- not treated like a derived model
- """
- parent = ParentModel(name1='parent name')
- associate = AssociatedModel(name='hello', ref=parent)
- serializer = AssociatedModelSerializer(associate)
- self.assertEqual(set(serializer.data.keys()),
- set(['name', 'ref']))
-
- def test_data_is_valid_without_parent_ptr(self):
- """
- Assert that the pointer to the parent table is not a required field
- for input data
- """
- data = {
- 'name1': 'parent name',
- 'name2': 'child name',
- }
- serializer = DerivedModelSerializer(data=data)
- self.assertEqual(serializer.is_valid(), True)
diff --git a/rest_framework/tests/test_negotiation.py b/rest_framework/tests/test_negotiation.py
deleted file mode 100644
index 04b89eb6..00000000
--- a/rest_framework/tests/test_negotiation.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework.negotiation import DefaultContentNegotiation
-from rest_framework.request import Request
-from rest_framework.renderers import BaseRenderer
-from rest_framework.test import APIRequestFactory
-
-
-factory = APIRequestFactory()
-
-
-class MockJSONRenderer(BaseRenderer):
- media_type = 'application/json'
-
-
-class MockHTMLRenderer(BaseRenderer):
- media_type = 'text/html'
-
-
-class NoCharsetSpecifiedRenderer(BaseRenderer):
- media_type = 'my/media'
-
-
-class TestAcceptedMediaType(TestCase):
- def setUp(self):
- self.renderers = [MockJSONRenderer(), MockHTMLRenderer()]
- self.negotiator = DefaultContentNegotiation()
-
- def select_renderer(self, request):
- return self.negotiator.select_renderer(request, self.renderers)
-
- def test_client_without_accept_use_renderer(self):
- request = Request(factory.get('/'))
- accepted_renderer, accepted_media_type = self.select_renderer(request)
- self.assertEqual(accepted_media_type, 'application/json')
-
- def test_client_underspecifies_accept_use_renderer(self):
- request = Request(factory.get('/', HTTP_ACCEPT='*/*'))
- accepted_renderer, accepted_media_type = self.select_renderer(request)
- self.assertEqual(accepted_media_type, 'application/json')
-
- def test_client_overspecifies_accept_use_client(self):
- request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8'))
- accepted_renderer, accepted_media_type = self.select_renderer(request)
- self.assertEqual(accepted_media_type, 'application/json; indent=8')
diff --git a/rest_framework/tests/test_nullable_fields.py b/rest_framework/tests/test_nullable_fields.py
deleted file mode 100644
index 6ee55c00..00000000
--- a/rest_framework/tests/test_nullable_fields.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from django.core.urlresolvers import reverse
-
-from rest_framework.compat import patterns, url
-from rest_framework.test import APITestCase
-from rest_framework.tests.models import NullableForeignKeySource
-from rest_framework.tests.serializers import NullableFKSourceSerializer
-from rest_framework.tests.views import NullableFKSourceDetail
-
-
-urlpatterns = patterns(
- '',
- url(r'^objects/(?P<pk>\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'),
-)
-
-
-class NullableForeignKeyTests(APITestCase):
- """
- DRF should be able to handle nullable foreign keys when a test
- Client POST/PUT request is made with its own serialized object.
- """
- urls = 'rest_framework.tests.test_nullable_fields'
-
- def test_updating_object_with_null_fk(self):
- obj = NullableForeignKeySource(name='example', target=None)
- obj.save()
- serialized_data = NullableFKSourceSerializer(obj).data
-
- response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data)
-
- self.assertEqual(response.data, serialized_data)
diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py
deleted file mode 100644
index 24c1ba39..00000000
--- a/rest_framework/tests/test_pagination.py
+++ /dev/null
@@ -1,521 +0,0 @@
-from __future__ import unicode_literals
-import datetime
-from decimal import Decimal
-from django.db import models
-from django.core.paginator import Paginator
-from django.test import TestCase
-from django.utils import unittest
-from rest_framework import generics, status, pagination, filters, serializers
-from rest_framework.compat import django_filters
-from rest_framework.test import APIRequestFactory
-from rest_framework.tests.models import BasicModel
-from .models import FilterableItem
-
-factory = APIRequestFactory()
-
-# Helper function to split arguments out of an url
-def split_arguments_from_url(url):
- if '?' not in url:
- return url
-
- path, args = url.split('?')
- args = dict(r.split('=') for r in args.split('&'))
- return path, args
-
-
-class RootView(generics.ListCreateAPIView):
- """
- Example description for OPTIONS.
- """
- model = BasicModel
- paginate_by = 10
-
-
-class DefaultPageSizeKwargView(generics.ListAPIView):
- """
- View for testing default paginate_by_param usage
- """
- model = BasicModel
-
-
-class PaginateByParamView(generics.ListAPIView):
- """
- View for testing custom paginate_by_param usage
- """
- model = BasicModel
- paginate_by_param = 'page_size'
-
-
-class MaxPaginateByView(generics.ListAPIView):
- """
- View for testing custom max_paginate_by usage
- """
- model = BasicModel
- paginate_by = 3
- max_paginate_by = 5
- paginate_by_param = 'page_size'
-
-
-class IntegrationTestPagination(TestCase):
- """
- Integration tests for paginated list views.
- """
-
- def setUp(self):
- """
- Create 26 BasicModel instances.
- """
- for char in 'abcdefghijklmnopqrstuvwxyz':
- BasicModel(text=char * 3).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = RootView.as_view()
-
- def test_get_paginated_root_view(self):
- """
- GET requests to paginated ListCreateAPIView should return paginated results.
- """
- request = factory.get('/')
- # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>`
- with self.assertNumQueries(2):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 26)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['next']))
- with self.assertNumQueries(2):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 26)
- self.assertEqual(response.data['results'], self.data[10:20])
- self.assertNotEqual(response.data['next'], None)
- self.assertNotEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['next']))
- with self.assertNumQueries(2):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 26)
- self.assertEqual(response.data['results'], self.data[20:])
- self.assertEqual(response.data['next'], None)
- self.assertNotEqual(response.data['previous'], None)
-
-
-class IntegrationTestPaginationAndFiltering(TestCase):
-
- def setUp(self):
- """
- Create 50 FilterableItem instances.
- """
- base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
- for i in range(26):
- text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
- decimal = base_data[1] + i
- date = base_data[2] - datetime.timedelta(days=i * 2)
- FilterableItem(text=text, decimal=decimal, date=date).save()
-
- self.objects = FilterableItem.objects
- self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
- for obj in self.objects.all()
- ]
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_get_django_filter_paginated_filtered_root_view(self):
- """
- GET requests to paginated filtered ListCreateAPIView should return
- paginated results. The next and previous links should preserve the
- filtered parameters.
- """
- class DecimalFilter(django_filters.FilterSet):
- decimal = django_filters.NumberFilter(lookup_type='lt')
-
- class Meta:
- model = FilterableItem
- fields = ['text', 'decimal', 'date']
-
- class FilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
- paginate_by = 10
- filter_class = DecimalFilter
- filter_backends = (filters.DjangoFilterBackend,)
-
- view = FilterFieldsRootView.as_view()
-
- EXPECTED_NUM_QUERIES = 2
-
- request = factory.get('/', {'decimal': '15.20'})
- with self.assertNumQueries(EXPECTED_NUM_QUERIES):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['next']))
- with self.assertNumQueries(EXPECTED_NUM_QUERIES):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[10:15])
- self.assertEqual(response.data['next'], None)
- self.assertNotEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['previous']))
- with self.assertNumQueries(EXPECTED_NUM_QUERIES):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
- def test_get_basic_paginated_filtered_root_view(self):
- """
- Same as `test_get_django_filter_paginated_filtered_root_view`,
- except using a custom filter backend instead of the django-filter
- backend,
- """
-
- class DecimalFilterBackend(filters.BaseFilterBackend):
- def filter_queryset(self, request, queryset, view):
- return queryset.filter(decimal__lt=Decimal(request.GET['decimal']))
-
- class BasicFilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
- paginate_by = 10
- filter_backends = (DecimalFilterBackend,)
-
- view = BasicFilterFieldsRootView.as_view()
-
- request = factory.get('/', {'decimal': '15.20'})
- with self.assertNumQueries(2):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['next']))
- with self.assertNumQueries(2):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[10:15])
- self.assertEqual(response.data['next'], None)
- self.assertNotEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['previous']))
- with self.assertNumQueries(2):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
-
-class PassOnContextPaginationSerializer(pagination.PaginationSerializer):
- class Meta:
- object_serializer_class = serializers.Serializer
-
-
-class UnitTestPagination(TestCase):
- """
- Unit tests for pagination of primitive objects.
- """
-
- def setUp(self):
- self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz']
- paginator = Paginator(self.objects, 10)
- self.first_page = paginator.page(1)
- self.last_page = paginator.page(3)
-
- def test_native_pagination(self):
- serializer = pagination.PaginationSerializer(self.first_page)
- self.assertEqual(serializer.data['count'], 26)
- self.assertEqual(serializer.data['next'], '?page=2')
- self.assertEqual(serializer.data['previous'], None)
- self.assertEqual(serializer.data['results'], self.objects[:10])
-
- serializer = pagination.PaginationSerializer(self.last_page)
- self.assertEqual(serializer.data['count'], 26)
- self.assertEqual(serializer.data['next'], None)
- self.assertEqual(serializer.data['previous'], '?page=2')
- self.assertEqual(serializer.data['results'], self.objects[20:])
-
- def test_context_available_in_result(self):
- """
- Ensure context gets passed through to the object serializer.
- """
- serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})
- serializer.data
- results = serializer.fields[serializer.results_field]
- self.assertEqual(serializer.context, results.context)
-
-
-class TestUnpaginated(TestCase):
- """
- Tests for list views without pagination.
- """
-
- def setUp(self):
- """
- Create 13 BasicModel instances.
- """
- for i in range(13):
- BasicModel(text=i).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = DefaultPageSizeKwargView.as_view()
-
- def test_unpaginated(self):
- """
- Tests the default page size for this view.
- no page size --> no limit --> no meta data
- """
- request = factory.get('/')
- response = self.view(request)
- self.assertEqual(response.data, self.data)
-
-
-class TestCustomPaginateByParam(TestCase):
- """
- Tests for list views with default page size kwarg
- """
-
- def setUp(self):
- """
- Create 13 BasicModel instances.
- """
- for i in range(13):
- BasicModel(text=i).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = PaginateByParamView.as_view()
-
- def test_default_page_size(self):
- """
- Tests the default page size for this view.
- no page size --> no limit --> no meta data
- """
- request = factory.get('/')
- response = self.view(request).render()
- self.assertEqual(response.data, self.data)
-
- def test_paginate_by_param(self):
- """
- If paginate_by_param is set, the new kwarg should limit per view requests.
- """
- request = factory.get('/', {'page_size': 5})
- response = self.view(request).render()
- self.assertEqual(response.data['count'], 13)
- self.assertEqual(response.data['results'], self.data[:5])
-
-
-class TestMaxPaginateByParam(TestCase):
- """
- Tests for list views with max_paginate_by kwarg
- """
-
- def setUp(self):
- """
- Create 13 BasicModel instances.
- """
- for i in range(13):
- BasicModel(text=i).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = MaxPaginateByView.as_view()
-
- def test_max_paginate_by(self):
- """
- If max_paginate_by is set, it should limit page size for the view.
- """
- request = factory.get('/', data={'page_size': 10})
- response = self.view(request).render()
- self.assertEqual(response.data['count'], 13)
- self.assertEqual(response.data['results'], self.data[:5])
-
- def test_max_paginate_by_without_page_size_param(self):
- """
- If max_paginate_by is set, but client does not specifiy page_size,
- standard `paginate_by` behavior should be used.
- """
- request = factory.get('/')
- response = self.view(request).render()
- self.assertEqual(response.data['results'], self.data[:3])
-
-
-### Tests for context in pagination serializers
-
-class CustomField(serializers.Field):
- def to_native(self, value):
- if not 'view' in self.context:
- raise RuntimeError("context isn't getting passed into custom field")
- return "value"
-
-
-class BasicModelSerializer(serializers.Serializer):
- text = CustomField()
-
- def __init__(self, *args, **kwargs):
- super(BasicModelSerializer, self).__init__(*args, **kwargs)
- if not 'view' in self.context:
- raise RuntimeError("context isn't getting passed into serializer init")
-
-
-class TestContextPassedToCustomField(TestCase):
- def setUp(self):
- BasicModel.objects.create(text='ala ma kota')
-
- def test_with_pagination(self):
- class ListView(generics.ListCreateAPIView):
- model = BasicModel
- serializer_class = BasicModelSerializer
- paginate_by = 1
-
- self.view = ListView.as_view()
- request = factory.get('/')
- response = self.view(request).render()
-
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
-
-### Tests for custom pagination serializers
-
-class LinksSerializer(serializers.Serializer):
- next = pagination.NextPageField(source='*')
- prev = pagination.PreviousPageField(source='*')
-
-
-class CustomPaginationSerializer(pagination.BasePaginationSerializer):
- links = LinksSerializer(source='*') # Takes the page object as the source
- total_results = serializers.Field(source='paginator.count')
-
- results_field = 'objects'
-
-
-class TestCustomPaginationSerializer(TestCase):
- def setUp(self):
- objects = ['john', 'paul', 'george', 'ringo']
- paginator = Paginator(objects, 2)
- self.page = paginator.page(1)
-
- def test_custom_pagination_serializer(self):
- request = APIRequestFactory().get('/foobar')
- serializer = CustomPaginationSerializer(
- instance=self.page,
- context={'request': request}
- )
- expected = {
- 'links': {
- 'next': 'http://testserver/foobar?page=2',
- 'prev': None
- },
- 'total_results': 4,
- 'objects': ['john', 'paul']
- }
- self.assertEqual(serializer.data, expected)
-
-
-class NonIntegerPage(object):
-
- def __init__(self, paginator, object_list, prev_token, token, next_token):
- self.paginator = paginator
- self.object_list = object_list
- self.prev_token = prev_token
- self.token = token
- self.next_token = next_token
-
- def has_next(self):
- return not not self.next_token
-
- def next_page_number(self):
- return self.next_token
-
- def has_previous(self):
- return not not self.prev_token
-
- def previous_page_number(self):
- return self.prev_token
-
-
-class NonIntegerPaginator(object):
-
- def __init__(self, object_list, per_page):
- self.object_list = object_list
- self.per_page = per_page
-
- def count(self):
- # pretend like we don't know how many pages we have
- return None
-
- def page(self, token=None):
- if token:
- try:
- first = self.object_list.index(token)
- except ValueError:
- first = 0
- else:
- first = 0
- n = len(self.object_list)
- last = min(first + self.per_page, n)
- prev_token = self.object_list[last - (2 * self.per_page)] if first else None
- next_token = self.object_list[last] if last < n else None
- return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token)
-
-
-class TestNonIntegerPagination(TestCase):
-
-
- def test_custom_pagination_serializer(self):
- objects = ['john', 'paul', 'george', 'ringo']
- paginator = NonIntegerPaginator(objects, 2)
-
- request = APIRequestFactory().get('/foobar')
- serializer = CustomPaginationSerializer(
- instance=paginator.page(),
- context={'request': request}
- )
- expected = {
- 'links': {
- 'next': 'http://testserver/foobar?page={0}'.format(objects[2]),
- 'prev': None
- },
- 'total_results': None,
- 'objects': objects[:2]
- }
- self.assertEqual(serializer.data, expected)
-
- request = APIRequestFactory().get('/foobar')
- serializer = CustomPaginationSerializer(
- instance=paginator.page('george'),
- context={'request': request}
- )
- expected = {
- 'links': {
- 'next': None,
- 'prev': 'http://testserver/foobar?page={0}'.format(objects[0]),
- },
- 'total_results': None,
- 'objects': objects[2:]
- }
- self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/test_parsers.py b/rest_framework/tests/test_parsers.py
deleted file mode 100644
index 7699e10c..00000000
--- a/rest_framework/tests/test_parsers.py
+++ /dev/null
@@ -1,115 +0,0 @@
-from __future__ import unicode_literals
-from rest_framework.compat import StringIO
-from django import forms
-from django.core.files.uploadhandler import MemoryFileUploadHandler
-from django.test import TestCase
-from django.utils import unittest
-from rest_framework.compat import etree
-from rest_framework.parsers import FormParser, FileUploadParser
-from rest_framework.parsers import XMLParser
-import datetime
-
-
-class Form(forms.Form):
- field1 = forms.CharField(max_length=3)
- field2 = forms.CharField()
-
-
-class TestFormParser(TestCase):
- def setUp(self):
- self.string = "field1=abc&field2=defghijk"
-
- def test_parse(self):
- """ Make sure the `QueryDict` works OK """
- parser = FormParser()
-
- stream = StringIO(self.string)
- data = parser.parse(stream)
-
- self.assertEqual(Form(data).is_valid(), True)
-
-
-class TestXMLParser(TestCase):
- def setUp(self):
- self._input = StringIO(
- '<?xml version="1.0" encoding="utf-8"?>'
- '<root>'
- '<field_a>121.0</field_a>'
- '<field_b>dasd</field_b>'
- '<field_c></field_c>'
- '<field_d>2011-12-25 12:45:00</field_d>'
- '</root>'
- )
- self._data = {
- 'field_a': 121,
- 'field_b': 'dasd',
- 'field_c': None,
- 'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00)
- }
- self._complex_data_input = StringIO(
- '<?xml version="1.0" encoding="utf-8"?>'
- '<root>'
- '<creation_date>2011-12-25 12:45:00</creation_date>'
- '<sub_data_list>'
- '<list-item><sub_id>1</sub_id><sub_name>first</sub_name></list-item>'
- '<list-item><sub_id>2</sub_id><sub_name>second</sub_name></list-item>'
- '</sub_data_list>'
- '<name>name</name>'
- '</root>'
- )
- self._complex_data = {
- "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00),
- "name": "name",
- "sub_data_list": [
- {
- "sub_id": 1,
- "sub_name": "first"
- },
- {
- "sub_id": 2,
- "sub_name": "second"
- }
- ]
- }
-
- @unittest.skipUnless(etree, 'defusedxml not installed')
- def test_parse(self):
- parser = XMLParser()
- data = parser.parse(self._input)
- self.assertEqual(data, self._data)
-
- @unittest.skipUnless(etree, 'defusedxml not installed')
- def test_complex_data_parse(self):
- parser = XMLParser()
- data = parser.parse(self._complex_data_input)
- self.assertEqual(data, self._complex_data)
-
-
-class TestFileUploadParser(TestCase):
- def setUp(self):
- class MockRequest(object):
- pass
- from io import BytesIO
- self.stream = BytesIO(
- "Test text file".encode('utf-8')
- )
- request = MockRequest()
- request.upload_handlers = (MemoryFileUploadHandler(),)
- request.META = {
- 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'),
- 'HTTP_CONTENT_LENGTH': 14,
- }
- self.parser_context = {'request': request, 'kwargs': {}}
-
- def test_parse(self):
- """ Make sure the `QueryDict` works OK """
- parser = FileUploadParser()
- self.stream.seek(0)
- data_and_files = parser.parse(self.stream, None, self.parser_context)
- file_obj = data_and_files.files['file']
- self.assertEqual(file_obj._size, 14)
-
- def test_get_filename(self):
- parser = FileUploadParser()
- filename = parser.get_filename(self.stream, None, self.parser_context)
- self.assertEqual(filename, 'file.txt'.encode('utf-8'))
diff --git a/rest_framework/tests/test_permissions.py b/rest_framework/tests/test_permissions.py
deleted file mode 100644
index 6e3a6303..00000000
--- a/rest_framework/tests/test_permissions.py
+++ /dev/null
@@ -1,300 +0,0 @@
-from __future__ import unicode_literals
-from django.contrib.auth.models import User, Permission, Group
-from django.db import models
-from django.test import TestCase
-from django.utils import unittest
-from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING
-from rest_framework.compat import guardian, get_model_name
-from rest_framework.filters import DjangoObjectPermissionsFilter
-from rest_framework.test import APIRequestFactory
-from rest_framework.tests.models import BasicModel
-import base64
-
-factory = APIRequestFactory()
-
-class RootView(generics.ListCreateAPIView):
- model = BasicModel
- authentication_classes = [authentication.BasicAuthentication]
- permission_classes = [permissions.DjangoModelPermissions]
-
-
-class InstanceView(generics.RetrieveUpdateDestroyAPIView):
- model = BasicModel
- authentication_classes = [authentication.BasicAuthentication]
- permission_classes = [permissions.DjangoModelPermissions]
-
-root_view = RootView.as_view()
-instance_view = InstanceView.as_view()
-
-
-def basic_auth_header(username, password):
- credentials = ('%s:%s' % (username, password))
- base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
- return 'Basic %s' % base64_credentials
-
-
-class ModelPermissionsIntegrationTests(TestCase):
- def setUp(self):
- User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
- user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
- user.user_permissions = [
- Permission.objects.get(codename='add_basicmodel'),
- Permission.objects.get(codename='change_basicmodel'),
- Permission.objects.get(codename='delete_basicmodel')
- ]
- user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
- user.user_permissions = [
- Permission.objects.get(codename='change_basicmodel'),
- ]
-
- self.permitted_credentials = basic_auth_header('permitted', 'password')
- self.disallowed_credentials = basic_auth_header('disallowed', 'password')
- self.updateonly_credentials = basic_auth_header('updateonly', 'password')
-
- BasicModel(text='foo').save()
-
- def test_has_create_permissions(self):
- request = factory.post('/', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.permitted_credentials)
- response = root_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
-
- def test_has_put_permissions(self):
- request = factory.put('/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.permitted_credentials)
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_has_delete_permissions(self):
- request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
- response = instance_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
-
- def test_does_not_have_create_permissions(self):
- request = factory.post('/', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.disallowed_credentials)
- response = root_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- def test_does_not_have_put_permissions(self):
- request = factory.put('/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.disallowed_credentials)
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- def test_does_not_have_delete_permissions(self):
- request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
- response = instance_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- def test_has_put_as_create_permissions(self):
- # User only has update permissions - should be able to update an entity.
- request = factory.put('/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.updateonly_credentials)
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- # But if PUTing to a new entity, permission should be denied.
- request = factory.put('/2', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.updateonly_credentials)
- response = instance_view(request, pk='2')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- def test_options_permitted(self):
- request = factory.options('/',
- HTTP_AUTHORIZATION=self.permitted_credentials)
- response = root_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertIn('actions', response.data)
- self.assertEqual(list(response.data['actions'].keys()), ['POST'])
-
- request = factory.options('/1',
- HTTP_AUTHORIZATION=self.permitted_credentials)
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertIn('actions', response.data)
- self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
-
- def test_options_disallowed(self):
- request = factory.options('/',
- HTTP_AUTHORIZATION=self.disallowed_credentials)
- response = root_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertNotIn('actions', response.data)
-
- request = factory.options('/1',
- HTTP_AUTHORIZATION=self.disallowed_credentials)
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertNotIn('actions', response.data)
-
- def test_options_updateonly(self):
- request = factory.options('/',
- HTTP_AUTHORIZATION=self.updateonly_credentials)
- response = root_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertNotIn('actions', response.data)
-
- request = factory.options('/1',
- HTTP_AUTHORIZATION=self.updateonly_credentials)
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertIn('actions', response.data)
- self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
-
-
-class BasicPermModel(models.Model):
- text = models.CharField(max_length=100)
-
- class Meta:
- app_label = 'tests'
- permissions = (
- ('view_basicpermmodel', 'Can view basic perm model'),
- # add, change, delete built in to django
- )
-
-# Custom object-level permission, that includes 'view' permissions
-class ViewObjectPermissions(permissions.DjangoObjectPermissions):
- perms_map = {
- 'GET': ['%(app_label)s.view_%(model_name)s'],
- 'OPTIONS': ['%(app_label)s.view_%(model_name)s'],
- 'HEAD': ['%(app_label)s.view_%(model_name)s'],
- 'POST': ['%(app_label)s.add_%(model_name)s'],
- 'PUT': ['%(app_label)s.change_%(model_name)s'],
- 'PATCH': ['%(app_label)s.change_%(model_name)s'],
- 'DELETE': ['%(app_label)s.delete_%(model_name)s'],
- }
-
-
-class ObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView):
- model = BasicPermModel
- authentication_classes = [authentication.BasicAuthentication]
- permission_classes = [ViewObjectPermissions]
-
-object_permissions_view = ObjectPermissionInstanceView.as_view()
-
-
-class ObjectPermissionListView(generics.ListAPIView):
- model = BasicPermModel
- authentication_classes = [authentication.BasicAuthentication]
- permission_classes = [ViewObjectPermissions]
-
-object_permissions_list_view = ObjectPermissionListView.as_view()
-
-
-@unittest.skipUnless(guardian, 'django-guardian not installed')
-class ObjectPermissionsIntegrationTests(TestCase):
- """
- Integration tests for the object level permissions API.
- """
- @classmethod
- def setUpClass(cls):
- from guardian.shortcuts import assign_perm
-
- # create users
- create = User.objects.create_user
- users = {
- 'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'),
- 'readonly': create('readonly', 'readonly@example.com', 'password'),
- 'writeonly': create('writeonly', 'writeonly@example.com', 'password'),
- 'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'),
- }
-
- # give everyone model level permissions, as we are not testing those
- everyone = Group.objects.create(name='everyone')
- model_name = get_model_name(BasicPermModel)
- app_label = BasicPermModel._meta.app_label
- f = '{0}_{1}'.format
- perms = {
- 'view': f('view', model_name),
- 'change': f('change', model_name),
- 'delete': f('delete', model_name)
- }
- for perm in perms.values():
- perm = '{0}.{1}'.format(app_label, perm)
- assign_perm(perm, everyone)
- everyone.user_set.add(*users.values())
-
- cls.perms = perms
- cls.users = users
-
- def setUp(self):
- from guardian.shortcuts import assign_perm
- perms = self.perms
- users = self.users
-
- # appropriate object level permissions
- readers = Group.objects.create(name='readers')
- writers = Group.objects.create(name='writers')
- deleters = Group.objects.create(name='deleters')
-
- model = BasicPermModel.objects.create(text='foo')
-
- assign_perm(perms['view'], readers, model)
- assign_perm(perms['change'], writers, model)
- assign_perm(perms['delete'], deleters, model)
-
- readers.user_set.add(users['fullaccess'], users['readonly'])
- writers.user_set.add(users['fullaccess'], users['writeonly'])
- deleters.user_set.add(users['fullaccess'], users['deleteonly'])
-
- self.credentials = {}
- for user in users.values():
- self.credentials[user.username] = basic_auth_header(user.username, 'password')
-
- # Delete
- def test_can_delete_permissions(self):
- request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly'])
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
-
- def test_cannot_delete_permissions(self):
- request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- # Update
- def test_can_update_permissions(self):
- request = factory.patch('/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.credentials['writeonly'])
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data.get('text'), 'foobar')
-
- def test_cannot_update_permissions(self):
- request = factory.patch('/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.credentials['deleteonly'])
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
- def test_cannot_update_permissions_non_existing(self):
- request = factory.patch('/999', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.credentials['deleteonly'])
- response = object_permissions_view(request, pk='999')
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
- # Read
- def test_can_read_permissions(self):
- request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_cannot_read_permissions(self):
- request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly'])
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
- # Read list
- def test_can_read_list_permissions(self):
- request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly'])
- object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
- response = object_permissions_list_view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data[0].get('id'), 1)
-
- def test_cannot_read_list_permissions(self):
- request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly'])
- object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
- response = object_permissions_list_view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertListEqual(response.data, [])
diff --git a/rest_framework/tests/test_relations.py b/rest_framework/tests/test_relations.py
deleted file mode 100644
index f52e0e1e..00000000
--- a/rest_framework/tests/test_relations.py
+++ /dev/null
@@ -1,120 +0,0 @@
-"""
-General tests for relational fields.
-"""
-from __future__ import unicode_literals
-from django.db import models
-from django.test import TestCase
-from rest_framework import serializers
-from rest_framework.tests.models import BlogPost
-
-
-class NullModel(models.Model):
- pass
-
-
-class FieldTests(TestCase):
- def test_pk_related_field_with_empty_string(self):
- """
- Regression test for #446
-
- https://github.com/tomchristie/django-rest-framework/issues/446
- """
- field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
- self.assertRaises(serializers.ValidationError, field.from_native, '')
- self.assertRaises(serializers.ValidationError, field.from_native, [])
-
- def test_hyperlinked_related_field_with_empty_string(self):
- field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
- self.assertRaises(serializers.ValidationError, field.from_native, '')
- self.assertRaises(serializers.ValidationError, field.from_native, [])
-
- def test_slug_related_field_with_empty_string(self):
- field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
- self.assertRaises(serializers.ValidationError, field.from_native, '')
- self.assertRaises(serializers.ValidationError, field.from_native, [])
-
-
-class TestManyRelatedMixin(TestCase):
- def test_missing_many_to_many_related_field(self):
- '''
- Regression test for #632
-
- https://github.com/tomchristie/django-rest-framework/pull/632
- '''
- field = serializers.RelatedField(many=True, read_only=False)
-
- into = {}
- field.field_from_native({}, None, 'field_name', into)
- self.assertEqual(into['field_name'], [])
-
-
-# Regression tests for #694 (`source` attribute on related fields)
-
-class RelatedFieldSourceTests(TestCase):
- def test_related_manager_source(self):
- """
- Relational fields should be able to use manager-returning methods as their source.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.RelatedField(many=True, source='get_blogposts_manager')
-
- class ClassWithManagerMethod(object):
- def get_blogposts_manager(self):
- return BlogPost.objects
-
- obj = ClassWithManagerMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, ['BlogPost object'])
-
- def test_related_queryset_source(self):
- """
- Relational fields should be able to use queryset-returning methods as their source.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.RelatedField(many=True, source='get_blogposts_queryset')
-
- class ClassWithQuerysetMethod(object):
- def get_blogposts_queryset(self):
- return BlogPost.objects.all()
-
- obj = ClassWithQuerysetMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, ['BlogPost object'])
-
- def test_dotted_source(self):
- """
- Source argument should support dotted.source notation.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.RelatedField(many=True, source='a.b.c')
-
- class ClassWithQuerysetMethod(object):
- a = {
- 'b': {
- 'c': BlogPost.objects.all()
- }
- }
-
- obj = ClassWithQuerysetMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, ['BlogPost object'])
-
- # Regression for #1129
- def test_exception_for_incorect_fk(self):
- """
- Check that the exception message are correct if the source field
- doesn't exist.
- """
- from rest_framework.tests.models import ManyToManySource
- class Meta:
- model = ManyToManySource
- attrs = {
- 'name': serializers.SlugRelatedField(
- slug_field='name', source='banzai'),
- 'Meta': Meta,
- }
-
- TestSerializer = type(str('TestSerializer'),
- (serializers.ModelSerializer,), attrs)
- with self.assertRaises(AttributeError):
- TestSerializer(data={'name': 'foo'})
diff --git a/rest_framework/tests/test_relations_hyperlink.py b/rest_framework/tests/test_relations_hyperlink.py
deleted file mode 100644
index 3c4d39af..00000000
--- a/rest_framework/tests/test_relations_hyperlink.py
+++ /dev/null
@@ -1,524 +0,0 @@
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework import serializers
-from rest_framework.compat import patterns, url
-from rest_framework.test import APIRequestFactory
-from rest_framework.tests.models import (
- BlogPost,
- ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
- NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
-)
-
-factory = APIRequestFactory()
-request = factory.get('/') # Just to ensure we have a request in the serializer context
-
-
-def dummy_view(request, pk):
- pass
-
-urlpatterns = patterns('',
- url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),
- url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
- url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
- url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),
- url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'),
- url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'),
- url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'),
- url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),
-)
-
-
-# ManyToMany
-class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = ManyToManyTarget
- fields = ('url', 'name', 'sources')
-
-
-class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = ManyToManySource
- fields = ('url', 'name', 'targets')
-
-
-# ForeignKey
-class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = ForeignKeyTarget
- fields = ('url', 'name', 'sources')
-
-
-class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = ForeignKeySource
- fields = ('url', 'name', 'target')
-
-
-# Nullable ForeignKey
-class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = NullableForeignKeySource
- fields = ('url', 'name', 'target')
-
-
-# Nullable OneToOne
-class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = OneToOneTarget
- fields = ('url', 'name', 'nullable_source')
-
-
-# TODO: Add test that .data cannot be accessed prior to .is_valid
-
-class HyperlinkedManyToManyTests(TestCase):
- urls = 'rest_framework.tests.test_relations_hyperlink'
-
- def setUp(self):
- for idx in range(1, 4):
- target = ManyToManyTarget(name='target-%d' % idx)
- target.save()
- source = ManyToManySource(name='source-%d' % idx)
- source.save()
- for target in ManyToManyTarget.objects.all():
- source.targets.add(target)
-
- def test_many_to_many_retrieve(self):
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
- {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
- {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_many_to_many_retrieve(self):
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_many_to_many_update(self):
- data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
- instance = ManyToManySource.objects.get(pk=1)
- serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- serializer.save()
- self.assertEqual(serializer.data, data)
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
- {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
- {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_many_to_many_update(self):
- data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}
- instance = ManyToManyTarget.objects.get(pk=1)
- serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- serializer.save()
- self.assertEqual(serializer.data, data)
-
- # Ensure target 1 is updated, and everything else is as expected
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']},
- {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
-
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_many_to_many_create(self):
- data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
- serializer = ManyToManySourceSerializer(data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is added, and everything else is as expected
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
- {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
- {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
- {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_many_to_many_create(self):
- data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
- serializer = ManyToManyTargetSerializer(data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-4')
-
- # Ensure target 4 is added, and everything else is as expected
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
- ]
- self.assertEqual(serializer.data, expected)
-
-
-class HyperlinkedForeignKeyTests(TestCase):
- urls = 'rest_framework.tests.test_relations_hyperlink'
-
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- new_target = ForeignKeyTarget(name='target-2')
- new_target.save()
- for idx in range(1, 4):
- source = ForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_foreign_key_retrieve(self):
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_foreign_key_retrieve(self):
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
- {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update(self):
- data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'},
- {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_incorrect_type(self):
- data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected url string, received int.']})
-
- def test_reverse_foreign_key_update(self):
- data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
- instance = ForeignKeyTarget.objects.get(pk=2)
- serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- # We shouldn't have saved anything to the db yet since save
- # hasn't been called.
- queryset = ForeignKeyTarget.objects.all()
- new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
- {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
- ]
- self.assertEqual(new_serializer.data, expected)
-
- serializer.save()
- self.assertEqual(serializer.data, data)
-
- # Ensure target 2 is update, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
- {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create(self):
- data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}
- serializer = ForeignKeySourceSerializer(data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_foreign_key_create(self):
- data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
- serializer = ForeignKeyTargetSerializer(data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-3')
-
- # Ensure target 4 is added, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
- {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
- {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_invalid_null(self):
- data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'target': ['This field is required.']})
-
-
-class HyperlinkedNullableForeignKeyTests(TestCase):
- urls = 'rest_framework.tests.test_relations_hyperlink'
-
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- for idx in range(1, 4):
- if idx == 3:
- target = None
- source = NullableForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_foreign_key_retrieve_with_null(self):
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create_with_valid_null(self):
- data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create_with_valid_emptystring(self):
- """
- The emptystring should be interpreted as null in the context
- of relationships.
- """
- data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
- expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, expected_data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_valid_null(self):
- data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_valid_emptystring(self):
- """
- The emptystring should be interpreted as null in the context
- of relationships.
- """
- data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
- expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, expected_data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- ]
- self.assertEqual(serializer.data, expected)
-
- # reverse foreign keys MUST be read_only
- # In the general case they do not provide .remove() or .clear()
- # and cannot be arbitrarily set.
-
- # def test_reverse_foreign_key_update(self):
- # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
- # instance = ForeignKeyTarget.objects.get(pk=1)
- # serializer = ForeignKeyTargetSerializer(instance, data=data)
- # self.assertTrue(serializer.is_valid())
- # self.assertEqual(serializer.data, data)
- # serializer.save()
-
- # # Ensure target 1 is updated, and everything else is as expected
- # queryset = ForeignKeyTarget.objects.all()
- # serializer = ForeignKeyTargetSerializer(queryset, many=True)
- # expected = [
- # {'id': 1, 'name': 'target-1', 'sources': [1]},
- # {'id': 2, 'name': 'target-2', 'sources': []},
- # ]
- # self.assertEqual(serializer.data, expected)
-
-
-class HyperlinkedNullableOneToOneTests(TestCase):
- urls = 'rest_framework.tests.test_relations_hyperlink'
-
- def setUp(self):
- target = OneToOneTarget(name='target-1')
- target.save()
- new_target = OneToOneTarget(name='target-2')
- new_target.save()
- source = NullableOneToOneSource(name='source-1', target=target)
- source.save()
-
- def test_reverse_foreign_key_retrieve_with_null(self):
- queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'},
- {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
- ]
- self.assertEqual(serializer.data, expected)
-
-
-# Regression tests for #694 (`source` attribute on related fields)
-
-class HyperlinkedRelatedFieldSourceTests(TestCase):
- urls = 'rest_framework.tests.test_relations_hyperlink'
-
- def test_related_manager_source(self):
- """
- Relational fields should be able to use manager-returning methods as their source.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.HyperlinkedRelatedField(
- many=True,
- source='get_blogposts_manager',
- view_name='dummy-url',
- )
- field.context = {'request': request}
-
- class ClassWithManagerMethod(object):
- def get_blogposts_manager(self):
- return BlogPost.objects
-
- obj = ClassWithManagerMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, ['http://testserver/dummyurl/1/'])
-
- def test_related_queryset_source(self):
- """
- Relational fields should be able to use queryset-returning methods as their source.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.HyperlinkedRelatedField(
- many=True,
- source='get_blogposts_queryset',
- view_name='dummy-url',
- )
- field.context = {'request': request}
-
- class ClassWithQuerysetMethod(object):
- def get_blogposts_queryset(self):
- return BlogPost.objects.all()
-
- obj = ClassWithQuerysetMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, ['http://testserver/dummyurl/1/'])
-
- def test_dotted_source(self):
- """
- Source argument should support dotted.source notation.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.HyperlinkedRelatedField(
- many=True,
- source='a.b.c',
- view_name='dummy-url',
- )
- field.context = {'request': request}
-
- class ClassWithQuerysetMethod(object):
- a = {
- 'b': {
- 'c': BlogPost.objects.all()
- }
- }
-
- obj = ClassWithQuerysetMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, ['http://testserver/dummyurl/1/'])
diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py
deleted file mode 100644
index 4d9da489..00000000
--- a/rest_framework/tests/test_relations_nested.py
+++ /dev/null
@@ -1,326 +0,0 @@
-from __future__ import unicode_literals
-from django.db import models
-from django.test import TestCase
-from rest_framework import serializers
-
-from .models import OneToOneTarget
-
-
-class OneToOneSource(models.Model):
- name = models.CharField(max_length=100)
- target = models.OneToOneField(OneToOneTarget, related_name='source',
- null=True, blank=True)
-
-
-class OneToManyTarget(models.Model):
- name = models.CharField(max_length=100)
-
-
-class OneToManySource(models.Model):
- name = models.CharField(max_length=100)
- target = models.ForeignKey(OneToManyTarget, related_name='sources')
-
-
-class ReverseNestedOneToOneTests(TestCase):
- def setUp(self):
- class OneToOneSourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = OneToOneSource
- fields = ('id', 'name')
-
- class OneToOneTargetSerializer(serializers.ModelSerializer):
- source = OneToOneSourceSerializer()
-
- class Meta:
- model = OneToOneTarget
- fields = ('id', 'name', 'source')
-
- self.Serializer = OneToOneTargetSerializer
-
- for idx in range(1, 4):
- target = OneToOneTarget(name='target-%d' % idx)
- target.save()
- source = OneToOneSource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_one_to_one_retrieve(self):
- queryset = OneToOneTarget.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
- {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
- {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_one_to_one_create(self):
- data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
- serializer = self.Serializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-4')
-
- # Ensure (target 4, target_source 4, source 4) are added, and
- # everything else is as expected.
- queryset = OneToOneTarget.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
- {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
- {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}},
- {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_one_to_one_create_with_invalid_data(self):
- data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}}
- serializer = self.Serializer(data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]})
-
- def test_one_to_one_update(self):
- data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
- instance = OneToOneTarget.objects.get(pk=3)
- serializer = self.Serializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-3-updated')
-
- # Ensure (target 3, target_source 3, source 3) are updated,
- # and everything else is as expected.
- queryset = OneToOneTarget.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
- {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
- {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
- ]
- self.assertEqual(serializer.data, expected)
-
-
-class ForwardNestedOneToOneTests(TestCase):
- def setUp(self):
- class OneToOneTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = OneToOneTarget
- fields = ('id', 'name')
-
- class OneToOneSourceSerializer(serializers.ModelSerializer):
- target = OneToOneTargetSerializer()
-
- class Meta:
- model = OneToOneSource
- fields = ('id', 'name', 'target')
-
- self.Serializer = OneToOneSourceSerializer
-
- for idx in range(1, 4):
- target = OneToOneTarget(name='target-%d' % idx)
- target.save()
- source = OneToOneSource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_one_to_one_retrieve(self):
- queryset = OneToOneSource.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
- {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_one_to_one_create(self):
- data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
- serializer = self.Serializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure (target 4, target_source 4, source 4) are added, and
- # everything else is as expected.
- queryset = OneToOneSource.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
- {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}},
- {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_one_to_one_create_with_invalid_data(self):
- data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}}
- serializer = self.Serializer(data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]})
-
- def test_one_to_one_update(self):
- data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
- instance = OneToOneSource.objects.get(pk=3)
- serializer = self.Serializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-3-updated')
-
- # Ensure (target 3, target_source 3, source 3) are updated,
- # and everything else is as expected.
- queryset = OneToOneSource.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
- {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_one_to_one_update_to_null(self):
- data = {'id': 3, 'name': 'source-3-updated', 'target': None}
- instance = OneToOneSource.objects.get(pk=3)
- serializer = self.Serializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
-
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-3-updated')
- self.assertEqual(obj.target, None)
-
- queryset = OneToOneSource.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
- {'id': 3, 'name': 'source-3-updated', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- # TODO: Nullable 1-1 tests
- # def test_one_to_one_delete(self):
- # data = {'id': 3, 'name': 'target-3', 'target_source': None}
- # instance = OneToOneTarget.objects.get(pk=3)
- # serializer = self.Serializer(instance, data=data)
- # self.assertTrue(serializer.is_valid())
- # serializer.save()
-
- # # Ensure (target_source 3, source 3) are deleted,
- # # and everything else is as expected.
- # queryset = OneToOneTarget.objects.all()
- # serializer = self.Serializer(queryset)
- # expected = [
- # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
- # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
- # {'id': 3, 'name': 'target-3', 'source': None}
- # ]
- # self.assertEqual(serializer.data, expected)
-
-
-class ReverseNestedOneToManyTests(TestCase):
- def setUp(self):
- class OneToManySourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = OneToManySource
- fields = ('id', 'name')
-
- class OneToManyTargetSerializer(serializers.ModelSerializer):
- sources = OneToManySourceSerializer(many=True, allow_add_remove=True)
-
- class Meta:
- model = OneToManyTarget
- fields = ('id', 'name', 'sources')
-
- self.Serializer = OneToManyTargetSerializer
-
- target = OneToManyTarget(name='target-1')
- target.save()
- for idx in range(1, 4):
- source = OneToManySource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_one_to_many_retrieve(self):
- queryset = OneToManyTarget.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
- {'id': 2, 'name': 'source-2'},
- {'id': 3, 'name': 'source-3'}]},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_one_to_many_create(self):
- data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
- {'id': 2, 'name': 'source-2'},
- {'id': 3, 'name': 'source-3'},
- {'id': 4, 'name': 'source-4'}]}
- instance = OneToManyTarget.objects.get(pk=1)
- serializer = self.Serializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-1')
-
- # Ensure source 4 is added, and everything else is as
- # expected.
- queryset = OneToManyTarget.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
- {'id': 2, 'name': 'source-2'},
- {'id': 3, 'name': 'source-3'},
- {'id': 4, 'name': 'source-4'}]}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_one_to_many_create_with_invalid_data(self):
- data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
- {'id': 2, 'name': 'source-2'},
- {'id': 3, 'name': 'source-3'},
- {'id': 4}]}
- serializer = self.Serializer(data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]})
-
- def test_one_to_many_update(self):
- data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
- {'id': 2, 'name': 'source-2'},
- {'id': 3, 'name': 'source-3'}]}
- instance = OneToManyTarget.objects.get(pk=1)
- serializer = self.Serializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-1-updated')
-
- # Ensure (target 1, source 1) are updated,
- # and everything else is as expected.
- queryset = OneToManyTarget.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
- {'id': 2, 'name': 'source-2'},
- {'id': 3, 'name': 'source-3'}]}
-
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_one_to_many_delete(self):
- data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
- {'id': 3, 'name': 'source-3'}]}
- instance = OneToManyTarget.objects.get(pk=1)
- serializer = self.Serializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- serializer.save()
-
- # Ensure source 2 is deleted, and everything else is as
- # expected.
- queryset = OneToManyTarget.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
- {'id': 3, 'name': 'source-3'}]}
-
- ]
- self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/test_relations_pk.py b/rest_framework/tests/test_relations_pk.py
deleted file mode 100644
index 3815afdd..00000000
--- a/rest_framework/tests/test_relations_pk.py
+++ /dev/null
@@ -1,551 +0,0 @@
-from __future__ import unicode_literals
-from django.db import models
-from django.test import TestCase
-from rest_framework import serializers
-from rest_framework.tests.models import (
- BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
- NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource,
-)
-from rest_framework.compat import six
-
-
-# ManyToMany
-class ManyToManyTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = ManyToManyTarget
- fields = ('id', 'name', 'sources')
-
-
-class ManyToManySourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = ManyToManySource
- fields = ('id', 'name', 'targets')
-
-
-# ForeignKey
-class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = ForeignKeyTarget
- fields = ('id', 'name', 'sources')
-
-
-class ForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = ForeignKeySource
- fields = ('id', 'name', 'target')
-
-
-# Nullable ForeignKey
-class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = NullableForeignKeySource
- fields = ('id', 'name', 'target')
-
-
-# Nullable OneToOne
-class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = OneToOneTarget
- fields = ('id', 'name', 'nullable_source')
-
-
-# TODO: Add test that .data cannot be accessed prior to .is_valid
-
-class PKManyToManyTests(TestCase):
- def setUp(self):
- for idx in range(1, 4):
- target = ManyToManyTarget(name='target-%d' % idx)
- target.save()
- source = ManyToManySource(name='source-%d' % idx)
- source.save()
- for target in ManyToManyTarget.objects.all():
- source.targets.add(target)
-
- def test_many_to_many_retrieve(self):
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'targets': [1]},
- {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_many_to_many_retrieve(self):
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': 'target-3', 'sources': [3]}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_many_to_many_update(self):
- data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}
- instance = ManyToManySource.objects.get(pk=1)
- serializer = ManyToManySourceSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- serializer.save()
- self.assertEqual(serializer.data, data)
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
- {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_many_to_many_update(self):
- data = {'id': 1, 'name': 'target-1', 'sources': [1]}
- instance = ManyToManyTarget.objects.get(pk=1)
- serializer = ManyToManyTargetSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- serializer.save()
- self.assertEqual(serializer.data, data)
-
- # Ensure target 1 is updated, and everything else is as expected
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1]},
- {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': 'target-3', 'sources': [3]}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_many_to_many_create(self):
- data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
- serializer = ManyToManySourceSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is added, and everything else is as expected
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True)
- self.assertFalse(serializer.fields['targets'].read_only)
- expected = [
- {'id': 1, 'name': 'source-1', 'targets': [1]},
- {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]},
- {'id': 4, 'name': 'source-4', 'targets': [1, 3]},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_many_to_many_create(self):
- data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
- serializer = ManyToManyTargetSerializer(data=data)
- self.assertFalse(serializer.fields['sources'].read_only)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-4')
-
- # Ensure target 4 is added, and everything else is as expected
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': 'target-3', 'sources': [3]},
- {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
- ]
- self.assertEqual(serializer.data, expected)
-
-
-class PKForeignKeyTests(TestCase):
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- new_target = ForeignKeyTarget(name='target-2')
- new_target.save()
- for idx in range(1, 4):
- source = ForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_foreign_key_retrieve(self):
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': 1}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_foreign_key_retrieve(self):
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': 'target-2', 'sources': []},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update(self):
- data = {'id': 1, 'name': 'source-1', 'target': 2}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 2},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': 1}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_incorrect_type(self):
- data = {'id': 1, 'name': 'source-1', 'target': 'foo'}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]})
-
- def test_reverse_foreign_key_update(self):
- data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
- instance = ForeignKeyTarget.objects.get(pk=2)
- serializer = ForeignKeyTargetSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- # We shouldn't have saved anything to the db yet since save
- # hasn't been called.
- queryset = ForeignKeyTarget.objects.all()
- new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': 'target-2', 'sources': []},
- ]
- self.assertEqual(new_serializer.data, expected)
-
- serializer.save()
- self.assertEqual(serializer.data, data)
-
- # Ensure target 2 is update, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [2]},
- {'id': 2, 'name': 'target-2', 'sources': [1, 3]},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create(self):
- data = {'id': 4, 'name': 'source-4', 'target': 2}
- serializer = ForeignKeySourceSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is added, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': 1},
- {'id': 4, 'name': 'source-4', 'target': 2},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_foreign_key_create(self):
- data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]}
- serializer = ForeignKeyTargetSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-3')
-
- # Ensure target 3 is added, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [2]},
- {'id': 2, 'name': 'target-2', 'sources': []},
- {'id': 3, 'name': 'target-3', 'sources': [1, 3]},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_invalid_null(self):
- data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'target': ['This field is required.']})
-
- def test_foreign_key_with_empty(self):
- """
- Regression test for #1072
-
- https://github.com/tomchristie/django-rest-framework/issues/1072
- """
- serializer = NullableForeignKeySourceSerializer()
- self.assertEqual(serializer.data['target'], None)
-
-
-class PKNullableForeignKeyTests(TestCase):
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- for idx in range(1, 4):
- if idx == 3:
- target = None
- source = NullableForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_foreign_key_retrieve_with_null(self):
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create_with_valid_null(self):
- data = {'id': 4, 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None},
- {'id': 4, 'name': 'source-4', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create_with_valid_emptystring(self):
- """
- The emptystring should be interpreted as null in the context
- of relationships.
- """
- data = {'id': 4, 'name': 'source-4', 'target': ''}
- expected_data = {'id': 4, 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, expected_data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None},
- {'id': 4, 'name': 'source-4', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_valid_null(self):
- data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': None},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_valid_emptystring(self):
- """
- The emptystring should be interpreted as null in the context
- of relationships.
- """
- data = {'id': 1, 'name': 'source-1', 'target': ''}
- expected_data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, expected_data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': None},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- # reverse foreign keys MUST be read_only
- # In the general case they do not provide .remove() or .clear()
- # and cannot be arbitrarily set.
-
- # def test_reverse_foreign_key_update(self):
- # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
- # instance = ForeignKeyTarget.objects.get(pk=1)
- # serializer = ForeignKeyTargetSerializer(instance, data=data)
- # self.assertTrue(serializer.is_valid())
- # self.assertEqual(serializer.data, data)
- # serializer.save()
-
- # # Ensure target 1 is updated, and everything else is as expected
- # queryset = ForeignKeyTarget.objects.all()
- # serializer = ForeignKeyTargetSerializer(queryset, many=True)
- # expected = [
- # {'id': 1, 'name': 'target-1', 'sources': [1]},
- # {'id': 2, 'name': 'target-2', 'sources': []},
- # ]
- # self.assertEqual(serializer.data, expected)
-
-
-class PKNullableOneToOneTests(TestCase):
- def setUp(self):
- target = OneToOneTarget(name='target-1')
- target.save()
- new_target = OneToOneTarget(name='target-2')
- new_target.save()
- source = NullableOneToOneSource(name='source-1', target=new_target)
- source.save()
-
- def test_reverse_foreign_key_retrieve_with_null(self):
- queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'nullable_source': None},
- {'id': 2, 'name': 'target-2', 'nullable_source': 1},
- ]
- self.assertEqual(serializer.data, expected)
-
-
-# The below models and tests ensure that serializer fields corresponding
-# to a ManyToManyField field with a user-specified ``through`` model are
-# set to read only
-
-
-class ManyToManyThroughTarget(models.Model):
- name = models.CharField(max_length=100)
-
-
-class ManyToManyThrough(models.Model):
- source = models.ForeignKey('ManyToManyThroughSource')
- target = models.ForeignKey(ManyToManyThroughTarget)
-
-
-class ManyToManyThroughSource(models.Model):
- name = models.CharField(max_length=100)
- targets = models.ManyToManyField(ManyToManyThroughTarget,
- related_name='sources',
- through='ManyToManyThrough')
-
-
-class ManyToManyThroughTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = ManyToManyThroughTarget
- fields = ('id', 'name', 'sources')
-
-
-class ManyToManyThroughSourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = ManyToManyThroughSource
- fields = ('id', 'name', 'targets')
-
-
-class PKManyToManyThroughTests(TestCase):
- def setUp(self):
- self.source = ManyToManyThroughSource.objects.create(
- name='through-source-1')
- self.target = ManyToManyThroughTarget.objects.create(
- name='through-target-1')
-
- def test_many_to_many_create(self):
- data = {'id': 2, 'name': 'source-2', 'targets': [self.target.pk]}
- serializer = ManyToManyThroughSourceSerializer(data=data)
- self.assertTrue(serializer.fields['targets'].read_only)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(obj.name, 'source-2')
- self.assertEqual(obj.targets.count(), 0)
-
- def test_many_to_many_reverse_create(self):
- data = {'id': 2, 'name': 'target-2', 'sources': [self.source.pk]}
- serializer = ManyToManyThroughTargetSerializer(data=data)
- self.assertTrue(serializer.fields['sources'].read_only)
- self.assertTrue(serializer.is_valid())
- serializer.save()
- obj = serializer.save()
- self.assertEqual(obj.name, 'target-2')
- self.assertEqual(obj.sources.count(), 0)
-
-
-# Regression tests for #694 (`source` attribute on related fields)
-
-
-class PrimaryKeyRelatedFieldSourceTests(TestCase):
- def test_related_manager_source(self):
- """
- Relational fields should be able to use manager-returning methods as their source.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager')
-
- class ClassWithManagerMethod(object):
- def get_blogposts_manager(self):
- return BlogPost.objects
-
- obj = ClassWithManagerMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, [1])
-
- def test_related_queryset_source(self):
- """
- Relational fields should be able to use queryset-returning methods as their source.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset')
-
- class ClassWithQuerysetMethod(object):
- def get_blogposts_queryset(self):
- return BlogPost.objects.all()
-
- obj = ClassWithQuerysetMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, [1])
-
- def test_dotted_source(self):
- """
- Source argument should support dotted.source notation.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c')
-
- class ClassWithQuerysetMethod(object):
- a = {
- 'b': {
- 'c': BlogPost.objects.all()
- }
- }
-
- obj = ClassWithQuerysetMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, [1])
diff --git a/rest_framework/tests/test_relations_slug.py b/rest_framework/tests/test_relations_slug.py
deleted file mode 100644
index 435c821c..00000000
--- a/rest_framework/tests/test_relations_slug.py
+++ /dev/null
@@ -1,257 +0,0 @@
-from django.test import TestCase
-from rest_framework import serializers
-from rest_framework.tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget
-
-
-class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.SlugRelatedField(many=True, slug_field='name')
-
- class Meta:
- model = ForeignKeyTarget
-
-
-class ForeignKeySourceSerializer(serializers.ModelSerializer):
- target = serializers.SlugRelatedField(slug_field='name')
-
- class Meta:
- model = ForeignKeySource
-
-
-class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
- target = serializers.SlugRelatedField(slug_field='name', required=False)
-
- class Meta:
- model = NullableForeignKeySource
-
-
-# TODO: M2M Tests, FKTests (Non-nullable), One2One
-class SlugForeignKeyTests(TestCase):
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- new_target = ForeignKeyTarget(name='target-2')
- new_target.save()
- for idx in range(1, 4):
- source = ForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_foreign_key_retrieve(self):
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': 'target-1'}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_foreign_key_retrieve(self):
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
- {'id': 2, 'name': 'target-2', 'sources': []},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update(self):
- data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-2'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': 'target-1'}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_incorrect_type(self):
- data = {'id': 1, 'name': 'source-1', 'target': 123}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']})
-
- def test_reverse_foreign_key_update(self):
- data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
- instance = ForeignKeyTarget.objects.get(pk=2)
- serializer = ForeignKeyTargetSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- # We shouldn't have saved anything to the db yet since save
- # hasn't been called.
- queryset = ForeignKeyTarget.objects.all()
- new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
- {'id': 2, 'name': 'target-2', 'sources': []},
- ]
- self.assertEqual(new_serializer.data, expected)
-
- serializer.save()
- self.assertEqual(serializer.data, data)
-
- # Ensure target 2 is update, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
- {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create(self):
- data = {'id': 4, 'name': 'source-4', 'target': 'target-2'}
- serializer = ForeignKeySourceSerializer(data=data)
- serializer.is_valid()
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is added, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': 'target-1'},
- {'id': 4, 'name': 'source-4', 'target': 'target-2'},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_foreign_key_create(self):
- data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
- serializer = ForeignKeyTargetSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-3')
-
- # Ensure target 3 is added, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
- {'id': 2, 'name': 'target-2', 'sources': []},
- {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_invalid_null(self):
- data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'target': ['This field is required.']})
-
-
-class SlugNullableForeignKeyTests(TestCase):
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- for idx in range(1, 4):
- if idx == 3:
- target = None
- source = NullableForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_foreign_key_retrieve_with_null(self):
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create_with_valid_null(self):
- data = {'id': 4, 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None},
- {'id': 4, 'name': 'source-4', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create_with_valid_emptystring(self):
- """
- The emptystring should be interpreted as null in the context
- of relationships.
- """
- data = {'id': 4, 'name': 'source-4', 'target': ''}
- expected_data = {'id': 4, 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, expected_data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None},
- {'id': 4, 'name': 'source-4', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_valid_null(self):
- data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': None},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_valid_emptystring(self):
- """
- The emptystring should be interpreted as null in the context
- of relationships.
- """
- data = {'id': 1, 'name': 'source-1', 'target': ''}
- expected_data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, expected_data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': None},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py
deleted file mode 100644
index c7bf772e..00000000
--- a/rest_framework/tests/test_renderers.py
+++ /dev/null
@@ -1,655 +0,0 @@
-# -*- coding: utf-8 -*-
-from __future__ import unicode_literals
-
-from decimal import Decimal
-from django.core.cache import cache
-from django.db import models
-from django.test import TestCase
-from django.utils import unittest
-from django.utils.translation import ugettext_lazy as _
-from rest_framework import status, permissions
-from rest_framework.compat import yaml, etree, patterns, url, include, six, StringIO
-from rest_framework.response import Response
-from rest_framework.views import APIView
-from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
- XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer
-from rest_framework.parsers import YAMLParser, XMLParser
-from rest_framework.settings import api_settings
-from rest_framework.test import APIRequestFactory
-from collections import MutableMapping
-import datetime
-import json
-import pickle
-import re
-
-
-DUMMYSTATUS = status.HTTP_200_OK
-DUMMYCONTENT = 'dummycontent'
-
-RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii')
-RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
-
-
-expected_results = [
- ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1, 2, 3]') # Generator
-]
-
-
-class DummyTestModel(models.Model):
- name = models.CharField(max_length=42, default='')
-
-
-class BasicRendererTests(TestCase):
- def test_expected_results(self):
- for value, renderer_cls, expected in expected_results:
- output = renderer_cls().render(value)
- self.assertEqual(output, expected)
-
-
-class RendererA(BaseRenderer):
- media_type = 'mock/renderera'
- format = "formata"
-
- def render(self, data, media_type=None, renderer_context=None):
- return RENDERER_A_SERIALIZER(data)
-
-
-class RendererB(BaseRenderer):
- media_type = 'mock/rendererb'
- format = "formatb"
-
- def render(self, data, media_type=None, renderer_context=None):
- return RENDERER_B_SERIALIZER(data)
-
-
-class MockView(APIView):
- renderer_classes = (RendererA, RendererB)
-
- def get(self, request, **kwargs):
- response = Response(DUMMYCONTENT, status=DUMMYSTATUS)
- return response
-
-
-class MockGETView(APIView):
- def get(self, request, **kwargs):
- return Response({'foo': ['bar', 'baz']})
-
-
-
-class MockPOSTView(APIView):
- def post(self, request, **kwargs):
- return Response({'foo': request.DATA})
-
-
-class EmptyGETView(APIView):
- renderer_classes = (JSONRenderer,)
-
- def get(self, request, **kwargs):
- return Response(status=status.HTTP_204_NO_CONTENT)
-
-
-class HTMLView(APIView):
- renderer_classes = (BrowsableAPIRenderer, )
-
- def get(self, request, **kwargs):
- return Response('text')
-
-
-class HTMLView1(APIView):
- renderer_classes = (BrowsableAPIRenderer, JSONRenderer)
-
- def get(self, request, **kwargs):
- return Response('text')
-
-urlpatterns = patterns('',
- url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
- url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
- url(r'^cache$', MockGETView.as_view()),
- url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])),
- url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])),
- url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])),
- url(r'^html$', HTMLView.as_view()),
- url(r'^html1$', HTMLView1.as_view()),
- url(r'^empty$', EmptyGETView.as_view()),
- url(r'^api', include('rest_framework.urls', namespace='rest_framework'))
-)
-
-
-class POSTDeniedPermission(permissions.BasePermission):
- def has_permission(self, request, view):
- return request.method != 'POST'
-
-
-class POSTDeniedView(APIView):
- renderer_classes = (BrowsableAPIRenderer,)
- permission_classes = (POSTDeniedPermission,)
-
- def get(self, request):
- return Response()
-
- def post(self, request):
- return Response()
-
- def put(self, request):
- return Response()
-
- def patch(self, request):
- return Response()
-
-
-class DocumentingRendererTests(TestCase):
- def test_only_permitted_forms_are_displayed(self):
- view = POSTDeniedView.as_view()
- request = APIRequestFactory().get('/')
- response = view(request).render()
- self.assertNotContains(response, '>POST<')
- self.assertContains(response, '>PUT<')
- self.assertContains(response, '>PATCH<')
-
-
-class RendererEndToEndTests(TestCase):
- """
- End-to-end testing of renderers using an RendererMixin on a generic view.
- """
-
- urls = 'rest_framework.tests.test_renderers'
-
- def test_default_renderer_serializes_content(self):
- """If the Accept header is not set the default renderer should serialize the response."""
- resp = self.client.get('/')
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_head_method_serializes_no_content(self):
- """No response must be included in HEAD requests."""
- resp = self.client.head('/')
- self.assertEqual(resp.status_code, DUMMYSTATUS)
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, six.b(''))
-
- def test_default_renderer_serializes_content_on_accept_any(self):
- """If the Accept header is set to */* the default renderer should serialize the response."""
- resp = self.client.get('/', HTTP_ACCEPT='*/*')
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_default_case(self):
- """If the Accept header is set the specified renderer should serialize the response.
- (In this case we check that works for the default renderer)"""
- resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_non_default_case(self):
- """If the Accept header is set the specified renderer should serialize the response.
- (In this case we check that works for a non-default renderer)"""
- resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_on_accept_query(self):
- """The '_accept' query string should behave in the same way as the Accept header."""
- param = '?%s=%s' % (
- api_settings.URL_ACCEPT_OVERRIDE,
- RendererB.media_type
- )
- resp = self.client.get('/' + param)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
- """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
- resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
- self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
-
- def test_specified_renderer_serializes_content_on_format_query(self):
- """If a 'format' query is specified, the renderer with the matching
- format attribute should serialize the response."""
- param = '?%s=%s' % (
- api_settings.URL_FORMAT_OVERRIDE,
- RendererB.format
- )
- resp = self.client.get('/' + param)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_on_format_kwargs(self):
- """If a 'format' keyword arg is specified, the renderer with the matching
- format attribute should serialize the response."""
- resp = self.client.get('/something.formatb')
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
- """If both a 'format' query and a matching Accept header specified,
- the renderer with the matching format attribute should serialize the response."""
- param = '?%s=%s' % (
- api_settings.URL_FORMAT_OVERRIDE,
- RendererB.format
- )
- resp = self.client.get('/' + param,
- HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_parse_error_renderers_browsable_api(self):
- """Invalid data should still render the browsable API correctly."""
- resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html')
- self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
- self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
-
- def test_204_no_content_responses_have_no_content_type_set(self):
- """
- Regression test for #1196
-
- https://github.com/tomchristie/django-rest-framework/issues/1196
- """
- resp = self.client.get('/empty')
- self.assertEqual(resp.get('Content-Type', None), None)
- self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT)
-
- def test_contains_headers_of_api_response(self):
- """
- Issue #1437
-
- Test we display the headers of the API response and not those from the
- HTML response
- """
- resp = self.client.get('/html1')
- self.assertContains(resp, '>GET, HEAD, OPTIONS<')
- self.assertContains(resp, '>application/json<')
- self.assertNotContains(resp, '>text/html; charset=utf-8<')
-
-
-_flat_repr = '{"foo": ["bar", "baz"]}'
-_indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}'
-
-
-def strip_trailing_whitespace(content):
- """
- Seems to be some inconsistencies re. trailing whitespace with
- different versions of the json lib.
- """
- return re.sub(' +\n', '\n', content)
-
-
-class JSONRendererTests(TestCase):
- """
- Tests specific to the JSON Renderer
- """
-
- def test_render_lazy_strings(self):
- """
- JSONRenderer should deal with lazy translated strings.
- """
- ret = JSONRenderer().render(_('test'))
- self.assertEqual(ret, b'"test"')
-
- def test_render_queryset_values(self):
- o = DummyTestModel.objects.create(name='dummy')
- qs = DummyTestModel.objects.values('id', 'name')
- ret = JSONRenderer().render(qs)
- data = json.loads(ret.decode('utf-8'))
- self.assertEquals(data, [{'id': o.id, 'name': o.name}])
-
- def test_render_queryset_values_list(self):
- o = DummyTestModel.objects.create(name='dummy')
- qs = DummyTestModel.objects.values_list('id', 'name')
- ret = JSONRenderer().render(qs)
- data = json.loads(ret.decode('utf-8'))
- self.assertEquals(data, [[o.id, o.name]])
-
- def test_render_dict_abc_obj(self):
- class Dict(MutableMapping):
- def __init__(self):
- self._dict = dict()
- def __getitem__(self, key):
- return self._dict.__getitem__(key)
- def __setitem__(self, key, value):
- return self._dict.__setitem__(key, value)
- def __delitem__(self, key):
- return self._dict.__delitem__(key)
- def __iter__(self):
- return self._dict.__iter__()
- def __len__(self):
- return self._dict.__len__()
- def keys(self):
- return self._dict.keys()
-
- x = Dict()
- x['key'] = 'string value'
- x[2] = 3
- ret = JSONRenderer().render(x)
- data = json.loads(ret.decode('utf-8'))
- self.assertEquals(data, {'key': 'string value', '2': 3})
-
- def test_render_obj_with_getitem(self):
- class DictLike(object):
- def __init__(self):
- self._dict = {}
- def set(self, value):
- self._dict = dict(value)
- def __getitem__(self, key):
- return self._dict[key]
-
- x = DictLike()
- x.set({'a': 1, 'b': 'string'})
- with self.assertRaises(TypeError):
- JSONRenderer().render(x)
-
- def test_without_content_type_args(self):
- """
- Test basic JSON rendering.
- """
- obj = {'foo': ['bar', 'baz']}
- renderer = JSONRenderer()
- content = renderer.render(obj, 'application/json')
- # Fix failing test case which depends on version of JSON library.
- self.assertEqual(content.decode('utf-8'), _flat_repr)
-
- def test_with_content_type_args(self):
- """
- Test JSON rendering with additional content type arguments supplied.
- """
- obj = {'foo': ['bar', 'baz']}
- renderer = JSONRenderer()
- content = renderer.render(obj, 'application/json; indent=2')
- self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr)
-
- def test_check_ascii(self):
- obj = {'countries': ['United Kingdom', 'France', 'España']}
- renderer = JSONRenderer()
- content = renderer.render(obj, 'application/json')
- self.assertEqual(content, '{"countries": ["United Kingdom", "France", "Espa\\u00f1a"]}'.encode('utf-8'))
-
-
-class UnicodeJSONRendererTests(TestCase):
- """
- Tests specific for the Unicode JSON Renderer
- """
- def test_proper_encoding(self):
- obj = {'countries': ['United Kingdom', 'France', 'España']}
- renderer = UnicodeJSONRenderer()
- content = renderer.render(obj, 'application/json')
- self.assertEqual(content, '{"countries": ["United Kingdom", "France", "España"]}'.encode('utf-8'))
-
-
-class JSONPRendererTests(TestCase):
- """
- Tests specific to the JSONP Renderer
- """
-
- urls = 'rest_framework.tests.test_renderers'
-
- def test_without_callback_with_json_renderer(self):
- """
- Test JSONP rendering with View JSON Renderer.
- """
- resp = self.client.get('/jsonp/jsonrenderer',
- HTTP_ACCEPT='application/javascript')
- self.assertEqual(resp.status_code, status.HTTP_200_OK)
- self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
- self.assertEqual(resp.content,
- ('callback(%s);' % _flat_repr).encode('ascii'))
-
- def test_without_callback_without_json_renderer(self):
- """
- Test JSONP rendering without View JSON Renderer.
- """
- resp = self.client.get('/jsonp/nojsonrenderer',
- HTTP_ACCEPT='application/javascript')
- self.assertEqual(resp.status_code, status.HTTP_200_OK)
- self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
- self.assertEqual(resp.content,
- ('callback(%s);' % _flat_repr).encode('ascii'))
-
- def test_with_callback(self):
- """
- Test JSONP rendering with callback function name.
- """
- callback_func = 'myjsonpcallback'
- resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func,
- HTTP_ACCEPT='application/javascript')
- self.assertEqual(resp.status_code, status.HTTP_200_OK)
- self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
- self.assertEqual(resp.content,
- ('%s(%s);' % (callback_func, _flat_repr)).encode('ascii'))
-
-
-if yaml:
- _yaml_repr = 'foo: [bar, baz]\n'
-
- class YAMLRendererTests(TestCase):
- """
- Tests specific to the YAML Renderer
- """
-
- def test_render(self):
- """
- Test basic YAML rendering.
- """
- obj = {'foo': ['bar', 'baz']}
- renderer = YAMLRenderer()
- content = renderer.render(obj, 'application/yaml')
- self.assertEqual(content, _yaml_repr)
-
- def test_render_and_parse(self):
- """
- Test rendering and then parsing returns the original object.
- IE obj -> render -> parse -> obj.
- """
- obj = {'foo': ['bar', 'baz']}
-
- renderer = YAMLRenderer()
- parser = YAMLParser()
-
- content = renderer.render(obj, 'application/yaml')
- data = parser.parse(StringIO(content))
- self.assertEqual(obj, data)
-
- def test_render_decimal(self):
- """
- Test YAML decimal rendering.
- """
- renderer = YAMLRenderer()
- content = renderer.render({'field': Decimal('111.2')}, 'application/yaml')
- self.assertYAMLContains(content, "field: '111.2'")
-
- def assertYAMLContains(self, content, string):
- self.assertTrue(string in content, '%r not in %r' % (string, content))
-
-
-class XMLRendererTestCase(TestCase):
- """
- Tests specific to the XML Renderer
- """
-
- _complex_data = {
- "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00),
- "name": "name",
- "sub_data_list": [
- {
- "sub_id": 1,
- "sub_name": "first"
- },
- {
- "sub_id": 2,
- "sub_name": "second"
- }
- ]
- }
-
- def test_render_string(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({'field': 'astring'}, 'application/xml')
- self.assertXMLContains(content, '<field>astring</field>')
-
- def test_render_integer(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({'field': 111}, 'application/xml')
- self.assertXMLContains(content, '<field>111</field>')
-
- def test_render_datetime(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({
- 'field': datetime.datetime(2011, 12, 25, 12, 45, 00)
- }, 'application/xml')
- self.assertXMLContains(content, '<field>2011-12-25 12:45:00</field>')
-
- def test_render_float(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({'field': 123.4}, 'application/xml')
- self.assertXMLContains(content, '<field>123.4</field>')
-
- def test_render_decimal(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({'field': Decimal('111.2')}, 'application/xml')
- self.assertXMLContains(content, '<field>111.2</field>')
-
- def test_render_none(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({'field': None}, 'application/xml')
- self.assertXMLContains(content, '<field></field>')
-
- def test_render_complex_data(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render(self._complex_data, 'application/xml')
- self.assertXMLContains(content, '<sub_name>first</sub_name>')
- self.assertXMLContains(content, '<sub_name>second</sub_name>')
-
- @unittest.skipUnless(etree, 'defusedxml not installed')
- def test_render_and_parse_complex_data(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = StringIO(renderer.render(self._complex_data, 'application/xml'))
-
- parser = XMLParser()
- complex_data_out = parser.parse(content)
- error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out))
- self.assertEqual(self._complex_data, complex_data_out, error_msg)
-
- def assertXMLContains(self, xml, string):
- self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
- self.assertTrue(xml.endswith('</root>'))
- self.assertTrue(string in xml, '%r not in %r' % (string, xml))
-
-
-# Tests for caching issue, #346
-class CacheRenderTest(TestCase):
- """
- Tests specific to caching responses
- """
-
- urls = 'rest_framework.tests.test_renderers'
-
- cache_key = 'just_a_cache_key'
-
- @classmethod
- def _get_pickling_errors(cls, obj, seen=None):
- """ Return any errors that would be raised if `obj' is pickled
- Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897
- """
- if seen == None:
- seen = []
- try:
- state = obj.__getstate__()
- except AttributeError:
- return
- if state == None:
- return
- if isinstance(state, tuple):
- if not isinstance(state[0], dict):
- state = state[1]
- else:
- state = state[0].update(state[1])
- result = {}
- for i in state:
- try:
- pickle.dumps(state[i], protocol=2)
- except pickle.PicklingError:
- if not state[i] in seen:
- seen.append(state[i])
- result[i] = cls._get_pickling_errors(state[i], seen)
- return result
-
- def http_resp(self, http_method, url):
- """
- Simple wrapper for Client http requests
- Removes the `client' and `request' attributes from as they are
- added by django.test.client.Client and not part of caching
- responses outside of tests.
- """
- method = getattr(self.client, http_method)
- resp = method(url)
- del resp.client, resp.request
- try:
- del resp.wsgi_request
- except AttributeError:
- pass
- return resp
-
- def test_obj_pickling(self):
- """
- Test that responses are properly pickled
- """
- resp = self.http_resp('get', '/cache')
-
- # Make sure that no pickling errors occurred
- self.assertEqual(self._get_pickling_errors(resp), {})
-
- # Unfortunately LocMem backend doesn't raise PickleErrors but returns
- # None instead.
- cache.set(self.cache_key, resp)
- self.assertTrue(cache.get(self.cache_key) is not None)
-
- def test_head_caching(self):
- """
- Test caching of HEAD requests
- """
- resp = self.http_resp('head', '/cache')
- cache.set(self.cache_key, resp)
-
- cached_resp = cache.get(self.cache_key)
- self.assertIsInstance(cached_resp, Response)
-
- def test_get_caching(self):
- """
- Test caching of GET requests
- """
- resp = self.http_resp('get', '/cache')
- cache.set(self.cache_key, resp)
-
- cached_resp = cache.get(self.cache_key)
- self.assertIsInstance(cached_resp, Response)
- self.assertEqual(cached_resp.content, resp.content)
diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py
deleted file mode 100644
index c0b50f33..00000000
--- a/rest_framework/tests/test_request.py
+++ /dev/null
@@ -1,347 +0,0 @@
-"""
-Tests for content parsing, and form-overloaded content parsing.
-"""
-from __future__ import unicode_literals
-from django.contrib.auth.models import User
-from django.contrib.auth import authenticate, login, logout
-from django.contrib.sessions.middleware import SessionMiddleware
-from django.core.handlers.wsgi import WSGIRequest
-from django.test import TestCase
-from rest_framework import status
-from rest_framework.authentication import SessionAuthentication
-from rest_framework.compat import patterns
-from rest_framework.parsers import (
- BaseParser,
- FormParser,
- MultiPartParser,
- JSONParser
-)
-from rest_framework.request import Request, Empty
-from rest_framework.response import Response
-from rest_framework.settings import api_settings
-from rest_framework.test import APIRequestFactory, APIClient
-from rest_framework.views import APIView
-from rest_framework.compat import six
-from io import BytesIO
-import json
-
-
-factory = APIRequestFactory()
-
-
-class PlainTextParser(BaseParser):
- media_type = 'text/plain'
-
- def parse(self, stream, media_type=None, parser_context=None):
- """
- Returns a 2-tuple of `(data, files)`.
-
- `data` will simply be a string representing the body of the request.
- `files` will always be `None`.
- """
- return stream.read()
-
-
-class TestMethodOverloading(TestCase):
- def test_method(self):
- """
- Request methods should be same as underlying request.
- """
- request = Request(factory.get('/'))
- self.assertEqual(request.method, 'GET')
- request = Request(factory.post('/'))
- self.assertEqual(request.method, 'POST')
-
- def test_overloaded_method(self):
- """
- POST requests can be overloaded to another method by setting a
- reserved form field
- """
- request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'}))
- self.assertEqual(request.method, 'DELETE')
-
- def test_x_http_method_override_header(self):
- """
- POST requests can also be overloaded to another method by setting
- the X-HTTP-Method-Override header.
- """
- request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE'))
- self.assertEqual(request.method, 'DELETE')
-
- request = Request(factory.get('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE'))
- self.assertEqual(request.method, 'DELETE')
-
-
-class TestContentParsing(TestCase):
- def test_standard_behaviour_determines_no_content_GET(self):
- """
- Ensure request.DATA returns empty QueryDict for GET request.
- """
- request = Request(factory.get('/'))
- self.assertEqual(request.DATA, {})
-
- def test_standard_behaviour_determines_no_content_HEAD(self):
- """
- Ensure request.DATA returns empty QueryDict for HEAD request.
- """
- request = Request(factory.head('/'))
- self.assertEqual(request.DATA, {})
-
- def test_request_DATA_with_form_content(self):
- """
- Ensure request.DATA returns content for POST request with form content.
- """
- data = {'qwerty': 'uiop'}
- request = Request(factory.post('/', data))
- request.parsers = (FormParser(), MultiPartParser())
- self.assertEqual(list(request.DATA.items()), list(data.items()))
-
- def test_request_DATA_with_text_content(self):
- """
- Ensure request.DATA returns content for POST request with
- non-form content.
- """
- content = six.b('qwerty')
- content_type = 'text/plain'
- request = Request(factory.post('/', content, content_type=content_type))
- request.parsers = (PlainTextParser(),)
- self.assertEqual(request.DATA, content)
-
- def test_request_POST_with_form_content(self):
- """
- Ensure request.POST returns content for POST request with form content.
- """
- data = {'qwerty': 'uiop'}
- request = Request(factory.post('/', data))
- request.parsers = (FormParser(), MultiPartParser())
- self.assertEqual(list(request.POST.items()), list(data.items()))
-
- def test_standard_behaviour_determines_form_content_PUT(self):
- """
- Ensure request.DATA returns content for PUT request with form content.
- """
- data = {'qwerty': 'uiop'}
- request = Request(factory.put('/', data))
- request.parsers = (FormParser(), MultiPartParser())
- self.assertEqual(list(request.DATA.items()), list(data.items()))
-
- def test_standard_behaviour_determines_non_form_content_PUT(self):
- """
- Ensure request.DATA returns content for PUT request with
- non-form content.
- """
- content = six.b('qwerty')
- content_type = 'text/plain'
- request = Request(factory.put('/', content, content_type=content_type))
- request.parsers = (PlainTextParser(), )
- self.assertEqual(request.DATA, content)
-
- def test_overloaded_behaviour_allows_content_tunnelling(self):
- """
- Ensure request.DATA returns content for overloaded POST request.
- """
- json_data = {'foobar': 'qwerty'}
- content = json.dumps(json_data)
- content_type = 'application/json'
- form_data = {
- api_settings.FORM_CONTENT_OVERRIDE: content,
- api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
- }
- request = Request(factory.post('/', form_data))
- request.parsers = (JSONParser(), )
- self.assertEqual(request.DATA, json_data)
-
- def test_form_POST_unicode(self):
- """
- JSON POST via default web interface with unicode data
- """
- # Note: environ and other variables here have simplified content compared to real Request
- CONTENT = b'_content_type=application%2Fjson&_content=%7B%22request%22%3A+4%2C+%22firm%22%3A+1%2C+%22text%22%3A+%22%D0%9F%D1%80%D0%B8%D0%B2%D0%B5%D1%82%21%22%7D'
- environ = {
- 'REQUEST_METHOD': 'POST',
- 'CONTENT_TYPE': 'application/x-www-form-urlencoded',
- 'CONTENT_LENGTH': len(CONTENT),
- 'wsgi.input': BytesIO(CONTENT),
- }
- wsgi_request = WSGIRequest(environ=environ)
- wsgi_request._load_post_and_files()
- parsers = (JSONParser(), FormParser(), MultiPartParser())
- parser_context = {
- 'encoding': 'utf-8',
- 'kwargs': {},
- 'args': (),
- }
- request = Request(wsgi_request, parsers=parsers, parser_context=parser_context)
- method = request.method
- self.assertEqual(method, 'POST')
- self.assertEqual(request._content_type, 'application/json')
- self.assertEqual(request._stream.getvalue(), b'{"request": 4, "firm": 1, "text": "\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82!"}')
- self.assertEqual(request._data, Empty)
- self.assertEqual(request._files, Empty)
-
- # def test_accessing_post_after_data_form(self):
- # """
- # Ensures request.POST can be accessed after request.DATA in
- # form request.
- # """
- # data = {'qwerty': 'uiop'}
- # request = factory.post('/', data=data)
- # self.assertEqual(request.DATA.items(), data.items())
- # self.assertEqual(request.POST.items(), data.items())
-
- # def test_accessing_post_after_data_for_json(self):
- # """
- # Ensures request.POST can be accessed after request.DATA in
- # json request.
- # """
- # data = {'qwerty': 'uiop'}
- # content = json.dumps(data)
- # content_type = 'application/json'
- # parsers = (JSONParser, )
-
- # request = factory.post('/', content, content_type=content_type,
- # parsers=parsers)
- # self.assertEqual(request.DATA.items(), data.items())
- # self.assertEqual(request.POST.items(), [])
-
- # def test_accessing_post_after_data_for_overloaded_json(self):
- # """
- # Ensures request.POST can be accessed after request.DATA in overloaded
- # json request.
- # """
- # data = {'qwerty': 'uiop'}
- # content = json.dumps(data)
- # content_type = 'application/json'
- # parsers = (JSONParser, )
- # form_data = {Request._CONTENT_PARAM: content,
- # Request._CONTENTTYPE_PARAM: content_type}
-
- # request = factory.post('/', form_data, parsers=parsers)
- # self.assertEqual(request.DATA.items(), data.items())
- # self.assertEqual(request.POST.items(), form_data.items())
-
- # def test_accessing_data_after_post_form(self):
- # """
- # Ensures request.DATA can be accessed after request.POST in
- # form request.
- # """
- # data = {'qwerty': 'uiop'}
- # parsers = (FormParser, MultiPartParser)
- # request = factory.post('/', data, parsers=parsers)
-
- # self.assertEqual(request.POST.items(), data.items())
- # self.assertEqual(request.DATA.items(), data.items())
-
- # def test_accessing_data_after_post_for_json(self):
- # """
- # Ensures request.DATA can be accessed after request.POST in
- # json request.
- # """
- # data = {'qwerty': 'uiop'}
- # content = json.dumps(data)
- # content_type = 'application/json'
- # parsers = (JSONParser, )
- # request = factory.post('/', content, content_type=content_type,
- # parsers=parsers)
- # self.assertEqual(request.POST.items(), [])
- # self.assertEqual(request.DATA.items(), data.items())
-
- # def test_accessing_data_after_post_for_overloaded_json(self):
- # """
- # Ensures request.DATA can be accessed after request.POST in overloaded
- # json request
- # """
- # data = {'qwerty': 'uiop'}
- # content = json.dumps(data)
- # content_type = 'application/json'
- # parsers = (JSONParser, )
- # form_data = {Request._CONTENT_PARAM: content,
- # Request._CONTENTTYPE_PARAM: content_type}
-
- # request = factory.post('/', form_data, parsers=parsers)
- # self.assertEqual(request.POST.items(), form_data.items())
- # self.assertEqual(request.DATA.items(), data.items())
-
-
-class MockView(APIView):
- authentication_classes = (SessionAuthentication,)
-
- def post(self, request):
- if request.POST.get('example') is not None:
- return Response(status=status.HTTP_200_OK)
-
- return Response(status=status.INTERNAL_SERVER_ERROR)
-
-urlpatterns = patterns('',
- (r'^$', MockView.as_view()),
-)
-
-
-class TestContentParsingWithAuthentication(TestCase):
- urls = 'rest_framework.tests.test_request'
-
- def setUp(self):
- self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(self.username, self.email, self.password)
-
- def test_user_logged_in_authentication_has_POST_when_not_logged_in(self):
- """
- Ensures request.POST exists after SessionAuthentication when user
- doesn't log in.
- """
- content = {'example': 'example'}
-
- response = self.client.post('/', content)
- self.assertEqual(status.HTTP_200_OK, response.status_code)
-
- response = self.csrf_client.post('/', content)
- self.assertEqual(status.HTTP_200_OK, response.status_code)
-
- # def test_user_logged_in_authentication_has_post_when_logged_in(self):
- # """Ensures request.POST exists after UserLoggedInAuthentication when user does log in"""
- # self.client.login(username='john', password='password')
- # self.csrf_client.login(username='john', password='password')
- # content = {'example': 'example'}
-
- # response = self.client.post('/', content)
- # self.assertEqual(status.OK, response.status_code, "POST data is malformed")
-
- # response = self.csrf_client.post('/', content)
- # self.assertEqual(status.OK, response.status_code, "POST data is malformed")
-
-
-class TestUserSetter(TestCase):
-
- def setUp(self):
- # Pass request object through session middleware so session is
- # available to login and logout functions
- self.request = Request(factory.get('/'))
- SessionMiddleware().process_request(self.request)
-
- User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow')
- self.user = authenticate(username='ringo', password='yellow')
-
- def test_user_can_be_set(self):
- self.request.user = self.user
- self.assertEqual(self.request.user, self.user)
-
- def test_user_can_login(self):
- login(self.request, self.user)
- self.assertEqual(self.request.user, self.user)
-
- def test_user_can_logout(self):
- self.request.user = self.user
- self.assertFalse(self.request.user.is_anonymous())
- logout(self.request)
- self.assertTrue(self.request.user.is_anonymous())
-
-
-class TestAuthSetter(TestCase):
-
- def test_auth_can_be_set(self):
- request = Request(factory.get('/'))
- request.auth = 'DUMMY'
- self.assertEqual(request.auth, 'DUMMY')
diff --git a/rest_framework/tests/test_response.py b/rest_framework/tests/test_response.py
deleted file mode 100644
index eea3c641..00000000
--- a/rest_framework/tests/test_response.py
+++ /dev/null
@@ -1,278 +0,0 @@
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework.tests.models import BasicModel, BasicModelSerializer
-from rest_framework.compat import patterns, url, include
-from rest_framework.response import Response
-from rest_framework.views import APIView
-from rest_framework import generics
-from rest_framework import routers
-from rest_framework import status
-from rest_framework.renderers import (
- BaseRenderer,
- JSONRenderer,
- BrowsableAPIRenderer
-)
-from rest_framework import viewsets
-from rest_framework.settings import api_settings
-from rest_framework.compat import six
-
-
-class MockPickleRenderer(BaseRenderer):
- media_type = 'application/pickle'
-
-
-class MockJsonRenderer(BaseRenderer):
- media_type = 'application/json'
-
-
-class MockTextMediaRenderer(BaseRenderer):
- media_type = 'text/html'
-
-DUMMYSTATUS = status.HTTP_200_OK
-DUMMYCONTENT = 'dummycontent'
-
-RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii')
-RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
-
-
-class RendererA(BaseRenderer):
- media_type = 'mock/renderera'
- format = "formata"
-
- def render(self, data, media_type=None, renderer_context=None):
- return RENDERER_A_SERIALIZER(data)
-
-
-class RendererB(BaseRenderer):
- media_type = 'mock/rendererb'
- format = "formatb"
-
- def render(self, data, media_type=None, renderer_context=None):
- return RENDERER_B_SERIALIZER(data)
-
-
-class RendererC(RendererB):
- media_type = 'mock/rendererc'
- format = 'formatc'
- charset = "rendererc"
-
-
-class MockView(APIView):
- renderer_classes = (RendererA, RendererB, RendererC)
-
- def get(self, request, **kwargs):
- return Response(DUMMYCONTENT, status=DUMMYSTATUS)
-
-
-class MockViewSettingContentType(APIView):
- renderer_classes = (RendererA, RendererB, RendererC)
-
- def get(self, request, **kwargs):
- return Response(DUMMYCONTENT, status=DUMMYSTATUS, content_type='setbyview')
-
-
-class HTMLView(APIView):
- renderer_classes = (BrowsableAPIRenderer, )
-
- def get(self, request, **kwargs):
- return Response('text')
-
-
-class HTMLView1(APIView):
- renderer_classes = (BrowsableAPIRenderer, JSONRenderer)
-
- def get(self, request, **kwargs):
- return Response('text')
-
-
-class HTMLNewModelViewSet(viewsets.ModelViewSet):
- model = BasicModel
-
-
-class HTMLNewModelView(generics.ListCreateAPIView):
- renderer_classes = (BrowsableAPIRenderer,)
- permission_classes = []
- serializer_class = BasicModelSerializer
- model = BasicModel
-
-
-new_model_viewset_router = routers.DefaultRouter()
-new_model_viewset_router.register(r'', HTMLNewModelViewSet)
-
-
-urlpatterns = patterns('',
- url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
- url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
- url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
- url(r'^html$', HTMLView.as_view()),
- url(r'^html1$', HTMLView1.as_view()),
- url(r'^html_new_model$', HTMLNewModelView.as_view()),
- url(r'^html_new_model_viewset', include(new_model_viewset_router.urls)),
- url(r'^restframework', include('rest_framework.urls', namespace='rest_framework'))
-)
-
-
-# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ...
-class RendererIntegrationTests(TestCase):
- """
- End-to-end testing of renderers using an ResponseMixin on a generic view.
- """
-
- urls = 'rest_framework.tests.test_response'
-
- def test_default_renderer_serializes_content(self):
- """If the Accept header is not set the default renderer should serialize the response."""
- resp = self.client.get('/')
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_head_method_serializes_no_content(self):
- """No response must be included in HEAD requests."""
- resp = self.client.head('/')
- self.assertEqual(resp.status_code, DUMMYSTATUS)
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, six.b(''))
-
- def test_default_renderer_serializes_content_on_accept_any(self):
- """If the Accept header is set to */* the default renderer should serialize the response."""
- resp = self.client.get('/', HTTP_ACCEPT='*/*')
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_default_case(self):
- """If the Accept header is set the specified renderer should serialize the response.
- (In this case we check that works for the default renderer)"""
- resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_non_default_case(self):
- """If the Accept header is set the specified renderer should serialize the response.
- (In this case we check that works for a non-default renderer)"""
- resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_on_accept_query(self):
- """The '_accept' query string should behave in the same way as the Accept header."""
- param = '?%s=%s' % (
- api_settings.URL_ACCEPT_OVERRIDE,
- RendererB.media_type
- )
- resp = self.client.get('/' + param)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_on_format_query(self):
- """If a 'format' query is specified, the renderer with the matching
- format attribute should serialize the response."""
- resp = self.client.get('/?format=%s' % RendererB.format)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_on_format_kwargs(self):
- """If a 'format' keyword arg is specified, the renderer with the matching
- format attribute should serialize the response."""
- resp = self.client.get('/something.formatb')
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
- """If both a 'format' query and a matching Accept header specified,
- the renderer with the matching format attribute should serialize the response."""
- resp = self.client.get('/?format=%s' % RendererB.format,
- HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
-
-class Issue122Tests(TestCase):
- """
- Tests that covers #122.
- """
- urls = 'rest_framework.tests.test_response'
-
- def test_only_html_renderer(self):
- """
- Test if no infinite recursion occurs.
- """
- self.client.get('/html')
-
- def test_html_renderer_is_first(self):
- """
- Test if no infinite recursion occurs.
- """
- self.client.get('/html1')
-
-
-class Issue467Tests(TestCase):
- """
- Tests for #467
- """
-
- urls = 'rest_framework.tests.test_response'
-
- def test_form_has_label_and_help_text(self):
- resp = self.client.get('/html_new_model')
- self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
- self.assertContains(resp, 'Text comes here')
- self.assertContains(resp, 'Text description.')
-
-
-class Issue807Tests(TestCase):
- """
- Covers #807
- """
-
- urls = 'rest_framework.tests.test_response'
-
- def test_does_not_append_charset_by_default(self):
- """
- Renderers don't include a charset unless set explicitly.
- """
- headers = {"HTTP_ACCEPT": RendererA.media_type}
- resp = self.client.get('/', **headers)
- expected = "{0}; charset={1}".format(RendererA.media_type, 'utf-8')
- self.assertEqual(expected, resp['Content-Type'])
-
- def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self):
- """
- If renderer class has charset attribute declared, it gets appended
- to Response's Content-Type
- """
- headers = {"HTTP_ACCEPT": RendererC.media_type}
- resp = self.client.get('/', **headers)
- expected = "{0}; charset={1}".format(RendererC.media_type, RendererC.charset)
- self.assertEqual(expected, resp['Content-Type'])
-
- def test_content_type_set_explictly_on_response(self):
- """
- The content type may be set explictly on the response.
- """
- headers = {"HTTP_ACCEPT": RendererC.media_type}
- resp = self.client.get('/setbyview', **headers)
- self.assertEqual('setbyview', resp['Content-Type'])
-
- def test_viewset_label_help_text(self):
- param = '?%s=%s' % (
- api_settings.URL_ACCEPT_OVERRIDE,
- 'text/html'
- )
- resp = self.client.get('/html_new_model_viewset/' + param)
- self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
- self.assertContains(resp, 'Text comes here')
- self.assertContains(resp, 'Text description.')
-
- def test_form_has_label_and_help_text(self):
- resp = self.client.get('/html_new_model')
- self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
- self.assertContains(resp, 'Text comes here')
- self.assertContains(resp, 'Text description.')
diff --git a/rest_framework/tests/test_reverse.py b/rest_framework/tests/test_reverse.py
deleted file mode 100644
index 690a30b1..00000000
--- a/rest_framework/tests/test_reverse.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework.compat import patterns, url
-from rest_framework.reverse import reverse
-from rest_framework.test import APIRequestFactory
-
-factory = APIRequestFactory()
-
-
-def null_view(request):
- pass
-
-urlpatterns = patterns('',
- url(r'^view$', null_view, name='view'),
-)
-
-
-class ReverseTests(TestCase):
- """
- Tests for fully qualified URLs when using `reverse`.
- """
- urls = 'rest_framework.tests.test_reverse'
-
- def test_reversed_urls_are_fully_qualified(self):
- request = factory.get('/view')
- url = reverse('view', request=request)
- self.assertEqual(url, 'http://testserver/view')
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
deleted file mode 100644
index e723f7d4..00000000
--- a/rest_framework/tests/test_routers.py
+++ /dev/null
@@ -1,216 +0,0 @@
-from __future__ import unicode_literals
-from django.db import models
-from django.test import TestCase
-from django.core.exceptions import ImproperlyConfigured
-from rest_framework import serializers, viewsets, permissions
-from rest_framework.compat import include, patterns, url
-from rest_framework.decorators import link, action
-from rest_framework.response import Response
-from rest_framework.routers import SimpleRouter, DefaultRouter
-from rest_framework.test import APIRequestFactory
-
-factory = APIRequestFactory()
-
-urlpatterns = patterns('',)
-
-
-class BasicViewSet(viewsets.ViewSet):
- def list(self, request, *args, **kwargs):
- return Response({'method': 'list'})
-
- @action()
- def action1(self, request, *args, **kwargs):
- return Response({'method': 'action1'})
-
- @action()
- def action2(self, request, *args, **kwargs):
- return Response({'method': 'action2'})
-
- @action(methods=['post', 'delete'])
- def action3(self, request, *args, **kwargs):
- return Response({'method': 'action2'})
-
- @link()
- def link1(self, request, *args, **kwargs):
- return Response({'method': 'link1'})
-
- @link()
- def link2(self, request, *args, **kwargs):
- return Response({'method': 'link2'})
-
-
-class TestSimpleRouter(TestCase):
- def setUp(self):
- self.router = SimpleRouter()
-
- def test_link_and_action_decorator(self):
- routes = self.router.get_routes(BasicViewSet)
- decorator_routes = routes[2:]
- # Make sure all these endpoints exist and none have been clobbered
- for i, endpoint in enumerate(['action1', 'action2', 'action3', 'link1', 'link2']):
- route = decorator_routes[i]
- # check url listing
- self.assertEqual(route.url,
- '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint))
- # check method to function mapping
- if endpoint == 'action3':
- methods_map = ['post', 'delete']
- elif endpoint.startswith('action'):
- methods_map = ['post']
- else:
- methods_map = ['get']
- for method in methods_map:
- self.assertEqual(route.mapping[method], endpoint)
-
-
-class RouterTestModel(models.Model):
- uuid = models.CharField(max_length=20)
- text = models.CharField(max_length=200)
-
-
-class TestCustomLookupFields(TestCase):
- """
- Ensure that custom lookup fields are correctly routed.
- """
- urls = 'rest_framework.tests.test_routers'
-
- def setUp(self):
- class NoteSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = RouterTestModel
- lookup_field = 'uuid'
- fields = ('url', 'uuid', 'text')
-
- class NoteViewSet(viewsets.ModelViewSet):
- queryset = RouterTestModel.objects.all()
- serializer_class = NoteSerializer
- lookup_field = 'uuid'
-
- RouterTestModel.objects.create(uuid='123', text='foo bar')
-
- self.router = SimpleRouter()
- self.router.register(r'notes', NoteViewSet)
-
- from rest_framework.tests import test_routers
- urls = getattr(test_routers, 'urlpatterns')
- urls += patterns('',
- url(r'^', include(self.router.urls)),
- )
-
- def test_custom_lookup_field_route(self):
- detail_route = self.router.urls[-1]
- detail_url_pattern = detail_route.regex.pattern
- self.assertIn('<uuid>', detail_url_pattern)
-
- def test_retrieve_lookup_field_list_view(self):
- response = self.client.get('/notes/')
- self.assertEqual(response.data,
- [{
- "url": "http://testserver/notes/123/",
- "uuid": "123", "text": "foo bar"
- }]
- )
-
- def test_retrieve_lookup_field_detail_view(self):
- response = self.client.get('/notes/123/')
- self.assertEqual(response.data,
- {
- "url": "http://testserver/notes/123/",
- "uuid": "123", "text": "foo bar"
- }
- )
-
-
-class TestTrailingSlashIncluded(TestCase):
- def setUp(self):
- class NoteViewSet(viewsets.ModelViewSet):
- model = RouterTestModel
-
- self.router = SimpleRouter()
- self.router.register(r'notes', NoteViewSet)
- self.urls = self.router.urls
-
- def test_urls_have_trailing_slash_by_default(self):
- expected = ['^notes/$', '^notes/(?P<pk>[^/]+)/$']
- for idx in range(len(expected)):
- self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
-
-
-class TestTrailingSlashRemoved(TestCase):
- def setUp(self):
- class NoteViewSet(viewsets.ModelViewSet):
- model = RouterTestModel
-
- self.router = SimpleRouter(trailing_slash=False)
- self.router.register(r'notes', NoteViewSet)
- self.urls = self.router.urls
-
- def test_urls_can_have_trailing_slash_removed(self):
- expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$']
- for idx in range(len(expected)):
- self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
-
-
-class TestNameableRoot(TestCase):
- def setUp(self):
- class NoteViewSet(viewsets.ModelViewSet):
- model = RouterTestModel
- self.router = DefaultRouter()
- self.router.root_view_name = 'nameable-root'
- self.router.register(r'notes', NoteViewSet)
- self.urls = self.router.urls
-
- def test_router_has_custom_name(self):
- expected = 'nameable-root'
- self.assertEqual(expected, self.urls[0].name)
-
-
-class TestActionKeywordArgs(TestCase):
- """
- Ensure keyword arguments passed in the `@action` decorator
- are properly handled. Refs #940.
- """
-
- def setUp(self):
- class TestViewSet(viewsets.ModelViewSet):
- permission_classes = []
-
- @action(permission_classes=[permissions.AllowAny])
- def custom(self, request, *args, **kwargs):
- return Response({
- 'permission_classes': self.permission_classes
- })
-
- self.router = SimpleRouter()
- self.router.register(r'test', TestViewSet, base_name='test')
- self.view = self.router.urls[-1].callback
-
- def test_action_kwargs(self):
- request = factory.post('/test/0/custom/')
- response = self.view(request)
- self.assertEqual(
- response.data,
- {'permission_classes': [permissions.AllowAny]}
- )
-
-
-class TestActionAppliedToExistingRoute(TestCase):
- """
- Ensure `@action` decorator raises an except when applied
- to an existing route
- """
-
- def test_exception_raised_when_action_applied_to_existing_route(self):
- class TestViewSet(viewsets.ModelViewSet):
-
- @action()
- def retrieve(self, request, *args, **kwargs):
- return Response({
- 'hello': 'world'
- })
-
- self.router = SimpleRouter()
- self.router.register(r'test', TestViewSet, base_name='test')
-
- with self.assertRaises(ImproperlyConfigured):
- self.router.urls
diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py
deleted file mode 100644
index 3ee2b38a..00000000
--- a/rest_framework/tests/test_serializer.py
+++ /dev/null
@@ -1,1949 +0,0 @@
-# -*- coding: utf-8 -*-
-from __future__ import unicode_literals
-from django.db import models
-from django.db.models.fields import BLANK_CHOICE_DASH
-from django.test import TestCase
-from django.utils import unittest
-from django.utils.datastructures import MultiValueDict
-from django.utils.translation import ugettext_lazy as _
-from rest_framework import serializers, fields, relations
-from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
- BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel,
- ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel)
-from rest_framework.tests.models import BasicModelSerializer
-import datetime
-import pickle
-try:
- import PIL
-except:
- PIL = None
-
-
-if PIL is not None:
- class AMOAFModel(RESTFrameworkModel):
- char_field = models.CharField(max_length=1024, blank=True)
- comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True)
- decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True)
- email_field = models.EmailField(max_length=1024, blank=True)
- file_field = models.FileField(upload_to='test', max_length=1024, blank=True)
- image_field = models.ImageField(upload_to='test', max_length=1024, blank=True)
- slug_field = models.SlugField(max_length=1024, blank=True)
- url_field = models.URLField(max_length=1024, blank=True)
-
- class DVOAFModel(RESTFrameworkModel):
- positive_integer_field = models.PositiveIntegerField(blank=True)
- positive_small_integer_field = models.PositiveSmallIntegerField(blank=True)
- email_field = models.EmailField(blank=True)
- file_field = models.FileField(upload_to='test', blank=True)
- image_field = models.ImageField(upload_to='test', blank=True)
- slug_field = models.SlugField(blank=True)
- url_field = models.URLField(blank=True)
-
-
-class SubComment(object):
- def __init__(self, sub_comment):
- self.sub_comment = sub_comment
-
-
-class Comment(object):
- def __init__(self, email, content, created):
- self.email = email
- self.content = content
- self.created = created or datetime.datetime.now()
-
- def __eq__(self, other):
- return all([getattr(self, attr) == getattr(other, attr)
- for attr in ('email', 'content', 'created')])
-
- def get_sub_comment(self):
- sub_comment = SubComment('And Merry Christmas!')
- return sub_comment
-
-
-class CommentSerializer(serializers.Serializer):
- email = serializers.EmailField()
- content = serializers.CharField(max_length=1000)
- created = serializers.DateTimeField()
- sub_comment = serializers.Field(source='get_sub_comment.sub_comment')
-
- def restore_object(self, data, instance=None):
- if instance is None:
- return Comment(**data)
- for key, val in data.items():
- setattr(instance, key, val)
- return instance
-
-
-class NamesSerializer(serializers.Serializer):
- first = serializers.CharField()
- last = serializers.CharField(required=False, default='')
- initials = serializers.CharField(required=False, default='')
-
-
-class PersonIdentifierSerializer(serializers.Serializer):
- ssn = serializers.CharField()
- names = NamesSerializer(source='names', required=False)
-
-
-class BookSerializer(serializers.ModelSerializer):
- isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'})
-
- class Meta:
- model = Book
-
-
-class ActionItemSerializer(serializers.ModelSerializer):
-
- class Meta:
- model = ActionItem
-
-class ActionItemSerializerOptionalFields(serializers.ModelSerializer):
- """
- Intended to test that fields with `required=False` are excluded from validation.
- """
- title = serializers.CharField(required=False)
-
- class Meta:
- model = ActionItem
- fields = ('title',)
-
-class ActionItemSerializerCustomRestore(serializers.ModelSerializer):
-
- class Meta:
- model = ActionItem
-
- def restore_object(self, data, instance=None):
- if instance is None:
- return ActionItem(**data)
- for key, val in data.items():
- setattr(instance, key, val)
- return instance
-
-
-class PersonSerializer(serializers.ModelSerializer):
- info = serializers.Field(source='info')
-
- class Meta:
- model = Person
- fields = ('name', 'age', 'info')
- read_only_fields = ('age',)
-
-
-class NestedSerializer(serializers.Serializer):
- info = serializers.Field()
-
-
-class ModelSerializerWithNestedSerializer(serializers.ModelSerializer):
- nested = NestedSerializer(source='*')
-
- class Meta:
- model = Person
-
-
-class NestedSerializerWithRenamedField(serializers.Serializer):
- renamed_info = serializers.Field(source='info')
-
-
-class ModelSerializerWithNestedSerializerWithRenamedField(serializers.ModelSerializer):
- nested = NestedSerializerWithRenamedField(source='*')
-
- class Meta:
- model = Person
-
-
-class PersonSerializerInvalidReadOnly(serializers.ModelSerializer):
- """
- Testing for #652.
- """
- info = serializers.Field(source='info')
-
- class Meta:
- model = Person
- fields = ('name', 'age', 'info')
- read_only_fields = ('age', 'info')
-
-
-class AlbumsSerializer(serializers.ModelSerializer):
-
- class Meta:
- model = Album
- fields = ['title', 'ref'] # lists are also valid options
-
-
-class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):
- class Meta:
- model = HasPositiveIntegerAsChoice
- fields = ['some_integer']
-
-
-class BasicTests(TestCase):
- def setUp(self):
- self.comment = Comment(
- 'tom@example.com',
- 'Happy new year!',
- datetime.datetime(2012, 1, 1)
- )
- self.actionitem = ActionItem(title='Some to do item',)
- self.data = {
- 'email': 'tom@example.com',
- 'content': 'Happy new year!',
- 'created': datetime.datetime(2012, 1, 1),
- 'sub_comment': 'This wont change'
- }
- self.expected = {
- 'email': 'tom@example.com',
- 'content': 'Happy new year!',
- 'created': datetime.datetime(2012, 1, 1),
- 'sub_comment': 'And Merry Christmas!'
- }
- self.person_data = {'name': 'dwight', 'age': 35}
- self.person = Person(**self.person_data)
- self.person.save()
-
- def test_empty(self):
- serializer = CommentSerializer()
- expected = {
- 'email': '',
- 'content': '',
- 'created': None
- }
- self.assertEqual(serializer.data, expected)
-
- def test_retrieve(self):
- serializer = CommentSerializer(self.comment)
- self.assertEqual(serializer.data, self.expected)
-
- def test_create(self):
- serializer = CommentSerializer(data=self.data)
- expected = self.comment
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, expected)
- self.assertFalse(serializer.object is expected)
- self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!')
-
- def test_create_nested(self):
- """Test a serializer with nested data."""
- names = {'first': 'John', 'last': 'Doe', 'initials': 'jd'}
- data = {'ssn': '1234567890', 'names': names}
- serializer = PersonIdentifierSerializer(data=data)
-
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, data)
- self.assertFalse(serializer.object is data)
- self.assertEqual(serializer.data['names'], names)
-
- def test_create_partial_nested(self):
- """Test a serializer with nested data which has missing fields."""
- names = {'first': 'John'}
- data = {'ssn': '1234567890', 'names': names}
- serializer = PersonIdentifierSerializer(data=data)
-
- expected_names = {'first': 'John', 'last': '', 'initials': ''}
- data['names'] = expected_names
-
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, data)
- self.assertFalse(serializer.object is expected_names)
- self.assertEqual(serializer.data['names'], expected_names)
-
- def test_null_nested(self):
- """Test a serializer with a nonexistent nested field"""
- data = {'ssn': '1234567890'}
- serializer = PersonIdentifierSerializer(data=data)
-
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, data)
- self.assertFalse(serializer.object is data)
- expected = {'ssn': '1234567890', 'names': None}
- self.assertEqual(serializer.data, expected)
-
- def test_update(self):
- serializer = CommentSerializer(self.comment, data=self.data)
- expected = self.comment
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, expected)
- self.assertTrue(serializer.object is expected)
- self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!')
-
- def test_partial_update(self):
- msg = 'Merry New Year!'
- partial_data = {'content': msg}
- serializer = CommentSerializer(self.comment, data=partial_data)
- self.assertEqual(serializer.is_valid(), False)
- serializer = CommentSerializer(self.comment, data=partial_data, partial=True)
- expected = self.comment
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, expected)
- self.assertTrue(serializer.object is expected)
- self.assertEqual(serializer.data['content'], msg)
-
- def test_model_fields_as_expected(self):
- """
- Make sure that the fields returned are the same as defined
- in the Meta data
- """
- serializer = PersonSerializer(self.person)
- self.assertEqual(set(serializer.data.keys()),
- set(['name', 'age', 'info']))
-
- def test_field_with_dictionary(self):
- """
- Make sure that dictionaries from fields are left intact
- """
- serializer = PersonSerializer(self.person)
- expected = self.person_data
- self.assertEqual(serializer.data['info'], expected)
-
- def test_read_only_fields(self):
- """
- Attempting to update fields set as read_only should have no effect.
- """
- serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99})
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(serializer.errors, {})
- # Assert age is unchanged (35)
- self.assertEqual(instance.age, self.person_data['age'])
-
- def test_invalid_read_only_fields(self):
- """
- Regression test for #652.
- """
- self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, [])
-
- def test_serializer_data_is_cleared_on_save(self):
- """
- Check _data attribute is cleared on `save()`
-
- Regression test for #1116
- — id field is not populated if `data` is accessed prior to `save()`
- """
- serializer = ActionItemSerializer(self.actionitem)
- self.assertIsNone(serializer.data.get('id',None), 'New instance. `id` should not be set.')
- serializer.save()
- self.assertIsNotNone(serializer.data.get('id',None), 'Model is saved. `id` should be set.')
-
- def test_fields_marked_as_not_required_are_excluded_from_validation(self):
- """
- Check that fields with `required=False` are included in list of exclusions.
- """
- serializer = ActionItemSerializerOptionalFields(self.actionitem)
- exclusions = serializer.get_validation_exclusions()
- self.assertTrue('title' in exclusions, '`title` field was marked `required=False` and should be excluded')
-
-
-class DictStyleSerializer(serializers.Serializer):
- """
- Note that we don't have any `restore_object` method, so the default
- case of simply returning a dict will apply.
- """
- email = serializers.EmailField()
-
-
-class DictStyleSerializerTests(TestCase):
- def test_dict_style_deserialize(self):
- """
- Ensure serializers can deserialize into a dict.
- """
- data = {'email': 'foo@example.com'}
- serializer = DictStyleSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, data)
-
- def test_dict_style_serialize(self):
- """
- Ensure serializers can serialize dict objects.
- """
- data = {'email': 'foo@example.com'}
- serializer = DictStyleSerializer(data)
- self.assertEqual(serializer.data, data)
-
-
-class ValidationTests(TestCase):
- def setUp(self):
- self.comment = Comment(
- 'tom@example.com',
- 'Happy new year!',
- datetime.datetime(2012, 1, 1)
- )
- self.data = {
- 'email': 'tom@example.com',
- 'content': 'x' * 1001,
- 'created': datetime.datetime(2012, 1, 1)
- }
- self.actionitem = ActionItem(title='Some to do item',)
-
- def test_create(self):
- serializer = CommentSerializer(data=self.data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']})
-
- def test_update(self):
- serializer = CommentSerializer(self.comment, data=self.data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']})
-
- def test_update_missing_field(self):
- data = {
- 'content': 'xxx',
- 'created': datetime.datetime(2012, 1, 1)
- }
- serializer = CommentSerializer(self.comment, data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'email': ['This field is required.']})
-
- def test_missing_bool_with_default(self):
- """Make sure that a boolean value with a 'False' value is not
- mistaken for not having a default."""
- data = {
- 'title': 'Some action item',
- #No 'done' value.
- }
- serializer = ActionItemSerializer(self.actionitem, data=data)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.errors, {})
-
- def test_cross_field_validation(self):
-
- class CommentSerializerWithCrossFieldValidator(CommentSerializer):
-
- def validate(self, attrs):
- if attrs["email"] not in attrs["content"]:
- raise serializers.ValidationError("Email address not in content")
- return attrs
-
- data = {
- 'email': 'tom@example.com',
- 'content': 'A comment from tom@example.com',
- 'created': datetime.datetime(2012, 1, 1)
- }
-
- serializer = CommentSerializerWithCrossFieldValidator(data=data)
- self.assertTrue(serializer.is_valid())
-
- data['content'] = 'A comment from foo@bar.com'
-
- serializer = CommentSerializerWithCrossFieldValidator(data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'non_field_errors': ['Email address not in content']})
-
- def test_null_is_true_fields(self):
- """
- Omitting a value for null-field should validate.
- """
- serializer = PersonSerializer(data={'name': 'marko'})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.errors, {})
-
- def test_modelserializer_max_length_exceeded(self):
- data = {
- 'title': 'x' * 201,
- }
- serializer = ActionItemSerializer(data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']})
-
- def test_modelserializer_max_length_exceeded_with_custom_restore(self):
- """
- When overriding ModelSerializer.restore_object, validation tests should still apply.
- Regression test for #623.
-
- https://github.com/tomchristie/django-rest-framework/pull/623
- """
- data = {
- 'title': 'x' * 201,
- }
- serializer = ActionItemSerializerCustomRestore(data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']})
-
- def test_default_modelfield_max_length_exceeded(self):
- data = {
- 'title': 'Testing "info" field...',
- 'info': 'x' * 13,
- }
- serializer = ActionItemSerializer(data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'info': ['Ensure this value has at most 12 characters (it has 13).']})
-
- def test_datetime_validation_failure(self):
- """
- Test DateTimeField validation errors on non-str values.
- Regression test for #669.
-
- https://github.com/tomchristie/django-rest-framework/issues/669
- """
- data = self.data
- data['created'] = 0
-
- serializer = CommentSerializer(data=data)
- self.assertEqual(serializer.is_valid(), False)
-
- self.assertIn('created', serializer.errors)
-
- def test_missing_model_field_exception_msg(self):
- """
- Assert that a meaningful exception message is outputted when the model
- field is missing (e.g. when mistyping ``model``).
- """
- class BrokenModelSerializer(serializers.ModelSerializer):
- class Meta:
- fields = ['some_field']
-
- try:
- BrokenModelSerializer()
- except AssertionError as e:
- self.assertEqual(e.args[0], "Serializer class 'BrokenModelSerializer' is missing 'model' Meta option")
- except:
- self.fail('Wrong exception type thrown.')
-
- def test_writable_star_source_on_nested_serializer(self):
- """
- Assert that a nested serializer instantiated with source='*' correctly
- expands the data into the outer serializer.
- """
- serializer = ModelSerializerWithNestedSerializer(data={
- 'name': 'marko',
- 'nested': {'info': 'hi'}},
- )
- self.assertEqual(serializer.is_valid(), True)
-
- def test_writable_star_source_on_nested_serializer_with_parent_object(self):
- class TitleSerializer(serializers.Serializer):
- title = serializers.WritableField(source='title')
-
- class AlbumSerializer(serializers.ModelSerializer):
- nested = TitleSerializer(source='*')
-
- class Meta:
- model = Album
- fields = ('nested',)
-
- class PhotoSerializer(serializers.ModelSerializer):
- album = AlbumSerializer(source='album')
-
- class Meta:
- model = Photo
- fields = ('album', )
-
- photo = Photo(album=Album())
-
- data = {'album': {'nested': {'title': 'test'}}}
-
- serializer = PhotoSerializer(photo, data=data)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.data, data)
-
- def test_writable_star_source_with_inner_source_fields(self):
- """
- Tests that a serializer with source="*" correctly expands the
- it's fields into the outer serializer even if they have their
- own 'source' parameters.
- """
-
- serializer = ModelSerializerWithNestedSerializerWithRenamedField(data={
- 'name': 'marko',
- 'nested': {'renamed_info': 'hi'}},
- )
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.errors, {})
-
-
-class CustomValidationTests(TestCase):
- class CommentSerializerWithFieldValidator(CommentSerializer):
-
- def validate_email(self, attrs, source):
- attrs[source]
- return attrs
-
- def validate_content(self, attrs, source):
- value = attrs[source]
- if "test" not in value:
- raise serializers.ValidationError("Test not in value")
- return attrs
-
- def test_field_validation(self):
- data = {
- 'email': 'tom@example.com',
- 'content': 'A test comment',
- 'created': datetime.datetime(2012, 1, 1)
- }
-
- serializer = self.CommentSerializerWithFieldValidator(data=data)
- self.assertTrue(serializer.is_valid())
-
- data['content'] = 'This should not validate'
-
- serializer = self.CommentSerializerWithFieldValidator(data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'content': ['Test not in value']})
-
- def test_missing_data(self):
- """
- Make sure that validate_content isn't called if the field is missing
- """
- incomplete_data = {
- 'email': 'tom@example.com',
- 'created': datetime.datetime(2012, 1, 1)
- }
- serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'content': ['This field is required.']})
-
- def test_wrong_data(self):
- """
- Make sure that validate_content isn't called if the field input is wrong
- """
- wrong_data = {
- 'email': 'not an email',
- 'content': 'A test comment',
- 'created': datetime.datetime(2012, 1, 1)
- }
- serializer = self.CommentSerializerWithFieldValidator(data=wrong_data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'email': ['Enter a valid email address.']})
-
- def test_partial_update(self):
- """
- Make sure that validate_email isn't called when partial=True and email
- isn't found in data.
- """
- initial_data = {
- 'email': 'tom@example.com',
- 'content': 'A test comment',
- 'created': datetime.datetime(2012, 1, 1)
- }
-
- serializer = self.CommentSerializerWithFieldValidator(data=initial_data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.object
-
- new_content = 'An *updated* test comment'
- partial_data = {
- 'content': new_content
- }
-
- serializer = self.CommentSerializerWithFieldValidator(instance=instance,
- data=partial_data,
- partial=True)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.object
- self.assertEqual(instance.content, new_content)
-
-
-class PositiveIntegerAsChoiceTests(TestCase):
- def test_positive_integer_in_json_is_correctly_parsed(self):
- data = {'some_integer': 1}
- serializer = PositiveIntegerAsChoiceSerializer(data=data)
- self.assertEqual(serializer.is_valid(), True)
-
-
-class ModelValidationTests(TestCase):
- def test_validate_unique(self):
- """
- Just check if serializers.ModelSerializer handles unique checks via .full_clean()
- """
- serializer = AlbumsSerializer(data={'title': 'a', 'ref': '1'})
- serializer.is_valid()
- serializer.save()
- second_serializer = AlbumsSerializer(data={'title': 'a'})
- self.assertFalse(second_serializer.is_valid())
- self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.'],})
- third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}])
- self.assertFalse(third_serializer.is_valid())
- self.assertEqual(third_serializer.errors, [{'ref': ['Album with this Ref already exists.']}, {}])
-
- def test_foreign_key_is_null_with_partial(self):
- """
- Test ModelSerializer validation with partial=True
-
- Specifically test that a null foreign key does not pass validation
- """
- album = Album(title='test')
- album.save()
-
- class PhotoSerializer(serializers.ModelSerializer):
- class Meta:
- model = Photo
-
- photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk})
- self.assertTrue(photo_serializer.is_valid())
- photo = photo_serializer.save()
-
- # Updating only the album (foreign key)
- photo_serializer = PhotoSerializer(instance=photo, data={'album': ''}, partial=True)
- self.assertFalse(photo_serializer.is_valid())
- self.assertTrue('album' in photo_serializer.errors)
- self.assertEqual(photo_serializer.errors['album'], photo_serializer.error_messages['required'])
-
- def test_foreign_key_with_partial(self):
- """
- Test ModelSerializer validation with partial=True
-
- Specifically test foreign key validation.
- """
-
- album = Album(title='test')
- album.save()
-
- class PhotoSerializer(serializers.ModelSerializer):
- class Meta:
- model = Photo
-
- photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk})
- self.assertTrue(photo_serializer.is_valid())
- photo = photo_serializer.save()
-
- # Updating only the album (foreign key)
- photo_serializer = PhotoSerializer(instance=photo, data={'album': album.pk}, partial=True)
- self.assertTrue(photo_serializer.is_valid())
- self.assertTrue(photo_serializer.save())
-
- # Updating only the description
- photo_serializer = PhotoSerializer(instance=photo,
- data={'description': 'new'},
- partial=True)
-
- self.assertTrue(photo_serializer.is_valid())
- self.assertTrue(photo_serializer.save())
-
-
-class RegexValidationTest(TestCase):
- def test_create_failed(self):
- serializer = BookSerializer(data={'isbn': '1234567890'})
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
-
- serializer = BookSerializer(data={'isbn': '12345678901234'})
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
-
- serializer = BookSerializer(data={'isbn': 'abcdefghijklm'})
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
-
- def test_create_success(self):
- serializer = BookSerializer(data={'isbn': '1234567890123'})
- self.assertTrue(serializer.is_valid())
-
-
-class MetadataTests(TestCase):
- def test_empty(self):
- serializer = CommentSerializer()
- expected = {
- 'email': serializers.CharField,
- 'content': serializers.CharField,
- 'created': serializers.DateTimeField
- }
- for field_name, field in expected.items():
- self.assertTrue(isinstance(serializer.data.fields[field_name], field))
-
-
-class ManyToManyTests(TestCase):
- def setUp(self):
- class ManyToManySerializer(serializers.ModelSerializer):
- class Meta:
- model = ManyToManyModel
-
- self.serializer_class = ManyToManySerializer
-
- # An anchor instance to use for the relationship
- self.anchor = Anchor()
- self.anchor.save()
-
- # A model instance with a many to many relationship to the anchor
- self.instance = ManyToManyModel()
- self.instance.save()
- self.instance.rel.add(self.anchor)
-
- # A serialized representation of the model instance
- self.data = {'id': 1, 'rel': [self.anchor.id]}
-
- def test_retrieve(self):
- """
- Serialize an instance of a model with a ManyToMany relationship.
- """
- serializer = self.serializer_class(instance=self.instance)
- expected = self.data
- self.assertEqual(serializer.data, expected)
-
- def test_create(self):
- """
- Create an instance of a model with a ManyToMany relationship.
- """
- data = {'rel': [self.anchor.id]}
- serializer = self.serializer_class(data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(ManyToManyModel.objects.all()), 2)
- self.assertEqual(instance.pk, 2)
- self.assertEqual(list(instance.rel.all()), [self.anchor])
-
- def test_update(self):
- """
- Update an instance of a model with a ManyToMany relationship.
- """
- new_anchor = Anchor()
- new_anchor.save()
- data = {'rel': [self.anchor.id, new_anchor.id]}
- serializer = self.serializer_class(self.instance, data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(ManyToManyModel.objects.all()), 1)
- self.assertEqual(instance.pk, 1)
- self.assertEqual(list(instance.rel.all()), [self.anchor, new_anchor])
-
- def test_create_empty_relationship(self):
- """
- Create an instance of a model with a ManyToMany relationship,
- containing no items.
- """
- data = {'rel': []}
- serializer = self.serializer_class(data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(ManyToManyModel.objects.all()), 2)
- self.assertEqual(instance.pk, 2)
- self.assertEqual(list(instance.rel.all()), [])
-
- def test_update_empty_relationship(self):
- """
- Update an instance of a model with a ManyToMany relationship,
- containing no items.
- """
- new_anchor = Anchor()
- new_anchor.save()
- data = {'rel': []}
- serializer = self.serializer_class(self.instance, data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(ManyToManyModel.objects.all()), 1)
- self.assertEqual(instance.pk, 1)
- self.assertEqual(list(instance.rel.all()), [])
-
- def test_create_empty_relationship_flat_data(self):
- """
- Create an instance of a model with a ManyToMany relationship,
- containing no items, using a representation that does not support
- lists (eg form data).
- """
- data = MultiValueDict()
- data.setlist('rel', [''])
- serializer = self.serializer_class(data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(ManyToManyModel.objects.all()), 2)
- self.assertEqual(instance.pk, 2)
- self.assertEqual(list(instance.rel.all()), [])
-
-
-class ReadOnlyManyToManyTests(TestCase):
- def setUp(self):
- class ReadOnlyManyToManySerializer(serializers.ModelSerializer):
- rel = serializers.RelatedField(many=True, read_only=True)
-
- class Meta:
- model = ReadOnlyManyToManyModel
-
- self.serializer_class = ReadOnlyManyToManySerializer
-
- # An anchor instance to use for the relationship
- self.anchor = Anchor()
- self.anchor.save()
-
- # A model instance with a many to many relationship to the anchor
- self.instance = ReadOnlyManyToManyModel()
- self.instance.save()
- self.instance.rel.add(self.anchor)
-
- # A serialized representation of the model instance
- self.data = {'rel': [self.anchor.id], 'id': 1, 'text': 'anchor'}
-
- def test_update(self):
- """
- Attempt to update an instance of a model with a ManyToMany
- relationship. Not updated due to read_only=True
- """
- new_anchor = Anchor()
- new_anchor.save()
- data = {'rel': [self.anchor.id, new_anchor.id]}
- serializer = self.serializer_class(self.instance, data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1)
- self.assertEqual(instance.pk, 1)
- # rel is still as original (1 entry)
- self.assertEqual(list(instance.rel.all()), [self.anchor])
-
- def test_update_without_relationship(self):
- """
- Attempt to update an instance of a model where many to ManyToMany
- relationship is not supplied. Not updated due to read_only=True
- """
- new_anchor = Anchor()
- new_anchor.save()
- data = {}
- serializer = self.serializer_class(self.instance, data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1)
- self.assertEqual(instance.pk, 1)
- # rel is still as original (1 entry)
- self.assertEqual(list(instance.rel.all()), [self.anchor])
-
-
-class DefaultValueTests(TestCase):
- def setUp(self):
- class DefaultValueSerializer(serializers.ModelSerializer):
- class Meta:
- model = DefaultValueModel
-
- self.serializer_class = DefaultValueSerializer
- self.objects = DefaultValueModel.objects
-
- def test_create_using_default(self):
- data = {}
- serializer = self.serializer_class(data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(self.objects.all()), 1)
- self.assertEqual(instance.pk, 1)
- self.assertEqual(instance.text, 'foobar')
-
- def test_create_overriding_default(self):
- data = {'text': 'overridden'}
- serializer = self.serializer_class(data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(self.objects.all()), 1)
- self.assertEqual(instance.pk, 1)
- self.assertEqual(instance.text, 'overridden')
-
- def test_partial_update_default(self):
- """ Regression test for issue #532 """
- data = {'text': 'overridden'}
- serializer = self.serializer_class(data=data, partial=True)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
-
- data = {'extra': 'extra_value'}
- serializer = self.serializer_class(instance=instance, data=data, partial=True)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
-
- self.assertEqual(instance.extra, 'extra_value')
- self.assertEqual(instance.text, 'overridden')
-
-
-class WritableFieldDefaultValueTests(TestCase):
-
- def setUp(self):
- self.expected = {'default': 'value'}
- self.create_field = fields.WritableField
-
- def test_get_default_value_with_noncallable(self):
- field = self.create_field(default=self.expected)
- got = field.get_default_value()
- self.assertEqual(got, self.expected)
-
- def test_get_default_value_with_callable(self):
- field = self.create_field(default=lambda : self.expected)
- got = field.get_default_value()
- self.assertEqual(got, self.expected)
-
- def test_get_default_value_when_not_required(self):
- field = self.create_field(default=self.expected, required=False)
- got = field.get_default_value()
- self.assertEqual(got, self.expected)
-
- def test_get_default_value_returns_None(self):
- field = self.create_field()
- got = field.get_default_value()
- self.assertIsNone(got)
-
- def test_get_default_value_returns_non_True_values(self):
- values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause
- for expected in values:
- field = self.create_field(default=expected)
- got = field.get_default_value()
- self.assertEqual(got, expected)
-
-
-class RelatedFieldDefaultValueTests(WritableFieldDefaultValueTests):
-
- def setUp(self):
- self.expected = {'foo': 'bar'}
- self.create_field = relations.RelatedField
-
- def test_get_default_value_returns_empty_list(self):
- field = self.create_field(many=True)
- got = field.get_default_value()
- self.assertListEqual(got, [])
-
- def test_get_default_value_returns_expected(self):
- expected = [1, 2, 3]
- field = self.create_field(many=True, default=expected)
- got = field.get_default_value()
- self.assertListEqual(got, expected)
-
-
-class CallableDefaultValueTests(TestCase):
- def setUp(self):
- class CallableDefaultValueSerializer(serializers.ModelSerializer):
- class Meta:
- model = CallableDefaultValueModel
-
- self.serializer_class = CallableDefaultValueSerializer
- self.objects = CallableDefaultValueModel.objects
-
- def test_create_using_default(self):
- data = {}
- serializer = self.serializer_class(data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(self.objects.all()), 1)
- self.assertEqual(instance.pk, 1)
- self.assertEqual(instance.text, 'foobar')
-
- def test_create_overriding_default(self):
- data = {'text': 'overridden'}
- serializer = self.serializer_class(data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(self.objects.all()), 1)
- self.assertEqual(instance.pk, 1)
- self.assertEqual(instance.text, 'overridden')
-
-
-class ManyRelatedTests(TestCase):
- def test_reverse_relations(self):
- post = BlogPost.objects.create(title="Test blog post")
- post.blogpostcomment_set.create(text="I hate this blog post")
- post.blogpostcomment_set.create(text="I love this blog post")
-
- class BlogPostCommentSerializer(serializers.Serializer):
- text = serializers.CharField()
-
- class BlogPostSerializer(serializers.Serializer):
- title = serializers.CharField()
- comments = BlogPostCommentSerializer(source='blogpostcomment_set')
-
- serializer = BlogPostSerializer(instance=post)
- expected = {
- 'title': 'Test blog post',
- 'comments': [
- {'text': 'I hate this blog post'},
- {'text': 'I love this blog post'}
- ]
- }
-
- self.assertEqual(serializer.data, expected)
-
- def test_include_reverse_relations(self):
- post = BlogPost.objects.create(title="Test blog post")
- post.blogpostcomment_set.create(text="I hate this blog post")
- post.blogpostcomment_set.create(text="I love this blog post")
-
- class BlogPostSerializer(serializers.ModelSerializer):
- class Meta:
- model = BlogPost
- fields = ('id', 'title', 'blogpostcomment_set')
-
- serializer = BlogPostSerializer(instance=post)
- expected = {
- 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2]
- }
- self.assertEqual(serializer.data, expected)
-
- def test_depth_include_reverse_relations(self):
- post = BlogPost.objects.create(title="Test blog post")
- post.blogpostcomment_set.create(text="I hate this blog post")
- post.blogpostcomment_set.create(text="I love this blog post")
-
- class BlogPostSerializer(serializers.ModelSerializer):
- class Meta:
- model = BlogPost
- fields = ('id', 'title', 'blogpostcomment_set')
- depth = 1
-
- serializer = BlogPostSerializer(instance=post)
- expected = {
- 'id': 1, 'title': 'Test blog post',
- 'blogpostcomment_set': [
- {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1},
- {'id': 2, 'text': 'I love this blog post', 'blog_post': 1}
- ]
- }
- self.assertEqual(serializer.data, expected)
-
- def test_callable_source(self):
- post = BlogPost.objects.create(title="Test blog post")
- post.blogpostcomment_set.create(text="I love this blog post")
-
- class BlogPostCommentSerializer(serializers.Serializer):
- text = serializers.CharField()
-
- class BlogPostSerializer(serializers.Serializer):
- title = serializers.CharField()
- first_comment = BlogPostCommentSerializer(source='get_first_comment')
-
- serializer = BlogPostSerializer(post)
-
- expected = {
- 'title': 'Test blog post',
- 'first_comment': {'text': 'I love this blog post'}
- }
- self.assertEqual(serializer.data, expected)
-
-
-class RelatedTraversalTest(TestCase):
- def test_nested_traversal(self):
- """
- Source argument should support dotted.source notation.
- """
- user = Person.objects.create(name="django")
- post = BlogPost.objects.create(title="Test blog post", writer=user)
- post.blogpostcomment_set.create(text="I love this blog post")
-
- class PersonSerializer(serializers.ModelSerializer):
- class Meta:
- model = Person
- fields = ("name", "age")
-
- class BlogPostCommentSerializer(serializers.ModelSerializer):
- class Meta:
- model = BlogPostComment
- fields = ("text", "post_owner")
-
- text = serializers.CharField()
- post_owner = PersonSerializer(source='blog_post.writer')
-
- class BlogPostSerializer(serializers.Serializer):
- title = serializers.CharField()
- comments = BlogPostCommentSerializer(source='blogpostcomment_set')
-
- serializer = BlogPostSerializer(instance=post)
-
- expected = {
- 'title': 'Test blog post',
- 'comments': [{
- 'text': 'I love this blog post',
- 'post_owner': {
- "name": "django",
- "age": None
- }
- }]
- }
-
- self.assertEqual(serializer.data, expected)
-
- def test_nested_traversal_with_none(self):
- """
- If a component of the dotted.source is None, return None for the field.
- """
- from rest_framework.tests.models import NullableForeignKeySource
- instance = NullableForeignKeySource.objects.create(name='Source with null FK')
-
- class NullableSourceSerializer(serializers.Serializer):
- target_name = serializers.Field(source='target.name')
-
- serializer = NullableSourceSerializer(instance=instance)
-
- expected = {
- 'target_name': None,
- }
-
- self.assertEqual(serializer.data, expected)
-
-
-class SerializerMethodFieldTests(TestCase):
- def setUp(self):
-
- class BoopSerializer(serializers.Serializer):
- beep = serializers.SerializerMethodField('get_beep')
- boop = serializers.Field()
- boop_count = serializers.SerializerMethodField('get_boop_count')
-
- def get_beep(self, obj):
- return 'hello!'
-
- def get_boop_count(self, obj):
- return len(obj.boop)
-
- self.serializer_class = BoopSerializer
-
- def test_serializer_method_field(self):
-
- class MyModel(object):
- boop = ['a', 'b', 'c']
-
- source_data = MyModel()
-
- serializer = self.serializer_class(source_data)
-
- expected = {
- 'beep': 'hello!',
- 'boop': ['a', 'b', 'c'],
- 'boop_count': 3,
- }
-
- self.assertEqual(serializer.data, expected)
-
-
-# Test for issue #324
-class BlankFieldTests(TestCase):
- def setUp(self):
-
- class BlankFieldModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = BlankFieldModel
-
- class BlankFieldSerializer(serializers.Serializer):
- title = serializers.CharField(required=False)
-
- class NotBlankFieldModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = BasicModel
-
- class NotBlankFieldSerializer(serializers.Serializer):
- title = serializers.CharField()
-
- self.model_serializer_class = BlankFieldModelSerializer
- self.serializer_class = BlankFieldSerializer
- self.not_blank_model_serializer_class = NotBlankFieldModelSerializer
- self.not_blank_serializer_class = NotBlankFieldSerializer
- self.data = {'title': ''}
-
- def test_create_blank_field(self):
- serializer = self.serializer_class(data=self.data)
- self.assertEqual(serializer.is_valid(), True)
-
- def test_create_model_blank_field(self):
- serializer = self.model_serializer_class(data=self.data)
- self.assertEqual(serializer.is_valid(), True)
-
- def test_create_model_null_field(self):
- serializer = self.model_serializer_class(data={'title': None})
- self.assertEqual(serializer.is_valid(), True)
-
- def test_create_not_blank_field(self):
- """
- Test to ensure blank data in a field not marked as blank=True
- is considered invalid in a non-model serializer
- """
- serializer = self.not_blank_serializer_class(data=self.data)
- self.assertEqual(serializer.is_valid(), False)
-
- def test_create_model_not_blank_field(self):
- """
- Test to ensure blank data in a field not marked as blank=True
- is considered invalid in a model serializer
- """
- serializer = self.not_blank_model_serializer_class(data=self.data)
- self.assertEqual(serializer.is_valid(), False)
-
- def test_create_model_empty_field(self):
- serializer = self.model_serializer_class(data={})
- self.assertEqual(serializer.is_valid(), True)
-
-
-#test for issue #460
-class SerializerPickleTests(TestCase):
- """
- Test pickleability of the output of Serializers
- """
- def test_pickle_simple_model_serializer_data(self):
- """
- Test simple serializer
- """
- pickle.dumps(PersonSerializer(Person(name="Methusela", age=969)).data)
-
- def test_pickle_inner_serializer(self):
- """
- Test pickling a serializer whose resulting .data (a SortedDictWithMetadata) will
- have unpickleable meta data--in order to make sure metadata doesn't get pulled into the pickle.
- See DictWithMetadata.__getstate__
- """
- class InnerPersonSerializer(serializers.ModelSerializer):
- class Meta:
- model = Person
- fields = ('name', 'age')
- pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data, 0)
-
- def test_getstate_method_should_not_return_none(self):
- """
- Regression test for #645.
- """
- data = serializers.DictWithMetadata({1: 1})
- self.assertEqual(data.__getstate__(), serializers.SortedDict({1: 1}))
-
- def test_serializer_data_is_pickleable(self):
- """
- Another regression test for #645.
- """
- data = serializers.SortedDictWithMetadata({1: 1})
- repr(pickle.loads(pickle.dumps(data, 0)))
-
-
-# test for issue #725
-class SeveralChoicesModel(models.Model):
- color = models.CharField(
- max_length=10,
- choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')],
- blank=False
- )
- drink = models.CharField(
- max_length=10,
- choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')],
- blank=False,
- default='beer'
- )
- os = models.CharField(
- max_length=10,
- choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')],
- blank=True
- )
- music_genre = models.CharField(
- max_length=10,
- choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')],
- blank=True,
- default='metal'
- )
-
-
-class SerializerChoiceFields(TestCase):
-
- def setUp(self):
- super(SerializerChoiceFields, self).setUp()
-
- class SeveralChoicesSerializer(serializers.ModelSerializer):
- class Meta:
- model = SeveralChoicesModel
- fields = ('color', 'drink', 'os', 'music_genre')
-
- self.several_choices_serializer = SeveralChoicesSerializer
-
- def test_choices_blank_false_not_default(self):
- serializer = self.several_choices_serializer()
- self.assertEqual(
- serializer.fields['color'].choices,
- [('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')]
- )
-
- def test_choices_blank_false_with_default(self):
- serializer = self.several_choices_serializer()
- self.assertEqual(
- serializer.fields['drink'].choices,
- [('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')]
- )
-
- def test_choices_blank_true_not_default(self):
- serializer = self.several_choices_serializer()
- self.assertEqual(
- serializer.fields['os'].choices,
- BLANK_CHOICE_DASH + [('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')]
- )
-
- def test_choices_blank_true_with_default(self):
- serializer = self.several_choices_serializer()
- self.assertEqual(
- serializer.fields['music_genre'].choices,
- BLANK_CHOICE_DASH + [('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')]
- )
-
-
-# Regression tests for #675
-class Ticket(models.Model):
- assigned = models.ForeignKey(
- Person, related_name='assigned_tickets')
- reviewer = models.ForeignKey(
- Person, blank=True, null=True, related_name='reviewed_tickets')
-
-
-class SerializerRelatedChoicesTest(TestCase):
-
- def setUp(self):
- super(SerializerRelatedChoicesTest, self).setUp()
-
- class RelatedChoicesSerializer(serializers.ModelSerializer):
- class Meta:
- model = Ticket
- fields = ('assigned', 'reviewer')
-
- self.related_fields_serializer = RelatedChoicesSerializer
-
- def test_empty_queryset_required(self):
- serializer = self.related_fields_serializer()
- self.assertEqual(serializer.fields['assigned'].queryset.count(), 0)
- self.assertEqual(
- [x for x in serializer.fields['assigned'].widget.choices],
- []
- )
-
- def test_empty_queryset_not_required(self):
- serializer = self.related_fields_serializer()
- self.assertEqual(serializer.fields['reviewer'].queryset.count(), 0)
- self.assertEqual(
- [x for x in serializer.fields['reviewer'].widget.choices],
- [('', '---------')]
- )
-
- def test_with_some_persons_required(self):
- Person.objects.create(name="Lionel Messi")
- Person.objects.create(name="Xavi Hernandez")
- serializer = self.related_fields_serializer()
- self.assertEqual(serializer.fields['assigned'].queryset.count(), 2)
- self.assertEqual(
- [x for x in serializer.fields['assigned'].widget.choices],
- [(1, 'Person object - 1'), (2, 'Person object - 2')]
- )
-
- def test_with_some_persons_not_required(self):
- Person.objects.create(name="Lionel Messi")
- Person.objects.create(name="Xavi Hernandez")
- serializer = self.related_fields_serializer()
- self.assertEqual(serializer.fields['reviewer'].queryset.count(), 2)
- self.assertEqual(
- [x for x in serializer.fields['reviewer'].widget.choices],
- [('', '---------'), (1, 'Person object - 1'), (2, 'Person object - 2')]
- )
-
-
-class DepthTest(TestCase):
- def test_implicit_nesting(self):
-
- writer = Person.objects.create(name="django", age=1)
- post = BlogPost.objects.create(title="Test blog post", writer=writer)
- comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post)
-
- class BlogPostCommentSerializer(serializers.ModelSerializer):
- class Meta:
- model = BlogPostComment
- depth = 2
-
- serializer = BlogPostCommentSerializer(instance=comment)
- expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post',
- 'writer': {'id': 1, 'name': 'django', 'age': 1}}}
-
- self.assertEqual(serializer.data, expected)
-
- def test_explicit_nesting(self):
- writer = Person.objects.create(name="django", age=1)
- post = BlogPost.objects.create(title="Test blog post", writer=writer)
- comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post)
-
- class PersonSerializer(serializers.ModelSerializer):
- class Meta:
- model = Person
-
- class BlogPostSerializer(serializers.ModelSerializer):
- writer = PersonSerializer()
-
- class Meta:
- model = BlogPost
-
- class BlogPostCommentSerializer(serializers.ModelSerializer):
- blog_post = BlogPostSerializer()
-
- class Meta:
- model = BlogPostComment
-
- serializer = BlogPostCommentSerializer(instance=comment)
- expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post',
- 'writer': {'id': 1, 'name': 'django', 'age': 1}}}
-
- self.assertEqual(serializer.data, expected)
-
-
-class NestedSerializerContextTests(TestCase):
-
- def test_nested_serializer_context(self):
- """
- Regression for #497
-
- https://github.com/tomchristie/django-rest-framework/issues/497
- """
- class PhotoSerializer(serializers.ModelSerializer):
- class Meta:
- model = Photo
- fields = ("description", "callable")
-
- callable = serializers.SerializerMethodField('_callable')
-
- def _callable(self, instance):
- if not 'context_item' in self.context:
- raise RuntimeError("context isn't getting passed into 2nd level nested serializer")
- return "success"
-
- class AlbumSerializer(serializers.ModelSerializer):
- class Meta:
- model = Album
- fields = ("photo_set", "callable")
-
- photo_set = PhotoSerializer(source="photo_set")
- callable = serializers.SerializerMethodField("_callable")
-
- def _callable(self, instance):
- if not 'context_item' in self.context:
- raise RuntimeError("context isn't getting passed into 1st level nested serializer")
- return "success"
-
- class AlbumCollection(object):
- albums = None
-
- class AlbumCollectionSerializer(serializers.Serializer):
- albums = AlbumSerializer(source="albums")
-
- album1 = Album.objects.create(title="album 1")
- album2 = Album.objects.create(title="album 2")
- Photo.objects.create(description="Bigfoot", album=album1)
- Photo.objects.create(description="Unicorn", album=album1)
- Photo.objects.create(description="Yeti", album=album2)
- Photo.objects.create(description="Sasquatch", album=album2)
- album_collection = AlbumCollection()
- album_collection.albums = [album1, album2]
-
- # This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers
- AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data
-
-
-class DeserializeListTestCase(TestCase):
-
- def setUp(self):
- self.data = {
- 'email': 'nobody@nowhere.com',
- 'content': 'This is some test content',
- 'created': datetime.datetime(2013, 3, 7),
- }
-
- def test_no_errors(self):
- data = [self.data.copy() for x in range(0, 3)]
- serializer = CommentSerializer(data=data, many=True)
- self.assertTrue(serializer.is_valid())
- self.assertTrue(isinstance(serializer.object, list))
- self.assertTrue(
- all((isinstance(item, Comment) for item in serializer.object))
- )
-
- def test_errors_return_as_list(self):
- invalid_item = self.data.copy()
- invalid_item['email'] = ''
- data = [self.data.copy(), invalid_item, self.data.copy()]
-
- serializer = CommentSerializer(data=data, many=True)
- self.assertFalse(serializer.is_valid())
- expected = [{}, {'email': ['This field is required.']}, {}]
- self.assertEqual(serializer.errors, expected)
-
-
-# Test for issue 747
-
-class LazyStringModel(object):
- def __init__(self, lazystring):
- self.lazystring = lazystring
-
-
-class LazyStringSerializer(serializers.Serializer):
- lazystring = serializers.Field()
-
- def restore_object(self, attrs, instance=None):
- if instance is not None:
- instance.lazystring = attrs.get('lazystring', instance.lazystring)
- return instance
- return LazyStringModel(**attrs)
-
-
-class LazyStringsTestCase(TestCase):
- def setUp(self):
- self.model = LazyStringModel(lazystring=_('lazystring'))
-
- def test_lazy_strings_are_translated(self):
- serializer = LazyStringSerializer(self.model)
- self.assertEqual(type(serializer.data['lazystring']),
- type('lazystring'))
-
-
-# Test for issue #467
-
-class FieldLabelTest(TestCase):
- def setUp(self):
- self.serializer_class = BasicModelSerializer
-
- def test_label_from_model(self):
- """
- Validates that label and help_text are correctly copied from the model class.
- """
- serializer = self.serializer_class()
- text_field = serializer.fields['text']
-
- self.assertEqual('Text comes here', text_field.label)
- self.assertEqual('Text description.', text_field.help_text)
-
- def test_field_ctor(self):
- """
- This is check that ctor supports both label and help_text.
- """
- self.assertEqual('Label', fields.Field(label='Label', help_text='Help').label)
- self.assertEqual('Help', fields.CharField(label='Label', help_text='Help').help_text)
- self.assertEqual('Label', relations.HyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help', many=True).label)
-
-
-# Test for issue #961
-
-class ManyFieldHelpTextTest(TestCase):
- def test_help_text_no_hold_down_control_msg(self):
- """
- Validate that help_text doesn't contain the 'Hold down "Control" ...'
- message that Django appends to choice fields.
- """
- rel_field = fields.Field(help_text=ManyToManyModel._meta.get_field('rel').help_text)
- self.assertEqual('Some help text.', rel_field.help_text)
-
-
-@unittest.skipUnless(PIL is not None, 'PIL is not installed')
-class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
-
- def setUp(self):
-
- class AMOAFSerializer(serializers.ModelSerializer):
- class Meta:
- model = AMOAFModel
-
- self.serializer_class = AMOAFSerializer
- self.fields_attributes = {
- 'char_field': [
- ('max_length', 1024),
- ],
- 'comma_separated_integer_field': [
- ('max_length', 1024),
- ],
- 'decimal_field': [
- ('max_digits', 64),
- ('decimal_places', 32),
- ],
- 'email_field': [
- ('max_length', 1024),
- ],
- 'file_field': [
- ('max_length', 1024),
- ],
- 'image_field': [
- ('max_length', 1024),
- ],
- 'slug_field': [
- ('max_length', 1024),
- ],
- 'url_field': [
- ('max_length', 1024),
- ],
- }
-
- def field_test(self, field):
- serializer = self.serializer_class(data={})
- self.assertEqual(serializer.is_valid(), True)
-
- for attribute in self.fields_attributes[field]:
- self.assertEqual(
- getattr(serializer.fields[field], attribute[0]),
- attribute[1]
- )
-
- def test_char_field(self):
- self.field_test('char_field')
-
- def test_comma_separated_integer_field(self):
- self.field_test('comma_separated_integer_field')
-
- def test_decimal_field(self):
- self.field_test('decimal_field')
-
- def test_email_field(self):
- self.field_test('email_field')
-
- def test_file_field(self):
- self.field_test('file_field')
-
- def test_image_field(self):
- self.field_test('image_field')
-
- def test_slug_field(self):
- self.field_test('slug_field')
-
- def test_url_field(self):
- self.field_test('url_field')
-
-
-@unittest.skipUnless(PIL is not None, 'PIL is not installed')
-class DefaultValuesOnAutogeneratedFieldsTests(TestCase):
-
- def setUp(self):
-
- class DVOAFSerializer(serializers.ModelSerializer):
- class Meta:
- model = DVOAFModel
-
- self.serializer_class = DVOAFSerializer
- self.fields_attributes = {
- 'positive_integer_field': [
- ('min_value', 0),
- ],
- 'positive_small_integer_field': [
- ('min_value', 0),
- ],
- 'email_field': [
- ('max_length', 75),
- ],
- 'file_field': [
- ('max_length', 100),
- ],
- 'image_field': [
- ('max_length', 100),
- ],
- 'slug_field': [
- ('max_length', 50),
- ],
- 'url_field': [
- ('max_length', 200),
- ],
- }
-
- def field_test(self, field):
- serializer = self.serializer_class(data={})
- self.assertEqual(serializer.is_valid(), True)
-
- for attribute in self.fields_attributes[field]:
- self.assertEqual(
- getattr(serializer.fields[field], attribute[0]),
- attribute[1]
- )
-
- def test_positive_integer_field(self):
- self.field_test('positive_integer_field')
-
- def test_positive_small_integer_field(self):
- self.field_test('positive_small_integer_field')
-
- def test_email_field(self):
- self.field_test('email_field')
-
- def test_file_field(self):
- self.field_test('file_field')
-
- def test_image_field(self):
- self.field_test('image_field')
-
- def test_slug_field(self):
- self.field_test('slug_field')
-
- def test_url_field(self):
- self.field_test('url_field')
-
-
-class MetadataSerializer(serializers.Serializer):
- field1 = serializers.CharField(3, required=True)
- field2 = serializers.CharField(10, required=False)
-
-
-class MetadataSerializerTestCase(TestCase):
- def setUp(self):
- self.serializer = MetadataSerializer()
-
- def test_serializer_metadata(self):
- metadata = self.serializer.metadata()
- expected = {
- 'field1': {
- 'required': True,
- 'max_length': 3,
- 'type': 'string',
- 'read_only': False
- },
- 'field2': {
- 'required': False,
- 'max_length': 10,
- 'type': 'string',
- 'read_only': False
- }
- }
- self.assertEqual(expected, metadata)
-
-
-### Regression test for #840
-
-class SimpleModel(models.Model):
- text = models.CharField(max_length=100)
-
-
-class SimpleModelSerializer(serializers.ModelSerializer):
- text = serializers.CharField()
- other = serializers.CharField()
-
- class Meta:
- model = SimpleModel
-
- def validate_other(self, attrs, source):
- del attrs['other']
- return attrs
-
-
-class FieldValidationRemovingAttr(TestCase):
- def test_removing_non_model_field_in_validation(self):
- """
- Removing an attr during field valiation should ensure that it is not
- passed through when restoring the object.
-
- This allows additional non-model fields to be supported.
-
- Regression test for #840.
- """
- serializer = SimpleModelSerializer(data={'text': 'foo', 'other': 'bar'})
- self.assertTrue(serializer.is_valid())
- serializer.save()
- self.assertEqual(serializer.object.text, 'foo')
-
-
-### Regression test for #878
-
-class SimpleTargetModel(models.Model):
- text = models.CharField(max_length=100)
-
-
-class SimplePKSourceModelSerializer(serializers.Serializer):
- targets = serializers.PrimaryKeyRelatedField(queryset=SimpleTargetModel.objects.all(), many=True)
- text = serializers.CharField()
-
-
-class SimpleSlugSourceModelSerializer(serializers.Serializer):
- targets = serializers.SlugRelatedField(queryset=SimpleTargetModel.objects.all(), many=True, slug_field='pk')
- text = serializers.CharField()
-
-
-class SerializerSupportsManyRelationships(TestCase):
- def setUp(self):
- SimpleTargetModel.objects.create(text='foo')
- SimpleTargetModel.objects.create(text='bar')
-
- def test_serializer_supports_pk_many_relationships(self):
- """
- Regression test for #878.
-
- Note that pk behavior has a different code path to usual cases,
- for performance reasons.
- """
- serializer = SimplePKSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})
-
- def test_serializer_supports_slug_many_relationships(self):
- """
- Regression test for #878.
- """
- serializer = SimpleSlugSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})
-
-
-class TransformMethodsSerializer(serializers.Serializer):
- a = serializers.CharField()
- b_renamed = serializers.CharField(source='b')
-
- def transform_a(self, obj, value):
- return value.lower()
-
- def transform_b_renamed(self, obj, value):
- if value is not None:
- return 'and ' + value
-
-
-class TestSerializerTransformMethods(TestCase):
- def setUp(self):
- self.s = TransformMethodsSerializer()
-
- def test_transform_methods(self):
- self.assertEqual(
- self.s.to_native({'a': 'GREEN EGGS', 'b': 'HAM'}),
- {
- 'a': 'green eggs',
- 'b_renamed': 'and HAM',
- }
- )
-
- def test_missing_fields(self):
- self.assertEqual(
- self.s.to_native({'a': 'GREEN EGGS'}),
- {
- 'a': 'green eggs',
- 'b_renamed': None,
- }
- )
-
-
-class DefaultTrueBooleanModel(models.Model):
- cat = models.BooleanField(default=True)
- dog = models.BooleanField(default=False)
-
-
-class SerializerDefaultTrueBoolean(TestCase):
-
- def setUp(self):
- super(SerializerDefaultTrueBoolean, self).setUp()
-
- class DefaultTrueBooleanSerializer(serializers.ModelSerializer):
- class Meta:
- model = DefaultTrueBooleanModel
- fields = ('cat', 'dog')
-
- self.default_true_boolean_serializer = DefaultTrueBooleanSerializer
-
- def test_enabled_as_false(self):
- serializer = self.default_true_boolean_serializer(data={'cat': False,
- 'dog': False})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.data['cat'], False)
- self.assertEqual(serializer.data['dog'], False)
-
- def test_enabled_as_true(self):
- serializer = self.default_true_boolean_serializer(data={'cat': True,
- 'dog': True})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.data['cat'], True)
- self.assertEqual(serializer.data['dog'], True)
-
- def test_enabled_partial(self):
- serializer = self.default_true_boolean_serializer(data={'cat': False},
- partial=True)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.data['cat'], False)
- self.assertEqual(serializer.data['dog'], False)
-
-
-class BoolenFieldTypeTest(TestCase):
- '''
- Ensure the various Boolean based model fields are rendered as the proper
- field type
-
- '''
-
- def setUp(self):
- '''
- Setup an ActionItemSerializer for BooleanTesting
- '''
- data = {
- 'title': 'b' * 201,
- }
- self.serializer = ActionItemSerializer(data=data)
-
- def test_booleanfield_type(self):
- '''
- Test that BooleanField is infered from models.BooleanField
- '''
- bfield = self.serializer.get_fields()['done']
- self.assertEqual(type(bfield), fields.BooleanField)
-
- def test_nullbooleanfield_type(self):
- '''
- Test that BooleanField is infered from models.NullBooleanField
-
- https://groups.google.com/forum/#!topic/django-rest-framework/D9mXEftpuQ8
- '''
- bfield = self.serializer.get_fields()['started']
- self.assertEqual(type(bfield), fields.BooleanField)
diff --git a/rest_framework/tests/test_serializer_bulk_update.py b/rest_framework/tests/test_serializer_bulk_update.py
deleted file mode 100644
index 8b0ded1a..00000000
--- a/rest_framework/tests/test_serializer_bulk_update.py
+++ /dev/null
@@ -1,278 +0,0 @@
-"""
-Tests to cover bulk create and update using serializers.
-"""
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework import serializers
-
-
-class BulkCreateSerializerTests(TestCase):
- """
- Creating multiple instances using serializers.
- """
-
- def setUp(self):
- class BookSerializer(serializers.Serializer):
- id = serializers.IntegerField()
- title = serializers.CharField(max_length=100)
- author = serializers.CharField(max_length=100)
-
- self.BookSerializer = BookSerializer
-
- def test_bulk_create_success(self):
- """
- Correct bulk update serialization should return the input data.
- """
-
- data = [
- {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 1,
- 'title': 'If this is a man',
- 'author': 'Primo Levi'
- }, {
- 'id': 2,
- 'title': 'The wind-up bird chronicle',
- 'author': 'Haruki Murakami'
- }
- ]
-
- serializer = self.BookSerializer(data=data, many=True)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, data)
-
- def test_bulk_create_errors(self):
- """
- Correct bulk update serialization should return the input data.
- """
-
- data = [
- {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 1,
- 'title': 'If this is a man',
- 'author': 'Primo Levi'
- }, {
- 'id': 'foo',
- 'title': 'The wind-up bird chronicle',
- 'author': 'Haruki Murakami'
- }
- ]
- expected_errors = [
- {},
- {},
- {'id': ['Enter a whole number.']}
- ]
-
- serializer = self.BookSerializer(data=data, many=True)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, expected_errors)
-
- def test_invalid_list_datatype(self):
- """
- Data containing list of incorrect data type should return errors.
- """
- data = ['foo', 'bar', 'baz']
- serializer = self.BookSerializer(data=data, many=True)
- self.assertEqual(serializer.is_valid(), False)
-
- expected_errors = [
- {'non_field_errors': ['Invalid data']},
- {'non_field_errors': ['Invalid data']},
- {'non_field_errors': ['Invalid data']}
- ]
-
- self.assertEqual(serializer.errors, expected_errors)
-
- def test_invalid_single_datatype(self):
- """
- Data containing a single incorrect data type should return errors.
- """
- data = 123
- serializer = self.BookSerializer(data=data, many=True)
- self.assertEqual(serializer.is_valid(), False)
-
- expected_errors = {'non_field_errors': ['Expected a list of items.']}
-
- self.assertEqual(serializer.errors, expected_errors)
-
- def test_invalid_single_object(self):
- """
- Data containing only a single object, instead of a list of objects
- should return errors.
- """
- data = {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }
- serializer = self.BookSerializer(data=data, many=True)
- self.assertEqual(serializer.is_valid(), False)
-
- expected_errors = {'non_field_errors': ['Expected a list of items.']}
-
- self.assertEqual(serializer.errors, expected_errors)
-
-
-class BulkUpdateSerializerTests(TestCase):
- """
- Updating multiple instances using serializers.
- """
-
- def setUp(self):
- class Book(object):
- """
- A data type that can be persisted to a mock storage backend
- with `.save()` and `.delete()`.
- """
- object_map = {}
-
- def __init__(self, id, title, author):
- self.id = id
- self.title = title
- self.author = author
-
- def save(self):
- Book.object_map[self.id] = self
-
- def delete(self):
- del Book.object_map[self.id]
-
- class BookSerializer(serializers.Serializer):
- id = serializers.IntegerField()
- title = serializers.CharField(max_length=100)
- author = serializers.CharField(max_length=100)
-
- def restore_object(self, attrs, instance=None):
- if instance:
- instance.id = attrs['id']
- instance.title = attrs['title']
- instance.author = attrs['author']
- return instance
- return Book(**attrs)
-
- self.Book = Book
- self.BookSerializer = BookSerializer
-
- data = [
- {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 1,
- 'title': 'If this is a man',
- 'author': 'Primo Levi'
- }, {
- 'id': 2,
- 'title': 'The wind-up bird chronicle',
- 'author': 'Haruki Murakami'
- }
- ]
-
- for item in data:
- book = Book(item['id'], item['title'], item['author'])
- book.save()
-
- def books(self):
- """
- Return all the objects in the mock storage backend.
- """
- return self.Book.object_map.values()
-
- def test_bulk_update_success(self):
- """
- Correct bulk update serialization should return the input data.
- """
- data = [
- {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 2,
- 'title': 'Kafka on the shore',
- 'author': 'Haruki Murakami'
- }
- ]
- serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.data, data)
- serializer.save()
- new_data = self.BookSerializer(self.books(), many=True).data
-
- self.assertEqual(data, new_data)
-
- def test_bulk_update_and_create(self):
- """
- Bulk update serialization may also include created items.
- """
- data = [
- {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 3,
- 'title': 'Kafka on the shore',
- 'author': 'Haruki Murakami'
- }
- ]
- serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.data, data)
- serializer.save()
- new_data = self.BookSerializer(self.books(), many=True).data
- self.assertEqual(data, new_data)
-
- def test_bulk_update_invalid_create(self):
- """
- Bulk update serialization without allow_add_remove may not create items.
- """
- data = [
- {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 3,
- 'title': 'Kafka on the shore',
- 'author': 'Haruki Murakami'
- }
- ]
- expected_errors = [
- {},
- {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}
- ]
- serializer = self.BookSerializer(self.books(), data=data, many=True)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, expected_errors)
-
- def test_bulk_update_error(self):
- """
- Incorrect bulk update serialization should return error data.
- """
- data = [
- {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 'foo',
- 'title': 'Kafka on the shore',
- 'author': 'Haruki Murakami'
- }
- ]
- expected_errors = [
- {},
- {'id': ['Enter a whole number.']}
- ]
- serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, expected_errors)
diff --git a/rest_framework/tests/test_serializer_empty.py b/rest_framework/tests/test_serializer_empty.py
deleted file mode 100644
index 30cff361..00000000
--- a/rest_framework/tests/test_serializer_empty.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from django.test import TestCase
-from rest_framework import serializers
-
-
-class EmptySerializerTestCase(TestCase):
- def test_empty_serializer(self):
- class FooBarSerializer(serializers.Serializer):
- foo = serializers.IntegerField()
- bar = serializers.SerializerMethodField('get_bar')
-
- def get_bar(self, obj):
- return 'bar'
-
- serializer = FooBarSerializer()
- self.assertEquals(serializer.data, {'foo': 0})
diff --git a/rest_framework/tests/test_serializer_import.py b/rest_framework/tests/test_serializer_import.py
deleted file mode 100644
index 9f30a7ff..00000000
--- a/rest_framework/tests/test_serializer_import.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from django.test import TestCase
-
-from rest_framework import serializers
-from rest_framework.tests.accounts.serializers import AccountSerializer
-
-
-class ImportingModelSerializerTests(TestCase):
- """
- In some situations like, GH #1225, it is possible, especially in
- testing, to import a serializer who's related models have not yet
- been resolved by Django. `AccountSerializer` is an example of such
- a serializer (imported at the top of this file).
- """
- def test_import_model_serializer(self):
- """
- The serializer at the top of this file should have been
- imported successfully, and we should be able to instantiate it.
- """
- self.assertIsInstance(AccountSerializer(), serializers.ModelSerializer)
diff --git a/rest_framework/tests/test_serializer_nested.py b/rest_framework/tests/test_serializer_nested.py
deleted file mode 100644
index 6d69ffbd..00000000
--- a/rest_framework/tests/test_serializer_nested.py
+++ /dev/null
@@ -1,347 +0,0 @@
-"""
-Tests to cover nested serializers.
-
-Doesn't cover model serializers.
-"""
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework import serializers
-from . import models
-
-
-class WritableNestedSerializerBasicTests(TestCase):
- """
- Tests for deserializing nested entities.
- Basic tests that use serializers that simply restore to dicts.
- """
-
- def setUp(self):
- class TrackSerializer(serializers.Serializer):
- order = serializers.IntegerField()
- title = serializers.CharField(max_length=100)
- duration = serializers.IntegerField()
-
- class AlbumSerializer(serializers.Serializer):
- album_name = serializers.CharField(max_length=100)
- artist = serializers.CharField(max_length=100)
- tracks = TrackSerializer(many=True)
-
- self.AlbumSerializer = AlbumSerializer
-
- def test_nested_validation_success(self):
- """
- Correct nested serialization should return the input data.
- """
-
- data = {
- 'album_name': 'Discovery',
- 'artist': 'Daft Punk',
- 'tracks': [
- {'order': 1, 'title': 'One More Time', 'duration': 235},
- {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
- {'order': 3, 'title': 'Digital Love', 'duration': 239}
- ]
- }
-
- serializer = self.AlbumSerializer(data=data)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, data)
-
- def test_nested_validation_error(self):
- """
- Incorrect nested serialization should return appropriate error data.
- """
-
- data = {
- 'album_name': 'Discovery',
- 'artist': 'Daft Punk',
- 'tracks': [
- {'order': 1, 'title': 'One More Time', 'duration': 235},
- {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
- {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'}
- ]
- }
- expected_errors = {
- 'tracks': [
- {},
- {},
- {'duration': ['Enter a whole number.']}
- ]
- }
-
- serializer = self.AlbumSerializer(data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, expected_errors)
-
- def test_many_nested_validation_error(self):
- """
- Incorrect nested serialization should return appropriate error data
- when multiple entities are being deserialized.
- """
-
- data = [
- {
- 'album_name': 'Russian Red',
- 'artist': 'I Love Your Glasses',
- 'tracks': [
- {'order': 1, 'title': 'Cigarettes', 'duration': 121},
- {'order': 2, 'title': 'No Past Land', 'duration': 198},
- {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191}
- ]
- },
- {
- 'album_name': 'Discovery',
- 'artist': 'Daft Punk',
- 'tracks': [
- {'order': 1, 'title': 'One More Time', 'duration': 235},
- {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
- {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'}
- ]
- }
- ]
- expected_errors = [
- {},
- {
- 'tracks': [
- {},
- {},
- {'duration': ['Enter a whole number.']}
- ]
- }
- ]
-
- serializer = self.AlbumSerializer(data=data, many=True)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, expected_errors)
-
-
-class WritableNestedSerializerObjectTests(TestCase):
- """
- Tests for deserializing nested entities.
- These tests use serializers that restore to concrete objects.
- """
-
- def setUp(self):
- # Couple of concrete objects that we're going to deserialize into
- class Track(object):
- def __init__(self, order, title, duration):
- self.order, self.title, self.duration = order, title, duration
-
- def __eq__(self, other):
- return (
- self.order == other.order and
- self.title == other.title and
- self.duration == other.duration
- )
-
- class Album(object):
- def __init__(self, album_name, artist, tracks):
- self.album_name, self.artist, self.tracks = album_name, artist, tracks
-
- def __eq__(self, other):
- return (
- self.album_name == other.album_name and
- self.artist == other.artist and
- self.tracks == other.tracks
- )
-
- # And their corresponding serializers
- class TrackSerializer(serializers.Serializer):
- order = serializers.IntegerField()
- title = serializers.CharField(max_length=100)
- duration = serializers.IntegerField()
-
- def restore_object(self, attrs, instance=None):
- return Track(attrs['order'], attrs['title'], attrs['duration'])
-
- class AlbumSerializer(serializers.Serializer):
- album_name = serializers.CharField(max_length=100)
- artist = serializers.CharField(max_length=100)
- tracks = TrackSerializer(many=True)
-
- def restore_object(self, attrs, instance=None):
- return Album(attrs['album_name'], attrs['artist'], attrs['tracks'])
-
- self.Album, self.Track = Album, Track
- self.AlbumSerializer = AlbumSerializer
-
- def test_nested_validation_success(self):
- """
- Correct nested serialization should return a restored object
- that corresponds to the input data.
- """
-
- data = {
- 'album_name': 'Discovery',
- 'artist': 'Daft Punk',
- 'tracks': [
- {'order': 1, 'title': 'One More Time', 'duration': 235},
- {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
- {'order': 3, 'title': 'Digital Love', 'duration': 239}
- ]
- }
- expected_object = self.Album(
- album_name='Discovery',
- artist='Daft Punk',
- tracks=[
- self.Track(order=1, title='One More Time', duration=235),
- self.Track(order=2, title='Aerodynamic', duration=184),
- self.Track(order=3, title='Digital Love', duration=239),
- ]
- )
-
- serializer = self.AlbumSerializer(data=data)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, expected_object)
-
- def test_many_nested_validation_success(self):
- """
- Correct nested serialization should return multiple restored objects
- that corresponds to the input data when multiple objects are
- being deserialized.
- """
-
- data = [
- {
- 'album_name': 'Russian Red',
- 'artist': 'I Love Your Glasses',
- 'tracks': [
- {'order': 1, 'title': 'Cigarettes', 'duration': 121},
- {'order': 2, 'title': 'No Past Land', 'duration': 198},
- {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191}
- ]
- },
- {
- 'album_name': 'Discovery',
- 'artist': 'Daft Punk',
- 'tracks': [
- {'order': 1, 'title': 'One More Time', 'duration': 235},
- {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
- {'order': 3, 'title': 'Digital Love', 'duration': 239}
- ]
- }
- ]
- expected_object = [
- self.Album(
- album_name='Russian Red',
- artist='I Love Your Glasses',
- tracks=[
- self.Track(order=1, title='Cigarettes', duration=121),
- self.Track(order=2, title='No Past Land', duration=198),
- self.Track(order=3, title='They Don\'t Believe', duration=191),
- ]
- ),
- self.Album(
- album_name='Discovery',
- artist='Daft Punk',
- tracks=[
- self.Track(order=1, title='One More Time', duration=235),
- self.Track(order=2, title='Aerodynamic', duration=184),
- self.Track(order=3, title='Digital Love', duration=239),
- ]
- )
- ]
-
- serializer = self.AlbumSerializer(data=data, many=True)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, expected_object)
-
-
-class ForeignKeyNestedSerializerUpdateTests(TestCase):
- def setUp(self):
- class Artist(object):
- def __init__(self, name):
- self.name = name
-
- def __eq__(self, other):
- return self.name == other.name
-
- class Album(object):
- def __init__(self, name, artist):
- self.name, self.artist = name, artist
-
- def __eq__(self, other):
- return self.name == other.name and self.artist == other.artist
-
- class ArtistSerializer(serializers.Serializer):
- name = serializers.CharField()
-
- def restore_object(self, attrs, instance=None):
- if instance:
- instance.name = attrs['name']
- else:
- instance = Artist(attrs['name'])
- return instance
-
- class AlbumSerializer(serializers.Serializer):
- name = serializers.CharField()
- by = ArtistSerializer(source='artist')
-
- def restore_object(self, attrs, instance=None):
- if instance:
- instance.name = attrs['name']
- instance.artist = attrs['artist']
- else:
- instance = Album(attrs['name'], attrs['artist'])
- return instance
-
- self.Artist = Artist
- self.Album = Album
- self.AlbumSerializer = AlbumSerializer
-
- def test_create_via_foreign_key_with_source(self):
- """
- Check that we can both *create* and *update* into objects across
- ForeignKeys that have a `source` specified.
- Regression test for #1170
- """
- data = {
- 'name': 'Discovery',
- 'by': {'name': 'Daft Punk'},
- }
-
- expected = self.Album(artist=self.Artist('Daft Punk'), name='Discovery')
-
- # create
- serializer = self.AlbumSerializer(data=data)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, expected)
-
- # update
- original = self.Album(artist=self.Artist('The Bats'), name='Free All the Monsters')
- serializer = self.AlbumSerializer(instance=original, data=data)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, expected)
-
-
-class NestedModelSerializerUpdateTests(TestCase):
- def test_second_nested_level(self):
- john = models.Person.objects.create(name="john")
-
- post = john.blogpost_set.create(title="Test blog post")
- post.blogpostcomment_set.create(text="I hate this blog post")
- post.blogpostcomment_set.create(text="I love this blog post")
-
- class BlogPostCommentSerializer(serializers.ModelSerializer):
- class Meta:
- model = models.BlogPostComment
-
- class BlogPostSerializer(serializers.ModelSerializer):
- comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set')
- class Meta:
- model = models.BlogPost
- fields = ('id', 'title', 'comments')
-
- class PersonSerializer(serializers.ModelSerializer):
- posts = BlogPostSerializer(many=True, source='blogpost_set')
- class Meta:
- model = models.Person
- fields = ('id', 'name', 'age', 'posts')
-
- serialize = PersonSerializer(instance=john)
- deserialize = PersonSerializer(data=serialize.data, instance=john)
- self.assertTrue(deserialize.is_valid())
-
- result = deserialize.object
- result.save()
- self.assertEqual(result.id, john.id)
diff --git a/rest_framework/tests/test_serializers.py b/rest_framework/tests/test_serializers.py
deleted file mode 100644
index 082a400c..00000000
--- a/rest_framework/tests/test_serializers.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from django.db import models
-from django.test import TestCase
-
-from rest_framework.serializers import _resolve_model
-from rest_framework.tests.models import BasicModel
-
-
-class ResolveModelTests(TestCase):
- """
- `_resolve_model` should return a Django model class given the
- provided argument is a Django model class itself, or a properly
- formatted string representation of one.
- """
- def test_resolve_django_model(self):
- resolved_model = _resolve_model(BasicModel)
- self.assertEqual(resolved_model, BasicModel)
-
- def test_resolve_string_representation(self):
- resolved_model = _resolve_model('tests.BasicModel')
- self.assertEqual(resolved_model, BasicModel)
-
- def test_resolve_non_django_model(self):
- with self.assertRaises(ValueError):
- _resolve_model(TestCase)
-
- def test_resolve_improper_string_representation(self):
- with self.assertRaises(ValueError):
- _resolve_model('BasicModel')
diff --git a/rest_framework/tests/test_settings.py b/rest_framework/tests/test_settings.py
deleted file mode 100644
index 857375c2..00000000
--- a/rest_framework/tests/test_settings.py
+++ /dev/null
@@ -1,22 +0,0 @@
-"""Tests for the settings module"""
-from __future__ import unicode_literals
-from django.test import TestCase
-
-from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS
-
-
-class TestSettings(TestCase):
- """Tests relating to the api settings"""
-
- def test_non_import_errors(self):
- """Make sure other errors aren't suppressed."""
- settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS)
- with self.assertRaises(ValueError):
- settings.DEFAULT_MODEL_SERIALIZER_CLASS
-
- def test_import_error_message_maintained(self):
- """Make sure real import errors are captured and raised sensibly."""
- settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS)
- with self.assertRaises(ImportError) as cm:
- settings.DEFAULT_MODEL_SERIALIZER_CLASS
- self.assertTrue('ImportError' in str(cm.exception))
diff --git a/rest_framework/tests/test_status.py b/rest_framework/tests/test_status.py
deleted file mode 100644
index 7b1bdae3..00000000
--- a/rest_framework/tests/test_status.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework.status import (
- is_informational, is_success, is_redirect, is_client_error, is_server_error
-)
-
-
-class TestStatus(TestCase):
- def test_status_categories(self):
- self.assertFalse(is_informational(99))
- self.assertTrue(is_informational(100))
- self.assertTrue(is_informational(199))
- self.assertFalse(is_informational(200))
-
- self.assertFalse(is_success(199))
- self.assertTrue(is_success(200))
- self.assertTrue(is_success(299))
- self.assertFalse(is_success(300))
-
- self.assertFalse(is_redirect(299))
- self.assertTrue(is_redirect(300))
- self.assertTrue(is_redirect(399))
- self.assertFalse(is_redirect(400))
-
- self.assertFalse(is_client_error(399))
- self.assertTrue(is_client_error(400))
- self.assertTrue(is_client_error(499))
- self.assertFalse(is_client_error(500))
-
- self.assertFalse(is_server_error(499))
- self.assertTrue(is_server_error(500))
- self.assertTrue(is_server_error(599))
- self.assertFalse(is_server_error(600)) \ No newline at end of file
diff --git a/rest_framework/tests/test_templatetags.py b/rest_framework/tests/test_templatetags.py
deleted file mode 100644
index d4da0c23..00000000
--- a/rest_framework/tests/test_templatetags.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# encoding: utf-8
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework.test import APIRequestFactory
-from rest_framework.templatetags.rest_framework import add_query_param, urlize_quoted_links
-
-factory = APIRequestFactory()
-
-
-class TemplateTagTests(TestCase):
-
- def test_add_query_param_with_non_latin_charactor(self):
- # Ensure we don't double-escape non-latin characters
- # that are present in the querystring.
- # See #1314.
- request = factory.get("/", {'q': '查询'})
- json_url = add_query_param(request, "format", "json")
- self.assertIn("q=%E6%9F%A5%E8%AF%A2", json_url)
- self.assertIn("format=json", json_url)
-
-
-class Issue1386Tests(TestCase):
- """
- Covers #1386
- """
-
- def test_issue_1386(self):
- """
- Test function urlize_quoted_links with different args
- """
- correct_urls = [
- "asdf.com",
- "asdf.net",
- "www.as_df.org",
- "as.d8f.ghj8.gov",
- ]
- for i in correct_urls:
- res = urlize_quoted_links(i)
- self.assertNotEqual(res, i)
- self.assertIn(i, res)
-
- incorrect_urls = [
- "mailto://asdf@fdf.com",
- "asdf.netnet",
- ]
- for i in incorrect_urls:
- res = urlize_quoted_links(i)
- self.assertEqual(i, res)
-
- # example from issue #1386, this shouldn't raise an exception
- _ = urlize_quoted_links("asdf:[/p]zxcv.com")
diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py
deleted file mode 100644
index a55d4b22..00000000
--- a/rest_framework/tests/test_testing.py
+++ /dev/null
@@ -1,164 +0,0 @@
-# -- coding: utf-8 --
-
-from __future__ import unicode_literals
-from io import BytesIO
-
-from django.contrib.auth.models import User
-from django.test import TestCase
-from rest_framework.compat import patterns, url
-from rest_framework.decorators import api_view
-from rest_framework.response import Response
-from rest_framework.test import APIClient, APIRequestFactory, force_authenticate
-
-
-@api_view(['GET', 'POST'])
-def view(request):
- return Response({
- 'auth': request.META.get('HTTP_AUTHORIZATION', b''),
- 'user': request.user.username
- })
-
-
-@api_view(['GET', 'POST'])
-def session_view(request):
- active_session = request.session.get('active_session', False)
- request.session['active_session'] = True
- return Response({
- 'active_session': active_session
- })
-
-
-urlpatterns = patterns('',
- url(r'^view/$', view),
- url(r'^session-view/$', session_view),
-)
-
-
-class TestAPITestClient(TestCase):
- urls = 'rest_framework.tests.test_testing'
-
- def setUp(self):
- self.client = APIClient()
-
- def test_credentials(self):
- """
- Setting `.credentials()` adds the required headers to each request.
- """
- self.client.credentials(HTTP_AUTHORIZATION='example')
- for _ in range(0, 3):
- response = self.client.get('/view/')
- self.assertEqual(response.data['auth'], 'example')
-
- def test_force_authenticate(self):
- """
- Setting `.force_authenticate()` forcibly authenticates each request.
- """
- user = User.objects.create_user('example', 'example@example.com')
- self.client.force_authenticate(user)
- response = self.client.get('/view/')
- self.assertEqual(response.data['user'], 'example')
-
- def test_force_authenticate_with_sessions(self):
- """
- Setting `.force_authenticate()` forcibly authenticates each request.
- """
- user = User.objects.create_user('example', 'example@example.com')
- self.client.force_authenticate(user)
-
- # First request does not yet have an active session
- response = self.client.get('/session-view/')
- self.assertEqual(response.data['active_session'], False)
-
- # Subsequant requests have an active session
- response = self.client.get('/session-view/')
- self.assertEqual(response.data['active_session'], True)
-
- # Force authenticating as `None` should also logout the user session.
- self.client.force_authenticate(None)
- response = self.client.get('/session-view/')
- self.assertEqual(response.data['active_session'], False)
-
- def test_csrf_exempt_by_default(self):
- """
- By default, the test client is CSRF exempt.
- """
- User.objects.create_user('example', 'example@example.com', 'password')
- self.client.login(username='example', password='password')
- response = self.client.post('/view/')
- self.assertEqual(response.status_code, 200)
-
- def test_explicitly_enforce_csrf_checks(self):
- """
- The test client can enforce CSRF checks.
- """
- client = APIClient(enforce_csrf_checks=True)
- User.objects.create_user('example', 'example@example.com', 'password')
- client.login(username='example', password='password')
- response = client.post('/view/')
- expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
- self.assertEqual(response.status_code, 403)
- self.assertEqual(response.data, expected)
-
-
-class TestAPIRequestFactory(TestCase):
- def test_csrf_exempt_by_default(self):
- """
- By default, the test client is CSRF exempt.
- """
- user = User.objects.create_user('example', 'example@example.com', 'password')
- factory = APIRequestFactory()
- request = factory.post('/view/')
- request.user = user
- response = view(request)
- self.assertEqual(response.status_code, 200)
-
- def test_explicitly_enforce_csrf_checks(self):
- """
- The test client can enforce CSRF checks.
- """
- user = User.objects.create_user('example', 'example@example.com', 'password')
- factory = APIRequestFactory(enforce_csrf_checks=True)
- request = factory.post('/view/')
- request.user = user
- response = view(request)
- expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
- self.assertEqual(response.status_code, 403)
- self.assertEqual(response.data, expected)
-
- def test_invalid_format(self):
- """
- Attempting to use a format that is not configured will raise an
- assertion error.
- """
- factory = APIRequestFactory()
- self.assertRaises(AssertionError, factory.post,
- path='/view/', data={'example': 1}, format='xml'
- )
-
- def test_force_authenticate(self):
- """
- Setting `force_authenticate()` forcibly authenticates the request.
- """
- user = User.objects.create_user('example', 'example@example.com')
- factory = APIRequestFactory()
- request = factory.get('/view')
- force_authenticate(request, user=user)
- response = view(request)
- self.assertEqual(response.data['user'], 'example')
-
- def test_upload_file(self):
- # This is a 1x1 black png
- simple_png = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\rIDATx\x9cc````\x00\x00\x00\x05\x00\x01\xa5\xf6E@\x00\x00\x00\x00IEND\xaeB`\x82')
- simple_png.name = 'test.png'
- factory = APIRequestFactory()
- factory.post('/', data={'image': simple_png})
-
- def test_request_factory_url_arguments(self):
- """
- This is a non regression test against #1461
- """
- factory = APIRequestFactory()
- request = factory.get('/view/?demo=test')
- self.assertEqual(dict(request.GET), {'demo': ['test']})
- request = factory.get('/view/', {'demo': 'test'})
- self.assertEqual(dict(request.GET), {'demo': ['test']})
diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py
deleted file mode 100644
index b5ae02cd..00000000
--- a/rest_framework/tests/test_throttling.py
+++ /dev/null
@@ -1,281 +0,0 @@
-"""
-Tests for the throttling implementations in the permissions module.
-"""
-from __future__ import unicode_literals
-from django.test import TestCase
-from django.contrib.auth.models import User
-from django.core.cache import cache
-from rest_framework.test import APIRequestFactory
-from rest_framework.views import APIView
-from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle
-from rest_framework.response import Response
-
-
-class User3SecRateThrottle(UserRateThrottle):
- rate = '3/sec'
- scope = 'seconds'
-
-
-class User3MinRateThrottle(UserRateThrottle):
- rate = '3/min'
- scope = 'minutes'
-
-
-class NonTimeThrottle(BaseThrottle):
- def allow_request(self, request, view):
- if not hasattr(self.__class__, 'called'):
- self.__class__.called = True
- return True
- return False
-
-
-class MockView(APIView):
- throttle_classes = (User3SecRateThrottle,)
-
- def get(self, request):
- return Response('foo')
-
-
-class MockView_MinuteThrottling(APIView):
- throttle_classes = (User3MinRateThrottle,)
-
- def get(self, request):
- return Response('foo')
-
-
-class MockView_NonTimeThrottling(APIView):
- throttle_classes = (NonTimeThrottle,)
-
- def get(self, request):
- return Response('foo')
-
-
-class ThrottlingTests(TestCase):
- def setUp(self):
- """
- Reset the cache so that no throttles will be active
- """
- cache.clear()
- self.factory = APIRequestFactory()
-
- def test_requests_are_throttled(self):
- """
- Ensure request rate is limited
- """
- request = self.factory.get('/')
- for dummy in range(4):
- response = MockView.as_view()(request)
- self.assertEqual(429, response.status_code)
-
- def set_throttle_timer(self, view, value):
- """
- Explicitly set the timer, overriding time.time()
- """
- view.throttle_classes[0].timer = lambda self: value
-
- def test_request_throttling_expires(self):
- """
- Ensure request rate is limited for a limited duration only
- """
- self.set_throttle_timer(MockView, 0)
-
- request = self.factory.get('/')
- for dummy in range(4):
- response = MockView.as_view()(request)
- self.assertEqual(429, response.status_code)
-
- # Advance the timer by one second
- self.set_throttle_timer(MockView, 1)
-
- response = MockView.as_view()(request)
- self.assertEqual(200, response.status_code)
-
- def ensure_is_throttled(self, view, expect):
- request = self.factory.get('/')
- request.user = User.objects.create(username='a')
- for dummy in range(3):
- view.as_view()(request)
- request.user = User.objects.create(username='b')
- response = view.as_view()(request)
- self.assertEqual(expect, response.status_code)
-
- def test_request_throttling_is_per_user(self):
- """
- Ensure request rate is only limited per user, not globally for
- PerUserThrottles
- """
- self.ensure_is_throttled(MockView, 200)
-
- def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
- """
- Ensure the response returns an X-Throttle field with status and next attributes
- set properly.
- """
- request = self.factory.get('/')
- for timer, expect in expected_headers:
- self.set_throttle_timer(view, timer)
- response = view.as_view()(request)
- if expect is not None:
- self.assertEqual(response['X-Throttle-Wait-Seconds'], expect)
- self.assertEqual(response['Retry-After'], expect)
- else:
- self.assertFalse('X-Throttle-Wait-Seconds' in response)
- self.assertFalse('Retry-After' in response)
-
- def test_seconds_fields(self):
- """
- Ensure for second based throttles.
- """
- self.ensure_response_header_contains_proper_throttle_field(MockView,
- ((0, None),
- (0, None),
- (0, None),
- (0, '1')
- ))
-
- def test_minutes_fields(self):
- """
- Ensure for minute based throttles.
- """
- self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
- ((0, None),
- (0, None),
- (0, None),
- (0, '60')
- ))
-
- def test_next_rate_remains_constant_if_followed(self):
- """
- If a client follows the recommended next request rate,
- the throttling rate should stay constant.
- """
- self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
- ((0, None),
- (20, None),
- (40, None),
- (60, None),
- (80, None)
- ))
-
- def test_non_time_throttle(self):
- """
- Ensure for second based throttles.
- """
- request = self.factory.get('/')
-
- self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called'))
-
- response = MockView_NonTimeThrottling.as_view()(request)
- self.assertFalse('X-Throttle-Wait-Seconds' in response)
- self.assertFalse('Retry-After' in response)
-
- self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called)
-
- response = MockView_NonTimeThrottling.as_view()(request)
- self.assertFalse('X-Throttle-Wait-Seconds' in response)
- self.assertFalse('Retry-After' in response)
-
-
-class ScopedRateThrottleTests(TestCase):
- """
- Tests for ScopedRateThrottle.
- """
-
- def setUp(self):
- class XYScopedRateThrottle(ScopedRateThrottle):
- TIMER_SECONDS = 0
- THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}
- timer = lambda self: self.TIMER_SECONDS
-
- class XView(APIView):
- throttle_classes = (XYScopedRateThrottle,)
- throttle_scope = 'x'
-
- def get(self, request):
- return Response('x')
-
- class YView(APIView):
- throttle_classes = (XYScopedRateThrottle,)
- throttle_scope = 'y'
-
- def get(self, request):
- return Response('y')
-
- class UnscopedView(APIView):
- throttle_classes = (XYScopedRateThrottle,)
-
- def get(self, request):
- return Response('y')
-
- self.throttle_class = XYScopedRateThrottle
- self.factory = APIRequestFactory()
- self.x_view = XView.as_view()
- self.y_view = YView.as_view()
- self.unscoped_view = UnscopedView.as_view()
-
- def increment_timer(self, seconds=1):
- self.throttle_class.TIMER_SECONDS += seconds
-
- def test_scoped_rate_throttle(self):
- request = self.factory.get('/')
-
- # Should be able to hit x view 3 times per minute.
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(429, response.status_code)
-
- # Should be able to hit y view 1 time per minute.
- self.increment_timer()
- response = self.y_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.y_view(request)
- self.assertEqual(429, response.status_code)
-
- # Ensure throttles properly reset by advancing the rest of the minute
- self.increment_timer(55)
-
- # Should still be able to hit x view 3 times per minute.
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(429, response.status_code)
-
- # Should still be able to hit y view 1 time per minute.
- self.increment_timer()
- response = self.y_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.y_view(request)
- self.assertEqual(429, response.status_code)
-
- def test_unscoped_view_not_throttled(self):
- request = self.factory.get('/')
-
- for idx in range(10):
- self.increment_timer()
- response = self.unscoped_view(request)
- self.assertEqual(200, response.status_code)
diff --git a/rest_framework/tests/test_urlpatterns.py b/rest_framework/tests/test_urlpatterns.py
deleted file mode 100644
index 8132ec4c..00000000
--- a/rest_framework/tests/test_urlpatterns.py
+++ /dev/null
@@ -1,76 +0,0 @@
-from __future__ import unicode_literals
-from collections import namedtuple
-from django.core import urlresolvers
-from django.test import TestCase
-from rest_framework.test import APIRequestFactory
-from rest_framework.compat import patterns, url, include
-from rest_framework.urlpatterns import format_suffix_patterns
-
-
-# A container class for test paths for the test case
-URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs'])
-
-
-def dummy_view(request, *args, **kwargs):
- pass
-
-
-class FormatSuffixTests(TestCase):
- """
- Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters.
- """
- def _resolve_urlpatterns(self, urlpatterns, test_paths):
- factory = APIRequestFactory()
- try:
- urlpatterns = format_suffix_patterns(urlpatterns)
- except Exception:
- self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
- resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
- for test_path in test_paths:
- request = factory.get(test_path.path)
- try:
- callback, callback_args, callback_kwargs = resolver.resolve(request.path_info)
- except Exception:
- self.fail("Failed to resolve URL: %s" % request.path_info)
- self.assertEqual(callback_args, test_path.args)
- self.assertEqual(callback_kwargs, test_path.kwargs)
-
- def test_format_suffix(self):
- urlpatterns = patterns(
- '',
- url(r'^test$', dummy_view),
- )
- test_paths = [
- URLTestPath('/test', (), {}),
- URLTestPath('/test.api', (), {'format': 'api'}),
- URLTestPath('/test.asdf', (), {'format': 'asdf'}),
- ]
- self._resolve_urlpatterns(urlpatterns, test_paths)
-
- def test_default_args(self):
- urlpatterns = patterns(
- '',
- url(r'^test$', dummy_view, {'foo': 'bar'}),
- )
- test_paths = [
- URLTestPath('/test', (), {'foo': 'bar', }),
- URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}),
- URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
- ]
- self._resolve_urlpatterns(urlpatterns, test_paths)
-
- def test_included_urls(self):
- nested_patterns = patterns(
- '',
- url(r'^path$', dummy_view)
- )
- urlpatterns = patterns(
- '',
- url(r'^test/', include(nested_patterns), {'foo': 'bar'}),
- )
- test_paths = [
- URLTestPath('/test/path', (), {'foo': 'bar', }),
- URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}),
- URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
- ]
- self._resolve_urlpatterns(urlpatterns, test_paths)
diff --git a/rest_framework/tests/test_validation.py b/rest_framework/tests/test_validation.py
deleted file mode 100644
index e13e4078..00000000
--- a/rest_framework/tests/test_validation.py
+++ /dev/null
@@ -1,148 +0,0 @@
-from __future__ import unicode_literals
-from django.core.validators import MaxValueValidator
-from django.db import models
-from django.test import TestCase
-from rest_framework import generics, serializers, status
-from rest_framework.test import APIRequestFactory
-
-factory = APIRequestFactory()
-
-
-# Regression for #666
-
-class ValidationModel(models.Model):
- blank_validated_field = models.CharField(max_length=255)
-
-
-class ValidationModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = ValidationModel
- fields = ('blank_validated_field',)
- read_only_fields = ('blank_validated_field',)
-
-
-class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView):
- model = ValidationModel
- serializer_class = ValidationModelSerializer
-
-
-class TestPreSaveValidationExclusions(TestCase):
- def test_pre_save_validation_exclusions(self):
- """
- Somewhat weird test case to ensure that we don't perform model
- validation on read only fields.
- """
- obj = ValidationModel.objects.create(blank_validated_field='')
- request = factory.put('/', {}, format='json')
- view = UpdateValidationModel().as_view()
- response = view(request, pk=obj.pk).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
-
-# Regression for #653
-
-class ShouldValidateModel(models.Model):
- should_validate_field = models.CharField(max_length=255)
-
-
-class ShouldValidateModelSerializer(serializers.ModelSerializer):
- renamed = serializers.CharField(source='should_validate_field', required=False)
-
- def validate_renamed(self, attrs, source):
- value = attrs[source]
- if len(value) < 3:
- raise serializers.ValidationError('Minimum 3 characters.')
- return attrs
-
- class Meta:
- model = ShouldValidateModel
- fields = ('renamed',)
-
-
-class TestPreSaveValidationExclusionsSerializer(TestCase):
- def test_renamed_fields_are_model_validated(self):
- """
- Ensure fields with 'source' applied do get still get model validation.
- """
- # We've set `required=False` on the serializer, but the model
- # does not have `blank=True`, so this serializer should not validate.
- serializer = ShouldValidateModelSerializer(data={'renamed': ''})
- self.assertEqual(serializer.is_valid(), False)
- self.assertIn('renamed', serializer.errors)
- self.assertNotIn('should_validate_field', serializer.errors)
-
-
-class TestCustomValidationMethods(TestCase):
- def test_custom_validation_method_is_executed(self):
- serializer = ShouldValidateModelSerializer(data={'renamed': 'fo'})
- self.assertFalse(serializer.is_valid())
- self.assertIn('renamed', serializer.errors)
-
- def test_custom_validation_method_passing(self):
- serializer = ShouldValidateModelSerializer(data={'renamed': 'foo'})
- self.assertTrue(serializer.is_valid())
-
-
-class ValidationSerializer(serializers.Serializer):
- foo = serializers.CharField()
-
- def validate_foo(self, attrs, source):
- raise serializers.ValidationError("foo invalid")
-
- def validate(self, attrs):
- raise serializers.ValidationError("serializer invalid")
-
-
-class TestAvoidValidation(TestCase):
- """
- If serializer was initialized with invalid data (None or non dict-like), it
- should avoid validation layer (validate_<field> and validate methods)
- """
- def test_serializer_errors_has_only_invalid_data_error(self):
- serializer = ValidationSerializer(data='invalid data')
- self.assertFalse(serializer.is_valid())
- self.assertDictEqual(serializer.errors,
- {'non_field_errors': ['Invalid data']})
-
-
-# regression tests for issue: 1493
-
-class ValidationMaxValueValidatorModel(models.Model):
- number_value = models.PositiveIntegerField(validators=[MaxValueValidator(100)])
-
-
-class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = ValidationMaxValueValidatorModel
-
-
-class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView):
- model = ValidationMaxValueValidatorModel
- serializer_class = ValidationMaxValueValidatorModelSerializer
-
-
-class TestMaxValueValidatorValidation(TestCase):
-
- def test_max_value_validation_serializer_success(self):
- serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 99})
- self.assertTrue(serializer.is_valid())
-
- def test_max_value_validation_serializer_fails(self):
- serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 101})
- self.assertFalse(serializer.is_valid())
- self.assertDictEqual({'number_value': ['Ensure this value is less than or equal to 100.']}, serializer.errors)
-
- def test_max_value_validation_success(self):
- obj = ValidationMaxValueValidatorModel.objects.create(number_value=100)
- request = factory.patch('/{0}'.format(obj.pk), {'number_value': 98}, format='json')
- view = UpdateMaxValueValidationModel().as_view()
- response = view(request, pk=obj.pk).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_max_value_validation_fail(self):
- obj = ValidationMaxValueValidatorModel.objects.create(number_value=100)
- request = factory.patch('/{0}'.format(obj.pk), {'number_value': 101}, format='json')
- view = UpdateMaxValueValidationModel().as_view()
- response = view(request, pk=obj.pk).render()
- self.assertEqual(response.content, b'{"number_value": ["Ensure this value is less than or equal to 100."]}')
- self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
diff --git a/rest_framework/tests/test_views.py b/rest_framework/tests/test_views.py
deleted file mode 100644
index 65c7e50e..00000000
--- a/rest_framework/tests/test_views.py
+++ /dev/null
@@ -1,142 +0,0 @@
-from __future__ import unicode_literals
-
-import copy
-from django.test import TestCase
-from rest_framework import status
-from rest_framework.decorators import api_view
-from rest_framework.response import Response
-from rest_framework.settings import api_settings
-from rest_framework.test import APIRequestFactory
-from rest_framework.views import APIView
-
-factory = APIRequestFactory()
-
-
-class BasicView(APIView):
- def get(self, request, *args, **kwargs):
- return Response({'method': 'GET'})
-
- def post(self, request, *args, **kwargs):
- return Response({'method': 'POST', 'data': request.DATA})
-
-
-@api_view(['GET', 'POST', 'PUT', 'PATCH'])
-def basic_view(request):
- if request.method == 'GET':
- return {'method': 'GET'}
- elif request.method == 'POST':
- return {'method': 'POST', 'data': request.DATA}
- elif request.method == 'PUT':
- return {'method': 'PUT', 'data': request.DATA}
- elif request.method == 'PATCH':
- return {'method': 'PATCH', 'data': request.DATA}
-
-
-class ErrorView(APIView):
- def get(self, request, *args, **kwargs):
- raise Exception
-
-
-@api_view(['GET'])
-def error_view(request):
- raise Exception
-
-
-def sanitise_json_error(error_dict):
- """
- Exact contents of JSON error messages depend on the installed version
- of json.
- """
- ret = copy.copy(error_dict)
- chop = len('JSON parse error - No JSON object could be decoded')
- ret['detail'] = ret['detail'][:chop]
- return ret
-
-
-class ClassBasedViewIntegrationTests(TestCase):
- def setUp(self):
- self.view = BasicView.as_view()
-
- def test_400_parse_error(self):
- request = factory.post('/', 'f00bar', content_type='application/json')
- response = self.view(request)
- expected = {
- 'detail': 'JSON parse error - No JSON object could be decoded'
- }
- self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEqual(sanitise_json_error(response.data), expected)
-
- def test_400_parse_error_tunneled_content(self):
- content = 'f00bar'
- content_type = 'application/json'
- form_data = {
- api_settings.FORM_CONTENT_OVERRIDE: content,
- api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
- }
- request = factory.post('/', form_data)
- response = self.view(request)
- expected = {
- 'detail': 'JSON parse error - No JSON object could be decoded'
- }
- self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEqual(sanitise_json_error(response.data), expected)
-
-
-class FunctionBasedViewIntegrationTests(TestCase):
- def setUp(self):
- self.view = basic_view
-
- def test_400_parse_error(self):
- request = factory.post('/', 'f00bar', content_type='application/json')
- response = self.view(request)
- expected = {
- 'detail': 'JSON parse error - No JSON object could be decoded'
- }
- self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEqual(sanitise_json_error(response.data), expected)
-
- def test_400_parse_error_tunneled_content(self):
- content = 'f00bar'
- content_type = 'application/json'
- form_data = {
- api_settings.FORM_CONTENT_OVERRIDE: content,
- api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
- }
- request = factory.post('/', form_data)
- response = self.view(request)
- expected = {
- 'detail': 'JSON parse error - No JSON object could be decoded'
- }
- self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEqual(sanitise_json_error(response.data), expected)
-
-
-class TestCustomExceptionHandler(TestCase):
- def setUp(self):
- self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
-
- def exception_handler(exc):
- return Response('Error!', status=status.HTTP_400_BAD_REQUEST)
-
- api_settings.EXCEPTION_HANDLER = exception_handler
-
- def tearDown(self):
- api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
-
- def test_class_based_view_exception_handler(self):
- view = ErrorView.as_view()
-
- request = factory.get('/', content_type='application/json')
- response = view(request)
- expected = 'Error!'
- self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEqual(response.data, expected)
-
- def test_function_based_view_exception_handler(self):
- view = error_view
-
- request = factory.get('/', content_type='application/json')
- response = view(request)
- expected = 'Error!'
- self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEqual(response.data, expected)
diff --git a/rest_framework/tests/test_write_only_fields.py b/rest_framework/tests/test_write_only_fields.py
deleted file mode 100644
index aabb18d6..00000000
--- a/rest_framework/tests/test_write_only_fields.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from django.db import models
-from django.test import TestCase
-from rest_framework import serializers
-
-
-class ExampleModel(models.Model):
- email = models.EmailField(max_length=100)
- password = models.CharField(max_length=100)
-
-
-class WriteOnlyFieldTests(TestCase):
- def test_write_only_fields(self):
- class ExampleSerializer(serializers.Serializer):
- email = serializers.EmailField()
- password = serializers.CharField(write_only=True)
-
- data = {
- 'email': 'foo@example.com',
- 'password': '123'
- }
- serializer = ExampleSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.object, data)
- self.assertEquals(serializer.data, {'email': 'foo@example.com'})
-
- def test_write_only_fields_meta(self):
- class ExampleSerializer(serializers.ModelSerializer):
- class Meta:
- model = ExampleModel
- fields = ('email', 'password')
- write_only_fields = ('password',)
-
- data = {
- 'email': 'foo@example.com',
- 'password': '123'
- }
- serializer = ExampleSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- self.assertTrue(isinstance(serializer.object, ExampleModel))
- self.assertEquals(serializer.object.email, data['email'])
- self.assertEquals(serializer.object.password, data['password'])
- self.assertEquals(serializer.data, {'email': 'foo@example.com'})
diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py
deleted file mode 100644
index 554ebd1a..00000000
--- a/rest_framework/tests/tests.py
+++ /dev/null
@@ -1,16 +0,0 @@
-"""
-Force import of all modules in this package in order to get the standard test
-runner to pick up the tests. Yowzers.
-"""
-from __future__ import unicode_literals
-import os
-import django
-
-modules = [filename.rsplit('.', 1)[0]
- for filename in os.listdir(os.path.dirname(__file__))
- if filename.endswith('.py') and not filename.startswith('_')]
-__test__ = dict()
-
-if django.VERSION < (1, 6):
- for module in modules:
- exec("from rest_framework.tests.%s import *" % module)
diff --git a/rest_framework/tests/users/__init__.py b/rest_framework/tests/users/__init__.py
deleted file mode 100644
index e69de29b..00000000
--- a/rest_framework/tests/users/__init__.py
+++ /dev/null
diff --git a/rest_framework/tests/users/models.py b/rest_framework/tests/users/models.py
deleted file mode 100644
index 128bac90..00000000
--- a/rest_framework/tests/users/models.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from django.db import models
-
-
-class User(models.Model):
- account = models.ForeignKey('accounts.Account', blank=True, null=True, related_name='users')
- active_record = models.ForeignKey('records.Record', blank=True, null=True)
diff --git a/rest_framework/tests/users/serializers.py b/rest_framework/tests/users/serializers.py
deleted file mode 100644
index da496554..00000000
--- a/rest_framework/tests/users/serializers.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from rest_framework import serializers
-
-from rest_framework.tests.users.models import User
-
-
-class UserSerializer(serializers.ModelSerializer):
- class Meta:
- model = User
diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py
deleted file mode 100644
index a8f2eb0b..00000000
--- a/rest_framework/tests/utils.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from contextlib import contextmanager
-from rest_framework.compat import six
-from rest_framework.settings import api_settings
-
-
-@contextmanager
-def temporary_setting(setting, value, module=None):
- """
- Temporarily change value of setting for test.
-
- Optionally reload given module, useful when module uses value of setting on
- import.
- """
- original_value = getattr(api_settings, setting)
- setattr(api_settings, setting, value)
-
- if module is not None:
- six.moves.reload_module(module)
-
- yield
-
- setattr(api_settings, setting, original_value)
-
- if module is not None:
- six.moves.reload_module(module)
diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py
deleted file mode 100644
index 3917b74a..00000000
--- a/rest_framework/tests/views.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from rest_framework import generics
-from rest_framework.tests.models import NullableForeignKeySource
-from rest_framework.tests.serializers import NullableFKSourceSerializer
-
-
-class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView):
- model = NullableForeignKeySource
- model_serializer_class = NullableFKSourceSerializer
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index efa9fb94..361dbddf 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -18,6 +18,25 @@ class BaseThrottle(object):
"""
raise NotImplementedError('.allow_request() must be overridden')
+ def get_ident(self, request):
+ """
+ Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
+ if present and number of proxies is > 0. If not use all of
+ HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
+ """
+ xff = request.META.get('HTTP_X_FORWARDED_FOR')
+ remote_addr = request.META.get('REMOTE_ADDR')
+ num_proxies = api_settings.NUM_PROXIES
+
+ if num_proxies is not None:
+ if num_proxies == 0 or xff is None:
+ return remote_addr
+ addrs = xff.split(',')
+ client_addr = addrs[-min(num_proxies, len(xff))]
+ return client_addr.strip()
+
+ return xff if xff else remote_addr
+
def wait(self):
"""
Optionally, return a recommended number of seconds to wait before
@@ -41,7 +60,7 @@ class SimpleRateThrottle(BaseThrottle):
cache = default_cache
timer = time.time
- cache_format = 'throtte_%(scope)s_%(ident)s'
+ cache_format = 'throttle_%(scope)s_%(ident)s'
scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
@@ -157,10 +176,12 @@ class AnonRateThrottle(SimpleRateThrottle):
ident = request.META.get('HTTP_X_FORWARDED_FOR')
if ident is None:
ident = request.META.get('REMOTE_ADDR')
+ else:
+ ident = ''.join(ident.split())
return self.cache_format % {
'scope': self.scope,
- 'ident': ident
+ 'ident': self.get_ident(request)
}
@@ -178,7 +199,7 @@ class UserRateThrottle(SimpleRateThrottle):
if request.user.is_authenticated():
ident = request.user.id
else:
- ident = request.META.get('REMOTE_ADDR', None)
+ ident = self.get_ident(request)
return self.cache_format % {
'scope': self.scope,
@@ -226,7 +247,7 @@ class ScopedRateThrottle(SimpleRateThrottle):
if request.user.is_authenticated():
ident = request.user.id
else:
- ident = request.META.get('REMOTE_ADDR', None)
+ ident = self.get_ident(request)
return self.cache_format % {
'scope': self.scope,
diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py
index 0ff137b0..038e9ee3 100644
--- a/rest_framework/urlpatterns.py
+++ b/rest_framework/urlpatterns.py
@@ -1,6 +1,6 @@
from __future__ import unicode_literals
+from django.conf.urls import url, include
from django.core.urlresolvers import RegexURLResolver
-from rest_framework.compat import url, include
from rest_framework.settings import api_settings
diff --git a/rest_framework/urls.py b/rest_framework/urls.py
index 9c4719f1..8fa3073e 100644
--- a/rest_framework/urls.py
+++ b/rest_framework/urls.py
@@ -2,23 +2,25 @@
Login and logout views for the browsable API.
Add these to your root URLconf if you're using the browsable API and
-your API requires authentication.
-
-The urls must be namespaced as 'rest_framework', and you should make sure
-your authentication settings include `SessionAuthentication`.
+your API requires authentication:
urlpatterns = patterns('',
...
url(r'^auth', include('rest_framework.urls', namespace='rest_framework'))
)
+
+The urls must be namespaced as 'rest_framework', and you should make sure
+your authentication settings include `SessionAuthentication`.
"""
from __future__ import unicode_literals
-from rest_framework.compat import patterns, url
+from django.conf.urls import patterns, url
+from django.contrib.auth import views
template_name = {'template_name': 'rest_framework/login.html'}
-urlpatterns = patterns('django.contrib.auth.views',
- url(r'^login/$', 'login', template_name, name='login'),
- url(r'^logout/$', 'logout', template_name, name='logout'),
+urlpatterns = patterns(
+ '',
+ url(r'^login/$', views.login, template_name, name='login'),
+ url(r'^logout/$', views.logout, template_name, name='logout')
)
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index e5fa4194..00ffdfba 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -2,10 +2,11 @@
Helper classes for parsers.
"""
from __future__ import unicode_literals
+from django.utils import timezone
from django.db.models.query import QuerySet
from django.utils.datastructures import SortedDict
from django.utils.functional import Promise
-from rest_framework.compat import timezone, force_text
+from rest_framework.compat import force_text
from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
import datetime
import decimal
@@ -97,14 +98,23 @@ else:
node.flow_style = best_style
return node
- SafeDumper.add_representer(decimal.Decimal,
- SafeDumper.represent_decimal)
-
- SafeDumper.add_representer(SortedDict,
- yaml.representer.SafeRepresenter.represent_dict)
- SafeDumper.add_representer(DictWithMetadata,
- yaml.representer.SafeRepresenter.represent_dict)
- SafeDumper.add_representer(SortedDictWithMetadata,
- yaml.representer.SafeRepresenter.represent_dict)
- SafeDumper.add_representer(types.GeneratorType,
- yaml.representer.SafeRepresenter.represent_list)
+ SafeDumper.add_representer(
+ decimal.Decimal,
+ SafeDumper.represent_decimal
+ )
+ SafeDumper.add_representer(
+ SortedDict,
+ yaml.representer.SafeRepresenter.represent_dict
+ )
+ SafeDumper.add_representer(
+ DictWithMetadata,
+ yaml.representer.SafeRepresenter.represent_dict
+ )
+ SafeDumper.add_representer(
+ SortedDictWithMetadata,
+ yaml.representer.SafeRepresenter.represent_dict
+ )
+ SafeDumper.add_representer(
+ types.GeneratorType,
+ yaml.representer.SafeRepresenter.represent_list
+ )
diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py
index 4b59ba84..6d53aed1 100644
--- a/rest_framework/utils/formatting.py
+++ b/rest_framework/utils/formatting.py
@@ -6,8 +6,6 @@ from __future__ import unicode_literals
from django.utils.html import escape
from django.utils.safestring import mark_safe
from rest_framework.compat import apply_markdown
-from rest_framework.settings import api_settings
-from textwrap import dedent
import re
@@ -40,6 +38,7 @@ def dedent(content):
return content.strip()
+
def camelcase_to_spaces(content):
"""
Translate 'CamelCaseNames' to 'Camel Case Names'.
@@ -49,6 +48,7 @@ def camelcase_to_spaces(content):
content = re.sub(camelcase_boundry, ' \\1', content).strip()
return ' '.join(content.split('_')).title()
+
def markup_description(description):
"""
Apply HTML markup to the given description.
diff --git a/rest_framework/utils/mediatypes.py b/rest_framework/utils/mediatypes.py
index c09c2933..87b3cc6a 100644
--- a/rest_framework/utils/mediatypes.py
+++ b/rest_framework/utils/mediatypes.py
@@ -57,7 +57,7 @@ class _MediaType(object):
if key != 'q' and other.params.get(key, None) != self.params.get(key, None):
return False
- if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type:
+ if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type:
return False
if self.main_type != '*' and other.main_type != '*' and other.main_type != self.main_type:
@@ -74,12 +74,12 @@ class _MediaType(object):
return 0
elif self.sub_type == '*':
return 1
- elif not self.params or self.params.keys() == ['q']:
+ elif not self.params or list(self.params.keys()) == ['q']:
return 2
return 3
def __str__(self):
- return unicode(self).encode('utf-8')
+ return self.__unicode__().encode('utf-8')
def __unicode__(self):
ret = "%s/%s" % (self.main_type, self.sub_type)
diff --git a/rest_framework/views.py b/rest_framework/views.py
index d6ccb301..23df3443 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -31,6 +31,7 @@ def get_view_name(view_cls, suffix=None):
return name
+
def get_view_description(view_cls, html=False):
"""
Given a view class, return a textual description to represent the view.
@@ -120,7 +121,6 @@ class APIView(View):
headers['Vary'] = 'Accept'
return headers
-
def http_method_not_allowed(self, request, *args, **kwargs):
"""
If `request.method` does not correspond to a handler method,
diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py
index 7eb29f99..bb5b304e 100644
--- a/rest_framework/viewsets.py
+++ b/rest_framework/viewsets.py
@@ -127,11 +127,11 @@ class ReadOnlyModelViewSet(mixins.RetrieveModelMixin,
class ModelViewSet(mixins.CreateModelMixin,
- mixins.RetrieveModelMixin,
- mixins.UpdateModelMixin,
- mixins.DestroyModelMixin,
- mixins.ListModelMixin,
- GenericViewSet):
+ mixins.RetrieveModelMixin,
+ mixins.UpdateModelMixin,
+ mixins.DestroyModelMixin,
+ mixins.ListModelMixin,
+ GenericViewSet):
"""
A viewset that provides default `create()`, `retrieve()`, `update()`,
`partial_update()`, `destroy()` and `list()` actions.