aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py5
-rw-r--r--rest_framework/authentication.py118
-rw-r--r--rest_framework/authtoken/migrations/0001_initial.py17
-rw-r--r--rest_framework/compat.py31
-rw-r--r--rest_framework/fields.py187
-rw-r--r--rest_framework/runtests/settings.py17
-rw-r--r--rest_framework/settings.py19
-rw-r--r--rest_framework/tests/authentication.py177
-rw-r--r--rest_framework/tests/fields.py365
-rw-r--r--rest_framework/tests/filterset.py13
-rw-r--r--rest_framework/tests/pagination.py4
-rw-r--r--rest_framework/tests/serializer.py2
12 files changed, 839 insertions, 116 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 29f3d7bc..2180c509 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,6 +1,9 @@
-__version__ = '2.2.1'
+__version__ = '2.2.2'
VERSION = __version__ # synonym
# Header encoding (see RFC5987)
HTTP_HEADER_ENCODING = 'iso-8859-1'
+
+# Default datetime input and output formats
+ISO_8601 = 'iso-8601'
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index d4ba7967..4d6e3375 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -3,10 +3,11 @@ Provides a set of pluggable authentication policies.
"""
from __future__ import unicode_literals
from django.contrib.auth import authenticate
-from django.utils.encoding import DjangoUnicodeDecodeError
+from django.core.exceptions import ImproperlyConfigured
from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware
-from rest_framework.compat import oauth2_provider
+from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
+from rest_framework.compat import oauth2_provider, oauth2_provider_forms, oauth2_provider_backends
from rest_framework.authtoken.models import Token
import base64
@@ -59,11 +60,7 @@ class BasicAuthentication(BaseAuthentication):
except (TypeError, UnicodeDecodeError):
raise exceptions.AuthenticationFailed('Invalid basic header')
- try:
- userid, password = auth_parts[0], auth_parts[2]
- except DjangoUnicodeDecodeError:
- raise exceptions.AuthenticationFailed('Invalid basic header')
-
+ userid, password = auth_parts[0], auth_parts[2]
return self.authenticate_credentials(userid, password)
def authenticate_credentials(self, userid, password):
@@ -156,6 +153,107 @@ class TokenAuthentication(BaseAuthentication):
return 'Token'
+class OAuthAuthentication(BaseAuthentication):
+ """
+ OAuth 1.0a authentication backend using `django-oauth-plus` and `oauth2`.
+
+ Note: The `oauth2` package actually provides oauth1.0a support. Urg.
+ We import it from the `compat` module as `oauth`.
+ """
+ www_authenticate_realm = 'api'
+
+ def __init__(self, **kwargs):
+ super(OAuthAuthentication, self).__init__(**kwargs)
+
+ if oauth is None:
+ raise ImproperlyConfigured(
+ "The 'oauth2' package could not be imported."
+ "It is required for use with the 'OAuthAuthentication' class.")
+
+ if oauth_provider is None:
+ raise ImproperlyConfigured(
+ "The 'django-oauth-plus' package could not be imported."
+ "It is required for use with the 'OAuthAuthentication' class.")
+
+ def authenticate(self, request):
+ """
+ Returns two-tuple of (user, token) if authentication succeeds,
+ or None otherwise.
+ """
+ try:
+ oauth_request = oauth_provider.utils.get_oauth_request(request)
+ except oauth.Error as err:
+ raise exceptions.AuthenticationFailed(err.message)
+
+ oauth_params = oauth_provider.consts.OAUTH_PARAMETERS_NAMES
+
+ found = any(param for param in oauth_params if param in oauth_request)
+ missing = list(param for param in oauth_params if param not in oauth_request)
+
+ if not found:
+ # OAuth authentication was not attempted.
+ return None
+
+ if missing:
+ # OAuth was attempted but missing parameters.
+ msg = 'Missing parameters: %s' % (', '.join(missing))
+ raise exceptions.AuthenticationFailed(msg)
+
+ if not self.check_nonce(request, oauth_request):
+ msg = 'Nonce check failed'
+ raise exceptions.AuthenticationFailed(msg)
+
+ try:
+ consumer_key = oauth_request.get_parameter('oauth_consumer_key')
+ consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key)
+ except oauth_provider_store.InvalidConsumerError as err:
+ raise exceptions.AuthenticationFailed(err)
+
+ if consumer.status != oauth_provider.consts.ACCEPTED:
+ msg = 'Invalid consumer key status: %s' % consumer.get_status_display()
+ raise exceptions.AuthenticationFailed(msg)
+
+ try:
+ token_param = oauth_request.get_parameter('oauth_token')
+ token = oauth_provider_store.get_access_token(request, oauth_request, consumer, token_param)
+ except oauth_provider_store.InvalidTokenError:
+ msg = 'Invalid access token: %s' % oauth_request.get_parameter('oauth_token')
+ raise exceptions.AuthenticationFailed(msg)
+
+ try:
+ self.validate_token(request, consumer, token)
+ except oauth.Error as err:
+ raise exceptions.AuthenticationFailed(err.message)
+
+ user = token.user
+
+ if not user.is_active:
+ msg = 'User inactive or deleted: %s' % user.username
+ raise exceptions.AuthenticationFailed(msg)
+
+ return (token.user, token)
+
+ def authenticate_header(self, request):
+ """
+ If permission is denied, return a '401 Unauthorized' response,
+ with an appropraite 'WWW-Authenticate' header.
+ """
+ return 'OAuth realm="%s"' % self.www_authenticate_realm
+
+ def validate_token(self, request, consumer, token):
+ """
+ Check the token and raise an `oauth.Error` exception if invalid.
+ """
+ oauth_server, oauth_request = oauth_provider.utils.initialize_server_request(request)
+ oauth_server.verify_request(oauth_request, consumer, token)
+
+ def check_nonce(self, request, oauth_request):
+ """
+ Checks nonce of request, and return True if valid.
+ """
+ return oauth_provider_store.check_nonce(request, oauth_request, oauth_request['oauth_nonce'])
+
+
class OAuth2Authentication(BaseAuthentication):
"""
OAuth 2 authentication backend using `django-oauth2-provider`
@@ -189,13 +287,13 @@ class OAuth2Authentication(BaseAuthentication):
"""
# authenticate the client
- oauth2_client_form = oauth2_provider.forms.ClientAuthForm(request.REQUEST)
+ oauth2_client_form = oauth2_provider_forms.ClientAuthForm(request.REQUEST)
if not oauth2_client_form.is_valid():
raise exceptions.AuthenticationFailed("Client could not be validated")
client = oauth2_client_form.cleaned_data.get('client')
# retrieve the `oauth2_provider.models.OAuth2AccessToken` instance from the access_token
- auth_backend = oauth2_provider.backends.AccessTokenBackend()
+ auth_backend = oauth2_provider_backends.AccessTokenBackend()
token = auth_backend.authenticate(access_token, client)
if token is None:
raise exceptions.AuthenticationFailed("Invalid token") # does not exist or is expired
@@ -214,7 +312,7 @@ class OAuth2Authentication(BaseAuthentication):
def authenticate_header(self, request):
"""
- Bearer is the only finalized type currently
+ Bearer is the only finalized type currently
Check details on the `OAuth2Authentication.authenticate` method
"""
diff --git a/rest_framework/authtoken/migrations/0001_initial.py b/rest_framework/authtoken/migrations/0001_initial.py
index f4e052e4..d5965e40 100644
--- a/rest_framework/authtoken/migrations/0001_initial.py
+++ b/rest_framework/authtoken/migrations/0001_initial.py
@@ -4,6 +4,8 @@ from south.db import db
from south.v2 import SchemaMigration
from django.db import models
+from rest_framework.settings import api_settings
+
try:
from django.contrib.auth import get_user_model
@@ -45,20 +47,7 @@ class Migration(SchemaMigration):
'name': ('django.db.models.fields.CharField', [], {'max_length': '50'})
},
"%s.%s" % (User._meta.app_label, User._meta.module_name): {
- 'Meta': {'object_name': 'User'},
- 'date_joined': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}),
- 'email': ('django.db.models.fields.EmailField', [], {'max_length': '75', 'blank': 'True'}),
- 'first_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}),
- 'groups': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Group']", 'symmetrical': 'False', 'blank': 'True'}),
- 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}),
- 'is_active': ('django.db.models.fields.BooleanField', [], {'default': 'True'}),
- 'is_staff': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
- 'is_superuser': ('django.db.models.fields.BooleanField', [], {'default': 'False'}),
- 'last_login': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}),
- 'last_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}),
- 'password': ('django.db.models.fields.CharField', [], {'max_length': '128'}),
- 'user_permissions': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Permission']", 'symmetrical': 'False', 'blank': 'True'}),
- 'username': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '30'})
+ 'Meta': {'object_name': User._meta.module_name},
},
'authtoken.token': {
'Meta': {'object_name': 'Token'},
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 1e04e8f6..69be9543 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -427,16 +427,35 @@ try:
except ImportError:
etree = None
+# OAuth is optional
+try:
+ # Note: The `oauth2` package actually provides oauth1.0a support. Urg.
+ import oauth2 as oauth
+except ImportError:
+ oauth = None
+
+# OAuth is optional
+try:
+ import oauth_provider
+ from oauth_provider.store import store as oauth_provider_store
+except ImportError:
+ oauth_provider = None
+ oauth_provider_store = None
# OAuth 2 support is optional
try:
import provider.oauth2 as oauth2_provider
-
- # Hack to fix submodule import issues
- submodules = ['backends', 'forms','managers','models','urls','views']
- for s in submodules:
- mod = __import__('provider.oauth2.%s.*' % s)
- setattr(oauth2_provider, s, mod)
+ # # Hack to fix submodule import issues
+ # submodules = ['backends', 'forms', 'managers', 'models', 'urls', 'views']
+ # for s in submodules:
+ # mod = __import__('provider.oauth2.%s.*' % s)
+ # setattr(oauth2_provider, s, mod)
+ from provider.oauth2 import backends as oauth2_provider_backends
+ from provider.oauth2 import models as oauth2_provider_models
+ from provider.oauth2 import forms as oauth2_provider_forms
except ImportError:
oauth2_provider = None
+ oauth2_provider_backends = None
+ oauth2_provider_models = None
+ oauth2_provider_forms = None
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 86c3a837..fe555ee5 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -13,12 +13,13 @@ from django import forms
from django.forms import widgets
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
-from rest_framework.compat import parse_date, parse_datetime
-from rest_framework.compat import timezone
+
+from rest_framework import ISO_8601
+from rest_framework.compat import timezone, parse_date, parse_datetime, parse_time
from rest_framework.compat import BytesIO
from rest_framework.compat import six
from rest_framework.compat import smart_text
-from rest_framework.compat import parse_time
+from rest_framework.settings import api_settings
def is_simple_callable(obj):
@@ -50,6 +51,46 @@ def get_component(obj, attr_name):
return val
+def readable_datetime_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]')
+ return humanize_strptime(format)
+
+
+def readable_date_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]')
+ return humanize_strptime(format)
+
+
+def readable_time_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
+ return humanize_strptime(format)
+
+
+def humanize_strptime(format_string):
+ # Note that we're missing some of the locale specific mappings that
+ # don't really make sense.
+ mapping = {
+ "%Y": "YYYY",
+ "%y": "YY",
+ "%m": "MM",
+ "%b": "[Jan-Dec]",
+ "%B": "[January-December]",
+ "%d": "DD",
+ "%H": "hh",
+ "%I": "hh", # Requires '%p' to differentiate from '%H'.
+ "%M": "mm",
+ "%S": "ss",
+ "%f": "uuuuuu",
+ "%a": "[Mon-Sun]",
+ "%A": "[Monday-Sunday]",
+ "%p": "[AM|PM]",
+ "%z": "[+HHMM|-HHMM]"
+ }
+ for key, val in mapping.items():
+ format_string = format_string.replace(key, val)
+ return format_string
+
+
class Field(object):
read_only = True
creation_counter = 0
@@ -447,12 +488,16 @@ class DateField(WritableField):
form_field_class = forms.DateField
default_error_messages = {
- 'invalid': _("'%s' value has an invalid date format. It must be "
- "in YYYY-MM-DD format."),
- 'invalid_date': _("'%s' value has the correct format (YYYY-MM-DD) "
- "but it is an invalid date."),
+ 'invalid': _("Date has wrong format. Use one of these formats instead: %s"),
}
empty = None
+ input_formats = api_settings.DATE_INPUT_FORMATS
+ format = api_settings.DATE_FORMAT
+
+ def __init__(self, input_formats=None, format=None, *args, **kwargs):
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
+ self.format = format if format is not None else self.format
+ super(DateField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
@@ -468,17 +513,33 @@ class DateField(WritableField):
if isinstance(value, datetime.date):
return value
- try:
- parsed = parse_date(value)
- if parsed is not None:
- return parsed
- except (ValueError, TypeError):
- msg = self.error_messages['invalid_date'] % value
- raise ValidationError(msg)
+ for format in self.input_formats:
+ if format.lower() == ISO_8601:
+ try:
+ parsed = parse_date(value)
+ except (ValueError, TypeError):
+ pass
+ else:
+ if parsed is not None:
+ return parsed
+ else:
+ try:
+ parsed = datetime.datetime.strptime(value, format)
+ except (ValueError, TypeError):
+ pass
+ else:
+ return parsed.date()
- msg = self.error_messages['invalid'] % value
+ msg = self.error_messages['invalid'] % readable_date_formats(self.input_formats)
raise ValidationError(msg)
+ def to_native(self, value):
+ if isinstance(value, datetime.datetime):
+ value = value.date()
+ if self.format.lower() == ISO_8601:
+ return value.isoformat()
+ return value.strftime(self.format)
+
class DateTimeField(WritableField):
type_name = 'DateTimeField'
@@ -486,15 +547,16 @@ class DateTimeField(WritableField):
form_field_class = forms.DateTimeField
default_error_messages = {
- 'invalid': _("'%s' value has an invalid format. It must be in "
- "YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."),
- 'invalid_date': _("'%s' value has the correct format "
- "(YYYY-MM-DD) but it is an invalid date."),
- 'invalid_datetime': _("'%s' value has the correct format "
- "(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) "
- "but it is an invalid date/time."),
+ 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"),
}
empty = None
+ input_formats = api_settings.DATETIME_INPUT_FORMATS
+ format = api_settings.DATETIME_FORMAT
+
+ def __init__(self, input_formats=None, format=None, *args, **kwargs):
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
+ self.format = format if format is not None else self.format
+ super(DateTimeField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
@@ -516,25 +578,31 @@ class DateTimeField(WritableField):
value = timezone.make_aware(value, default_timezone)
return value
- try:
- parsed = parse_datetime(value)
- if parsed is not None:
- return parsed
- except (ValueError, TypeError):
- msg = self.error_messages['invalid_datetime'] % value
- raise ValidationError(msg)
-
- try:
- parsed = parse_date(value)
- if parsed is not None:
- return datetime.datetime(parsed.year, parsed.month, parsed.day)
- except (ValueError, TypeError):
- msg = self.error_messages['invalid_date'] % value
- raise ValidationError(msg)
+ for format in self.input_formats:
+ if format.lower() == ISO_8601:
+ try:
+ parsed = parse_datetime(value)
+ except (ValueError, TypeError):
+ pass
+ else:
+ if parsed is not None:
+ return parsed
+ else:
+ try:
+ parsed = datetime.datetime.strptime(value, format)
+ except (ValueError, TypeError):
+ pass
+ else:
+ return parsed
- msg = self.error_messages['invalid'] % value
+ msg = self.error_messages['invalid'] % readable_datetime_formats(self.input_formats)
raise ValidationError(msg)
+ def to_native(self, value):
+ if self.format.lower() == ISO_8601:
+ return value.isoformat()
+ return value.strftime(self.format)
+
class TimeField(WritableField):
type_name = 'TimeField'
@@ -542,10 +610,16 @@ class TimeField(WritableField):
form_field_class = forms.TimeField
default_error_messages = {
- 'invalid': _("'%s' value has an invalid format. It must be a valid "
- "time in the HH:MM[:ss[.uuuuuu]] format."),
+ 'invalid': _("Time has wrong format. Use one of these formats instead: %s"),
}
empty = None
+ input_formats = api_settings.TIME_INPUT_FORMATS
+ format = api_settings.TIME_FORMAT
+
+ def __init__(self, input_formats=None, format=None, *args, **kwargs):
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
+ self.format = format if format is not None else self.format
+ super(TimeField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
@@ -554,13 +628,32 @@ class TimeField(WritableField):
if isinstance(value, datetime.time):
return value
- try:
- parsed = parse_time(value)
- assert parsed is not None
- return parsed
- except (ValueError, TypeError):
- msg = self.error_messages['invalid'] % value
- raise ValidationError(msg)
+ for format in self.input_formats:
+ if format.lower() == ISO_8601:
+ try:
+ parsed = parse_time(value)
+ except (ValueError, TypeError):
+ pass
+ else:
+ if parsed is not None:
+ return parsed
+ else:
+ try:
+ parsed = datetime.datetime.strptime(value, format)
+ except (ValueError, TypeError):
+ pass
+ else:
+ return parsed.time()
+
+ msg = self.error_messages['invalid'] % readable_time_formats(self.input_formats)
+ raise ValidationError(msg)
+
+ def to_native(self, value):
+ if isinstance(value, datetime.datetime):
+ value = value.time()
+ if self.format.lower() == ISO_8601:
+ return value.isoformat()
+ return value.strftime(self.format)
class IntegerField(WritableField):
diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py
index 16ef1d2b..9b519f27 100644
--- a/rest_framework/runtests/settings.py
+++ b/rest_framework/runtests/settings.py
@@ -100,15 +100,26 @@ INSTALLED_APPS = (
'rest_framework.tests',
)
+# OAuth is optional and won't work if there is no oauth_provider & oauth2
+try:
+ import oauth_provider
+ import oauth2
+except ImportError:
+ pass
+else:
+ INSTALLED_APPS += (
+ 'oauth_provider',
+ )
+
try:
import provider
+except ImportError:
+ pass
+else:
INSTALLED_APPS += (
'provider',
'provider.oauth2',
)
-except ImportError:
- import logging
- logging.warning("django-oauth2-provider is not install, some tests will be skipped")
STATIC_URL = '/static/'
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index b7aa0bbe..eede0c5a 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -18,8 +18,11 @@ REST framework settings, checking for user settings first, then falling
back to the defaults.
"""
from __future__ import unicode_literals
+
from django.conf import settings
from django.utils import importlib
+
+from rest_framework import ISO_8601
from rest_framework.compat import six
@@ -76,6 +79,22 @@ DEFAULTS = {
'URL_FORMAT_OVERRIDE': 'format',
'FORMAT_SUFFIX_KWARG': 'format',
+
+ # Input and output formats
+ 'DATE_INPUT_FORMATS': (
+ ISO_8601,
+ ),
+ 'DATE_FORMAT': ISO_8601,
+
+ 'DATETIME_INPUT_FORMATS': (
+ ISO_8601,
+ ),
+ 'DATETIME_FORMAT': ISO_8601,
+
+ 'TIME_INPUT_FORMATS': (
+ ISO_8601,
+ ),
+ 'TIME_FORMAT': ISO_8601,
}
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index a02aef55..ddd61b63 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -1,5 +1,4 @@
from __future__ import unicode_literals
-from django.core.urlresolvers import reverse
from django.contrib.auth.models import User
from django.http import HttpResponse
from django.test import Client, TestCase
@@ -8,23 +7,25 @@ from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions
from rest_framework import permissions
from rest_framework import status
-from rest_framework.authtoken.models import Token
from rest_framework.authentication import (
BaseAuthentication,
TokenAuthentication,
BasicAuthentication,
SessionAuthentication,
+ OAuthAuthentication,
OAuth2Authentication
)
+from rest_framework.authtoken.models import Token
from rest_framework.compat import patterns, url, include
-from rest_framework.compat import oauth2_provider
+from rest_framework.compat import oauth2_provider, oauth2_provider_models
+from rest_framework.compat import oauth, oauth_provider
from rest_framework.tests.utils import RequestFactory
from rest_framework.views import APIView
import json
import base64
+import time
import datetime
-
factory = RequestFactory()
@@ -46,11 +47,12 @@ urlpatterns = patterns('',
(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
+ (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication]))
)
if oauth2_provider is not None:
- urlpatterns += patterns('',
- url(r'^oauth2/', include('provider.oauth2.urls', namespace = 'oauth2')),
+ urlpatterns += patterns('',
+ url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
)
@@ -235,6 +237,159 @@ class IncorrectCredentialsTests(TestCase):
self.assertEqual(response.data, {'detail': 'Bad credentials'})
+class OAuthTests(TestCase):
+ """OAuth 1.0a authentication"""
+ urls = 'rest_framework.tests.authentication'
+
+ def setUp(self):
+ # these imports are here because oauth is optional and hiding them in try..except block or compat
+ # could obscure problems if something breaks
+ from oauth_provider.models import Consumer, Resource
+ from oauth_provider.models import Token as OAuthToken
+ from oauth_provider import consts
+
+ self.consts = consts
+
+ self.csrf_client = Client(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.CONSUMER_KEY = 'consumer_key'
+ self.CONSUMER_SECRET = 'consumer_secret'
+ self.TOKEN_KEY = "token_key"
+ self.TOKEN_SECRET = "token_secret"
+
+ self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
+ name='example', user=self.user, status=self.consts.ACCEPTED)
+
+ self.resource = Resource.objects.create(name="resource name", url="api/")
+ self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, resource=self.resource,
+ token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True
+ )
+
+ def _create_authorization_header(self):
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="GET", url="http://example.com", parameters=params)
+
+ signature_method = oauth.SignatureMethod_PLAINTEXT()
+ req.sign_request(signature_method, self.consumer, self.token)
+
+ return req.to_header()["Authorization"]
+
+ def _create_authorization_url_parameters(self):
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="GET", url="http://example.com", parameters=params)
+
+ signature_method = oauth.SignatureMethod_PLAINTEXT()
+ req.sign_request(signature_method, self.consumer, self.token)
+ return dict(req)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_passing_oauth(self):
+ """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_repeated_nonce_failing_oauth(self):
+ """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ # simulate reply attack auth header containes already used (nonce, timestamp) pair
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_token_removed_failing_oauth(self):
+ """Ensure POSTing when there is no OAuth access token in db fails"""
+ self.token.delete()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_consumer_status_not_accepted_failing_oauth(self):
+ """Ensure POSTing when consumer status is anything other than ACCEPTED fails"""
+ for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED):
+ self.consumer.status = consumer_status
+ self.consumer.save()
+
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_request_token_failing_oauth(self):
+ """Ensure POSTing with unauthorized request token instead of access token fails"""
+ self.token.token_type = self.token.REQUEST
+ self.token.save()
+
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_urlencoded_parameters(self):
+ """Ensure POSTing with x-www-form-urlencoded auth parameters passes"""
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.post('/oauth/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_get_form_with_url_parameters(self):
+ """Ensure GETing with auth in url parameters passes"""
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.get('/oauth/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_hmac_sha1_signature_passes(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+
class OAuth2Tests(TestCase):
"""OAuth 2.0 authentication"""
urls = 'rest_framework.tests.authentication'
@@ -251,21 +406,21 @@ class OAuth2Tests(TestCase):
self.ACCESS_TOKEN = "access_token"
self.REFRESH_TOKEN = "refresh_token"
- self.oauth2_client = oauth2_provider.models.Client.objects.create(
- client_id=self.CLIENT_ID,
+ self.oauth2_client = oauth2_provider_models.Client.objects.create(
+ client_id=self.CLIENT_ID,
client_secret=self.CLIENT_SECRET,
redirect_uri='',
client_type=0,
- name='example',
+ name='example',
user=None,
)
- self.access_token = oauth2_provider.models.AccessToken.objects.create(
+ self.access_token = oauth2_provider_models.AccessToken.objects.create(
token=self.ACCESS_TOKEN,
client=self.oauth2_client,
user=self.user,
)
- self.refresh_token = oauth2_provider.models.RefreshToken.objects.create(
+ self.refresh_token = oauth2_provider_models.RefreshToken.objects.create(
user=self.user,
access_token=self.access_token,
client=self.oauth2_client
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py
index 840ed320..28f18ed8 100644
--- a/rest_framework/tests/fields.py
+++ b/rest_framework/tests/fields.py
@@ -3,9 +3,11 @@ General serializer field tests.
"""
from __future__ import unicode_literals
import datetime
+
from django.db import models
from django.test import TestCase
from django.core import validators
+
from rest_framework import serializers
@@ -59,37 +61,370 @@ class BasicFieldTests(TestCase):
serializer = CharPrimaryKeyModelSerializer()
self.assertEqual(serializer.fields['id'].read_only, False)
- def test_TimeField_from_native(self):
+
+class DateFieldTest(TestCase):
+ """
+ Tests for the DateFieldTest from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.DateField()
+ result_1 = f.from_native('1984-07-31')
+
+ self.assertEqual(datetime.date(1984, 7, 31), result_1)
+
+ def test_from_native_datetime_date(self):
+ """
+ Make sure from_native() accepts a datetime.date instance.
+ """
+ f = serializers.DateField()
+ result_1 = f.from_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual(result_1, datetime.date(1984, 7, 31))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.DateField(input_formats=['%Y -- %d'])
+ result = f.from_native('1984 -- 31')
+
+ self.assertEqual(datetime.date(1984, 1, 31), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.DateField(input_formats=['%Y -- %d'])
+
+ try:
+ f.from_native('1984-07-31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DateField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_date(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid date.
+ """
+ f = serializers.DateField()
+
+ try:
+ f.from_native('1984-13-31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.DateField()
+
+ try:
+ f.from_native('1984 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns isoformat as default.
+ """
+ f = serializers.DateField()
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual('1984-07-31', result_1)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.DateField(format="%Y - %m.%d")
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual('1984 - 07.31', result_1)
+
+
+class DateTimeFieldTest(TestCase):
+ """
+ Tests for the DateTimeField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.DateTimeField()
+ result_1 = f.from_native('1984-07-31 04:31')
+ result_2 = f.from_native('1984-07-31 04:31:59')
+ result_3 = f.from_native('1984-07-31 04:31:59.000200')
+
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3)
+
+ def test_from_native_datetime_datetime(self):
+ """
+ Make sure from_native() accepts a datetime.datetime instance.
+ """
+ f = serializers.DateTimeField()
+ result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31))
+ self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59))
+ self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
+ result = f.from_native('1984 -- 04:59')
+
+ self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
+
+ try:
+ f.from_native('1984-07-31 04:31:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DateTimeField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateTimeField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_datetime(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid datetime.
+ """
+ f = serializers.DateTimeField()
+
+ try:
+ f.from_native('04:61:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
+ "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.DateTimeField()
+
+ try:
+ f.from_native('04 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
+ "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns isoformat as default.
+ """
+ f = serializers.DateTimeField()
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual('1984-07-31T00:00:00', result_1)
+ self.assertEqual('1984-07-31T04:31:00', result_2)
+ self.assertEqual('1984-07-31T04:31:59', result_3)
+ self.assertEqual('1984-07-31T04:31:59.000200', result_4)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.DateTimeField(format="%Y - %H:%M")
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual('1984 - 00:00', result_1)
+ self.assertEqual('1984 - 04:31', result_2)
+ self.assertEqual('1984 - 04:31', result_3)
+ self.assertEqual('1984 - 04:31', result_4)
+
+
+class TimeFieldTest(TestCase):
+ """
+ Tests for the TimeField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
f = serializers.TimeField()
- result = f.from_native('12:34:56.987654')
+ result_1 = f.from_native('04:31')
+ result_2 = f.from_native('04:31:59')
+ result_3 = f.from_native('04:31:59.000200')
- self.assertEqual(datetime.time(12, 34, 56, 987654), result)
+ self.assertEqual(datetime.time(4, 31), result_1)
+ self.assertEqual(datetime.time(4, 31, 59), result_2)
+ self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
- def test_TimeField_from_native_datetime_time(self):
+ def test_from_native_datetime_time(self):
"""
Make sure from_native() accepts a datetime.time instance.
"""
f = serializers.TimeField()
- result = f.from_native(datetime.time(12, 34, 56))
- self.assertEqual(result, datetime.time(12, 34, 56))
+ result_1 = f.from_native(datetime.time(4, 31))
+ result_2 = f.from_native(datetime.time(4, 31, 59))
+ result_3 = f.from_native(datetime.time(4, 31, 59, 200))
- def test_TimeField_from_native_empty(self):
+ self.assertEqual(result_1, datetime.time(4, 31))
+ self.assertEqual(result_2, datetime.time(4, 31, 59))
+ self.assertEqual(result_3, datetime.time(4, 31, 59, 200))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.TimeField(input_formats=['%H -- %M'])
+ result = f.from_native('04 -- 31')
+
+ self.assertEqual(datetime.time(4, 31), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.TimeField(input_formats=['%H -- %M'])
+
+ try:
+ f.from_native('04:31:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
f = serializers.TimeField()
result = f.from_native('')
+
self.assertEqual(result, None)
- def test_TimeField_from_native_invalid_time(self):
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.TimeField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_time(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid time.
+ """
+ f = serializers.TimeField()
+
+ try:
+ f.from_native('04:61:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
+ "hh:mm[:ss[.uuuuuu]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
f = serializers.TimeField()
try:
- f.from_native('12:69:12')
+ f.from_native('04 -- 31')
except validators.ValidationError as e:
- self.assertEqual(e.messages, ["'12:69:12' value has an invalid "
- "format. It must be a valid time "
- "in the HH:MM[:ss[.uuuuuu]] format."])
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
+ "hh:mm[:ss[.uuuuuu]]"])
else:
self.fail("ValidationError was not properly raised")
- def test_TimeFieldModelSerializer(self):
- serializer = TimeFieldModelSerializer()
- self.assertTrue(isinstance(serializer.fields['clock'], serializers.TimeField))
+ def test_to_native(self):
+ """
+ Make sure to_native() returns isoformat as default.
+ """
+ f = serializers.TimeField()
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual('04:31:00', result_1)
+ self.assertEqual('04:31:59', result_2)
+ self.assertEqual('04:31:59.000200', result_3)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.TimeField(format="%H - %S [%f]")
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual('04 - 00 [000000]', result_1)
+ self.assertEqual('04 - 59 [000000]', result_2)
+ self.assertEqual('04 - 59 [000200]', result_3)
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
index 8c13947c..fe92e0bc 100644
--- a/rest_framework/tests/filterset.py
+++ b/rest_framework/tests/filterset.py
@@ -65,8 +65,8 @@ class IntegrationTestFiltering(TestCase):
self.objects = FilterableItem.objects
self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
- for obj in self.objects.all()
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()}
+ for obj in self.objects.all()
]
@unittest.skipUnless(django_filters, 'django-filters not installed')
@@ -95,7 +95,7 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] == search_date]
+ expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() == search_date]
self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filters not installed')
@@ -125,7 +125,7 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date]
+ expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date]
self.assertEqual(response.data, expected_data)
# Tests that the text filter set with 'icontains' in the filter class works.
@@ -142,8 +142,9 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date and
- f['decimal'] < search_decimal]
+ expected_data = [f for f in self.data if
+ datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date and
+ f['decimal'] < search_decimal]
self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filters not installed')
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index 6b9970a6..472ffcdd 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -112,8 +112,8 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.objects = FilterableItem.objects
self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
- for obj in self.objects.all()
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()}
+ for obj in self.objects.all()
]
self.view = FilterFieldsRootView.as_view()
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index d0300f9e..51065017 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -112,7 +112,7 @@ class BasicTests(TestCase):
self.expected = {
'email': 'tom@example.com',
'content': 'Happy new year!',
- 'created': datetime.datetime(2012, 1, 1),
+ 'created': '2012-01-01T00:00:00',
'sub_comment': 'And Merry Christmas!'
}
self.person_data = {'name': 'dwight', 'age': 35}