diff options
Diffstat (limited to 'rest_framework')
51 files changed, 4798 insertions, 979 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 29f3d7bc..0b1e67fb 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,6 +1,9 @@ -__version__ = '2.2.1' +__version__ = '2.3.3' VERSION = __version__ # synonym # Header encoding (see RFC5987) HTTP_HEADER_ENCODING = 'iso-8859-1' + +# Default datetime input and output formats +ISO_8601 = 'iso-8601' diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 14b2136b..9caca788 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -1,13 +1,30 @@ """ -Provides a set of pluggable authentication policies. +Provides various authentication policies. """ from __future__ import unicode_literals +import base64 +from datetime import datetime + from django.contrib.auth import authenticate -from django.utils.encoding import DjangoUnicodeDecodeError +from django.core.exceptions import ImproperlyConfigured from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework.compat import CsrfViewMiddleware +from rest_framework.compat import oauth, oauth_provider, oauth_provider_store +from rest_framework.compat import oauth2_provider from rest_framework.authtoken.models import Token -import base64 + + +def get_authorization_header(request): + """ + Return request's 'Authorization:' header, as a bytestring. + + Hide some test client ickyness where the header can be unicode. + """ + auth = request.META.get('HTTP_AUTHORIZATION', b'') + if type(auth) == type(''): + # Work around django test client oddness + auth = auth.encode(HTTP_HEADER_ENCODING) + return auth class BaseAuthentication(object): @@ -41,28 +58,25 @@ class BasicAuthentication(BaseAuthentication): Returns a `User` if a correct username and password have been supplied using HTTP Basic authentication. Otherwise returns `None`. """ - auth = request.META.get('HTTP_AUTHORIZATION', b'') - if type(auth) == type(''): - # Work around django test client oddness - auth = auth.encode(HTTP_HEADER_ENCODING) - auth = auth.split() + auth = get_authorization_header(request).split() if not auth or auth[0].lower() != b'basic': return None - if len(auth) != 2: - raise exceptions.AuthenticationFailed('Invalid basic header') + if len(auth) == 1: + msg = 'Invalid basic header. No credentials provided.' + raise exceptions.AuthenticationFailed(msg) + elif len(auth) > 2: + msg = 'Invalid basic header. Credentials string should not contain spaces.' + raise exceptions.AuthenticationFailed(msg) try: auth_parts = base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(':') except (TypeError, UnicodeDecodeError): - raise exceptions.AuthenticationFailed('Invalid basic header') - - try: - userid, password = auth_parts[0], auth_parts[2] - except DjangoUnicodeDecodeError: - raise exceptions.AuthenticationFailed('Invalid basic header') + msg = 'Invalid basic header. Credentials not correctly base64 encoded' + raise exceptions.AuthenticationFailed(msg) + userid, password = auth_parts[0], auth_parts[2] return self.authenticate_credentials(userid, password) def authenticate_credentials(self, userid, password): @@ -70,9 +84,9 @@ class BasicAuthentication(BaseAuthentication): Authenticate the userid and password against username and password. """ user = authenticate(username=userid, password=password) - if user is not None and user.is_active: - return (user, None) - raise exceptions.AuthenticationFailed('Invalid username/password') + if user is None or not user.is_active: + raise exceptions.AuthenticationFailed('Invalid username/password') + return (user, None) def authenticate_header(self, request): return 'Basic realm="%s"' % self.www_authenticate_realm @@ -131,13 +145,17 @@ class TokenAuthentication(BaseAuthentication): """ def authenticate(self, request): - auth = request.META.get('HTTP_AUTHORIZATION', '').split() + auth = get_authorization_header(request).split() - if not auth or auth[0].lower() != "token": + if not auth or auth[0].lower() != b'token': return None - if len(auth) != 2: - raise exceptions.AuthenticationFailed('Invalid token header') + if len(auth) == 1: + msg = 'Invalid token header. No credentials provided.' + raise exceptions.AuthenticationFailed(msg) + elif len(auth) > 2: + msg = 'Invalid token header. Token string should not contain spaces.' + raise exceptions.AuthenticationFailed(msg) return self.authenticate_credentials(auth[1]) @@ -147,12 +165,178 @@ class TokenAuthentication(BaseAuthentication): except self.model.DoesNotExist: raise exceptions.AuthenticationFailed('Invalid token') - if token.user.is_active: - return (token.user, token) - raise exceptions.AuthenticationFailed('User inactive or deleted') + if not token.user.is_active: + raise exceptions.AuthenticationFailed('User inactive or deleted') + + return (token.user, token) def authenticate_header(self, request): return 'Token' -# TODO: OAuthAuthentication +class OAuthAuthentication(BaseAuthentication): + """ + OAuth 1.0a authentication backend using `django-oauth-plus` and `oauth2`. + + Note: The `oauth2` package actually provides oauth1.0a support. Urg. + We import it from the `compat` module as `oauth`. + """ + www_authenticate_realm = 'api' + + def __init__(self, *args, **kwargs): + super(OAuthAuthentication, self).__init__(*args, **kwargs) + + if oauth is None: + raise ImproperlyConfigured( + "The 'oauth2' package could not be imported." + "It is required for use with the 'OAuthAuthentication' class.") + + if oauth_provider is None: + raise ImproperlyConfigured( + "The 'django-oauth-plus' package could not be imported." + "It is required for use with the 'OAuthAuthentication' class.") + + def authenticate(self, request): + """ + Returns two-tuple of (user, token) if authentication succeeds, + or None otherwise. + """ + try: + oauth_request = oauth_provider.utils.get_oauth_request(request) + except oauth.Error as err: + raise exceptions.AuthenticationFailed(err.message) + + if not oauth_request: + return None + + oauth_params = oauth_provider.consts.OAUTH_PARAMETERS_NAMES + + found = any(param for param in oauth_params if param in oauth_request) + missing = list(param for param in oauth_params if param not in oauth_request) + + if not found: + # OAuth authentication was not attempted. + return None + + if missing: + # OAuth was attempted but missing parameters. + msg = 'Missing parameters: %s' % (', '.join(missing)) + raise exceptions.AuthenticationFailed(msg) + + if not self.check_nonce(request, oauth_request): + msg = 'Nonce check failed' + raise exceptions.AuthenticationFailed(msg) + + try: + consumer_key = oauth_request.get_parameter('oauth_consumer_key') + consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key) + except oauth_provider.store.InvalidConsumerError as err: + raise exceptions.AuthenticationFailed(err) + + if consumer.status != oauth_provider.consts.ACCEPTED: + msg = 'Invalid consumer key status: %s' % consumer.get_status_display() + raise exceptions.AuthenticationFailed(msg) + + try: + token_param = oauth_request.get_parameter('oauth_token') + token = oauth_provider_store.get_access_token(request, oauth_request, consumer, token_param) + except oauth_provider.store.InvalidTokenError: + msg = 'Invalid access token: %s' % oauth_request.get_parameter('oauth_token') + raise exceptions.AuthenticationFailed(msg) + + try: + self.validate_token(request, consumer, token) + except oauth.Error as err: + raise exceptions.AuthenticationFailed(err.message) + + user = token.user + + if not user.is_active: + msg = 'User inactive or deleted: %s' % user.username + raise exceptions.AuthenticationFailed(msg) + + return (token.user, token) + + def authenticate_header(self, request): + """ + If permission is denied, return a '401 Unauthorized' response, + with an appropraite 'WWW-Authenticate' header. + """ + return 'OAuth realm="%s"' % self.www_authenticate_realm + + def validate_token(self, request, consumer, token): + """ + Check the token and raise an `oauth.Error` exception if invalid. + """ + oauth_server, oauth_request = oauth_provider.utils.initialize_server_request(request) + oauth_server.verify_request(oauth_request, consumer, token) + + def check_nonce(self, request, oauth_request): + """ + Checks nonce of request, and return True if valid. + """ + return oauth_provider_store.check_nonce(request, oauth_request, oauth_request['oauth_nonce']) + + +class OAuth2Authentication(BaseAuthentication): + """ + OAuth 2 authentication backend using `django-oauth2-provider` + """ + www_authenticate_realm = 'api' + + def __init__(self, *args, **kwargs): + super(OAuth2Authentication, self).__init__(*args, **kwargs) + + if oauth2_provider is None: + raise ImproperlyConfigured( + "The 'django-oauth2-provider' package could not be imported. " + "It is required for use with the 'OAuth2Authentication' class.") + + def authenticate(self, request): + """ + Returns two-tuple of (user, token) if authentication succeeds, + or None otherwise. + """ + + auth = get_authorization_header(request).split() + + if not auth or auth[0].lower() != b'bearer': + return None + + if len(auth) == 1: + msg = 'Invalid bearer header. No credentials provided.' + raise exceptions.AuthenticationFailed(msg) + elif len(auth) > 2: + msg = 'Invalid bearer header. Token string should not contain spaces.' + raise exceptions.AuthenticationFailed(msg) + + return self.authenticate_credentials(request, auth[1]) + + def authenticate_credentials(self, request, access_token): + """ + Authenticate the request, given the access token. + """ + + try: + token = oauth2_provider.models.AccessToken.objects.select_related('user') + # TODO: Change to timezone aware datetime when oauth2_provider add + # support to it. + token = token.get(token=access_token, expires__gt=datetime.now()) + except oauth2_provider.models.AccessToken.DoesNotExist: + raise exceptions.AuthenticationFailed('Invalid token') + + user = token.user + + if not user.is_active: + msg = 'User inactive or deleted: %s' % user.username + raise exceptions.AuthenticationFailed(msg) + + return (user, token) + + def authenticate_header(self, request): + """ + Bearer is the only finalized type currently + + Check details on the `OAuth2Authentication.authenticate` method + """ + return 'Bearer realm="%s"' % self.www_authenticate_realm diff --git a/rest_framework/authtoken/migrations/0001_initial.py b/rest_framework/authtoken/migrations/0001_initial.py index f4e052e4..d5965e40 100644 --- a/rest_framework/authtoken/migrations/0001_initial.py +++ b/rest_framework/authtoken/migrations/0001_initial.py @@ -4,6 +4,8 @@ from south.db import db from south.v2 import SchemaMigration from django.db import models +from rest_framework.settings import api_settings + try: from django.contrib.auth import get_user_model @@ -45,20 +47,7 @@ class Migration(SchemaMigration): 'name': ('django.db.models.fields.CharField', [], {'max_length': '50'}) }, "%s.%s" % (User._meta.app_label, User._meta.module_name): { - 'Meta': {'object_name': 'User'}, - 'date_joined': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}), - 'email': ('django.db.models.fields.EmailField', [], {'max_length': '75', 'blank': 'True'}), - 'first_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}), - 'groups': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Group']", 'symmetrical': 'False', 'blank': 'True'}), - 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), - 'is_active': ('django.db.models.fields.BooleanField', [], {'default': 'True'}), - 'is_staff': ('django.db.models.fields.BooleanField', [], {'default': 'False'}), - 'is_superuser': ('django.db.models.fields.BooleanField', [], {'default': 'False'}), - 'last_login': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}), - 'last_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}), - 'password': ('django.db.models.fields.CharField', [], {'max_length': '128'}), - 'user_permissions': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Permission']", 'symmetrical': 'False', 'blank': 'True'}), - 'username': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '30'}) + 'Meta': {'object_name': User._meta.module_name}, }, 'authtoken.token': { 'Meta': {'object_name': 'Token'}, diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py index 7f5a75a3..52c45ad1 100644 --- a/rest_framework/authtoken/models.py +++ b/rest_framework/authtoken/models.py @@ -2,6 +2,7 @@ import uuid import hmac from hashlib import sha1 from rest_framework.compat import User +from django.conf import settings from django.db import models @@ -13,6 +14,14 @@ class Token(models.Model): user = models.OneToOneField(User, related_name='auth_token') created = models.DateTimeField(auto_now_add=True) + class Meta: + # Work around for a bug in Django: + # https://code.djangoproject.com/ticket/19422 + # + # Also see corresponding ticket: + # https://github.com/tomchristie/django-rest-framework/issues/705 + abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS + def save(self, *args, **kwargs): if not self.key: self.key = self.generate_key() diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 07fdddce..cd39f544 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -6,6 +6,7 @@ versions of django/python, and compatibility wrappers around optional packages. from __future__ import unicode_literals import django +from django.core.exceptions import ImproperlyConfigured # Try to import six from Django, fallback to included `six`. try: @@ -87,9 +88,7 @@ else: raise ImportError("User model is not to be found.") -# First implementation of Django class-based views did not include head method -# in base View class - https://code.djangoproject.com/ticket/15668 -if django.VERSION >= (1, 4): +if django.VERSION >= (1, 5): from django.views.generic import View else: from django.views.generic import View as _View @@ -97,6 +96,8 @@ else: from django.utils.functional import update_wrapper class View(_View): + # 1.3 does not include head method in base View class + # See: https://code.djangoproject.com/ticket/15668 @classonlymethod def as_view(cls, **initkwargs): """ @@ -126,11 +127,15 @@ else: update_wrapper(view, cls.dispatch, assigned=()) return view -# Taken from @markotibold's attempt at supporting PATCH. -# https://github.com/markotibold/django-rest-framework/tree/patch -http_method_names = set(View.http_method_names) -http_method_names.add('patch') -View.http_method_names = list(http_method_names) # PATCH method is not implemented by Django + # _allowed_methods only present from 1.5 onwards + def _allowed_methods(self): + return [m.upper() for m in self.http_method_names if hasattr(self, m)] + + +# PATCH method is not implemented by Django +if 'patch' not in View.http_method_names: + View.http_method_names = View.http_method_names + ['patch'] + # PUT, DELETE do not require CSRF until 1.4. They should. Make it better. if django.VERSION >= (1, 4): @@ -395,6 +400,41 @@ except ImportError: kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None) return datetime.datetime(**kw) + +# smart_urlquote is new on Django 1.4 +try: + from django.utils.html import smart_urlquote +except ImportError: + import re + from django.utils.encoding import smart_str + try: + from urllib.parse import quote, urlsplit, urlunsplit + except ImportError: # Python 2 + from urllib import quote + from urlparse import urlsplit, urlunsplit + + unquoted_percents_re = re.compile(r'%(?![0-9A-Fa-f]{2})') + + def smart_urlquote(url): + "Quotes a URL if it isn't already quoted." + # Handle IDN before quoting. + scheme, netloc, path, query, fragment = urlsplit(url) + try: + netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE + except UnicodeError: # invalid domain part + pass + else: + url = urlunsplit((scheme, netloc, path, query, fragment)) + + # An URL is considered unquoted if it contains no % characters or + # contains a % not followed by two hexadecimal digits. See #9655. + if '%' not in url or unquoted_percents_re.search(url): + # See http://bugs.python.org/issue2637 + url = quote(smart_str(url), safe=b'!*\'();:@&=+$,/?#[]~') + + return force_text(url) + + # Markdown is optional try: import markdown @@ -426,3 +466,32 @@ try: import defusedxml.ElementTree as etree except ImportError: etree = None + +# OAuth is optional +try: + # Note: The `oauth2` package actually provides oauth1.0a support. Urg. + import oauth2 as oauth +except ImportError: + oauth = None + +# OAuth is optional +try: + import oauth_provider + from oauth_provider.store import store as oauth_provider_store +except (ImportError, ImproperlyConfigured): + oauth_provider = None + oauth_provider_store = None + +# OAuth 2 support is optional +try: + import provider.oauth2 as oauth2_provider + from provider.oauth2 import models as oauth2_provider_models + from provider.oauth2 import forms as oauth2_provider_forms + from provider import scope as oauth2_provider_scope + from provider import constants as oauth2_constants +except ImportError: + oauth2_provider = None + oauth2_provider_models = None + oauth2_provider_forms = None + oauth2_provider_scope = None + oauth2_constants = None diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 8250cd3b..81e585e1 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -1,3 +1,11 @@ +""" +The most imporant decorator in this module is `@api_view`, which is used +for writing function-based views with REST framework. + +There are also various decorators for setting the API policies on function +based views, as well as the `@action` and `@link` decorators, which are +used to annotate methods on viewsets that should be included by routers. +""" from __future__ import unicode_literals from rest_framework.compat import six from rest_framework.views import APIView @@ -97,3 +105,25 @@ def permission_classes(permission_classes): func.permission_classes = permission_classes return func return decorator + + +def link(**kwargs): + """ + Used to mark a method on a ViewSet that should be routed for GET requests. + """ + def decorator(func): + func.bind_to_method = 'get' + func.kwargs = kwargs + return func + return decorator + + +def action(**kwargs): + """ + Used to mark a method on a ViewSet that should be routed for POST requests. + """ + def decorator(func): + func.bind_to_method = 'post' + func.kwargs = kwargs + return func + return decorator diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 86c3a837..c83ee5ec 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,7 +1,13 @@ +""" +Serializer fields perform validation on incoming data. + +They are very similar to Django's form fields. +""" from __future__ import unicode_literals import copy import datetime +from decimal import Decimal, DecimalException import inspect import re import warnings @@ -13,26 +19,29 @@ from django import forms from django.forms import widgets from django.utils.encoding import is_protected_type from django.utils.translation import ugettext_lazy as _ -from rest_framework.compat import parse_date, parse_datetime -from rest_framework.compat import timezone + +from rest_framework import ISO_8601 +from rest_framework.compat import timezone, parse_date, parse_datetime, parse_time from rest_framework.compat import BytesIO from rest_framework.compat import six from rest_framework.compat import smart_text -from rest_framework.compat import parse_time +from rest_framework.settings import api_settings def is_simple_callable(obj): """ True if the object is a callable that takes no arguments. """ - try: - args, _, _, defaults = inspect.getargspec(obj) - except TypeError: + function = inspect.isfunction(obj) + method = inspect.ismethod(obj) + + if not (function or method): return False - else: - len_args = len(args) if inspect.isfunction(obj) else len(args) - 1 - len_defaults = len(defaults) if defaults else 0 - return len_args <= len_defaults + + args, _, _, defaults = inspect.getargspec(obj) + len_args = len(args) if function else len(args) - 1 + len_defaults = len(defaults) if defaults else 0 + return len_args <= len_defaults def get_component(obj, attr_name): @@ -50,6 +59,46 @@ def get_component(obj, attr_name): return val +def readable_datetime_formats(formats): + format = ', '.join(formats).replace(ISO_8601, 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]') + return humanize_strptime(format) + + +def readable_date_formats(formats): + format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]') + return humanize_strptime(format) + + +def readable_time_formats(formats): + format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]') + return humanize_strptime(format) + + +def humanize_strptime(format_string): + # Note that we're missing some of the locale specific mappings that + # don't really make sense. + mapping = { + "%Y": "YYYY", + "%y": "YY", + "%m": "MM", + "%b": "[Jan-Dec]", + "%B": "[January-December]", + "%d": "DD", + "%H": "hh", + "%I": "hh", # Requires '%p' to differentiate from '%H'. + "%M": "mm", + "%S": "ss", + "%f": "uuuuuu", + "%a": "[Mon-Sun]", + "%A": "[Monday-Sunday]", + "%p": "[AM|PM]", + "%z": "[+HHMM|-HHMM]" + } + for key, val in mapping.items(): + format_string = format_string.replace(key, val) + return format_string + + class Field(object): read_only = True creation_counter = 0 @@ -151,9 +200,9 @@ class WritableField(Field): # 'blank' is to be deprecated in favor of 'required' if blank is not None: - warnings.warn('The `blank` keyword argument is due to deprecated. ' + warnings.warn('The `blank` keyword argument is deprecated. ' 'Use the `required` keyword argument instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) required = not(blank) super(WritableField, self).__init__(source=source) @@ -447,12 +496,16 @@ class DateField(WritableField): form_field_class = forms.DateField default_error_messages = { - 'invalid': _("'%s' value has an invalid date format. It must be " - "in YYYY-MM-DD format."), - 'invalid_date': _("'%s' value has the correct format (YYYY-MM-DD) " - "but it is an invalid date."), + 'invalid': _("Date has wrong format. Use one of these formats instead: %s"), } empty = None + input_formats = api_settings.DATE_INPUT_FORMATS + format = api_settings.DATE_FORMAT + + def __init__(self, input_formats=None, format=None, *args, **kwargs): + self.input_formats = input_formats if input_formats is not None else self.input_formats + self.format = format if format is not None else self.format + super(DateField, self).__init__(*args, **kwargs) def from_native(self, value): if value in validators.EMPTY_VALUES: @@ -468,17 +521,37 @@ class DateField(WritableField): if isinstance(value, datetime.date): return value - try: - parsed = parse_date(value) - if parsed is not None: - return parsed - except (ValueError, TypeError): - msg = self.error_messages['invalid_date'] % value - raise ValidationError(msg) + for format in self.input_formats: + if format.lower() == ISO_8601: + try: + parsed = parse_date(value) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = datetime.datetime.strptime(value, format) + except (ValueError, TypeError): + pass + else: + return parsed.date() - msg = self.error_messages['invalid'] % value + msg = self.error_messages['invalid'] % readable_date_formats(self.input_formats) raise ValidationError(msg) + def to_native(self, value): + if value is None or self.format is None: + return value + + if isinstance(value, datetime.datetime): + value = value.date() + + if self.format.lower() == ISO_8601: + return value.isoformat() + return value.strftime(self.format) + class DateTimeField(WritableField): type_name = 'DateTimeField' @@ -486,15 +559,16 @@ class DateTimeField(WritableField): form_field_class = forms.DateTimeField default_error_messages = { - 'invalid': _("'%s' value has an invalid format. It must be in " - "YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."), - 'invalid_date': _("'%s' value has the correct format " - "(YYYY-MM-DD) but it is an invalid date."), - 'invalid_datetime': _("'%s' value has the correct format " - "(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) " - "but it is an invalid date/time."), + 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"), } empty = None + input_formats = api_settings.DATETIME_INPUT_FORMATS + format = api_settings.DATETIME_FORMAT + + def __init__(self, input_formats=None, format=None, *args, **kwargs): + self.input_formats = input_formats if input_formats is not None else self.input_formats + self.format = format if format is not None else self.format + super(DateTimeField, self).__init__(*args, **kwargs) def from_native(self, value): if value in validators.EMPTY_VALUES: @@ -516,25 +590,37 @@ class DateTimeField(WritableField): value = timezone.make_aware(value, default_timezone) return value - try: - parsed = parse_datetime(value) - if parsed is not None: - return parsed - except (ValueError, TypeError): - msg = self.error_messages['invalid_datetime'] % value - raise ValidationError(msg) - - try: - parsed = parse_date(value) - if parsed is not None: - return datetime.datetime(parsed.year, parsed.month, parsed.day) - except (ValueError, TypeError): - msg = self.error_messages['invalid_date'] % value - raise ValidationError(msg) + for format in self.input_formats: + if format.lower() == ISO_8601: + try: + parsed = parse_datetime(value) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = datetime.datetime.strptime(value, format) + except (ValueError, TypeError): + pass + else: + return parsed - msg = self.error_messages['invalid'] % value + msg = self.error_messages['invalid'] % readable_datetime_formats(self.input_formats) raise ValidationError(msg) + def to_native(self, value): + if value is None or self.format is None: + return value + + if self.format.lower() == ISO_8601: + ret = value.isoformat() + if ret.endswith('+00:00'): + ret = ret[:-6] + 'Z' + return ret + return value.strftime(self.format) + class TimeField(WritableField): type_name = 'TimeField' @@ -542,10 +628,16 @@ class TimeField(WritableField): form_field_class = forms.TimeField default_error_messages = { - 'invalid': _("'%s' value has an invalid format. It must be a valid " - "time in the HH:MM[:ss[.uuuuuu]] format."), + 'invalid': _("Time has wrong format. Use one of these formats instead: %s"), } empty = None + input_formats = api_settings.TIME_INPUT_FORMATS + format = api_settings.TIME_FORMAT + + def __init__(self, input_formats=None, format=None, *args, **kwargs): + self.input_formats = input_formats if input_formats is not None else self.input_formats + self.format = format if format is not None else self.format + super(TimeField, self).__init__(*args, **kwargs) def from_native(self, value): if value in validators.EMPTY_VALUES: @@ -554,13 +646,36 @@ class TimeField(WritableField): if isinstance(value, datetime.time): return value - try: - parsed = parse_time(value) - assert parsed is not None - return parsed - except (ValueError, TypeError): - msg = self.error_messages['invalid'] % value - raise ValidationError(msg) + for format in self.input_formats: + if format.lower() == ISO_8601: + try: + parsed = parse_time(value) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = datetime.datetime.strptime(value, format) + except (ValueError, TypeError): + pass + else: + return parsed.time() + + msg = self.error_messages['invalid'] % readable_time_formats(self.input_formats) + raise ValidationError(msg) + + def to_native(self, value): + if value is None or self.format is None: + return value + + if isinstance(value, datetime.datetime): + value = value.time() + + if self.format.lower() == ISO_8601: + return value.isoformat() + return value.strftime(self.format) class IntegerField(WritableField): @@ -612,6 +727,75 @@ class FloatField(WritableField): raise ValidationError(msg) +class DecimalField(WritableField): + type_name = 'DecimalField' + form_field_class = forms.DecimalField + + default_error_messages = { + 'invalid': _('Enter a number.'), + 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), + 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), + 'max_digits': _('Ensure that there are no more than %s digits in total.'), + 'max_decimal_places': _('Ensure that there are no more than %s decimal places.'), + 'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.') + } + + def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): + self.max_value, self.min_value = max_value, min_value + self.max_digits, self.decimal_places = max_digits, decimal_places + super(DecimalField, self).__init__(*args, **kwargs) + + if max_value is not None: + self.validators.append(validators.MaxValueValidator(max_value)) + if min_value is not None: + self.validators.append(validators.MinValueValidator(min_value)) + + def from_native(self, value): + """ + Validates that the input is a decimal number. Returns a Decimal + instance. Returns None for empty values. Ensures that there are no more + than max_digits in the number, and no more than decimal_places digits + after the decimal point. + """ + if value in validators.EMPTY_VALUES: + return None + value = smart_text(value).strip() + try: + value = Decimal(value) + except DecimalException: + raise ValidationError(self.error_messages['invalid']) + return value + + def validate(self, value): + super(DecimalField, self).validate(value) + if value in validators.EMPTY_VALUES: + return + # Check for NaN, Inf and -Inf values. We can't compare directly for NaN, + # since it is never equal to itself. However, NaN is the only value that + # isn't equal to itself, so we can use this to identify NaN + if value != value or value == Decimal("Inf") or value == Decimal("-Inf"): + raise ValidationError(self.error_messages['invalid']) + sign, digittuple, exponent = value.as_tuple() + decimals = abs(exponent) + # digittuple doesn't include any leading zeros. + digits = len(digittuple) + if decimals > digits: + # We have leading zeros up to or past the decimal point. Count + # everything past the decimal point as a digit. We do not count + # 0 before the decimal point as a digit since that would mean + # we would not allow max_digits = decimal_places. + digits = decimals + whole_digits = digits - decimals + + if self.max_digits is not None and digits > self.max_digits: + raise ValidationError(self.error_messages['max_digits'] % self.max_digits) + if self.decimal_places is not None and decimals > self.decimal_places: + raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places) + if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): + raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places)) + return value + + class FileField(WritableField): use_files = True type_name = 'FileField' diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 6fea46fa..c058bc71 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -1,5 +1,12 @@ +""" +Provides generic filtering backends that can be used to filter the results +returned by list views. +""" from __future__ import unicode_literals -from rest_framework.compat import django_filters +from django.db import models +from rest_framework.compat import django_filters, six +from functools import reduce +import operator FilterSet = django_filters and django_filters.FilterSet or None @@ -25,36 +32,112 @@ class DjangoFilterBackend(BaseFilterBackend): def __init__(self): assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed' - def get_filter_class(self, view): + def get_filter_class(self, view, queryset=None): """ Return the django-filters `FilterSet` used to filter the queryset. """ filter_class = getattr(view, 'filter_class', None) filter_fields = getattr(view, 'filter_fields', None) - view_model = getattr(view, 'model', None) if filter_class: filter_model = filter_class.Meta.model - assert issubclass(filter_model, view_model), \ - 'FilterSet model %s does not match view model %s' % \ - (filter_model, view_model) + assert issubclass(filter_model, queryset.model), \ + 'FilterSet model %s does not match queryset model %s' % \ + (filter_model, queryset.model) return filter_class if filter_fields: class AutoFilterSet(self.default_filter_set): class Meta: - model = view_model + model = queryset.model fields = filter_fields return AutoFilterSet return None def filter_queryset(self, request, queryset, view): - filter_class = self.get_filter_class(view) + filter_class = self.get_filter_class(view, queryset) if filter_class: - return filter_class(request.QUERY_PARAMS, queryset=queryset) + return filter_class(request.QUERY_PARAMS, queryset=queryset).qs + + return queryset + + +class SearchFilter(BaseFilterBackend): + search_param = 'search' # The URL query parameter used for the search. + + def get_search_terms(self, request): + """ + Search terms are set by a ?search=... query parameter, + and may be comma and/or whitespace delimited. + """ + params = request.QUERY_PARAMS.get(self.search_param, '') + return params.replace(',', ' ').split() + + def construct_search(self, field_name): + if field_name.startswith('^'): + return "%s__istartswith" % field_name[1:] + elif field_name.startswith('='): + return "%s__iexact" % field_name[1:] + elif field_name.startswith('@'): + return "%s__search" % field_name[1:] + else: + return "%s__icontains" % field_name + + def filter_queryset(self, request, queryset, view): + search_fields = getattr(view, 'search_fields', None) + + if not search_fields: + return queryset + + orm_lookups = [self.construct_search(str(search_field)) + for search_field in search_fields] + + for search_term in self.get_search_terms(request): + or_queries = [models.Q(**{orm_lookup: search_term}) + for orm_lookup in orm_lookups] + queryset = queryset.filter(reduce(operator.or_, or_queries)) + + return queryset + + +class OrderingFilter(BaseFilterBackend): + ordering_param = 'ordering' # The URL query parameter used for the ordering. + + def get_ordering(self, request): + """ + Search terms are set by a ?search=... query parameter, + and may be comma and/or whitespace delimited. + """ + params = request.QUERY_PARAMS.get(self.ordering_param) + if params: + return [param.strip() for param in params.split(',')] + + def get_default_ordering(self, view): + ordering = getattr(view, 'ordering', None) + if isinstance(ordering, six.string_types): + return (ordering,) + return ordering + + def remove_invalid_fields(self, queryset, ordering): + field_names = [field.name for field in queryset.model._meta.fields] + return [term for term in ordering if term.lstrip('-') in field_names] + + def filter_queryset(self, request, queryset, view): + ordering = self.get_ordering(request) + + if ordering: + # Skip any incorrect parameters + ordering = self.remove_invalid_fields(queryset, ordering) + + if not ordering: + # Use 'ordering' attribtue by default + ordering = self.get_default_ordering(view) + + if ordering: + return queryset.order_by(*ordering) return queryset diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 9ae8cf0a..05ec93d3 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -2,22 +2,59 @@ Generic views that provide commonly needed behaviour. """ from __future__ import unicode_literals + +from django.core.exceptions import ImproperlyConfigured +from django.core.paginator import Paginator, InvalidPage +from django.http import Http404 +from django.shortcuts import get_object_or_404 +from django.utils.translation import ugettext as _ from rest_framework import views, mixins +from rest_framework.exceptions import ConfigurationError from rest_framework.settings import api_settings -from django.views.generic.detail import SingleObjectMixin -from django.views.generic.list import MultipleObjectMixin - +import warnings -### Base classes for the generic views ### class GenericAPIView(views.APIView): """ Base class for all other generic views. """ - model = None + # You'll need to either set these attributes, + # or override `get_queryset()`/`get_serializer_class()`. + queryset = None serializer_class = None + + # This shortcut may be used instead of setting either or both + # of the `queryset`/`serializer_class` attributes, although using + # the explicit style is generally preferred. + model = None + + # If you want to use object lookups other than pk, set this attribute. + # For more complex lookup requirements override `get_object()`. + lookup_field = 'pk' + + # Pagination settings + paginate_by = api_settings.PAGINATE_BY + paginate_by_param = api_settings.PAGINATE_BY_PARAM + pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS + page_kwarg = 'page' + + # The filter backend classes to use for queryset filtering + filter_backends = api_settings.DEFAULT_FILTER_BACKENDS + + # The following attributes may be subject to change, + # and should be considered private API. model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS + paginator_class = Paginator + + ###################################### + # These are pending deprecation... + + pk_url_kwarg = 'pk' + slug_url_kwarg = 'slug' + slug_field = 'slug' + allow_empty = True + filter_backend = api_settings.FILTER_BACKEND def get_serializer_context(self): """ @@ -29,24 +66,6 @@ class GenericAPIView(views.APIView): 'view': self } - def get_serializer_class(self): - """ - Return the class to use for the serializer. - - Defaults to using `self.serializer_class`, falls back to constructing a - model serializer class using `self.model_serializer_class`, with - `self.model` as the model. - """ - serializer_class = self.serializer_class - - if serializer_class is None: - class DefaultSerializer(self.model_serializer_class): - class Meta: - model = self.model - serializer_class = DefaultSerializer - - return serializer_class - def get_serializer(self, instance=None, data=None, files=None, many=False, partial=False): """ @@ -58,86 +77,244 @@ class GenericAPIView(views.APIView): return serializer_class(instance, data=data, files=files, many=many, partial=partial, context=context) - def pre_save(self, obj): + def get_pagination_serializer(self, page): """ - Placeholder method for calling before saving an object. - May be used eg. to set attributes on the object that are implicit - in either the request, or the url. + Return a serializer instance to use with paginated data. """ - pass + class SerializerClass(self.pagination_serializer_class): + class Meta: + object_serializer_class = self.get_serializer_class() - def post_save(self, obj, created=False): + pagination_serializer_class = SerializerClass + context = self.get_serializer_context() + return pagination_serializer_class(instance=page, context=context) + + def paginate_queryset(self, queryset, page_size=None): """ - Placeholder method for calling after saving an object. + Paginate a queryset if required, either returning a page object, + or `None` if pagination is not configured for this view. """ - pass - - -class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): - """ - Base class for generic views onto a queryset. - """ - - paginate_by = api_settings.PAGINATE_BY - paginate_by_param = api_settings.PAGINATE_BY_PARAM - pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS - filter_backend = api_settings.FILTER_BACKEND + deprecated_style = False + if page_size is not None: + warnings.warn('The `page_size` parameter to `paginate_queryset()` ' + 'is due to be deprecated. ' + 'Note that the return style of this method is also ' + 'changed, and will simply return a page object ' + 'when called without a `page_size` argument.', + PendingDeprecationWarning, stacklevel=2) + deprecated_style = True + else: + # Determine the required page size. + # If pagination is not configured, simply return None. + page_size = self.get_paginate_by() + if not page_size: + return None + + if not self.allow_empty: + warnings.warn( + 'The `allow_empty` parameter is due to be deprecated. ' + 'To use `allow_empty=False` style behavior, You should override ' + '`get_queryset()` and explicitly raise a 404 on empty querysets.', + PendingDeprecationWarning, stacklevel=2 + ) + + paginator = self.paginator_class(queryset, page_size, + allow_empty_first_page=self.allow_empty) + page_kwarg = self.kwargs.get(self.page_kwarg) + page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) + page = page_kwarg or page_query_param or 1 + try: + page_number = int(page) + except ValueError: + if page == 'last': + page_number = paginator.num_pages + else: + raise Http404(_("Page is not 'last', nor can it be converted to an int.")) + try: + page = paginator.page(page_number) + except InvalidPage as e: + raise Http404(_('Invalid page (%(page_number)s): %(message)s') % { + 'page_number': page_number, + 'message': str(e) + }) + + if deprecated_style: + return (paginator, page, page.object_list, page.has_other_pages()) + return page def filter_queryset(self, queryset): """ Given a queryset, filter it with whichever filter backend is in use. - """ - if not self.filter_backend: - return queryset - backend = self.filter_backend() - return backend.filter_queryset(self.request, queryset, self) - def get_pagination_serializer(self, page=None): + You are unlikely to want to override this method, although you may need + to call it either from a list view, or from a custom `get_object` + method if you want to apply the configured filtering backend to the + default queryset. """ - Return a serializer instance to use with paginated data. + filter_backends = self.filter_backends or [] + if not filter_backends and self.filter_backend: + warnings.warn( + 'The `filter_backend` attribute and `FILTER_BACKEND` setting ' + 'are due to be deprecated in favor of a `filter_backends` ' + 'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take ' + 'a *list* of filter backend classes.', + PendingDeprecationWarning, stacklevel=2 + ) + filter_backends = [self.filter_backend] + + for backend in filter_backends: + queryset = backend().filter_queryset(self.request, queryset, self) + return queryset + + ######################## + ### The following methods provide default implementations + ### that you may want to override for more complex cases. + + def get_paginate_by(self, queryset=None): """ - class SerializerClass(self.pagination_serializer_class): - class Meta: - object_serializer_class = self.get_serializer_class() + Return the size of pages to use with pagination. - pagination_serializer_class = SerializerClass - context = self.get_serializer_context() - return pagination_serializer_class(instance=page, context=context) + If `PAGINATE_BY_PARAM` is set it will attempt to get the page size + from a named query parameter in the url, eg. ?page_size=100 - def get_paginate_by(self, queryset): - """ - Return the size of pages to use with pagination. + Otherwise defaults to using `self.paginate_by`. """ + if queryset is not None: + warnings.warn('The `queryset` parameter to `get_paginate_by()` ' + 'is due to be deprecated.', + PendingDeprecationWarning, stacklevel=2) + if self.paginate_by_param: query_params = self.request.QUERY_PARAMS try: return int(query_params[self.paginate_by_param]) except (KeyError, ValueError): pass + return self.paginate_by + def get_serializer_class(self): + """ + Return the class to use for the serializer. + Defaults to using `self.serializer_class`. + + You may want to override this if you need to provide different + serializations depending on the incoming request. -class SingleObjectAPIView(SingleObjectMixin, GenericAPIView): - """ - Base class for generic views onto a model instance. - """ + (Eg. admins get full serialization, others get basic serilization) + """ + serializer_class = self.serializer_class + if serializer_class is not None: + return serializer_class - pk_url_kwarg = 'pk' # Not provided in Django 1.3 - slug_url_kwarg = 'slug' # Not provided in Django 1.3 - slug_field = 'slug' + assert self.model is not None, \ + "'%s' should either include a 'serializer_class' attribute, " \ + "or use the 'model' attribute as a shortcut for " \ + "automatically generating a serializer class." \ + % self.__class__.__name__ + + class DefaultSerializer(self.model_serializer_class): + class Meta: + model = self.model + return DefaultSerializer + + def get_queryset(self): + """ + Get the list of items for this view. + This must be an iterable, and may be a queryset. + Defaults to using `self.queryset`. + + You may want to override this if you need to provide different + querysets depending on the incoming request. + + (Eg. return a list of items that is specific to the user) + """ + if self.queryset is not None: + return self.queryset._clone() + + if self.model is not None: + return self.model._default_manager.all() + + raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" + % self.__class__.__name__) def get_object(self, queryset=None): """ - Override default to add support for object-level permissions. + Returns the object the view is displaying. + + You may want to override this if you need to provide non-standard + queryset lookups. Eg if objects are referenced using multiple + keyword arguments in the url conf. """ - obj = super(SingleObjectAPIView, self).get_object(queryset) + # Determine the base queryset to use. + if queryset is None: + queryset = self.filter_queryset(self.get_queryset()) + else: + pass # Deprecation warning + + # Perform the lookup filtering. + pk = self.kwargs.get(self.pk_url_kwarg, None) + slug = self.kwargs.get(self.slug_url_kwarg, None) + lookup = self.kwargs.get(self.lookup_field, None) + + if lookup is not None: + filter_kwargs = {self.lookup_field: lookup} + elif pk is not None and self.lookup_field == 'pk': + warnings.warn( + 'The `pk_url_kwarg` attribute is due to be deprecated. ' + 'Use the `lookup_field` attribute instead', + PendingDeprecationWarning + ) + filter_kwargs = {'pk': pk} + elif slug is not None and self.lookup_field == 'pk': + warnings.warn( + 'The `slug_url_kwarg` attribute is due to be deprecated. ' + 'Use the `lookup_field` attribute instead', + PendingDeprecationWarning + ) + filter_kwargs = {self.slug_field: slug} + else: + raise ConfigurationError( + 'Expected view %s to be called with a URL keyword argument ' + 'named "%s". Fix your URL conf, or set the `.lookup_field` ' + 'attribute on the view correctly.' % + (self.__class__.__name__, self.lookup_field) + ) + + obj = get_object_or_404(queryset, **filter_kwargs) + + # May raise a permission denied self.check_object_permissions(self.request, obj) + return obj + ######################## + ### The following are placeholder methods, + ### and are intended to be overridden. + ### + ### The are not called by GenericAPIView directly, + ### but are used by the mixin methods. + + def pre_save(self, obj): + """ + Placeholder method for calling before saving an object. + + May be used to set attributes on the object that are implicit + in either the request, or the url. + """ + pass + + def post_save(self, obj, created=False): + """ + Placeholder method for calling after saving an object. + """ + pass -### Concrete view classes that provide method handlers ### -### by composing the mixin classes with a base view. ### +########################################################## +### Concrete view classes that provide method handlers ### +### by composing the mixin classes with the base view. ### +########################################################## class CreateAPIView(mixins.CreateModelMixin, GenericAPIView): @@ -150,7 +327,7 @@ class CreateAPIView(mixins.CreateModelMixin, class ListAPIView(mixins.ListModelMixin, - MultipleObjectAPIView): + GenericAPIView): """ Concrete view for listing a queryset. """ @@ -159,7 +336,7 @@ class ListAPIView(mixins.ListModelMixin, class RetrieveAPIView(mixins.RetrieveModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for retrieving a model instance. """ @@ -168,7 +345,7 @@ class RetrieveAPIView(mixins.RetrieveModelMixin, class DestroyAPIView(mixins.DestroyModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for deleting a model instance. @@ -178,7 +355,7 @@ class DestroyAPIView(mixins.DestroyModelMixin, class UpdateAPIView(mixins.UpdateModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for updating a model instance. @@ -187,13 +364,12 @@ class UpdateAPIView(mixins.UpdateModelMixin, return self.update(request, *args, **kwargs) def patch(self, request, *args, **kwargs): - kwargs['partial'] = True - return self.update(request, *args, **kwargs) + return self.partial_update(request, *args, **kwargs) class ListCreateAPIView(mixins.ListModelMixin, mixins.CreateModelMixin, - MultipleObjectAPIView): + GenericAPIView): """ Concrete view for listing a queryset or creating a model instance. """ @@ -206,7 +382,7 @@ class ListCreateAPIView(mixins.ListModelMixin, class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for retrieving, updating a model instance. """ @@ -217,13 +393,12 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, return self.update(request, *args, **kwargs) def patch(self, request, *args, **kwargs): - kwargs['partial'] = True - return self.update(request, *args, **kwargs) + return self.partial_update(request, *args, **kwargs) class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, mixins.DestroyModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for retrieving or deleting a model instance. """ @@ -237,7 +412,7 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, mixins.DestroyModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for retrieving, updating or deleting a model instance. """ @@ -248,8 +423,31 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, return self.update(request, *args, **kwargs) def patch(self, request, *args, **kwargs): - kwargs['partial'] = True - return self.update(request, *args, **kwargs) + return self.partial_update(request, *args, **kwargs) def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) + + +########################## +### Deprecated classes ### +########################## + +class MultipleObjectAPIView(GenericAPIView): + def __init__(self, *args, **kwargs): + warnings.warn( + 'Subclassing `MultipleObjectAPIView` is due to be deprecated. ' + 'You should simply subclass `GenericAPIView` instead.', + PendingDeprecationWarning, stacklevel=2 + ) + super(MultipleObjectAPIView, self).__init__(*args, **kwargs) + + +class SingleObjectAPIView(GenericAPIView): + def __init__(self, *args, **kwargs): + warnings.warn( + 'Subclassing `SingleObjectAPIView` is due to be deprecated. ' + 'You should simply subclass `GenericAPIView` instead.', + PendingDeprecationWarning, stacklevel=2 + ) + super(SingleObjectAPIView, self).__init__(*args, **kwargs) diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 97201c4b..f3cd5868 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -10,9 +10,10 @@ from django.http import Http404 from rest_framework import status from rest_framework.response import Response from rest_framework.request import clone_request +import warnings -def _get_validation_exclusions(obj, pk=None, slug_field=None): +def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None): """ Given a model instance, and an optional pk and slug field, return the full list of all other field names on that model. @@ -23,28 +24,32 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None): include = [] if pk: + # Pending deprecation pk_field = obj._meta.pk while pk_field.rel: pk_field = pk_field.rel.to._meta.pk include.append(pk_field.name) if slug_field: + # Pending deprecation include.append(slug_field) + if lookup_field and lookup_field != 'pk': + include.append(lookup_field) + return [field.name for field in obj._meta.fields if field.name not in include] class CreateModelMixin(object): """ Create a model instance. - Should be mixed in with any `GenericAPIView`. """ def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.DATA, files=request.FILES) if serializer.is_valid(): self.pre_save(serializer.object) - self.object = serializer.save() + self.object = serializer.save(force_insert=True) self.post_save(self.object, created=True) headers = self.get_success_headers(serializer.data) return Response(serializer.data, status=status.HTTP_201_CREATED, @@ -62,28 +67,28 @@ class CreateModelMixin(object): class ListModelMixin(object): """ List a queryset. - Should be mixed in with `MultipleObjectAPIView`. """ empty_error = "Empty list and '%(class_name)s.allow_empty' is False." def list(self, request, *args, **kwargs): - queryset = self.get_queryset() - self.object_list = self.filter_queryset(queryset) + self.object_list = self.filter_queryset(self.get_queryset()) # Default is to allow empty querysets. This can be altered by setting # `.allow_empty = False`, to raise 404 errors on empty querysets. - allow_empty = self.get_allow_empty() - if not allow_empty and not self.object_list: + if not self.allow_empty and not self.object_list: + warnings.warn( + 'The `allow_empty` parameter is due to be deprecated. ' + 'To use `allow_empty=False` style behavior, You should override ' + '`get_queryset()` and explicitly raise a 404 on empty querysets.', + PendingDeprecationWarning + ) class_name = self.__class__.__name__ error_msg = self.empty_error % {'class_name': class_name} raise Http404(error_msg) - # Pagination size is set by the `.paginate_by` attribute, - # which may be `None` to disable pagination. - page_size = self.get_paginate_by(self.object_list) - if page_size: - packed = self.paginate_queryset(self.object_list, page_size) - paginator, page, queryset, is_paginated = packed + # Switch between paginated or standard style responses + page = self.paginate_queryset(self.object_list) + if page is not None: serializer = self.get_pagination_serializer(page) else: serializer = self.get_serializer(self.object_list, many=True) @@ -94,7 +99,6 @@ class ListModelMixin(object): class RetrieveModelMixin(object): """ Retrieve a model instance. - Should be mixed in with `SingleObjectAPIView`. """ def retrieve(self, request, *args, **kwargs): self.object = self.get_object() @@ -105,21 +109,28 @@ class RetrieveModelMixin(object): class UpdateModelMixin(object): """ Update a model instance. - Should be mixed in with `SingleObjectAPIView`. """ - def update(self, request, *args, **kwargs): - partial = kwargs.pop('partial', False) - self.object = None + def get_object_or_none(self): try: - self.object = self.get_object() + return self.get_object() except Http404: # If this is a PUT-as-create operation, we need to ensure that # we have relevant permissions, as if this was a POST request. - self.check_permissions(clone_request(request, 'POST')) + # This will either raise a PermissionDenied exception, + # or simply return None + self.check_permissions(clone_request(self.request, 'POST')) + + def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + self.object = self.get_object_or_none() + + if self.object is None: created = True + save_kwargs = {'force_insert': True} success_status_code = status.HTTP_201_CREATED else: created = False + save_kwargs = {'force_update': True} success_status_code = status.HTTP_200_OK serializer = self.get_serializer(self.object, data=request.DATA, @@ -127,20 +138,28 @@ class UpdateModelMixin(object): if serializer.is_valid(): self.pre_save(serializer.object) - self.object = serializer.save() + self.object = serializer.save(**save_kwargs) self.post_save(self.object, created=created) return Response(serializer.data, status=success_status_code) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + def partial_update(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + def pre_save(self, obj): """ Set any attributes on the object that are implicit in the request. """ # pk and/or slug attributes are implicit in the URL. + lookup = self.kwargs.get(self.lookup_field, None) pk = self.kwargs.get(self.pk_url_kwarg, None) slug = self.kwargs.get(self.slug_url_kwarg, None) - slug_field = slug and self.get_slug_field() or None + slug_field = slug and self.slug_field or None + + if lookup: + setattr(obj, self.lookup_field, lookup) if pk: setattr(obj, 'pk', pk) @@ -151,14 +170,13 @@ class UpdateModelMixin(object): # Ensure we clean the attributes so that we don't eg return integer # pk using a string representation, as provided by the url conf kwarg. if hasattr(obj, 'full_clean'): - exclude = _get_validation_exclusions(obj, pk, slug_field) + exclude = _get_validation_exclusions(obj, pk, slug_field, self.lookup_field) obj.full_clean(exclude) class DestroyModelMixin(object): """ Destroy a model instance. - Should be mixed in with `SingleObjectAPIView`. """ def destroy(self, request, *args, **kwargs): obj = self.get_object() diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py index 0694d35f..4d205c0e 100644 --- a/rest_framework/negotiation.py +++ b/rest_framework/negotiation.py @@ -1,3 +1,7 @@ +""" +Content negotiation deals with selecting an appropriate renderer given the +incoming request. Typically this will be based on the request's Accept header. +""" from __future__ import unicode_literals from django.http import Http404 from rest_framework import exceptions diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 03a7a30f..d51ea929 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -1,9 +1,11 @@ +""" +Pagination serializers determine the structure of the output that should +be used for paginated responses. +""" from __future__ import unicode_literals from rest_framework import serializers from rest_framework.templatetags.rest_framework import replace_query_param -# TODO: Support URLconf kwarg-style paging - class NextPageField(serializers.Field): """ diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 491acd68..25be2e6a 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -6,9 +6,10 @@ on the request, such as form content or json encoded data. """ from __future__ import unicode_literals from django.conf import settings +from django.core.files.uploadhandler import StopFutureHandlers from django.http import QueryDict from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser -from django.http.multipartparser import MultiPartParserError +from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter from rest_framework.compat import yaml, etree from rest_framework.exceptions import ParseError from rest_framework.compat import six @@ -205,3 +206,90 @@ class XMLParser(BaseParser): pass return value + + +class FileUploadParser(BaseParser): + """ + Parser for file upload data. + """ + media_type = '*/*' + + def parse(self, stream, media_type=None, parser_context=None): + """ + Returns a DataAndFiles object. + + `.data` will be None (we expect request body to be a file content). + `.files` will be a `QueryDict` containing one 'file' element. + """ + + parser_context = parser_context or {} + request = parser_context['request'] + encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) + meta = request.META + upload_handlers = request.upload_handlers + filename = self.get_filename(stream, media_type, parser_context) + + # Note that this code is extracted from Django's handling of + # file uploads in MultiPartParser. + content_type = meta.get('HTTP_CONTENT_TYPE', + meta.get('CONTENT_TYPE', '')) + try: + content_length = int(meta.get('HTTP_CONTENT_LENGTH', + meta.get('CONTENT_LENGTH', 0))) + except (ValueError, TypeError): + content_length = None + + # See if the handler will want to take care of the parsing. + for handler in upload_handlers: + result = handler.handle_raw_input(None, + meta, + content_length, + None, + encoding) + if result is not None: + return DataAndFiles(None, {'file': result[1]}) + + # This is the standard case. + possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size] + chunk_size = min([2 ** 31 - 4] + possible_sizes) + chunks = ChunkIter(stream, chunk_size) + counters = [0] * len(upload_handlers) + + for handler in upload_handlers: + try: + handler.new_file(None, filename, content_type, + content_length, encoding) + except StopFutureHandlers: + break + + for chunk in chunks: + for i, handler in enumerate(upload_handlers): + chunk_length = len(chunk) + chunk = handler.receive_data_chunk(chunk, counters[i]) + counters[i] += chunk_length + if chunk is None: + break + + for i, handler in enumerate(upload_handlers): + file_obj = handler.file_complete(counters[i]) + if file_obj: + return DataAndFiles(None, {'file': file_obj}) + raise ParseError("FileUpload parse error - " + "none of upload handlers can handle the stream") + + def get_filename(self, stream, media_type, parser_context): + """ + Detects the uploaded file name. First searches a 'filename' url kwarg. + Then tries to parse Content-Disposition header. + """ + try: + return parser_context['kwargs']['filename'] + except KeyError: + pass + + try: + meta = parser_context['request'].META + disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION']) + return disposition[1]['filename'] + except (AttributeError, KeyError): + pass diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 306f00ca..45fcfd66 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -7,6 +7,8 @@ import warnings SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS'] +from rest_framework.compat import oauth2_provider_scope, oauth2_constants + class BasePermission(object): """ @@ -23,10 +25,12 @@ class BasePermission(object): """ Return `True` if permission is granted, `False` otherwise. """ - if len(inspect.getargspec(self.has_permission)[0]) == 4: - warnings.warn('The `obj` argument in `has_permission` is due to be deprecated. ' - 'Use `has_object_permission()` instead for object permissions.', - PendingDeprecationWarning, stacklevel=2) + if len(inspect.getargspec(self.has_permission).args) == 4: + warnings.warn( + 'The `obj` argument in `has_permission` is deprecated. ' + 'Use `has_object_permission()` instead for object permissions.', + DeprecationWarning, stacklevel=2 + ) return self.has_permission(request, view, obj) return True @@ -85,8 +89,8 @@ class DjangoModelPermissions(BasePermission): It ensures that the user is authenticated, and has the appropriate `add`/`change`/`delete` permissions on the model. - This permission will only be applied against view classes that - provide a `.model` attribute, such as the generic class-based views. + This permission can only be applied against view classes that + provide a `.model` or `.queryset` attribute. """ # Map methods into required permission codes. @@ -102,6 +106,8 @@ class DjangoModelPermissions(BasePermission): 'DELETE': ['%(app_label)s.delete_%(model_name)s'], } + authenticated_users_only = True + def get_required_permissions(self, method, model_cls): """ Given a model and an HTTP method, return the list of permission @@ -115,13 +121,54 @@ class DjangoModelPermissions(BasePermission): def has_permission(self, request, view): model_cls = getattr(view, 'model', None) - if not model_cls: + queryset = getattr(view, 'queryset', None) + + if model_cls is None and queryset is not None: + model_cls = queryset.model + + # Workaround to ensure DjangoModelPermissions are not applied + # to the root view when using DefaultRouter. + if model_cls is None and getattr(view, '_ignore_model_permissions'): return True + assert model_cls, ('Cannot apply DjangoModelPermissions on a view that' + ' does not have `.model` or `.queryset` property.') + perms = self.get_required_permissions(request.method, model_cls) if (request.user and - request.user.is_authenticated() and + (request.user.is_authenticated() or not self.authenticated_users_only) and request.user.has_perms(perms)): return True return False + + +class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions): + """ + Similar to DjangoModelPermissions, except that anonymous users are + allowed read-only access. + """ + authenticated_users_only = False + + +class TokenHasReadWriteScope(BasePermission): + """ + The request is authenticated as a user and the token used has the right scope + """ + + def has_permission(self, request, view): + token = request.auth + read_only = request.method in SAFE_METHODS + + if not token: + return False + + if hasattr(token, 'resource'): # OAuth 1 + return read_only or not request.auth.resource.is_readonly + elif hasattr(token, 'scope'): # OAuth 2 + required = oauth2_constants.READ if read_only else oauth2_constants.WRITE + return oauth2_provider_scope.check(required, request.auth.scope) + + assert False, ('TokenHasReadWriteScope requires either the' + '`OAuthAuthentication` or `OAuth2Authentication` authentication ' + 'class to be used.') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index ef465b3c..884b954c 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,3 +1,9 @@ +""" +Serializer fields that deal with relationships. + +These fields allow you to specify the style that should be used to represent +model relationships, including hyperlinks, primary keys, or slugs. +""" from __future__ import unicode_literals from django.core.exceptions import ObjectDoesNotExist, ValidationError from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch @@ -36,9 +42,9 @@ class RelatedField(WritableField): # 'null' is to be deprecated in favor of 'required' if 'null' in kwargs: - warnings.warn('The `null` keyword argument is due to be deprecated. ' + warnings.warn('The `null` keyword argument is deprecated. ' 'Use the `required` keyword argument instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) kwargs['required'] = not kwargs.pop('null') self.queryset = kwargs.pop('queryset', None) @@ -243,7 +249,6 @@ class PrimaryKeyRelatedField(RelatedField): pk = getattr(obj, self.source or field_name).pk except ObjectDoesNotExist: return None - return self.to_native(obj.pk) # Forward relationship return self.to_native(pk) @@ -291,10 +296,8 @@ class HyperlinkedRelatedField(RelatedField): """ Represents a relationship using hyperlinking. """ - pk_url_kwarg = 'pk' - slug_field = 'slug' - slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden read_only = False + lookup_field = 'pk' default_error_messages = { 'no_match': _('Invalid hyperlink - No URL match'), @@ -304,69 +307,138 @@ class HyperlinkedRelatedField(RelatedField): 'incorrect_type': _('Incorrect type. Expected url string, received %s.'), } + # These are all pending deprecation + pk_url_kwarg = 'pk' + slug_field = 'slug' + slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden + def __init__(self, *args, **kwargs): try: self.view_name = kwargs.pop('view_name') except KeyError: raise ValueError("Hyperlinked field requires 'view_name' kwarg") + self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) + self.format = kwargs.pop('format', None) + + # These are pending deprecation + if 'pk_url_kwarg' in kwargs: + msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_url_kwarg' in kwargs: + msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_field' in kwargs: + msg = 'slug_field is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + + self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) self.slug_field = kwargs.pop('slug_field', self.slug_field) default_slug_kwarg = self.slug_url_kwarg or self.slug_field - self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) - self.format = kwargs.pop('format', None) super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) - def get_slug_field(self): - """ - Get the name of a slug field to be used to look up by slug. + def get_url(self, obj, view_name, request, format): """ - return self.slug_field - - def to_native(self, obj): - view_name = self.view_name - request = self.context.get('request', None) - format = self.format or self.context.get('format', None) + Given an object, return the URL that hyperlinks to the object. - if request is None: - warnings.warn("Using `HyperlinkedRelatedField` without including the " - "request in the serializer context is due to be deprecated. " - "Add `context={'request': request}` when instantiating the serializer.", - PendingDeprecationWarning, stacklevel=4) - - pk = getattr(obj, 'pk', None) - if pk is None: - return - kwargs = {self.pk_url_kwarg: pk} + May raise a `NoReverseMatch` if the `view_name` and `lookup_field` + attributes are not configured to correctly match the URL conf. + """ + lookup_field = getattr(obj, self.lookup_field) + kwargs = {self.lookup_field: lookup_field} try: return reverse(view_name, kwargs=kwargs, request=request, format=format) except NoReverseMatch: pass + if self.pk_url_kwarg != 'pk': + # Only try pk if it has been explicitly set. + # Otherwise, the default `lookup_field = 'pk'` has us covered. + pk = obj.pk + kwargs = {self.pk_url_kwarg: pk} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except NoReverseMatch: + pass + slug = getattr(obj, self.slug_field, None) + if slug is not None: + # Only try slug if it corresponds to an attribute on the object. + kwargs = {self.slug_url_kwarg: slug} + try: + ret = reverse(view_name, kwargs=kwargs, request=request, format=format) + if self.slug_field == 'slug' and self.slug_url_kwarg == 'slug': + # If the lookup succeeds using the default slug params, + # then `slug_field` is being used implicitly, and we + # we need to warn about the pending deprecation. + msg = 'Implicit slug field hyperlinked fields are pending deprecation.' \ + 'You should set `lookup_field=slug` on the HyperlinkedRelatedField.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + return ret + except NoReverseMatch: + pass + + raise NoReverseMatch() + + def get_object(self, queryset, view_name, view_args, view_kwargs): + """ + Return the object corresponding to a matched URL. - if not slug: - raise Exception('Could not resolve URL for field using view name "%s"' % view_name) + Takes the matched URL conf arguments, and the queryset, and should + return an object instance, or raise an `ObjectDoesNotExist` exception. + """ + lookup = view_kwargs.get(self.lookup_field, None) + pk = view_kwargs.get(self.pk_url_kwarg, None) + slug = view_kwargs.get(self.slug_url_kwarg, None) + + if lookup is not None: + filter_kwargs = {self.lookup_field: lookup} + elif pk is not None: + filter_kwargs = {'pk': pk} + elif slug is not None: + filter_kwargs = {self.slug_field: slug} + else: + raise ObjectDoesNotExist() - kwargs = {self.slug_url_kwarg: slug} - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except NoReverseMatch: - pass + return queryset.get(**filter_kwargs) - kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} + def to_native(self, obj): + view_name = self.view_name + request = self.context.get('request', None) + format = self.format or self.context.get('format', None) + + if request is None: + msg = ( + "Using `HyperlinkedRelatedField` without including the request " + "in the serializer context is deprecated. " + "Add `context={'request': request}` when instantiating " + "the serializer." + ) + warnings.warn(msg, DeprecationWarning, stacklevel=4) + + # If the object has not yet been saved then we cannot hyperlink to it. + if getattr(obj, 'pk', None) is None: + return + + # Return the hyperlink, or error if incorrectly configured. try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) + return self.get_url(obj, view_name, request, format) except NoReverseMatch: - pass - - raise Exception('Could not resolve URL for field using view name "%s"' % view_name) + msg = ( + 'Could not resolve URL for hyperlinked relationship using ' + 'view name "%s". You may have failed to include the related ' + 'model in your API, or incorrectly configured the ' + '`lookup_field` attribute on this field.' + ) + raise Exception(msg % view_name) def from_native(self, value): # Convert URL -> model instance pk # TODO: Use values_list - if self.queryset is None: + queryset = self.queryset + if queryset is None: raise Exception('Writable related fields must include a `queryset` argument') try: @@ -390,39 +462,24 @@ class HyperlinkedRelatedField(RelatedField): if match.view_name != self.view_name: raise ValidationError(self.error_messages['incorrect_match']) - pk = match.kwargs.get(self.pk_url_kwarg, None) - slug = match.kwargs.get(self.slug_url_kwarg, None) - - # Try explicit primary key. - if pk is not None: - queryset = self.queryset.filter(pk=pk) - # Next, try looking up by slug. - elif slug is not None: - slug_field = self.get_slug_field() - queryset = self.queryset.filter(**{slug_field: slug}) - # If none of those are defined, it's probably a configuation error. - else: - raise ValidationError(self.error_messages['configuration_error']) - try: - obj = queryset.get() - except ObjectDoesNotExist: + return self.get_object(queryset, match.view_name, + match.args, match.kwargs) + except (ObjectDoesNotExist, TypeError, ValueError): raise ValidationError(self.error_messages['does_not_exist']) - except (TypeError, ValueError): - msg = self.error_messages['incorrect_type'] - raise ValidationError(msg % type(value).__name__) - - return obj class HyperlinkedIdentityField(Field): """ Represents the instance, or a property on the instance, using hyperlinking. """ + lookup_field = 'pk' + read_only = True + + # These are all pending deprecation pk_url_kwarg = 'pk' slug_field = 'slug' slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden - read_only = True def __init__(self, *args, **kwargs): # TODO: Make view_name mandatory, and have the @@ -431,6 +488,19 @@ class HyperlinkedIdentityField(Field): # Optionally the format of the target hyperlink may be specified self.format = kwargs.pop('format', None) + self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) + + # These are pending deprecation + if 'pk_url_kwarg' in kwargs: + msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_url_kwarg' in kwargs: + msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_field' in kwargs: + msg = 'slug_field is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + self.slug_field = kwargs.pop('slug_field', self.slug_field) default_slug_kwarg = self.slug_url_kwarg or self.slug_field self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) @@ -442,13 +512,14 @@ class HyperlinkedIdentityField(Field): request = self.context.get('request', None) format = self.context.get('format', None) view_name = self.view_name or self.parent.opts.view_name - kwargs = {self.pk_url_kwarg: obj.pk} + lookup_field = getattr(obj, self.lookup_field) + kwargs = {self.lookup_field: lookup_field} if request is None: warnings.warn("Using `HyperlinkedIdentityField` without including the " - "request in the serializer context is due to be deprecated. " + "request in the serializer context is deprecated. " "Add `context={'request': request}` when instantiating the serializer.", - PendingDeprecationWarning, stacklevel=4) + DeprecationWarning, stacklevel=4) # By default use whatever format is given for the current context # unless the target is a different type to the source. @@ -491,35 +562,35 @@ class HyperlinkedIdentityField(Field): class ManyRelatedField(RelatedField): def __init__(self, *args, **kwargs): - warnings.warn('`ManyRelatedField()` is due to be deprecated. ' + warnings.warn('`ManyRelatedField()` is deprecated. ' 'Use `RelatedField(many=True)` instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) kwargs['many'] = True super(ManyRelatedField, self).__init__(*args, **kwargs) class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField): def __init__(self, *args, **kwargs): - warnings.warn('`ManyPrimaryKeyRelatedField()` is due to be deprecated. ' + warnings.warn('`ManyPrimaryKeyRelatedField()` is deprecated. ' 'Use `PrimaryKeyRelatedField(many=True)` instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) kwargs['many'] = True super(ManyPrimaryKeyRelatedField, self).__init__(*args, **kwargs) class ManySlugRelatedField(SlugRelatedField): def __init__(self, *args, **kwargs): - warnings.warn('`ManySlugRelatedField()` is due to be deprecated. ' + warnings.warn('`ManySlugRelatedField()` is deprecated. ' 'Use `SlugRelatedField(many=True)` instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) kwargs['many'] = True super(ManySlugRelatedField, self).__init__(*args, **kwargs) class ManyHyperlinkedRelatedField(HyperlinkedRelatedField): def __init__(self, *args, **kwargs): - warnings.warn('`ManyHyperlinkedRelatedField()` is due to be deprecated. ' + warnings.warn('`ManyHyperlinkedRelatedField()` is deprecated. ' 'Use `HyperlinkedRelatedField(many=True)` instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) kwargs['many'] = True super(ManyHyperlinkedRelatedField, self).__init__(*args, **kwargs) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 4c15e0db..1917a080 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -24,6 +24,7 @@ from rest_framework.settings import api_settings from rest_framework.request import clone_request from rest_framework.utils import encoders from rest_framework.utils.breadcrumbs import get_breadcrumbs +from rest_framework.utils.formatting import get_view_name, get_view_description from rest_framework import exceptions, parsers, status, VERSION @@ -57,7 +58,7 @@ class JSONRenderer(BaseRenderer): return '' # If 'indent' is provided in the context, then pretty print the result. - # E.g. If we're being called by the BrowseableAPIRenderer. + # E.g. If we're being called by the BrowsableAPIRenderer. renderer_context = renderer_context or {} indent = renderer_context.get('indent', None) @@ -438,16 +439,13 @@ class BrowsableAPIRenderer(BaseRenderer): return GenericContentForm() def get_name(self, view): - try: - return view.get_name() - except AttributeError: - return smart_text(view.__class__.__name__) + return get_view_name(view.__class__, getattr(view, 'suffix', None)) def get_description(self, view): - try: - return view.get_description(html=True) - except AttributeError: - return smart_text(view.__doc__ or '') + return get_view_description(view.__class__, html=True) + + def get_breadcrumbs(self, request): + return get_breadcrumbs(request.path) def render(self, data, accepted_media_type=None, renderer_context=None): """ @@ -480,7 +478,7 @@ class BrowsableAPIRenderer(BaseRenderer): name = self.get_name(view) description = self.get_description(view) - breadcrumb_list = get_breadcrumbs(request.path) + breadcrumb_list = self.get_breadcrumbs(request) template = loader.get_template(self.template) context = RequestContext(request, { diff --git a/rest_framework/request.py b/rest_framework/request.py index 3e2fbd88..a434659c 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -1,11 +1,10 @@ """ -The :mod:`request` module provides a :class:`Request` class used to wrap the standard `request` -object received in all the views. +The Request class is used as a wrapper around the standard request object. The wrapped request then offers a richer API, in particular : - content automatically parsed according to `Content-Type` header, - and available as :meth:`.DATA<Request.DATA>` + and available as `request.DATA` - full support of PUT method, including support for file uploads - form overloading of HTTP method, content type and content """ @@ -231,11 +230,17 @@ class Request(object): """ self._content_type = self.META.get('HTTP_CONTENT_TYPE', self.META.get('CONTENT_TYPE', '')) + self._perform_form_overloading() - # if the HTTP method was not overloaded, we take the raw HTTP method + if not _hasattr(self, '_method'): self._method = self._request.method + if self._method == 'POST': + # Allow X-HTTP-METHOD-OVERRIDE header + self._method = self.META.get('HTTP_X_HTTP_METHOD_OVERRIDE', + self._method) + def _load_stream(self): """ Return the content body of the request, as a stream. diff --git a/rest_framework/response.py b/rest_framework/response.py index 5e1bf46e..26e4ab37 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -1,3 +1,9 @@ +""" +The Response class in REST framework is similiar to HTTPResponse, except that +it is initialized with unrendered data, instead of a pre-rendered string. + +The appropriate renderer is called during Django's template response rendering. +""" from __future__ import unicode_literals from django.core.handlers.wsgi import STATUS_CODE_TEXT from django.template.response import SimpleTemplateResponse diff --git a/rest_framework/routers.py b/rest_framework/routers.py new file mode 100644 index 00000000..dba104c3 --- /dev/null +++ b/rest_framework/routers.py @@ -0,0 +1,249 @@ +""" +Routers provide a convenient and consistent way of automatically +determining the URL conf for your API. + +They are used by simply instantiating a Router class, and then registering +all the required ViewSets with that router. + +For example, you might have a `urls.py` that looks something like this: + + router = routers.DefaultRouter() + router.register('users', UserViewSet, 'user') + router.register('accounts', AccountViewSet, 'account') + + urlpatterns = router.urls +""" +from __future__ import unicode_literals + +from collections import namedtuple +from rest_framework import views +from rest_framework.compat import patterns, url +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.reverse import reverse +from rest_framework.urlpatterns import format_suffix_patterns + + +Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs']) + + +def replace_methodname(format_string, methodname): + """ + Partially format a format_string, swapping out any + '{methodname}' or '{methodnamehyphen}' components. + """ + methodnamehyphen = methodname.replace('_', '-') + ret = format_string + ret = ret.replace('{methodname}', methodname) + ret = ret.replace('{methodnamehyphen}', methodnamehyphen) + return ret + + +class BaseRouter(object): + def __init__(self): + self.registry = [] + + def register(self, prefix, viewset, base_name=None): + if base_name is None: + base_name = self.get_default_base_name(viewset) + self.registry.append((prefix, viewset, base_name)) + + def get_default_base_name(self, viewset): + """ + If `base_name` is not specified, attempt to automatically determine + it from the viewset. + """ + raise NotImplemented('get_default_base_name must be overridden') + + def get_urls(self): + """ + Return a list of URL patterns, given the registered viewsets. + """ + raise NotImplemented('get_urls must be overridden') + + @property + def urls(self): + if not hasattr(self, '_urls'): + self._urls = patterns('', *self.get_urls()) + return self._urls + + +class SimpleRouter(BaseRouter): + routes = [ + # List route. + Route( + url=r'^{prefix}/$', + mapping={ + 'get': 'list', + 'post': 'create' + }, + name='{basename}-list', + initkwargs={'suffix': 'List'} + ), + # Detail route. + Route( + url=r'^{prefix}/{lookup}/$', + mapping={ + 'get': 'retrieve', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy' + }, + name='{basename}-detail', + initkwargs={'suffix': 'Instance'} + ), + # Dynamically generated routes. + # Generated using @action or @link decorators on methods of the viewset. + Route( + url=r'^{prefix}/{lookup}/{methodname}/$', + mapping={ + '{httpmethod}': '{methodname}', + }, + name='{basename}-{methodnamehyphen}', + initkwargs={} + ), + ] + + def get_default_base_name(self, viewset): + """ + If `base_name` is not specified, attempt to automatically determine + it from the viewset. + """ + model_cls = getattr(viewset, 'model', None) + queryset = getattr(viewset, 'queryset', None) + if model_cls is None and queryset is not None: + model_cls = queryset.model + + assert model_cls, '`name` not argument not specified, and could ' \ + 'not automatically determine the name from the viewset, as ' \ + 'it does not have a `.model` or `.queryset` attribute.' + + return model_cls._meta.object_name.lower() + + def get_routes(self, viewset): + """ + Augment `self.routes` with any dynamically generated routes. + + Returns a list of the Route namedtuple. + """ + + # Determine any `@action` or `@link` decorated methods on the viewset + dynamic_routes = [] + for methodname in dir(viewset): + attr = getattr(viewset, methodname) + httpmethod = getattr(attr, 'bind_to_method', None) + if httpmethod: + dynamic_routes.append((httpmethod, methodname)) + + ret = [] + for route in self.routes: + if route.mapping == {'{httpmethod}': '{methodname}'}: + # Dynamic routes (@link or @action decorator) + for httpmethod, methodname in dynamic_routes: + initkwargs = route.initkwargs.copy() + initkwargs.update(getattr(viewset, methodname).kwargs) + ret.append(Route( + url=replace_methodname(route.url, methodname), + mapping={httpmethod: methodname}, + name=replace_methodname(route.name, methodname), + initkwargs=initkwargs, + )) + else: + # Standard route + ret.append(route) + + return ret + + def get_method_map(self, viewset, method_map): + """ + Given a viewset, and a mapping of http methods to actions, + return a new mapping which only includes any mappings that + are actually implemented by the viewset. + """ + bound_methods = {} + for method, action in method_map.items(): + if hasattr(viewset, action): + bound_methods[method] = action + return bound_methods + + def get_lookup_regex(self, viewset): + """ + Given a viewset, return the portion of URL regex that is used + to match against a single instance. + """ + base_regex = '(?P<{lookup_field}>[^/]+)' + lookup_field = getattr(viewset, 'lookup_field', 'pk') + return base_regex.format(lookup_field=lookup_field) + + def get_urls(self): + """ + Use the registered viewsets to generate a list of URL patterns. + """ + ret = [] + + for prefix, viewset, basename in self.registry: + lookup = self.get_lookup_regex(viewset) + routes = self.get_routes(viewset) + + for route in routes: + + # Only actions which actually exist on the viewset will be bound + mapping = self.get_method_map(viewset, route.mapping) + if not mapping: + continue + + # Build the url pattern + regex = route.url.format(prefix=prefix, lookup=lookup) + view = viewset.as_view(mapping, **route.initkwargs) + name = route.name.format(basename=basename) + ret.append(url(regex, view, name=name)) + + return ret + + +class DefaultRouter(SimpleRouter): + """ + The default router extends the SimpleRouter, but also adds in a default + API root view, and adds format suffix patterns to the URLs. + """ + include_root_view = True + include_format_suffixes = True + + def get_api_root_view(self): + """ + Return a view to use as the API root. + """ + api_root_dict = {} + list_name = self.routes[0].name + for prefix, viewset, basename in self.registry: + api_root_dict[prefix] = list_name.format(basename=basename) + + class APIRoot(views.APIView): + _ignore_model_permissions = True + + def get(self, request, format=None): + ret = {} + for key, url_name in api_root_dict.items(): + ret[key] = reverse(url_name, request=request, format=format) + return Response(ret) + + return APIRoot.as_view() + + def get_urls(self): + """ + Generate the list of URL patterns, including a default root view + for the API, and appending `.json` style format suffixes. + """ + urls = [] + + if self.include_root_view: + root_url = url(r'^$', self.get_api_root_view(), name='api-root') + urls.append(root_url) + + default_urls = super(DefaultRouter, self).get_urls() + urls.extend(default_urls) + + if self.include_format_suffixes: + urls = format_suffix_patterns(urls) + + return urls diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py index 03bfc216..9b519f27 100644 --- a/rest_framework/runtests/settings.py +++ b/rest_framework/runtests/settings.py @@ -97,9 +97,30 @@ INSTALLED_APPS = ( # 'django.contrib.admindocs', 'rest_framework', 'rest_framework.authtoken', - 'rest_framework.tests' + 'rest_framework.tests', ) +# OAuth is optional and won't work if there is no oauth_provider & oauth2 +try: + import oauth_provider + import oauth2 +except ImportError: + pass +else: + INSTALLED_APPS += ( + 'oauth_provider', + ) + +try: + import provider +except ImportError: + pass +else: + INSTALLED_APPS += ( + 'provider', + 'provider.oauth2', + ) + STATIC_URL = '/static/' PASSWORD_HASHERS = ( diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index ba9e9e9c..7707de7a 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1,3 +1,15 @@ +""" +Serializers and ModelSerializers are similar to Forms and ModelForms. +Unlike forms, they are not constrained to dealing with HTML output, and +form encoded input. + +Serialization in REST framework is a two-phase process: + +1. Serializers marshal between complex types like model instances, and +python primatives. +2. The process of marshalling between python primatives and request and +response content is handled by parsers and renderers. +""" from __future__ import unicode_literals import copy import datetime @@ -7,8 +19,7 @@ from django.core.paginator import Page from django.db import models from django.forms import widgets from django.utils.datastructures import SortedDict -from rest_framework.compat import get_concrete_model -from rest_framework.compat import six +from rest_framework.compat import get_concrete_model, six # Note: We do the following so that users of the framework can use this style: # @@ -21,6 +32,25 @@ from rest_framework.relations import * from rest_framework.fields import * +class NestedValidationError(ValidationError): + """ + The default ValidationError behavior is to stringify each item in the list + if the messages are a list of error messages. + + In the case of nested serializers, where the parent has many children, + then the child's `serializer.errors` will be a list of dicts. In the case + of a single child, the `serializer.errors` will be a dict. + + We need to override the default behavior to get properly nested error dicts. + """ + + def __init__(self, message): + if isinstance(message, dict): + self.messages = [message] + else: + self.messages = message + + class DictWithMetadata(dict): """ A dict-like object, that can have additional properties attached. @@ -99,7 +129,7 @@ class SerializerOptions(object): self.exclude = getattr(meta, 'exclude', ()) -class BaseSerializer(Field): +class BaseSerializer(WritableField): """ This is the Serializer implementation. We need to implement it as `BaseSerializer` due to metaclass magicks. @@ -111,13 +141,15 @@ class BaseSerializer(Field): _dict_class = SortedDictWithMetadata def __init__(self, instance=None, data=None, files=None, - context=None, partial=False, many=None, source=None): - super(BaseSerializer, self).__init__(source=source) + context=None, partial=False, many=None, + allow_add_remove=False, **kwargs): + super(BaseSerializer, self).__init__(**kwargs) self.opts = self._options_class(self.Meta) self.parent = None self.root = None self.partial = partial self.many = many + self.allow_add_remove = allow_add_remove self.context = context or {} @@ -129,6 +161,13 @@ class BaseSerializer(Field): self._data = None self._files = None self._errors = None + self._deleted = None + + if many and instance is not None and not hasattr(instance, '__iter__'): + raise ValueError('instance should be a queryset or other iterable with many=True') + + if allow_add_remove and not many: + raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True') ##### # Methods to determine which fields to use when (de)serializing objects. @@ -161,7 +200,7 @@ class BaseSerializer(Field): # If 'fields' is specified, use those fields, in that order. if self.opts.fields: - assert isinstance(self.opts.fields, (list, tuple)), '`include` must be a list or tuple' + assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple' new = SortedDict() for key in self.opts.fields: new[key] = ret[key] @@ -169,7 +208,7 @@ class BaseSerializer(Field): # Remove anything in 'exclude' if self.opts.exclude: - assert isinstance(self.opts.fields, (list, tuple)), '`exclude` must be a list or tuple' + assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple' for key in self.opts.exclude: ret.pop(key, None) @@ -179,18 +218,6 @@ class BaseSerializer(Field): return ret ##### - # Field methods - used when the serializer class is itself used as a field. - - def initialize(self, parent, field_name): - """ - Same behaviour as usual Field, except that we need to keep track - of state so that we can deal with handling maximum depth. - """ - super(BaseSerializer, self).initialize(parent, field_name) - if parent.opts.depth: - self.opts.depth = parent.opts.depth - 1 - - ##### # Methods to convert or revert from objects <--> primitive representations. def get_field_key(self, field_name): @@ -285,10 +312,6 @@ class BaseSerializer(Field): """ Deserialize primitives -> objects. """ - if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)): - # TODO: error data when deserializing lists - return [self.from_native(item, None) for item in data] - self._errors = {} if data is not None or files is not None: attrs = self.restore_fields(data, files) @@ -301,40 +324,91 @@ class BaseSerializer(Field): def field_to_native(self, obj, field_name): """ - Override default so that we can apply ModelSerializer as a nested - field to relationships. + Override default so that the serializer can be used as a nested field + across relationships. """ if self.source == '*': return self.to_native(obj) try: - if self.source: - for component in self.source.split('.'): - obj = getattr(obj, component) - if is_simple_callable(obj): - obj = obj() - else: - obj = getattr(obj, field_name) - if is_simple_callable(obj): - obj = obj() + source = self.source or field_name + value = obj + + for component in source.split('.'): + value = get_component(value, component) + if value is None: + break except ObjectDoesNotExist: return None - # If the object has an "all" method, assume it's a relationship - if is_simple_callable(getattr(obj, 'all', None)): - return [self.to_native(item) for item in obj.all()] + if is_simple_callable(getattr(value, 'all', None)): + return [self.to_native(item) for item in value.all()] - if obj is None: + if value is None: return None if self.many is not None: many = self.many else: - many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict)) + many = hasattr(value, '__iter__') and not isinstance(value, (Page, dict, six.text_type)) if many: - return [self.to_native(item) for item in obj] - return self.to_native(obj) + return [self.to_native(item) for item in value] + return self.to_native(value) + + def field_from_native(self, data, files, field_name, into): + """ + Override default so that the serializer can be used as a writable + nested field across relationships. + """ + if self.read_only: + return + + try: + value = data[field_name] + except KeyError: + if self.default is not None and not self.partial: + # Note: partial updates shouldn't set defaults + value = copy.deepcopy(self.default) + else: + if self.required: + raise ValidationError(self.error_messages['required']) + return + + # Set the serializer object if it exists + obj = getattr(self.parent.object, field_name) if self.parent.object else None + + if value in (None, ''): + into[(self.source or field_name)] = None + else: + kwargs = { + 'instance': obj, + 'data': value, + 'context': self.context, + 'partial': self.partial, + 'many': self.many + } + serializer = self.__class__(**kwargs) + + if serializer.is_valid(): + into[self.source or field_name] = serializer.object + else: + # Propagate errors up to our parent + raise NestedValidationError(serializer.errors) + + def get_identity(self, data): + """ + This hook is required for bulk update. + It is used to determine the canonical identity of a given object. + + Note that the data has not been validated at this point, so we need + to make sure that we catch any cases of incorrect datatypes being + passed to this method. + """ + try: + return data.get('id', None) + except AttributeError: + return None @property def errors(self): @@ -348,19 +422,52 @@ class BaseSerializer(Field): if self.many is not None: many = self.many else: - many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict)) + many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type)) if many: - warnings.warn('Implict list/queryset serialization is due to be deprecated. ' + warnings.warn('Implict list/queryset serialization is deprecated. ' 'Use the `many=True` flag when instantiating the serializer.', - PendingDeprecationWarning, stacklevel=3) + DeprecationWarning, stacklevel=3) - # TODO: error data when deserializing lists if many: - ret = [self.from_native(item, None) for item in data] - ret = self.from_native(data, files) + ret = [] + errors = [] + update = self.object is not None + + if update: + # If this is a bulk update we need to map all the objects + # to a canonical identity so we can determine which + # individual object is being updated for each item in the + # incoming data + objects = self.object + identities = [self.get_identity(self.to_native(obj)) for obj in objects] + identity_to_objects = dict(zip(identities, objects)) + + if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)): + for item in data: + if update: + # Determine which object we're updating + identity = self.get_identity(item) + self.object = identity_to_objects.pop(identity, None) + if self.object is None and not self.allow_add_remove: + ret.append(None) + errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}) + continue + + ret.append(self.from_native(item, None)) + errors.append(self._errors) + + if update: + self._deleted = identity_to_objects.values() + + self._errors = any(errors) and errors or [] + else: + self._errors = {'non_field_errors': ['Expected a list of items.']} + else: + ret = self.from_native(data, files) if not self._errors: self.object = ret + return self._errors def is_valid(self): @@ -379,9 +486,9 @@ class BaseSerializer(Field): else: many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict)) if many: - warnings.warn('Implict list/queryset serialization is due to be deprecated. ' + warnings.warn('Implict list/queryset serialization is deprecated. ' 'Use the `many=True` flag when instantiating the serializer.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) if many: self._data = [self.to_native(item) for item in obj] @@ -390,11 +497,24 @@ class BaseSerializer(Field): return self._data - def save(self): + def save_object(self, obj, **kwargs): + obj.save(**kwargs) + + def delete_object(self, obj): + obj.delete() + + def save(self, **kwargs): """ Save the deserialized object and return it. """ - self.object.save() + if isinstance(self.object, list): + [self.save_object(item, **kwargs) for item in self.object] + else: + self.save_object(self.object, **kwargs) + + if self.allow_add_remove and self._deleted: + [self.delete_object(item) for item in self._deleted] + return self.object @@ -428,6 +548,7 @@ class ModelSerializer(Serializer): models.DateTimeField: DateTimeField, models.DateField: DateField, models.TimeField: TimeField, + models.DecimalField: DecimalField, models.EmailField: EmailField, models.CharField: CharField, models.URLField: URLField, @@ -448,36 +569,94 @@ class ModelSerializer(Serializer): assert cls is not None, \ "Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__ opts = get_concrete_model(cls)._meta - pk_field = opts.pk - # while pk_field.rel: - # pk_field = pk_field.rel.to._meta.pk - fields = [pk_field] - fields += [field for field in opts.fields if field.serialize] - fields += [field for field in opts.many_to_many if field.serialize] - ret = SortedDict() nested = bool(self.opts.depth) - is_pk = True # First field in the list is the pk - - for model_field in fields: - if is_pk: - field = self.get_pk_field(model_field) - is_pk = False - elif model_field.rel and nested: - field = self.get_nested_field(model_field) - elif model_field.rel: + + # Deal with adding the primary key field + pk_field = opts.pk + while pk_field.rel and pk_field.rel.parent_link: + # If model is a child via multitable inheritance, use parent's pk + pk_field = pk_field.rel.to._meta.pk + + field = self.get_pk_field(pk_field) + if field: + ret[pk_field.name] = field + + # Deal with forward relationships + forward_rels = [field for field in opts.fields if field.serialize] + forward_rels += [field for field in opts.many_to_many if field.serialize] + + for model_field in forward_rels: + if model_field.rel: to_many = isinstance(model_field, models.fields.related.ManyToManyField) - field = self.get_related_field(model_field, to_many=to_many) + related_model = model_field.rel.to + + if model_field.rel and nested: + if len(inspect.getargspec(self.get_nested_field).args) == 2: + warnings.warn( + 'The `get_nested_field(model_field)` call signature ' + 'is due to be deprecated. ' + 'Use `get_nested_field(model_field, related_model, ' + 'to_many) instead', + PendingDeprecationWarning + ) + field = self.get_nested_field(model_field) + else: + field = self.get_nested_field(model_field, related_model, to_many) + elif model_field.rel: + if len(inspect.getargspec(self.get_nested_field).args) == 3: + warnings.warn( + 'The `get_related_field(model_field, to_many)` call ' + 'signature is due to be deprecated. ' + 'Use `get_related_field(model_field, related_model, ' + 'to_many) instead', + PendingDeprecationWarning + ) + field = self.get_related_field(model_field, to_many=to_many) + else: + field = self.get_related_field(model_field, related_model, to_many) else: field = self.get_field(model_field) if field: ret[model_field.name] = field + # Deal with reverse relationships + if not self.opts.fields: + reverse_rels = [] + else: + # Reverse relationships are only included if they are explicitly + # present in the `fields` option on the serializer + reverse_rels = opts.get_all_related_objects() + reverse_rels += opts.get_all_related_many_to_many_objects() + + for relation in reverse_rels: + accessor_name = relation.get_accessor_name() + if not self.opts.fields or accessor_name not in self.opts.fields: + continue + related_model = relation.model + to_many = relation.field.rel.multiple + + if nested: + field = self.get_nested_field(None, related_model, to_many) + else: + field = self.get_related_field(None, related_model, to_many) + + if field: + ret[accessor_name] = field + + # Add the `read_only` flag to any fields that have bee specified + # in the `read_only_fields` option for field_name in self.opts.read_only_fields: + assert field_name not in self.base_fields.keys(), \ + "field '%s' on serializer '%s' specfied in " \ + "`read_only_fields`, but also added " \ + "as an explict field. Remove it from `read_only_fields`." % \ + (field_name, self.__class__.__name__) assert field_name in ret, \ - "read_only_fields on '%s' included invalid item '%s'" % \ + "Noexistant field '%s' specified in `read_only_fields` " \ + "on serializer '%s'." % \ (self.__class__.__name__, field_name) ret[field_name].read_only = True @@ -489,27 +668,36 @@ class ModelSerializer(Serializer): """ return self.get_field(model_field) - def get_nested_field(self, model_field): + def get_nested_field(self, model_field, related_model, to_many): """ Creates a default instance of a nested relational field. + + Note that model_field will be `None` for reverse relationships. """ class NestedModelSerializer(ModelSerializer): class Meta: - model = model_field.rel.to - return NestedModelSerializer() + model = related_model + depth = self.opts.depth - 1 + + return NestedModelSerializer(many=to_many) - def get_related_field(self, model_field, to_many=False): + def get_related_field(self, model_field, related_model, to_many): """ Creates a default instance of a flat relational field. + + Note that model_field will be `None` for reverse relationships. """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) + kwargs = { - 'required': not(model_field.null or model_field.blank), - 'queryset': model_field.rel.to._default_manager, + 'queryset': related_model._default_manager, 'many': to_many } + if model_field: + kwargs['required'] = not(model_field.null or model_field.blank) + return PrimaryKeyRelatedField(**kwargs) def get_field(self, model_field): @@ -574,33 +762,43 @@ class ModelSerializer(Serializer): """ Restore the model instance. """ - self.m2m_data = {} - self.related_data = {} + m2m_data = {} + related_data = {} + meta = self.opts.model._meta - # Reverse fk relations - for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model(): + # Reverse fk or one-to-one relations + for (obj, model) in meta.get_all_related_objects_with_model(): field_name = obj.field.related_query_name() if field_name in attrs: - self.related_data[field_name] = attrs.pop(field_name) + related_data[field_name] = attrs.pop(field_name) # Reverse m2m relations - for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): + for (obj, model) in meta.get_all_related_m2m_objects_with_model(): field_name = obj.field.related_query_name() if field_name in attrs: - self.m2m_data[field_name] = attrs.pop(field_name) + m2m_data[field_name] = attrs.pop(field_name) # Forward m2m relations - for field in self.opts.model._meta.many_to_many: + for field in meta.many_to_many: if field.name in attrs: - self.m2m_data[field.name] = attrs.pop(field.name) + m2m_data[field.name] = attrs.pop(field.name) + # Update an existing instance... if instance is not None: for key, val in attrs.items(): setattr(instance, key, val) + # ...or create a new instance else: instance = self.opts.model(**attrs) + # Any relations that cannot be set until we've + # saved the model get hidden away on these + # private attributes, so we can deal with them + # at the point of save. + instance._related_data = related_data + instance._m2m_data = m2m_data + return instance def from_native(self, data, files): @@ -608,26 +806,24 @@ class ModelSerializer(Serializer): Override the default method to also include model field validation. """ instance = super(ModelSerializer, self).from_native(data, files) - if instance: + if not self._errors: return self.full_clean(instance) - def save(self): + def save_object(self, obj, **kwargs): """ Save the deserialized object and return it. """ - self.object.save() - - if getattr(self, 'm2m_data', None): - for accessor_name, object_list in self.m2m_data.items(): - setattr(self.object, accessor_name, object_list) - self.m2m_data = {} + obj.save(**kwargs) - if getattr(self, 'related_data', None): - for accessor_name, object_list in self.related_data.items(): - setattr(self.object, accessor_name, object_list) - self.related_data = {} + if getattr(obj, '_m2m_data', None): + for accessor_name, object_list in obj._m2m_data.items(): + setattr(obj, accessor_name, object_list) + del(obj._m2m_data) - return self.object + if getattr(obj, '_related_data', None): + for accessor_name, related in obj._related_data.items(): + setattr(obj, accessor_name, related) + del(obj._related_data) class HyperlinkedModelSerializerOptions(ModelSerializerOptions): @@ -637,6 +833,7 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): def __init__(self, meta): super(HyperlinkedModelSerializerOptions, self).__init__(meta) self.view_name = getattr(meta, 'view_name', None) + self.lookup_field = getattr(meta, 'lookup_field', None) class HyperlinkedModelSerializer(ModelSerializer): @@ -646,6 +843,7 @@ class HyperlinkedModelSerializer(ModelSerializer): """ _options_class = HyperlinkedModelSerializerOptions _default_view_name = '%(model_name)s-detail' + _hyperlink_field_class = HyperlinkedRelatedField url = HyperlinkedIdentityField() @@ -666,19 +864,35 @@ class HyperlinkedModelSerializer(ModelSerializer): return self._default_view_name % format_kwargs def get_pk_field(self, model_field): - return None + if self.opts.fields and model_field.name in self.opts.fields: + return self.get_field(model_field) - def get_related_field(self, model_field, to_many): + def get_related_field(self, model_field, related_model, to_many): """ Creates a default instance of a flat relational field. """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) - rel = model_field.rel.to kwargs = { - 'required': not(model_field.null or model_field.blank), - 'queryset': rel._default_manager, - 'view_name': self._get_default_view_name(rel), + 'queryset': related_model._default_manager, + 'view_name': self._get_default_view_name(related_model), 'many': to_many } - return HyperlinkedRelatedField(**kwargs) + + if model_field: + kwargs['required'] = not(model_field.null or model_field.blank) + + if self.opts.lookup_field: + kwargs['lookup_field'] = self.opts.lookup_field + + return self._hyperlink_field_class(**kwargs) + + def get_identity(self, data): + """ + This hook is required for bulk update. + We need to override the default, to use the url as the identity. + """ + try: + return data.get('url', None) + except AttributeError: + return None diff --git a/rest_framework/settings.py b/rest_framework/settings.py index b7aa0bbe..beb511ac 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -18,14 +18,18 @@ REST framework settings, checking for user settings first, then falling back to the defaults. """ from __future__ import unicode_literals + from django.conf import settings from django.utils import importlib + +from rest_framework import ISO_8601 from rest_framework.compat import six USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None) DEFAULTS = { + # Base API policies 'DEFAULT_RENDERER_CLASSES': ( 'rest_framework.renderers.JSONRenderer', 'rest_framework.renderers.BrowsableAPIRenderer', @@ -47,11 +51,15 @@ DEFAULTS = { 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', + + # Genric view behavior 'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.serializers.ModelSerializer', 'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer', + 'DEFAULT_FILTER_BACKENDS': (), + # Throttling 'DEFAULT_THROTTLE_RATES': { 'user': None, 'anon': None, @@ -61,9 +69,6 @@ DEFAULTS = { 'PAGINATE_BY': None, 'PAGINATE_BY_PARAM': None, - # Filtering - 'FILTER_BACKEND': None, - # Authentication 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, @@ -76,6 +81,25 @@ DEFAULTS = { 'URL_FORMAT_OVERRIDE': 'format', 'FORMAT_SUFFIX_KWARG': 'format', + + # Input and output formats + 'DATE_INPUT_FORMATS': ( + ISO_8601, + ), + 'DATE_FORMAT': None, + + 'DATETIME_INPUT_FORMATS': ( + ISO_8601, + ), + 'DATETIME_FORMAT': None, + + 'TIME_INPUT_FORMATS': ( + ISO_8601, + ), + 'TIME_FORMAT': None, + + # Pending deprecation + 'FILTER_BACKEND': None, } @@ -89,6 +113,7 @@ IMPORT_STRINGS = ( 'DEFAULT_CONTENT_NEGOTIATION_CLASS', 'DEFAULT_MODEL_SERIALIZER_CLASS', 'DEFAULT_PAGINATION_SERIALIZER_CLASS', + 'DEFAULT_FILTER_BACKENDS', 'FILTER_BACKEND', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 44633f5a..4410f285 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -115,7 +115,7 @@ </div> <div class="response-info"> <pre class="prettyprint"><div class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %} -{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|urlize_quoted_links }}</span> +{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|break_long_headers|urlize_quoted_links }}</span> {% endfor %} </div>{{ content|urlize_quoted_links }}</pre>{% endautoescape %} </div> diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html index e10ce20f..b7629327 100644 --- a/rest_framework/templates/rest_framework/login.html +++ b/rest_framework/templates/rest_framework/login.html @@ -1,53 +1,3 @@ -{% load url from future %} -{% load rest_framework %} -<html> +{% extends "rest_framework/login_base.html" %} - <head> - <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/> - <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/> - <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/> - </head> - - <body class="container"> - -<div class="container-fluid" style="margin-top: 30px"> - <div class="row-fluid"> - - <div class="well" style="width: 320px; margin-left: auto; margin-right: auto"> - <div class="row-fluid"> - <div> - <h3 style="margin: 0 0 20px;">Django REST framework</h3> - </div> - </div><!-- /row fluid --> - - <div class="row-fluid"> - <div> - <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post"> - {% csrf_token %} - <div id="div_id_username" class="clearfix control-group"> - <div class="controls"> - <Label class="span4">Username:</label> - <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username"> - </div> - </div> - <div id="div_id_password" class="clearfix control-group"> - <div class="controls"> - <Label class="span4">Password:</label> - <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password"> - </div> - </div> - <input type="hidden" name="next" value="{{ next }}" /> - <div class="form-actions-no-box"> - <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit"> - </div> - </form> - </div> - </div><!-- /row fluid --> - </div><!--/span--> - - </div><!-- /.row-fluid --> - </div> - - </div> - </body> -</html> +{# Override this template in your own templates directory to customize #} diff --git a/rest_framework/templates/rest_framework/login_base.html b/rest_framework/templates/rest_framework/login_base.html new file mode 100644 index 00000000..a3e73b6b --- /dev/null +++ b/rest_framework/templates/rest_framework/login_base.html @@ -0,0 +1,51 @@ +{% load url from future %} +{% load rest_framework %} +<html> + + <head> + {% block style %} + {% block bootstrap_theme %}<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>{% endblock %} + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/> + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/> + {% endblock %} + </head> + + <body class="container"> + + <div class="container-fluid" style="margin-top: 30px"> + <div class="row-fluid"> + <div class="well" style="width: 320px; margin-left: auto; margin-right: auto"> + <div class="row-fluid"> + <div> + {% block branding %}<h3 style="margin: 0 0 20px;">Django REST framework</h3>{% endblock %} + </div> + </div><!-- /row fluid --> + + <div class="row-fluid"> + <div> + <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post"> + {% csrf_token %} + <div id="div_id_username" class="clearfix control-group"> + <div class="controls"> + <Label class="span4">Username:</label> + <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username"> + </div> + </div> + <div id="div_id_password" class="clearfix control-group"> + <div class="controls"> + <Label class="span4">Password:</label> + <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password"> + </div> + </div> + <input type="hidden" name="next" value="{{ next }}" /> + <div class="form-actions-no-box"> + <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit"> + </div> + </form> + </div> + </div><!-- /.row-fluid --> + </div><!--/.well--> + </div><!-- /.row-fluid --> + </div><!-- /.container-fluid --> + </body> +</html> diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index c21ddcd7..c86b6456 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -4,11 +4,8 @@ from django.core.urlresolvers import reverse, NoReverseMatch from django.http import QueryDict from django.utils.html import escape from django.utils.safestring import SafeData, mark_safe -from rest_framework.compat import urlparse -from rest_framework.compat import force_text -from rest_framework.compat import six -import re -import string +from rest_framework.compat import urlparse, force_text, six, smart_urlquote +import re, string register = template.Library() @@ -112,22 +109,6 @@ def replace_query_param(url, key, val): class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') -# Bunch of stuff cloned from urlize -LEADING_PUNCTUATION = ['(', '<', '<', '"', "'"] -TRAILING_PUNCTUATION = ['.', ',', ')', '>', '\n', '>', '"', "'"] -DOTS = ['·', '*', '\xe2\x80\xa2', '•', '•', '•'] -unencoded_ampersands_re = re.compile(r'&(?!(\w+|#\d+);)') -word_split_re = re.compile(r'(\s+)') -punctuation_re = re.compile('^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' % \ - ('|'.join([re.escape(x) for x in LEADING_PUNCTUATION]), - '|'.join([re.escape(x) for x in TRAILING_PUNCTUATION]))) -simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$') -link_target_attribute_re = re.compile(r'(<a [^>]*?)target=[^\s>]+') -html_gunk_re = re.compile(r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)', re.IGNORECASE) -hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|'.join([re.escape(x) for x in DOTS]), re.DOTALL) -trailing_empty_content_re = re.compile(r'(?:<p>(?: |\s|<br \/>)*?</p>\s*)+\Z') - - # And the template tags themselves... @register.simple_tag @@ -195,15 +176,25 @@ def add_class(value, css_class): return value +# Bunch of stuff cloned from urlize +TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"] +WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'), + ('"', '"'), ("'", "'")] +word_split_re = re.compile(r'(\s+)') +simple_url_re = re.compile(r'^https?://\[?\w', re.IGNORECASE) +simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE) +simple_email_re = re.compile(r'^\S+@\S+\.\S+$') + + @register.filter def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True): """ Converts any URLs in text into clickable links. - Works on http://, https://, www. links and links ending in .org, .net or - .com. Links can have trailing punctuation (periods, commas, close-parens) - and leading punctuation (opening parens) and it'll still do the right - thing. + Works on http://, https://, www. links, and also on links ending in one of + the original seven gTLDs (.com, .edu, .gov, .int, .mil, .net, and .org). + Links can have trailing punctuation (periods, commas, close-parens) and + leading punctuation (opening parens) and it'll still do the right thing. If trim_url_limit is not None, the URLs in link text longer than this limit will truncated to trim_url_limit-3 characters and appended with an elipsis. @@ -216,24 +207,41 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x safe_input = isinstance(text, SafeData) words = word_split_re.split(force_text(text)) - nofollow_attr = nofollow and ' rel="nofollow"' or '' for i, word in enumerate(words): match = None if '.' in word or '@' in word or ':' in word: - match = punctuation_re.match(word) - if match: - lead, middle, trail = match.groups() + # Deal with punctuation. + lead, middle, trail = '', word, '' + for punctuation in TRAILING_PUNCTUATION: + if middle.endswith(punctuation): + middle = middle[:-len(punctuation)] + trail = punctuation + trail + for opening, closing in WRAPPING_PUNCTUATION: + if middle.startswith(opening): + middle = middle[len(opening):] + lead = lead + opening + # Keep parentheses at the end only if they're balanced. + if (middle.endswith(closing) + and middle.count(closing) == middle.count(opening) + 1): + middle = middle[:-len(closing)] + trail = closing + trail + # Make URL we want to point to. url = None - if middle.startswith('http://') or middle.startswith('https://'): - url = middle - elif middle.startswith('www.') or ('@' not in middle and \ - middle and middle[0] in string.ascii_letters + string.digits and \ - (middle.endswith('.org') or middle.endswith('.net') or middle.endswith('.com'))): - url = 'http://%s' % middle - elif '@' in middle and not ':' in middle and simple_email_re.match(middle): - url = 'mailto:%s' % middle + nofollow_attr = ' rel="nofollow"' if nofollow else '' + if simple_url_re.match(middle): + url = smart_urlquote(middle) + elif simple_url_2_re.match(middle): + url = smart_urlquote('http://%s' % middle) + elif not ':' in middle and simple_email_re.match(middle): + local, domain = middle.rsplit('@', 1) + try: + domain = domain.encode('idna').decode('ascii') + except UnicodeError: + continue + url = 'mailto:%s@%s' % (local, domain) nofollow_attr = '' + # Make link. if url: trimmed = trim_url(middle) @@ -251,4 +259,15 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru words[i] = mark_safe(word) elif autoescape: words[i] = escape(word) - return mark_safe(''.join(words)) + return ''.join(words) + + +@register.filter +def break_long_headers(header): + """ + Breaks headers longer than 160 characters (~page length) + when possible (are comma separated) + """ + if len(header) > 160 and ',' in header: + header = mark_safe('<br> ' + ', <br>'.join(header.split(','))) + return header diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index 7b754af5..8e6d3e51 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -2,23 +2,29 @@ from __future__ import unicode_literals from django.contrib.auth.models import User from django.http import HttpResponse from django.test import Client, TestCase +from django.utils import unittest from rest_framework import HTTP_HEADER_ENCODING from rest_framework import exceptions from rest_framework import permissions from rest_framework import status -from rest_framework.authtoken.models import Token from rest_framework.authentication import ( BaseAuthentication, TokenAuthentication, BasicAuthentication, - SessionAuthentication + SessionAuthentication, + OAuthAuthentication, + OAuth2Authentication ) -from rest_framework.compat import patterns +from rest_framework.authtoken.models import Token +from rest_framework.compat import patterns, url, include +from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope +from rest_framework.compat import oauth, oauth_provider from rest_framework.tests.utils import RequestFactory from rest_framework.views import APIView import json import base64 - +import time +import datetime factory = RequestFactory() @@ -41,8 +47,19 @@ urlpatterns = patterns('', (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), + (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])), + (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication], + permission_classes=[permissions.TokenHasReadWriteScope])) ) +if oauth2_provider is not None: + urlpatterns += patterns('', + url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), + url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), + url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], + permission_classes=[permissions.TokenHasReadWriteScope])), + ) + class BasicAuthTests(TestCase): """Basic authentication""" @@ -146,7 +163,7 @@ class TokenAuthTests(TestCase): def test_post_form_passing_token_auth(self): """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" - auth = "Token " + self.key + auth = 'Token ' + self.key response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -222,3 +239,317 @@ class IncorrectCredentialsTests(TestCase): response = view(request) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.data, {'detail': 'Bad credentials'}) + + +class OAuthTests(TestCase): + """OAuth 1.0a authentication""" + urls = 'rest_framework.tests.authentication' + + def setUp(self): + # these imports are here because oauth is optional and hiding them in try..except block or compat + # could obscure problems if something breaks + from oauth_provider.models import Consumer, Resource + from oauth_provider.models import Token as OAuthToken + from oauth_provider import consts + + self.consts = consts + + self.csrf_client = Client(enforce_csrf_checks=True) + self.username = 'john' + self.email = 'lennon@thebeatles.com' + self.password = 'password' + self.user = User.objects.create_user(self.username, self.email, self.password) + + self.CONSUMER_KEY = 'consumer_key' + self.CONSUMER_SECRET = 'consumer_secret' + self.TOKEN_KEY = "token_key" + self.TOKEN_SECRET = "token_secret" + + self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET, + name='example', user=self.user, status=self.consts.ACCEPTED) + + self.resource = Resource.objects.create(name="resource name", url="api/") + self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, resource=self.resource, + token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True + ) + + def _create_authorization_header(self): + params = { + 'oauth_version': "1.0", + 'oauth_nonce': oauth.generate_nonce(), + 'oauth_timestamp': int(time.time()), + 'oauth_token': self.token.key, + 'oauth_consumer_key': self.consumer.key + } + + req = oauth.Request(method="GET", url="http://example.com", parameters=params) + + signature_method = oauth.SignatureMethod_PLAINTEXT() + req.sign_request(signature_method, self.consumer, self.token) + + return req.to_header()["Authorization"] + + def _create_authorization_url_parameters(self): + params = { + 'oauth_version': "1.0", + 'oauth_nonce': oauth.generate_nonce(), + 'oauth_timestamp': int(time.time()), + 'oauth_token': self.token.key, + 'oauth_consumer_key': self.consumer.key + } + + req = oauth.Request(method="GET", url="http://example.com", parameters=params) + + signature_method = oauth.SignatureMethod_PLAINTEXT() + req.sign_request(signature_method, self.consumer, self.token) + return dict(req) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_passing_oauth(self): + """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_repeated_nonce_failing_oauth(self): + """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails""" + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + # simulate reply attack auth header containes already used (nonce, timestamp) pair + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_token_removed_failing_oauth(self): + """Ensure POSTing when there is no OAuth access token in db fails""" + self.token.delete() + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_consumer_status_not_accepted_failing_oauth(self): + """Ensure POSTing when consumer status is anything other than ACCEPTED fails""" + for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED): + self.consumer.status = consumer_status + self.consumer.save() + + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_with_request_token_failing_oauth(self): + """Ensure POSTing with unauthorized request token instead of access token fails""" + self.token.token_type = self.token.REQUEST + self.token.save() + + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_with_urlencoded_parameters(self): + """Ensure POSTing with x-www-form-urlencoded auth parameters passes""" + params = self._create_authorization_url_parameters() + response = self.csrf_client.post('/oauth/', params) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_get_form_with_url_parameters(self): + """Ensure GETing with auth in url parameters passes""" + params = self._create_authorization_url_parameters() + response = self.csrf_client.get('/oauth/', params) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_hmac_sha1_signature_passes(self): + """Ensure POSTing using HMAC_SHA1 signature method passes""" + params = { + 'oauth_version': "1.0", + 'oauth_nonce': oauth.generate_nonce(), + 'oauth_timestamp': int(time.time()), + 'oauth_token': self.token.key, + 'oauth_consumer_key': self.consumer.key + } + + req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) + + signature_method = oauth.SignatureMethod_HMAC_SHA1() + req.sign_request(signature_method, self.consumer, self.token) + auth = req.to_header()["Authorization"] + + response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_get_form_with_readonly_resource_passing_auth(self): + """Ensure POSTing with a readonly resource instead of a write scope fails""" + read_only_access_token = self.token + read_only_access_token.resource.is_readonly = True + read_only_access_token.resource.save() + params = self._create_authorization_url_parameters() + response = self.csrf_client.get('/oauth-with-scope/', params) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_with_readonly_resource_failing_auth(self): + """Ensure POSTing with a readonly resource instead of a write scope fails""" + read_only_access_token = self.token + read_only_access_token.resource.is_readonly = True + read_only_access_token.resource.save() + params = self._create_authorization_url_parameters() + response = self.csrf_client.post('/oauth-with-scope/', params) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_with_write_resource_passing_auth(self): + """Ensure POSTing with a write resource succeed""" + read_write_access_token = self.token + read_write_access_token.resource.is_readonly = False + read_write_access_token.resource.save() + params = self._create_authorization_url_parameters() + response = self.csrf_client.post('/oauth-with-scope/', params) + self.assertEqual(response.status_code, 200) + + +class OAuth2Tests(TestCase): + """OAuth 2.0 authentication""" + urls = 'rest_framework.tests.authentication' + + def setUp(self): + self.csrf_client = Client(enforce_csrf_checks=True) + self.username = 'john' + self.email = 'lennon@thebeatles.com' + self.password = 'password' + self.user = User.objects.create_user(self.username, self.email, self.password) + + self.CLIENT_ID = 'client_key' + self.CLIENT_SECRET = 'client_secret' + self.ACCESS_TOKEN = "access_token" + self.REFRESH_TOKEN = "refresh_token" + + self.oauth2_client = oauth2_provider_models.Client.objects.create( + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + redirect_uri='', + client_type=0, + name='example', + user=None, + ) + + self.access_token = oauth2_provider_models.AccessToken.objects.create( + token=self.ACCESS_TOKEN, + client=self.oauth2_client, + user=self.user, + ) + self.refresh_token = oauth2_provider_models.RefreshToken.objects.create( + user=self.user, + access_token=self.access_token, + client=self.oauth2_client + ) + + def _create_authorization_header(self, token=None): + return "Bearer {0}".format(token or self.access_token.token) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_authorization_header_token_type_failing(self): + """Ensure that a wrong token type lead to the correct HTTP error status code""" + auth = "Wrong token-type-obsviously" + response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_authorization_header_token_format_failing(self): + """Ensure that a wrong token format lead to the correct HTTP error status code""" + auth = "Bearer wrong token format" + response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_authorization_header_token_failing(self): + """Ensure that a wrong token lead to the correct HTTP error status code""" + auth = "Bearer wrong-token" + response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_passing_auth(self): + """Ensure GETing form over OAuth with correct client credentials succeed""" + auth = self._create_authorization_header() + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_passing_auth(self): + """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_token_removed_failing_auth(self): + """Ensure POSTing when there is no OAuth access token in db fails""" + self.access_token.delete() + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_refresh_token_failing_auth(self): + """Ensure POSTing with refresh token instead of access token fails""" + auth = self._create_authorization_header(token=self.refresh_token.token) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_expired_access_token_failing_auth(self): + """Ensure POSTing with expired access token fails with an 'Invalid token' error""" + self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late + self.access_token.save() + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + self.assertIn('Invalid token', response.content) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_invalid_scope_failing_auth(self): + """Ensure POSTing with a readonly scope instead of a write scope fails""" + read_only_access_token = self.access_token + read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read'] + read_only_access_token.save() + auth = self._create_authorization_header(token=read_only_access_token.token) + response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_valid_scope_passing_auth(self): + """Ensure POSTing with a write scope succeed""" + read_write_access_token = self.access_token + read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write'] + read_write_access_token.save() + auth = self._create_authorization_header(token=read_write_access_token.token) + response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py index 5b3315bc..52c1a34c 100644 --- a/rest_framework/tests/description.py +++ b/rest_framework/tests/description.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals from django.test import TestCase from rest_framework.views import APIView from rest_framework.compat import apply_markdown +from rest_framework.utils.formatting import get_view_name, get_view_description # We check that docstrings get nicely un-indented. DESCRIPTION = """an example docstring @@ -49,22 +50,16 @@ MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2> class TestViewNamesAndDescriptions(TestCase): - def test_resource_name_uses_classname_by_default(self): - """Ensure Resource names are based on the classname by default.""" + def test_view_name_uses_class_name(self): + """ + Ensure view names are based on the class name. + """ class MockView(APIView): pass - self.assertEqual(MockView().get_name(), 'Mock') + self.assertEqual(get_view_name(MockView), 'Mock') - def test_resource_name_can_be_set_explicitly(self): - """Ensure Resource names can be set using the 'get_name' method.""" - example = 'Some Other Name' - class MockView(APIView): - def get_name(self): - return example - self.assertEqual(MockView().get_name(), example) - - def test_resource_description_uses_docstring_by_default(self): - """Ensure Resource names are based on the docstring by default.""" + def test_view_description_uses_docstring(self): + """Ensure view descriptions are based on the docstring.""" class MockView(APIView): """an example docstring ==================== @@ -81,44 +76,32 @@ class TestViewNamesAndDescriptions(TestCase): # hash style header #""" - self.assertEqual(MockView().get_description(), DESCRIPTION) - - def test_resource_description_can_be_set_explicitly(self): - """Ensure Resource descriptions can be set using the 'get_description' method.""" - example = 'Some other description' - - class MockView(APIView): - """docstring""" - def get_description(self): - return example - self.assertEqual(MockView().get_description(), example) + self.assertEqual(get_view_description(MockView), DESCRIPTION) - def test_resource_description_supports_unicode(self): + def test_view_description_supports_unicode(self): + """ + Unicode in docstrings should be respected. + """ class MockView(APIView): """Проверка""" pass - self.assertEqual(MockView().get_description(), "Проверка") - - - def test_resource_description_does_not_require_docstring(self): - """Ensure that empty docstrings do not affect the Resource's description if it has been set using the 'get_description' method.""" - example = 'Some other description' - - class MockView(APIView): - def get_description(self): - return example - self.assertEqual(MockView().get_description(), example) + self.assertEqual(get_view_description(MockView), "Проверка") - def test_resource_description_can_be_empty(self): - """Ensure that if a resource has no doctring or 'description' class attribute, then it's description is the empty string.""" + def test_view_description_can_be_empty(self): + """ + Ensure that if a view has no docstring, + then it's description is the empty string. + """ class MockView(APIView): pass - self.assertEqual(MockView().get_description(), '') + self.assertEqual(get_view_description(MockView), '') def test_markdown(self): - """Ensure markdown to HTML works as expected""" + """ + Ensure markdown to HTML works as expected. + """ if apply_markdown: gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21 lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21 diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index 840ed320..3cdfa0f6 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -3,10 +3,14 @@ General serializer field tests. """ from __future__ import unicode_literals import datetime +from decimal import Decimal + from django.db import models from django.test import TestCase from django.core import validators + from rest_framework import serializers +from rest_framework.serializers import Serializer class TimestampedModel(models.Model): @@ -59,37 +63,586 @@ class BasicFieldTests(TestCase): serializer = CharPrimaryKeyModelSerializer() self.assertEqual(serializer.fields['id'].read_only, False) - def test_TimeField_from_native(self): + +class DateFieldTest(TestCase): + """ + Tests for the DateFieldTest from_native() and to_native() behavior + """ + + def test_from_native_string(self): + """ + Make sure from_native() accepts default iso input formats. + """ + f = serializers.DateField() + result_1 = f.from_native('1984-07-31') + + self.assertEqual(datetime.date(1984, 7, 31), result_1) + + def test_from_native_datetime_date(self): + """ + Make sure from_native() accepts a datetime.date instance. + """ + f = serializers.DateField() + result_1 = f.from_native(datetime.date(1984, 7, 31)) + + self.assertEqual(result_1, datetime.date(1984, 7, 31)) + + def test_from_native_custom_format(self): + """ + Make sure from_native() accepts custom input formats. + """ + f = serializers.DateField(input_formats=['%Y -- %d']) + result = f.from_native('1984 -- 31') + + self.assertEqual(datetime.date(1984, 1, 31), result) + + def test_from_native_invalid_default_on_custom_format(self): + """ + Make sure from_native() don't accept default formats if custom format is preset + """ + f = serializers.DateField(input_formats=['%Y -- %d']) + + try: + f.from_native('1984-07-31') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_empty(self): + """ + Make sure from_native() returns None on empty param. + """ + f = serializers.DateField() + result = f.from_native('') + + self.assertEqual(result, None) + + def test_from_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DateField() + result = f.from_native(None) + + self.assertEqual(result, None) + + def test_from_native_invalid_date(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid date. + """ + f = serializers.DateField() + + try: + f.from_native('1984-13-31') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_invalid_format(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid format. + """ + f = serializers.DateField() + + try: + f.from_native('1984 -- 31') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_to_native(self): + """ + Make sure to_native() returns datetime as default. + """ + f = serializers.DateField() + + result_1 = f.to_native(datetime.date(1984, 7, 31)) + + self.assertEqual(datetime.date(1984, 7, 31), result_1) + + def test_to_native_iso(self): + """ + Make sure to_native() with 'iso-8601' returns iso formated date. + """ + f = serializers.DateField(format='iso-8601') + + result_1 = f.to_native(datetime.date(1984, 7, 31)) + + self.assertEqual('1984-07-31', result_1) + + def test_to_native_custom_format(self): + """ + Make sure to_native() returns correct custom format. + """ + f = serializers.DateField(format="%Y - %m.%d") + + result_1 = f.to_native(datetime.date(1984, 7, 31)) + + self.assertEqual('1984 - 07.31', result_1) + + def test_to_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DateField(required=False) + self.assertEqual(None, f.to_native(None)) + + +class DateTimeFieldTest(TestCase): + """ + Tests for the DateTimeField from_native() and to_native() behavior + """ + + def test_from_native_string(self): + """ + Make sure from_native() accepts default iso input formats. + """ + f = serializers.DateTimeField() + result_1 = f.from_native('1984-07-31 04:31') + result_2 = f.from_native('1984-07-31 04:31:59') + result_3 = f.from_native('1984-07-31 04:31:59.000200') + + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1) + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2) + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3) + + def test_from_native_datetime_datetime(self): + """ + Make sure from_native() accepts a datetime.datetime instance. + """ + f = serializers.DateTimeField() + result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31)) + result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) + result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + + self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31)) + self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59)) + self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + + def test_from_native_custom_format(self): + """ + Make sure from_native() accepts custom input formats. + """ + f = serializers.DateTimeField(input_formats=['%Y -- %H:%M']) + result = f.from_native('1984 -- 04:59') + + self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result) + + def test_from_native_invalid_default_on_custom_format(self): + """ + Make sure from_native() don't accept default formats if custom format is preset + """ + f = serializers.DateTimeField(input_formats=['%Y -- %H:%M']) + + try: + f.from_native('1984-07-31 04:31:59') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_empty(self): + """ + Make sure from_native() returns None on empty param. + """ + f = serializers.DateTimeField() + result = f.from_native('') + + self.assertEqual(result, None) + + def test_from_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DateTimeField() + result = f.from_native(None) + + self.assertEqual(result, None) + + def test_from_native_invalid_datetime(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid datetime. + """ + f = serializers.DateTimeField() + + try: + f.from_native('04:61:59') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: " + "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_invalid_format(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid format. + """ + f = serializers.DateTimeField() + + try: + f.from_native('04 -- 31') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: " + "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_to_native(self): + """ + Make sure to_native() returns isoformat as default. + """ + f = serializers.DateTimeField() + + result_1 = f.to_native(datetime.datetime(1984, 7, 31)) + result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) + result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) + result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + + self.assertEqual(datetime.datetime(1984, 7, 31), result_1) + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2) + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3) + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4) + + def test_to_native_iso(self): + """ + Make sure to_native() with format=iso-8601 returns iso formatted datetime. + """ + f = serializers.DateTimeField(format='iso-8601') + + result_1 = f.to_native(datetime.datetime(1984, 7, 31)) + result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) + result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) + result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + + self.assertEqual('1984-07-31T00:00:00', result_1) + self.assertEqual('1984-07-31T04:31:00', result_2) + self.assertEqual('1984-07-31T04:31:59', result_3) + self.assertEqual('1984-07-31T04:31:59.000200', result_4) + + def test_to_native_custom_format(self): + """ + Make sure to_native() returns correct custom format. + """ + f = serializers.DateTimeField(format="%Y - %H:%M") + + result_1 = f.to_native(datetime.datetime(1984, 7, 31)) + result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) + result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) + result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + + self.assertEqual('1984 - 00:00', result_1) + self.assertEqual('1984 - 04:31', result_2) + self.assertEqual('1984 - 04:31', result_3) + self.assertEqual('1984 - 04:31', result_4) + + def test_to_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DateTimeField(required=False) + self.assertEqual(None, f.to_native(None)) + + +class TimeFieldTest(TestCase): + """ + Tests for the TimeField from_native() and to_native() behavior + """ + + def test_from_native_string(self): + """ + Make sure from_native() accepts default iso input formats. + """ f = serializers.TimeField() - result = f.from_native('12:34:56.987654') + result_1 = f.from_native('04:31') + result_2 = f.from_native('04:31:59') + result_3 = f.from_native('04:31:59.000200') - self.assertEqual(datetime.time(12, 34, 56, 987654), result) + self.assertEqual(datetime.time(4, 31), result_1) + self.assertEqual(datetime.time(4, 31, 59), result_2) + self.assertEqual(datetime.time(4, 31, 59, 200), result_3) - def test_TimeField_from_native_datetime_time(self): + def test_from_native_datetime_time(self): """ Make sure from_native() accepts a datetime.time instance. """ f = serializers.TimeField() - result = f.from_native(datetime.time(12, 34, 56)) - self.assertEqual(result, datetime.time(12, 34, 56)) + result_1 = f.from_native(datetime.time(4, 31)) + result_2 = f.from_native(datetime.time(4, 31, 59)) + result_3 = f.from_native(datetime.time(4, 31, 59, 200)) + + self.assertEqual(result_1, datetime.time(4, 31)) + self.assertEqual(result_2, datetime.time(4, 31, 59)) + self.assertEqual(result_3, datetime.time(4, 31, 59, 200)) + + def test_from_native_custom_format(self): + """ + Make sure from_native() accepts custom input formats. + """ + f = serializers.TimeField(input_formats=['%H -- %M']) + result = f.from_native('04 -- 31') - def test_TimeField_from_native_empty(self): + self.assertEqual(datetime.time(4, 31), result) + + def test_from_native_invalid_default_on_custom_format(self): + """ + Make sure from_native() don't accept default formats if custom format is preset + """ + f = serializers.TimeField(input_formats=['%H -- %M']) + + try: + f.from_native('04:31:59') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_empty(self): + """ + Make sure from_native() returns None on empty param. + """ f = serializers.TimeField() result = f.from_native('') + + self.assertEqual(result, None) + + def test_from_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.TimeField() + result = f.from_native(None) + self.assertEqual(result, None) - def test_TimeField_from_native_invalid_time(self): + def test_from_native_invalid_time(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid time. + """ + f = serializers.TimeField() + + try: + f.from_native('04:61:59') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: " + "hh:mm[:ss[.uuuuuu]]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_invalid_format(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid format. + """ + f = serializers.TimeField() + + try: + f.from_native('04 -- 31') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: " + "hh:mm[:ss[.uuuuuu]]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_to_native(self): + """ + Make sure to_native() returns time object as default. + """ f = serializers.TimeField() + result_1 = f.to_native(datetime.time(4, 31)) + result_2 = f.to_native(datetime.time(4, 31, 59)) + result_3 = f.to_native(datetime.time(4, 31, 59, 200)) + + self.assertEqual(datetime.time(4, 31), result_1) + self.assertEqual(datetime.time(4, 31, 59), result_2) + self.assertEqual(datetime.time(4, 31, 59, 200), result_3) + + def test_to_native_iso(self): + """ + Make sure to_native() with format='iso-8601' returns iso formatted time. + """ + f = serializers.TimeField(format='iso-8601') + result_1 = f.to_native(datetime.time(4, 31)) + result_2 = f.to_native(datetime.time(4, 31, 59)) + result_3 = f.to_native(datetime.time(4, 31, 59, 200)) + + self.assertEqual('04:31:00', result_1) + self.assertEqual('04:31:59', result_2) + self.assertEqual('04:31:59.000200', result_3) + + def test_to_native_custom_format(self): + """ + Make sure to_native() returns correct custom format. + """ + f = serializers.TimeField(format="%H - %S [%f]") + result_1 = f.to_native(datetime.time(4, 31)) + result_2 = f.to_native(datetime.time(4, 31, 59)) + result_3 = f.to_native(datetime.time(4, 31, 59, 200)) + + self.assertEqual('04 - 00 [000000]', result_1) + self.assertEqual('04 - 59 [000000]', result_2) + self.assertEqual('04 - 59 [000200]', result_3) + + +class DecimalFieldTest(TestCase): + """ + Tests for the DecimalField from_native() and to_native() behavior + """ + + def test_from_native_string(self): + """ + Make sure from_native() accepts string values + """ + f = serializers.DecimalField() + result_1 = f.from_native('9000') + result_2 = f.from_native('1.00000001') + + self.assertEqual(Decimal('9000'), result_1) + self.assertEqual(Decimal('1.00000001'), result_2) + + def test_from_native_invalid_string(self): + """ + Make sure from_native() raises ValidationError on passing invalid string + """ + f = serializers.DecimalField() try: - f.from_native('12:69:12') + f.from_native('123.45.6') except validators.ValidationError as e: - self.assertEqual(e.messages, ["'12:69:12' value has an invalid " - "format. It must be a valid time " - "in the HH:MM[:ss[.uuuuuu]] format."]) + self.assertEqual(e.messages, ["Enter a number."]) else: self.fail("ValidationError was not properly raised") - def test_TimeFieldModelSerializer(self): - serializer = TimeFieldModelSerializer() - self.assertTrue(isinstance(serializer.fields['clock'], serializers.TimeField)) + def test_from_native_integer(self): + """ + Make sure from_native() accepts integer values + """ + f = serializers.DecimalField() + result = f.from_native(9000) + + self.assertEqual(Decimal('9000'), result) + + def test_from_native_float(self): + """ + Make sure from_native() accepts float values + """ + f = serializers.DecimalField() + result = f.from_native(1.00000001) + + self.assertEqual(Decimal('1.00000001'), result) + + def test_from_native_empty(self): + """ + Make sure from_native() returns None on empty param. + """ + f = serializers.DecimalField() + result = f.from_native('') + + self.assertEqual(result, None) + + def test_from_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DecimalField() + result = f.from_native(None) + + self.assertEqual(result, None) + + def test_to_native(self): + """ + Make sure to_native() returns Decimal as string. + """ + f = serializers.DecimalField() + + result_1 = f.to_native(Decimal('9000')) + result_2 = f.to_native(Decimal('1.00000001')) + + self.assertEqual(Decimal('9000'), result_1) + self.assertEqual(Decimal('1.00000001'), result_2) + + def test_to_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DecimalField(required=False) + self.assertEqual(None, f.to_native(None)) + + def test_valid_serialization(self): + """ + Make sure the serializer works correctly + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(max_value=9010, + min_value=9000, + max_digits=6, + decimal_places=2) + + self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid()) + self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid()) + self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid()) + + self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid()) + self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid()) + self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid()) + + def test_raise_max_value(self): + """ + Make sure max_value violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(max_value=100) + + s = DecimalSerializer(data={'decimal_field': '123'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']}) + + def test_raise_min_value(self): + """ + Make sure min_value violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(min_value=100) + + s = DecimalSerializer(data={'decimal_field': '99'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']}) + + def test_raise_max_digits(self): + """ + Make sure max_digits violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(max_digits=5) + + s = DecimalSerializer(data={'decimal_field': '123.456'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']}) + + def test_raise_max_decimal_places(self): + """ + Make sure max_decimal_places violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(decimal_places=3) + + s = DecimalSerializer(data={'decimal_field': '123.4567'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']}) + + def test_raise_max_whole_digits(self): + """ + Make sure max_whole_digits violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3) + + s = DecimalSerializer(data={'decimal_field': '12345.6'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
\ No newline at end of file diff --git a/rest_framework/tests/filters.py b/rest_framework/tests/filters.py new file mode 100644 index 00000000..8ae6d530 --- /dev/null +++ b/rest_framework/tests/filters.py @@ -0,0 +1,474 @@ +from __future__ import unicode_literals +import datetime +from decimal import Decimal +from django.db import models +from django.core.urlresolvers import reverse +from django.test import TestCase +from django.test.client import RequestFactory +from django.utils import unittest +from rest_framework import generics, serializers, status, filters +from rest_framework.compat import django_filters, patterns, url +from rest_framework.tests.models import BasicModel + +factory = RequestFactory() + + +class FilterableItem(models.Model): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + +if django_filters: + # Basic filter on a list view. + class FilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_fields = ['decimal', 'date'] + filter_backends = (filters.DjangoFilterBackend,) + + # These class are used to test a filter class. + class SeveralFieldsFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + decimal = django_filters.NumberFilter(lookup_type='lt') + date = django_filters.DateFilter(lookup_type='gt') + + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + + class FilterClassRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = SeveralFieldsFilter + filter_backends = (filters.DjangoFilterBackend,) + + # These classes are used to test a misconfigured filter class. + class MisconfiguredFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + + class Meta: + model = BasicModel + fields = ['text'] + + class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = MisconfiguredFilter + filter_backends = (filters.DjangoFilterBackend,) + + class FilterClassDetailView(generics.RetrieveAPIView): + model = FilterableItem + filter_class = SeveralFieldsFilter + filter_backends = (filters.DjangoFilterBackend,) + + # Regression test for #814 + class FilterableItemSerializer(serializers.ModelSerializer): + class Meta: + model = FilterableItem + + class FilterFieldsQuerysetView(generics.ListCreateAPIView): + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer + filter_fields = ['decimal', 'date'] + filter_backends = (filters.DjangoFilterBackend,) + + class GetQuerysetView(generics.ListCreateAPIView): + serializer_class = FilterableItemSerializer + filter_class = SeveralFieldsFilter + filter_backends = (filters.DjangoFilterBackend,) + + def get_queryset(self): + return FilterableItem.objects.all() + + urlpatterns = patterns('', + url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), + url(r'^$', FilterClassRootView.as_view(), name='root-view'), + url(r'^get-queryset/$', GetQuerysetView.as_view(), + name='get-queryset-view'), + ) + + +class CommonFilteringTestCase(TestCase): + def _serialize_object(self, obj): + return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + + def setUp(self): + """ + Create 10 FilterableItem instances. + """ + base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) + for i in range(10): + text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. + decimal = base_data[1] + i + date = base_data[2] - datetime.timedelta(days=i * 2) + FilterableItem(text=text, decimal=decimal, date=date).save() + + self.objects = FilterableItem.objects + self.data = [ + self._serialize_object(obj) + for obj in self.objects.all() + ] + + +class IntegrationTestFiltering(CommonFilteringTestCase): + """ + Integration tests for filtered list views. + """ + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_fields_root_view(self): + """ + GET requests to paginated ListCreateAPIView should return paginated results. + """ + view = FilterFieldsRootView.as_view() + + # Basic test with no filter. + request = factory.get('/') + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data) + + # Tests that the decimal filter works. + search_decimal = Decimal('2.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] == search_decimal] + self.assertEqual(response.data, expected_data) + + # Tests that the date filter works. + search_date = datetime.date(2012, 9, 22) + request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] == search_date] + self.assertEqual(response.data, expected_data) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_filter_with_queryset(self): + """ + Regression test for #814. + """ + view = FilterFieldsQuerysetView.as_view() + + # Tests that the decimal filter works. + search_decimal = Decimal('2.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] == search_decimal] + self.assertEqual(response.data, expected_data) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_filter_with_get_queryset_only(self): + """ + Regression test for #834. + """ + view = GetQuerysetView.as_view() + request = factory.get('/get-queryset/') + view(request).render() + # Used to raise "issubclass() arg 2 must be a class or tuple of classes" + # here when neither `model' nor `queryset' was specified. + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_class_root_view(self): + """ + GET requests to filtered ListCreateAPIView that have a filter_class set + should return filtered results. + """ + view = FilterClassRootView.as_view() + + # Basic test with no filter. + request = factory.get('/') + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data) + + # Tests that the decimal filter set with 'lt' in the filter class works. + search_decimal = Decimal('4.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] < search_decimal] + self.assertEqual(response.data, expected_data) + + # Tests that the date filter set with 'gt' in the filter class works. + search_date = datetime.date(2012, 10, 2) + request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] > search_date] + self.assertEqual(response.data, expected_data) + + # Tests that the text filter set with 'icontains' in the filter class works. + search_text = 'ff' + request = factory.get('/?text=%s' % search_text) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if search_text in f['text'].lower()] + self.assertEqual(response.data, expected_data) + + # Tests that multiple filters works. + search_decimal = Decimal('5.25') + search_date = datetime.date(2012, 10, 2) + request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] > search_date and + f['decimal'] < search_decimal] + self.assertEqual(response.data, expected_data) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_incorrectly_configured_filter(self): + """ + An error should be displayed when the filter class is misconfigured. + """ + view = IncorrectlyConfiguredRootView.as_view() + + request = factory.get('/') + self.assertRaises(AssertionError, view, request) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_unknown_filter(self): + """ + GET requests with filters that aren't configured should return 200. + """ + view = FilterFieldsRootView.as_view() + + search_integer = 10 + request = factory.get('/?integer=%s' % search_integer) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class IntegrationTestDetailFiltering(CommonFilteringTestCase): + """ + Integration tests for filtered detail views. + """ + urls = 'rest_framework.tests.filters' + + def _get_url(self, item): + return reverse('detail-view', kwargs=dict(pk=item.pk)) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_detail_view(self): + """ + GET requests to filtered RetrieveAPIView that have a filter_class set + should return filtered results. + """ + item = self.objects.all()[0] + data = self._serialize_object(item) + + # Basic test with no filter. + response = self.client.get(self._get_url(item)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, data) + + # Tests that the decimal filter set that should fail. + search_decimal = Decimal('4.25') + high_item = self.objects.filter(decimal__gt=search_decimal)[0] + response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + # Tests that the decimal filter set that should succeed. + search_decimal = Decimal('4.25') + low_item = self.objects.filter(decimal__lt=search_decimal)[0] + low_item_data = self._serialize_object(low_item) + response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, low_item_data) + + # Tests that multiple filters works. + search_decimal = Decimal('5.25') + search_date = datetime.date(2012, 10, 2) + valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] + valid_item_data = self._serialize_object(valid_item) + response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, valid_item_data) + + +class SearchFilterModel(models.Model): + title = models.CharField(max_length=20) + text = models.CharField(max_length=100) + + +class SearchFilterTests(TestCase): + def setUp(self): + # Sequence of title/text is: + # + # z abc + # zz bcd + # zzz cde + # ... + for idx in range(10): + title = 'z' * (idx + 1) + text = ( + chr(idx + ord('a')) + + chr(idx + ord('b')) + + chr(idx + ord('c')) + ) + SearchFilterModel(title=title, text=text).save() + + def test_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('title', 'text') + + view = SearchListView.as_view() + request = factory.get('?search=b') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 1, 'title': 'z', 'text': 'abc'}, + {'id': 2, 'title': 'zz', 'text': 'bcd'} + ] + ) + + def test_exact_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('=title', 'text') + + view = SearchListView.as_view() + request = factory.get('?search=zzz') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'zzz', 'text': 'cde'} + ] + ) + + def test_startswith_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('title', '^text') + + view = SearchListView.as_view() + request = factory.get('?search=b') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 2, 'title': 'zz', 'text': 'bcd'} + ] + ) + + +class OrdringFilterModel(models.Model): + title = models.CharField(max_length=20) + text = models.CharField(max_length=100) + + +class OrderingFilterTests(TestCase): + def setUp(self): + # Sequence of title/text is: + # + # zyx abc + # yxw bcd + # xwv cde + for idx in range(3): + title = ( + chr(ord('z') - idx) + + chr(ord('y') - idx) + + chr(ord('x') - idx) + ) + text = ( + chr(idx + ord('a')) + + chr(idx + ord('b')) + + chr(idx + ord('c')) + ) + OrdringFilterModel(title=title, text=text).save() + + def test_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?ordering=text') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + ] + ) + + def test_reverse_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?ordering=-text') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_incorrectfield_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?ordering=foobar') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_default_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_default_ordering_using_string(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = 'title' + + view = OrderingListView.as_view() + request = factory.get('') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py deleted file mode 100644 index 8c13947c..00000000 --- a/rest_framework/tests/filterset.py +++ /dev/null @@ -1,169 +0,0 @@ -from __future__ import unicode_literals -import datetime -from decimal import Decimal -from django.test import TestCase -from django.test.client import RequestFactory -from django.utils import unittest -from rest_framework import generics, status, filters -from rest_framework.compat import django_filters -from rest_framework.tests.models import FilterableItem, BasicModel - -factory = RequestFactory() - - -if django_filters: - # Basic filter on a list view. - class FilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem - filter_fields = ['decimal', 'date'] - filter_backend = filters.DjangoFilterBackend - - # These class are used to test a filter class. - class SeveralFieldsFilter(django_filters.FilterSet): - text = django_filters.CharFilter(lookup_type='icontains') - decimal = django_filters.NumberFilter(lookup_type='lt') - date = django_filters.DateFilter(lookup_type='gt') - - class Meta: - model = FilterableItem - fields = ['text', 'decimal', 'date'] - - class FilterClassRootView(generics.ListCreateAPIView): - model = FilterableItem - filter_class = SeveralFieldsFilter - filter_backend = filters.DjangoFilterBackend - - # These classes are used to test a misconfigured filter class. - class MisconfiguredFilter(django_filters.FilterSet): - text = django_filters.CharFilter(lookup_type='icontains') - - class Meta: - model = BasicModel - fields = ['text'] - - class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): - model = FilterableItem - filter_class = MisconfiguredFilter - filter_backend = filters.DjangoFilterBackend - - -class IntegrationTestFiltering(TestCase): - """ - Integration tests for filtered list views. - """ - - def setUp(self): - """ - Create 10 FilterableItem instances. - """ - base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) - for i in range(10): - text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. - decimal = base_data[1] + i - date = base_data[2] - datetime.timedelta(days=i * 2) - FilterableItem(text=text, decimal=decimal, date=date).save() - - self.objects = FilterableItem.objects - self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} - for obj in self.objects.all() - ] - - @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_get_filtered_fields_root_view(self): - """ - GET requests to paginated ListCreateAPIView should return paginated results. - """ - view = FilterFieldsRootView.as_view() - - # Basic test with no filter. - request = factory.get('/') - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, self.data) - - # Tests that the decimal filter works. - search_decimal = Decimal('2.25') - request = factory.get('/?decimal=%s' % search_decimal) - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['decimal'] == search_decimal] - self.assertEqual(response.data, expected_data) - - # Tests that the date filter works. - search_date = datetime.date(2012, 9, 22) - request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['date'] == search_date] - self.assertEqual(response.data, expected_data) - - @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_get_filtered_class_root_view(self): - """ - GET requests to filtered ListCreateAPIView that have a filter_class set - should return filtered results. - """ - view = FilterClassRootView.as_view() - - # Basic test with no filter. - request = factory.get('/') - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, self.data) - - # Tests that the decimal filter set with 'lt' in the filter class works. - search_decimal = Decimal('4.25') - request = factory.get('/?decimal=%s' % search_decimal) - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['decimal'] < search_decimal] - self.assertEqual(response.data, expected_data) - - # Tests that the date filter set with 'gt' in the filter class works. - search_date = datetime.date(2012, 10, 2) - request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['date'] > search_date] - self.assertEqual(response.data, expected_data) - - # Tests that the text filter set with 'icontains' in the filter class works. - search_text = 'ff' - request = factory.get('/?text=%s' % search_text) - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if search_text in f['text'].lower()] - self.assertEqual(response.data, expected_data) - - # Tests that multiple filters works. - search_decimal = Decimal('5.25') - search_date = datetime.date(2012, 10, 2) - request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['date'] > search_date and - f['decimal'] < search_decimal] - self.assertEqual(response.data, expected_data) - - @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_incorrectly_configured_filter(self): - """ - An error should be displayed when the filter class is misconfigured. - """ - view = IncorrectlyConfiguredRootView.as_view() - - request = factory.get('/') - self.assertRaises(AssertionError, view, request) - - @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_unknown_filter(self): - """ - GET requests with filters that aren't configured should return 200. - """ - view = FilterFieldsRootView.as_view() - - search_integer = 10 - request = factory.get('/?integer=%s' % search_integer) - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index f8f2ddaa..2799d143 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals from django.db import models +from django.shortcuts import get_object_or_404 from django.test import TestCase from rest_framework import generics, serializers, status from rest_framework.tests.utils import RequestFactory @@ -38,6 +39,7 @@ class SlugBasedInstanceView(InstanceView): """ model = SlugBasedModel serializer_class = SlugSerializer + lookup_field = 'slug' class TestRootView(TestCase): @@ -60,7 +62,8 @@ class TestRootView(TestCase): GET requests to ListCreateAPIView should return list of objects. """ request = factory.get('/') - response = self.view(request).render() + with self.assertNumQueries(1): + response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, self.data) @@ -71,7 +74,8 @@ class TestRootView(TestCase): content = {'text': 'foobar'} request = factory.post('/', json.dumps(content), content_type='application/json') - response = self.view(request).render() + with self.assertNumQueries(1): + response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.data, {'id': 4, 'text': 'foobar'}) created = self.objects.get(id=4) @@ -84,7 +88,8 @@ class TestRootView(TestCase): content = {'text': 'foobar'} request = factory.put('/', json.dumps(content), content_type='application/json') - response = self.view(request).render() + with self.assertNumQueries(0): + response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."}) @@ -93,7 +98,8 @@ class TestRootView(TestCase): DELETE requests to ListCreateAPIView should not be allowed """ request = factory.delete('/') - response = self.view(request).render() + with self.assertNumQueries(0): + response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."}) @@ -102,7 +108,8 @@ class TestRootView(TestCase): OPTIONS requests to ListCreateAPIView should return metadata """ request = factory.options('/') - response = self.view(request).render() + with self.assertNumQueries(0): + response = self.view(request).render() expected = { 'parses': [ 'application/json', @@ -126,7 +133,8 @@ class TestRootView(TestCase): content = {'id': 999, 'text': 'foobar'} request = factory.post('/', json.dumps(content), content_type='application/json') - response = self.view(request).render() + with self.assertNumQueries(1): + response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.data, {'id': 4, 'text': 'foobar'}) created = self.objects.get(id=4) @@ -154,7 +162,8 @@ class TestInstanceView(TestCase): GET requests to RetrieveUpdateDestroyAPIView should return a single object. """ request = factory.get('/1') - response = self.view(request, pk=1).render() + with self.assertNumQueries(1): + response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, self.data[0]) @@ -165,7 +174,8 @@ class TestInstanceView(TestCase): content = {'text': 'foobar'} request = factory.post('/', json.dumps(content), content_type='application/json') - response = self.view(request).render() + with self.assertNumQueries(0): + response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."}) @@ -176,7 +186,8 @@ class TestInstanceView(TestCase): content = {'text': 'foobar'} request = factory.put('/1', json.dumps(content), content_type='application/json') - response = self.view(request, pk='1').render() + with self.assertNumQueries(2): + response = self.view(request, pk='1').render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) updated = self.objects.get(id=1) @@ -190,7 +201,8 @@ class TestInstanceView(TestCase): request = factory.patch('/1', json.dumps(content), content_type='application/json') - response = self.view(request, pk=1).render() + with self.assertNumQueries(2): + response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) updated = self.objects.get(id=1) @@ -201,7 +213,8 @@ class TestInstanceView(TestCase): DELETE requests to RetrieveUpdateDestroyAPIView should delete an object. """ request = factory.delete('/1') - response = self.view(request, pk=1).render() + with self.assertNumQueries(2): + response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) self.assertEqual(response.content, six.b('')) ids = [obj.id for obj in self.objects.all()] @@ -212,7 +225,8 @@ class TestInstanceView(TestCase): OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata """ request = factory.options('/') - response = self.view(request).render() + with self.assertNumQueries(0): + response = self.view(request).render() expected = { 'parses': [ 'application/json', @@ -236,7 +250,8 @@ class TestInstanceView(TestCase): content = {'id': 999, 'text': 'foobar'} request = factory.put('/1', json.dumps(content), content_type='application/json') - response = self.view(request, pk=1).render() + with self.assertNumQueries(2): + response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) updated = self.objects.get(id=1) @@ -251,7 +266,8 @@ class TestInstanceView(TestCase): content = {'text': 'foobar'} request = factory.put('/1', json.dumps(content), content_type='application/json') - response = self.view(request, pk=1).render() + with self.assertNumQueries(3): + response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) updated = self.objects.get(id=1) @@ -263,10 +279,11 @@ class TestInstanceView(TestCase): at the requested url if it doesn't exist. """ content = {'text': 'foobar'} - # pk fields can not be created on demand, only the database can set th pk for a new object + # pk fields can not be created on demand, only the database can set the pk for a new object request = factory.put('/5', json.dumps(content), content_type='application/json') - response = self.view(request, pk=5).render() + with self.assertNumQueries(3): + response = self.view(request, pk=5).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) new_obj = self.objects.get(pk=5) self.assertEqual(new_obj.text, 'foobar') @@ -279,13 +296,55 @@ class TestInstanceView(TestCase): content = {'text': 'foobar'} request = factory.put('/test_slug', json.dumps(content), content_type='application/json') - response = self.slug_based_view(request, slug='test_slug').render() + with self.assertNumQueries(2): + response = self.slug_based_view(request, slug='test_slug').render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.data, {'slug': 'test_slug', 'text': 'foobar'}) new_obj = SlugBasedModel.objects.get(slug='test_slug') self.assertEqual(new_obj.text, 'foobar') +class TestOverriddenGetObject(TestCase): + """ + Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the + queryset/model mechanism but instead overrides get_object() + """ + def setUp(self): + """ + Create 3 BasicModel intances. + """ + items = ['foo', 'bar', 'baz'] + for item in items: + BasicModel(text=item).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + + class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView): + """ + Example detail view for override of get_object(). + """ + model = BasicModel + + def get_object(self): + pk = int(self.kwargs['pk']) + return get_object_or_404(BasicModel.objects.all(), id=pk) + + self.view = OverriddenGetObjectView.as_view() + + def test_overridden_get_object_view(self): + """ + GET requests to RetrieveUpdateDestroyAPIView should return a single object. + """ + request = factory.get('/1') + with self.assertNumQueries(1): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[0]) + + # Regression test for #285 class CommentSerializer(serializers.ModelSerializer): @@ -319,7 +378,7 @@ class TestCreateModelWithAutoNowAddField(TestCase): self.assertEqual(created.content, 'foobar') -# Test for particularly ugly regression with m2m in browseable API +# Test for particularly ugly regression with m2m in browsable API class ClassB(models.Model): name = models.CharField(max_length=255) @@ -344,9 +403,76 @@ class ExampleView(generics.ListCreateAPIView): class TestM2MBrowseableAPI(TestCase): def test_m2m_in_browseable_api(self): """ - Test for particularly ugly regression with m2m in browseable API + Test for particularly ugly regression with m2m in browsable API """ request = factory.get('/', HTTP_ACCEPT='text/html') view = ExampleView().as_view() response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class InclusiveFilterBackend(object): + def filter_queryset(self, request, queryset, view): + return queryset.filter(text='foo') + + +class ExclusiveFilterBackend(object): + def filter_queryset(self, request, queryset, view): + return queryset.filter(text='other') + + +class TestFilterBackendAppliedToViews(TestCase): + + def setUp(self): + """ + Create 3 BasicModel instances to filter on. + """ + items = ['foo', 'bar', 'baz'] + for item in items: + BasicModel(text=item).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + + def test_get_root_view_filters_by_name_with_filter_backend(self): + """ + GET requests to ListCreateAPIView should return filtered list. + """ + root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,)) + request = factory.get('/') + response = root_view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}]) + + def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self): + """ + GET requests to ListCreateAPIView should return empty list when all models are filtered out. + """ + root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,)) + request = factory.get('/') + response = root_view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, []) + + def test_get_instance_view_filters_out_name_with_filter_backend(self): + """ + GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out. + """ + instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,)) + request = factory.get('/1') + response = instance_view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertEqual(response.data, {'detail': 'Not found'}) + + def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self): + """ + GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded + """ + instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,)) + request = factory.get('/1') + response = instance_view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'id': 1, 'text': 'foo'}) diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index 9a61f299..8fc6ba77 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -27,6 +27,14 @@ class PhotoSerializer(serializers.Serializer): return Photo(**attrs) +class AlbumSerializer(serializers.ModelSerializer): + url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title') + + class Meta: + model = Album + fields = ('title', 'url') + + class BasicList(generics.ListCreateAPIView): model = BasicModel model_serializer_class = serializers.HyperlinkedModelSerializer @@ -73,6 +81,8 @@ class PhotoListCreate(generics.ListCreateAPIView): class AlbumDetail(generics.RetrieveAPIView): model = Album + serializer_class = AlbumSerializer + lookup_field = 'title' class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView): @@ -180,6 +190,36 @@ class TestManyToManyHyperlinkedView(TestCase): self.assertEqual(response.data, self.data[0]) +class TestHyperlinkedIdentityFieldLookup(TestCase): + urls = 'rest_framework.tests.hyperlinkedserializers' + + def setUp(self): + """ + Create 3 Album instances. + """ + titles = ['foo', 'bar', 'baz'] + for title in titles: + album = Album(title=title) + album.save() + self.detail_view = AlbumDetail.as_view() + self.data = { + 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'}, + 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'}, + 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'} + } + + def test_lookup_field(self): + """ + GET requests to AlbumDetail view should return serialized Albums + with a url field keyed by `title`. + """ + for album in Album.objects.all(): + request = factory.get('/albums/{0}/'.format(album.title)) + response = self.detail_view(request, title=album.title) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[album.title]) + + class TestCreateWithForeignKeys(TestCase): urls = 'rest_framework.tests.hyperlinkedserializers' diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index f2117538..40e41a64 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -58,13 +58,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel): rel = models.ManyToManyField(Anchor) -# Model to test filtering. -class FilterableItem(RESTFrameworkModel): - text = models.CharField(max_length=100) - decimal = models.DecimalField(max_digits=4, decimal_places=2) - date = models.DateField() - - # Model for regression test for #285 class Comment(RESTFrameworkModel): diff --git a/rest_framework/tests/multitable_inheritance.py b/rest_framework/tests/multitable_inheritance.py new file mode 100644 index 00000000..00c15327 --- /dev/null +++ b/rest_framework/tests/multitable_inheritance.py @@ -0,0 +1,67 @@ +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from rest_framework import serializers +from rest_framework.tests.models import RESTFrameworkModel + + +# Models +class ParentModel(RESTFrameworkModel): + name1 = models.CharField(max_length=100) + + +class ChildModel(ParentModel): + name2 = models.CharField(max_length=100) + + +class AssociatedModel(RESTFrameworkModel): + ref = models.OneToOneField(ParentModel, primary_key=True) + name = models.CharField(max_length=100) + + +# Serializers +class DerivedModelSerializer(serializers.ModelSerializer): + class Meta: + model = ChildModel + + +class AssociatedModelSerializer(serializers.ModelSerializer): + class Meta: + model = AssociatedModel + + +# Tests +class IneritedModelSerializationTests(TestCase): + + def test_multitable_inherited_model_fields_as_expected(self): + """ + Assert that the parent pointer field is not included in the fields + serialized fields + """ + child = ChildModel(name1='parent name', name2='child name') + serializer = DerivedModelSerializer(child) + self.assertEqual(set(serializer.data.keys()), + set(['name1', 'name2', 'id'])) + + def test_onetoone_primary_key_model_fields_as_expected(self): + """ + Assert that a model with a onetoone field that is the primary key is + not treated like a derived model + """ + parent = ParentModel(name1='parent name') + associate = AssociatedModel(name='hello', ref=parent) + serializer = AssociatedModelSerializer(associate) + self.assertEqual(set(serializer.data.keys()), + set(['name', 'ref'])) + + def test_data_is_valid_without_parent_ptr(self): + """ + Assert that the pointer to the parent table is not a required field + for input data + """ + data = { + 'name1': 'parent name', + 'name2': 'child name', + } + serializer = DerivedModelSerializer(data=data) + self.assertEqual(serializer.is_valid(), True) diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 6b9970a6..e538a78e 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -1,17 +1,24 @@ from __future__ import unicode_literals import datetime from decimal import Decimal +from django.db import models from django.core.paginator import Paginator from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, status, pagination, filters, serializers from rest_framework.compat import django_filters -from rest_framework.tests.models import BasicModel, FilterableItem +from rest_framework.tests.models import BasicModel factory = RequestFactory() +class FilterableItem(models.Model): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + class RootView(generics.ListCreateAPIView): """ Example description for OPTIONS. @@ -20,21 +27,6 @@ class RootView(generics.ListCreateAPIView): paginate_by = 10 -if django_filters: - class DecimalFilter(django_filters.FilterSet): - decimal = django_filters.NumberFilter(lookup_type='lt') - - class Meta: - model = FilterableItem - fields = ['text', 'decimal', 'date'] - - class FilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem - paginate_by = 10 - filter_class = DecimalFilter - filter_backend = filters.DjangoFilterBackend - - class DefaultPageSizeKwargView(generics.ListAPIView): """ View for testing default paginate_by_param usage @@ -73,7 +65,9 @@ class IntegrationTestPagination(TestCase): GET requests to paginated ListCreateAPIView should return paginated results. """ request = factory.get('/') - response = self.view(request).render() + # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>` + with self.assertNumQueries(2): + response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['count'], 26) self.assertEqual(response.data['results'], self.data[:10]) @@ -81,7 +75,8 @@ class IntegrationTestPagination(TestCase): self.assertEqual(response.data['previous'], None) request = factory.get(response.data['next']) - response = self.view(request).render() + with self.assertNumQueries(2): + response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['count'], 26) self.assertEqual(response.data['results'], self.data[10:20]) @@ -89,7 +84,8 @@ class IntegrationTestPagination(TestCase): self.assertNotEqual(response.data['previous'], None) request = factory.get(response.data['next']) - response = self.view(request).render() + with self.assertNumQueries(2): + response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['count'], 26) self.assertEqual(response.data['results'], self.data[20:]) @@ -112,20 +108,37 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.objects = FilterableItem.objects self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} - for obj in self.objects.all() + {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + for obj in self.objects.all() ] - self.view = FilterFieldsRootView.as_view() @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_get_paginated_filtered_root_view(self): + def test_get_django_filter_paginated_filtered_root_view(self): """ GET requests to paginated filtered ListCreateAPIView should return paginated results. The next and previous links should preserve the filtered parameters. """ + class DecimalFilter(django_filters.FilterSet): + decimal = django_filters.NumberFilter(lookup_type='lt') + + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + + class FilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + paginate_by = 10 + filter_class = DecimalFilter + filter_backends = (filters.DjangoFilterBackend,) + + view = FilterFieldsRootView.as_view() + + EXPECTED_NUM_QUERIES = 2 + request = factory.get('/?decimal=15.20') - response = self.view(request).render() + with self.assertNumQueries(EXPECTED_NUM_QUERIES): + response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['count'], 15) self.assertEqual(response.data['results'], self.data[:10]) @@ -133,7 +146,8 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.assertEqual(response.data['previous'], None) request = factory.get(response.data['next']) - response = self.view(request).render() + with self.assertNumQueries(EXPECTED_NUM_QUERIES): + response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['count'], 15) self.assertEqual(response.data['results'], self.data[10:15]) @@ -141,7 +155,53 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.assertNotEqual(response.data['previous'], None) request = factory.get(response.data['previous']) - response = self.view(request).render() + with self.assertNumQueries(EXPECTED_NUM_QUERIES): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) + + def test_get_basic_paginated_filtered_root_view(self): + """ + Same as `test_get_django_filter_paginated_filtered_root_view`, + except using a custom filter backend instead of the django-filter + backend, + """ + + class DecimalFilterBackend(filters.BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) + + class BasicFilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + paginate_by = 10 + filter_backends = (DecimalFilterBackend,) + + view = BasicFilterFieldsRootView.as_view() + + request = factory.get('/?decimal=15.20') + with self.assertNumQueries(2): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) + + request = factory.get(response.data['next']) + with self.assertNumQueries(2): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[10:15]) + self.assertEqual(response.data['next'], None) + self.assertNotEqual(response.data['previous'], None) + + request = factory.get(response.data['previous']) + with self.assertNumQueries(2): + response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['count'], 15) self.assertEqual(response.data['results'], self.data[:10]) diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/parsers.py index 539c5b44..7699e10c 100644 --- a/rest_framework/tests/parsers.py +++ b/rest_framework/tests/parsers.py @@ -1,10 +1,11 @@ from __future__ import unicode_literals from rest_framework.compat import StringIO from django import forms +from django.core.files.uploadhandler import MemoryFileUploadHandler from django.test import TestCase from django.utils import unittest from rest_framework.compat import etree -from rest_framework.parsers import FormParser +from rest_framework.parsers import FormParser, FileUploadParser from rest_framework.parsers import XMLParser import datetime @@ -82,3 +83,33 @@ class TestXMLParser(TestCase): parser = XMLParser() data = parser.parse(self._complex_data_input) self.assertEqual(data, self._complex_data) + + +class TestFileUploadParser(TestCase): + def setUp(self): + class MockRequest(object): + pass + from io import BytesIO + self.stream = BytesIO( + "Test text file".encode('utf-8') + ) + request = MockRequest() + request.upload_handlers = (MemoryFileUploadHandler(),) + request.META = { + 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'), + 'HTTP_CONTENT_LENGTH': 14, + } + self.parser_context = {'request': request, 'kwargs': {}} + + def test_parse(self): + """ Make sure the `QueryDict` works OK """ + parser = FileUploadParser() + self.stream.seek(0) + data_and_files = parser.parse(self.stream, None, self.parser_context) + file_obj = data_and_files.files['file'] + self.assertEqual(file_obj._size, 14) + + def test_get_filename(self): + parser = FileUploadParser() + filename = parser.get_filename(self.stream, None, self.parser_context) + self.assertEqual(filename, 'file.txt'.encode('utf-8')) diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py index b5702a48..b1eed9a7 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/relations_hyperlink.py @@ -26,42 +26,44 @@ urlpatterns = patterns('', ) +# ManyToMany class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): - sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail') - class Meta: model = ManyToManyTarget + fields = ('url', 'name', 'sources') class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = ManyToManySource + fields = ('url', 'name', 'targets') +# ForeignKey class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): - sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail') - class Meta: model = ForeignKeyTarget + fields = ('url', 'name', 'sources') class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = ForeignKeySource + fields = ('url', 'name', 'target') # Nullable ForeignKey class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = NullableForeignKeySource + fields = ('url', 'name', 'target') -# OneToOne +# Nullable OneToOne class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): - nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail') - class Meta: model = OneToOneTarget + fields = ('url', 'name', 'nullable_source') # TODO: Add test that .data cannot be accessed prior to .is_valid diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index a125ba65..f6d006b3 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -6,38 +6,30 @@ from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, Null class ForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: - depth = 1 - model = ForeignKeySource - - -class FlatForeignKeySourceSerializer(serializers.ModelSerializer): - class Meta: model = ForeignKeySource + fields = ('id', 'name', 'target') + depth = 1 class ForeignKeyTargetSerializer(serializers.ModelSerializer): - sources = FlatForeignKeySourceSerializer(many=True) - class Meta: model = ForeignKeyTarget + fields = ('id', 'name', 'sources') + depth = 1 class NullableForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: - depth = 1 model = NullableForeignKeySource - - -class NullableOneToOneSourceSerializer(serializers.ModelSerializer): - class Meta: - model = NullableOneToOneSource + fields = ('id', 'name', 'target') + depth = 1 class NullableOneToOneTargetSerializer(serializers.ModelSerializer): - nullable_source = NullableOneToOneSourceSerializer() - class Meta: model = OneToOneTarget + fields = ('id', 'name', 'nullable_source') + depth = 1 class ReverseForeignKeyTests(TestCase): diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py index d6ae3176..5ce8b567 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/relations_pk.py @@ -5,41 +5,44 @@ from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, Fore from rest_framework.compat import six +# ManyToMany class ManyToManyTargetSerializer(serializers.ModelSerializer): - sources = serializers.PrimaryKeyRelatedField(many=True) - class Meta: model = ManyToManyTarget + fields = ('id', 'name', 'sources') class ManyToManySourceSerializer(serializers.ModelSerializer): class Meta: model = ManyToManySource + fields = ('id', 'name', 'targets') +# ForeignKey class ForeignKeyTargetSerializer(serializers.ModelSerializer): - sources = serializers.PrimaryKeyRelatedField(many=True) - class Meta: model = ForeignKeyTarget + fields = ('id', 'name', 'sources') class ForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: model = ForeignKeySource + fields = ('id', 'name', 'target') +# Nullable ForeignKey class NullableForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: model = NullableForeignKeySource + fields = ('id', 'name', 'target') -# OneToOne +# Nullable OneToOne class NullableOneToOneTargetSerializer(serializers.ModelSerializer): - nullable_source = serializers.PrimaryKeyRelatedField() - class Meta: model = OneToOneTarget + fields = ('id', 'name', 'nullable_source') # TODO: Add test that .data cannot be accessed prior to .is_valid @@ -407,14 +410,14 @@ class PKNullableOneToOneTests(TestCase): target.save() new_target = OneToOneTarget(name='target-2') new_target.save() - source = NullableOneToOneSource(name='source-1', target=target) + source = NullableOneToOneSource(name='source-1', target=new_target) source.save() def test_reverse_foreign_key_retrieve_with_null(self): queryset = OneToOneTarget.objects.all() serializer = NullableOneToOneTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': 'target-1', 'nullable_source': 1}, - {'id': 2, 'name': 'target-2', 'nullable_source': None}, + {'id': 1, 'name': 'target-1', 'nullable_source': None}, + {'id': 2, 'name': 'target-2', 'nullable_source': 1}, ] self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py index 4892f7a6..97e5af20 100644 --- a/rest_framework/tests/request.py +++ b/rest_framework/tests/request.py @@ -58,6 +58,14 @@ class TestMethodOverloading(TestCase): request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'})) self.assertEqual(request.method, 'DELETE') + def test_x_http_method_override_header(self): + """ + POST requests can also be overloaded to another method by setting + the X-HTTP-Method-Override header. + """ + request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE')) + self.assertEqual(request.method, 'DELETE') + class TestContentParsing(TestCase): def test_standard_behaviour_determines_no_content_GET(self): diff --git a/rest_framework/tests/routers.py b/rest_framework/tests/routers.py new file mode 100644 index 00000000..4e4765cb --- /dev/null +++ b/rest_framework/tests/routers.py @@ -0,0 +1,55 @@ +from __future__ import unicode_literals +from django.test import TestCase +from django.test.client import RequestFactory +from rest_framework import status +from rest_framework.response import Response +from rest_framework import viewsets +from rest_framework.decorators import link, action +from rest_framework.routers import SimpleRouter +import copy + +factory = RequestFactory() + + +class BasicViewSet(viewsets.ViewSet): + def list(self, request, *args, **kwargs): + return Response({'method': 'list'}) + + @action() + def action1(self, request, *args, **kwargs): + return Response({'method': 'action1'}) + + @action() + def action2(self, request, *args, **kwargs): + return Response({'method': 'action2'}) + + @link() + def link1(self, request, *args, **kwargs): + return Response({'method': 'link1'}) + + @link() + def link2(self, request, *args, **kwargs): + return Response({'method': 'link2'}) + + +class TestSimpleRouter(TestCase): + def setUp(self): + self.router = SimpleRouter() + + def test_link_and_action_decorator(self): + routes = self.router.get_routes(BasicViewSet) + decorator_routes = routes[2:] + # Make sure all these endpoints exist and none have been clobbered + for i, endpoint in enumerate(['action1', 'action2', 'link1', 'link2']): + route = decorator_routes[i] + # check url listing + self.assertEqual(route.url, + '^{{prefix}}/{{lookup}}/{0}/$'.format(endpoint)) + # check method to function mapping + if endpoint.startswith('action'): + method_map = 'post' + else: + method_map = 'get' + self.assertEqual(route.mapping[method_map], endpoint) + + diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index d0300f9e..db3881f9 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -3,7 +3,7 @@ from django.utils.datastructures import MultiValueDict from django.test import TestCase from rest_framework import serializers from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, - BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel, + BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel, ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo) import datetime import pickle @@ -78,6 +78,18 @@ class PersonSerializer(serializers.ModelSerializer): read_only_fields = ('age',) +class PersonSerializerInvalidReadOnly(serializers.ModelSerializer): + """ + Testing for #652. + """ + info = serializers.Field(source='info') + + class Meta: + model = Person + fields = ('name', 'age', 'info') + read_only_fields = ('age', 'info') + + class AlbumsSerializer(serializers.ModelSerializer): class Meta: @@ -189,6 +201,12 @@ class BasicTests(TestCase): # Assert age is unchanged (35) self.assertEqual(instance.age, self.person_data['age']) + def test_invalid_read_only_fields(self): + """ + Regression test for #652. + """ + self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, []) + class DictStyleSerializer(serializers.Serializer): """ @@ -261,25 +279,6 @@ class ValidationTests(TestCase): self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.errors, {}) - def test_bad_type_data_is_false(self): - """ - Data of the wrong type is not valid. - """ - data = ['i am', 'a', 'list'] - serializer = CommentSerializer(self.comment, data=data, many=True) - self.assertEqual(serializer.is_valid(), False) - self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']}) - - data = 'and i am a string' - serializer = CommentSerializer(self.comment, data=data) - self.assertEqual(serializer.is_valid(), False) - self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']}) - - data = 42 - serializer = CommentSerializer(self.comment, data=data) - self.assertEqual(serializer.is_valid(), False) - self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']}) - def test_cross_field_validation(self): class CommentSerializerWithCrossFieldValidator(CommentSerializer): @@ -376,7 +375,6 @@ class CustomValidationTests(TestCase): def validate_email(self, attrs, source): value = attrs[source] - return attrs def validate_content(self, attrs, source): @@ -757,6 +755,43 @@ class ManyRelatedTests(TestCase): self.assertEqual(serializer.data, expected) + def test_include_reverse_relations(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I hate this blog post") + post.blogpostcomment_set.create(text="I love this blog post") + + class BlogPostSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPost + fields = ('id', 'title', 'blogpostcomment_set') + + serializer = BlogPostSerializer(instance=post) + expected = { + 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2] + } + self.assertEqual(serializer.data, expected) + + def test_depth_include_reverse_relations(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I hate this blog post") + post.blogpostcomment_set.create(text="I love this blog post") + + class BlogPostSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPost + fields = ('id', 'title', 'blogpostcomment_set') + depth = 1 + + serializer = BlogPostSerializer(instance=post) + expected = { + 'id': 1, 'title': 'Test blog post', + 'blogpostcomment_set': [ + {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1}, + {'id': 2, 'text': 'I love this blog post', 'blog_post': 1} + ] + } + self.assertEqual(serializer.data, expected) + def test_callable_source(self): post = BlogPost.objects.create(title="Test blog post") post.blogpostcomment_set.create(text="I love this blog post") @@ -786,8 +821,6 @@ class RelatedTraversalTest(TestCase): post = BlogPost.objects.create(title="Test blog post", writer=user) post.blogpostcomment_set.create(text="I love this blog post") - from rest_framework.tests.models import BlogPostComment - class PersonSerializer(serializers.ModelSerializer): class Meta: model = Person @@ -987,23 +1020,26 @@ class SerializerPickleTests(TestCase): class DepthTest(TestCase): def test_implicit_nesting(self): + writer = Person.objects.create(name="django", age=1) post = BlogPost.objects.create(title="Test blog post", writer=writer) + comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post) - class BlogPostSerializer(serializers.ModelSerializer): + class BlogPostCommentSerializer(serializers.ModelSerializer): class Meta: - model = BlogPost - depth = 1 + model = BlogPostComment + depth = 2 - serializer = BlogPostSerializer(instance=post) - expected = {'id': 1, 'title': 'Test blog post', - 'writer': {'id': 1, 'name': 'django', 'age': 1}} + serializer = BlogPostCommentSerializer(instance=comment) + expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post', + 'writer': {'id': 1, 'name': 'django', 'age': 1}}} self.assertEqual(serializer.data, expected) def test_explicit_nesting(self): writer = Person.objects.create(name="django", age=1) post = BlogPost.objects.create(title="Test blog post", writer=writer) + comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post) class PersonSerializer(serializers.ModelSerializer): class Meta: @@ -1015,9 +1051,15 @@ class DepthTest(TestCase): class Meta: model = BlogPost - serializer = BlogPostSerializer(instance=post) - expected = {'id': 1, 'title': 'Test blog post', - 'writer': {'id': 1, 'name': 'django', 'age': 1}} + class BlogPostCommentSerializer(serializers.ModelSerializer): + blog_post = BlogPostSerializer() + + class Meta: + model = BlogPostComment + + serializer = BlogPostCommentSerializer(instance=comment) + expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post', + 'writer': {'id': 1, 'name': 'django', 'age': 1}}} self.assertEqual(serializer.data, expected) @@ -1072,3 +1114,32 @@ class NestedSerializerContextTests(TestCase): # This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data + + +class DeserializeListTestCase(TestCase): + + def setUp(self): + self.data = { + 'email': 'nobody@nowhere.com', + 'content': 'This is some test content', + 'created': datetime.datetime(2013, 3, 7), + } + + def test_no_errors(self): + data = [self.data.copy() for x in range(0, 3)] + serializer = CommentSerializer(data=data, many=True) + self.assertTrue(serializer.is_valid()) + self.assertTrue(isinstance(serializer.object, list)) + self.assertTrue( + all((isinstance(item, Comment) for item in serializer.object)) + ) + + def test_errors_return_as_list(self): + invalid_item = self.data.copy() + invalid_item['email'] = '' + data = [self.data.copy(), invalid_item, self.data.copy()] + + serializer = CommentSerializer(data=data, many=True) + self.assertFalse(serializer.is_valid()) + expected = [{}, {'email': ['This field is required.']}, {}] + self.assertEqual(serializer.errors, expected) diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py new file mode 100644 index 00000000..8b0ded1a --- /dev/null +++ b/rest_framework/tests/serializer_bulk_update.py @@ -0,0 +1,278 @@ +""" +Tests to cover bulk create and update using serializers. +""" +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework import serializers + + +class BulkCreateSerializerTests(TestCase): + """ + Creating multiple instances using serializers. + """ + + def setUp(self): + class BookSerializer(serializers.Serializer): + id = serializers.IntegerField() + title = serializers.CharField(max_length=100) + author = serializers.CharField(max_length=100) + + self.BookSerializer = BookSerializer + + def test_bulk_create_success(self): + """ + Correct bulk update serialization should return the input data. + """ + + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 1, + 'title': 'If this is a man', + 'author': 'Primo Levi' + }, { + 'id': 2, + 'title': 'The wind-up bird chronicle', + 'author': 'Haruki Murakami' + } + ] + + serializer = self.BookSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, data) + + def test_bulk_create_errors(self): + """ + Correct bulk update serialization should return the input data. + """ + + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 1, + 'title': 'If this is a man', + 'author': 'Primo Levi' + }, { + 'id': 'foo', + 'title': 'The wind-up bird chronicle', + 'author': 'Haruki Murakami' + } + ] + expected_errors = [ + {}, + {}, + {'id': ['Enter a whole number.']} + ] + + serializer = self.BookSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors) + + def test_invalid_list_datatype(self): + """ + Data containing list of incorrect data type should return errors. + """ + data = ['foo', 'bar', 'baz'] + serializer = self.BookSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + + expected_errors = [ + {'non_field_errors': ['Invalid data']}, + {'non_field_errors': ['Invalid data']}, + {'non_field_errors': ['Invalid data']} + ] + + self.assertEqual(serializer.errors, expected_errors) + + def test_invalid_single_datatype(self): + """ + Data containing a single incorrect data type should return errors. + """ + data = 123 + serializer = self.BookSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + + expected_errors = {'non_field_errors': ['Expected a list of items.']} + + self.assertEqual(serializer.errors, expected_errors) + + def test_invalid_single_object(self): + """ + Data containing only a single object, instead of a list of objects + should return errors. + """ + data = { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + } + serializer = self.BookSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + + expected_errors = {'non_field_errors': ['Expected a list of items.']} + + self.assertEqual(serializer.errors, expected_errors) + + +class BulkUpdateSerializerTests(TestCase): + """ + Updating multiple instances using serializers. + """ + + def setUp(self): + class Book(object): + """ + A data type that can be persisted to a mock storage backend + with `.save()` and `.delete()`. + """ + object_map = {} + + def __init__(self, id, title, author): + self.id = id + self.title = title + self.author = author + + def save(self): + Book.object_map[self.id] = self + + def delete(self): + del Book.object_map[self.id] + + class BookSerializer(serializers.Serializer): + id = serializers.IntegerField() + title = serializers.CharField(max_length=100) + author = serializers.CharField(max_length=100) + + def restore_object(self, attrs, instance=None): + if instance: + instance.id = attrs['id'] + instance.title = attrs['title'] + instance.author = attrs['author'] + return instance + return Book(**attrs) + + self.Book = Book + self.BookSerializer = BookSerializer + + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 1, + 'title': 'If this is a man', + 'author': 'Primo Levi' + }, { + 'id': 2, + 'title': 'The wind-up bird chronicle', + 'author': 'Haruki Murakami' + } + ] + + for item in data: + book = Book(item['id'], item['title'], item['author']) + book.save() + + def books(self): + """ + Return all the objects in the mock storage backend. + """ + return self.Book.object_map.values() + + def test_bulk_update_success(self): + """ + Correct bulk update serialization should return the input data. + """ + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 2, + 'title': 'Kafka on the shore', + 'author': 'Haruki Murakami' + } + ] + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.data, data) + serializer.save() + new_data = self.BookSerializer(self.books(), many=True).data + + self.assertEqual(data, new_data) + + def test_bulk_update_and_create(self): + """ + Bulk update serialization may also include created items. + """ + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 3, + 'title': 'Kafka on the shore', + 'author': 'Haruki Murakami' + } + ] + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.data, data) + serializer.save() + new_data = self.BookSerializer(self.books(), many=True).data + self.assertEqual(data, new_data) + + def test_bulk_update_invalid_create(self): + """ + Bulk update serialization without allow_add_remove may not create items. + """ + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 3, + 'title': 'Kafka on the shore', + 'author': 'Haruki Murakami' + } + ] + expected_errors = [ + {}, + {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']} + ] + serializer = self.BookSerializer(self.books(), data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors) + + def test_bulk_update_error(self): + """ + Incorrect bulk update serialization should return error data. + """ + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 'foo', + 'title': 'Kafka on the shore', + 'author': 'Haruki Murakami' + } + ] + expected_errors = [ + {}, + {'id': ['Enter a whole number.']} + ] + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors) diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/serializer_nested.py new file mode 100644 index 00000000..71d0e24b --- /dev/null +++ b/rest_framework/tests/serializer_nested.py @@ -0,0 +1,246 @@ +""" +Tests to cover nested serializers. + +Doesn't cover model serializers. +""" +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework import serializers + + +class WritableNestedSerializerBasicTests(TestCase): + """ + Tests for deserializing nested entities. + Basic tests that use serializers that simply restore to dicts. + """ + + def setUp(self): + class TrackSerializer(serializers.Serializer): + order = serializers.IntegerField() + title = serializers.CharField(max_length=100) + duration = serializers.IntegerField() + + class AlbumSerializer(serializers.Serializer): + album_name = serializers.CharField(max_length=100) + artist = serializers.CharField(max_length=100) + tracks = TrackSerializer(many=True) + + self.AlbumSerializer = AlbumSerializer + + def test_nested_validation_success(self): + """ + Correct nested serialization should return the input data. + """ + + data = { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 239} + ] + } + + serializer = self.AlbumSerializer(data=data) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, data) + + def test_nested_validation_error(self): + """ + Incorrect nested serialization should return appropriate error data. + """ + + data = { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} + ] + } + expected_errors = { + 'tracks': [ + {}, + {}, + {'duration': ['Enter a whole number.']} + ] + } + + serializer = self.AlbumSerializer(data=data) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors) + + def test_many_nested_validation_error(self): + """ + Incorrect nested serialization should return appropriate error data + when multiple entities are being deserialized. + """ + + data = [ + { + 'album_name': 'Russian Red', + 'artist': 'I Love Your Glasses', + 'tracks': [ + {'order': 1, 'title': 'Cigarettes', 'duration': 121}, + {'order': 2, 'title': 'No Past Land', 'duration': 198}, + {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191} + ] + }, + { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} + ] + } + ] + expected_errors = [ + {}, + { + 'tracks': [ + {}, + {}, + {'duration': ['Enter a whole number.']} + ] + } + ] + + serializer = self.AlbumSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors) + + +class WritableNestedSerializerObjectTests(TestCase): + """ + Tests for deserializing nested entities. + These tests use serializers that restore to concrete objects. + """ + + def setUp(self): + # Couple of concrete objects that we're going to deserialize into + class Track(object): + def __init__(self, order, title, duration): + self.order, self.title, self.duration = order, title, duration + + def __eq__(self, other): + return ( + self.order == other.order and + self.title == other.title and + self.duration == other.duration + ) + + class Album(object): + def __init__(self, album_name, artist, tracks): + self.album_name, self.artist, self.tracks = album_name, artist, tracks + + def __eq__(self, other): + return ( + self.album_name == other.album_name and + self.artist == other.artist and + self.tracks == other.tracks + ) + + # And their corresponding serializers + class TrackSerializer(serializers.Serializer): + order = serializers.IntegerField() + title = serializers.CharField(max_length=100) + duration = serializers.IntegerField() + + def restore_object(self, attrs, instance=None): + return Track(attrs['order'], attrs['title'], attrs['duration']) + + class AlbumSerializer(serializers.Serializer): + album_name = serializers.CharField(max_length=100) + artist = serializers.CharField(max_length=100) + tracks = TrackSerializer(many=True) + + def restore_object(self, attrs, instance=None): + return Album(attrs['album_name'], attrs['artist'], attrs['tracks']) + + self.Album, self.Track = Album, Track + self.AlbumSerializer = AlbumSerializer + + def test_nested_validation_success(self): + """ + Correct nested serialization should return a restored object + that corresponds to the input data. + """ + + data = { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 239} + ] + } + expected_object = self.Album( + album_name='Discovery', + artist='Daft Punk', + tracks=[ + self.Track(order=1, title='One More Time', duration=235), + self.Track(order=2, title='Aerodynamic', duration=184), + self.Track(order=3, title='Digital Love', duration=239), + ] + ) + + serializer = self.AlbumSerializer(data=data) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, expected_object) + + def test_many_nested_validation_success(self): + """ + Correct nested serialization should return multiple restored objects + that corresponds to the input data when multiple objects are + being deserialized. + """ + + data = [ + { + 'album_name': 'Russian Red', + 'artist': 'I Love Your Glasses', + 'tracks': [ + {'order': 1, 'title': 'Cigarettes', 'duration': 121}, + {'order': 2, 'title': 'No Past Land', 'duration': 198}, + {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191} + ] + }, + { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 239} + ] + } + ] + expected_object = [ + self.Album( + album_name='Russian Red', + artist='I Love Your Glasses', + tracks=[ + self.Track(order=1, title='Cigarettes', duration=121), + self.Track(order=2, title='No Past Land', duration=198), + self.Track(order=3, title='They Don\'t Believe', duration=191), + ] + ), + self.Album( + album_name='Discovery', + artist='Daft Punk', + tracks=[ + self.Track(order=1, title='One More Time', duration=235), + self.Track(order=2, title='Aerodynamic', duration=184), + self.Track(order=3, title='Digital Love', duration=239), + ] + ) + ] + + serializer = self.AlbumSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, expected_object) diff --git a/rest_framework/tests/status.py b/rest_framework/tests/status.py deleted file mode 100644 index e1644a6b..00000000 --- a/rest_framework/tests/status.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Tests for the status module""" -from __future__ import unicode_literals -from django.test import TestCase -from rest_framework import status - - -class TestStatus(TestCase): - """Simple sanity test to check the status module""" - - def test_status(self): - """Ensure the status module is present and correct.""" - self.assertEqual(200, status.HTTP_200_OK) - self.assertEqual(404, status.HTTP_404_NOT_FOUND) diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 810cad63..93ea9816 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -1,3 +1,6 @@ +""" +Provides various throttling policies. +""" from __future__ import unicode_literals from django.core.cache import cache from rest_framework import exceptions @@ -28,9 +31,8 @@ class SimpleRateThrottle(BaseThrottle): A simple cache implementation, that only requires `.get_cache_key()` to be overridden. - The rate (requests / seconds) is set by a :attr:`throttle` attribute - on the :class:`.View` class. The attribute is a string of the form 'number of - requests/period'. + The rate (requests / seconds) is set by a `throttle` attribute on the View + class. The attribute is a string of the form 'number_of_requests/period'. Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day') diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index af21ac79..d51374b0 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -1,26 +1,37 @@ from __future__ import unicode_literals from django.core.urlresolvers import resolve, get_script_prefix +from rest_framework.utils.formatting import get_view_name def get_breadcrumbs(url): - """Given a url returns a list of breadcrumbs, which are each a tuple of (name, url).""" + """ + Given a url returns a list of breadcrumbs, which are each a + tuple of (name, url). + """ from rest_framework.views import APIView def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen): - """Add tuples of (name, url) to the breadcrumbs list, progressively chomping off parts of the url.""" + """ + Add tuples of (name, url) to the breadcrumbs list, + progressively chomping off parts of the url. + """ try: (view, unused_args, unused_kwargs) = resolve(url) except Exception: pass else: - # Check if this is a REST framework view, and if so add it to the breadcrumbs - if isinstance(getattr(view, 'cls_instance', None), APIView): + # Check if this is a REST framework view, + # and if so add it to the breadcrumbs + cls = getattr(view, 'cls', None) + if cls is not None and issubclass(cls, APIView): # Don't list the same view twice in a row. # Probably an optional trailing slash. if not seen or seen[-1] != view: - breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url)) + suffix = getattr(view, 'suffix', None) + name = get_view_name(view.cls, suffix) + breadcrumbs_list.insert(0, (name, prefix + url)) seen.append(view) if url == '': @@ -28,11 +39,15 @@ def get_breadcrumbs(url): return breadcrumbs_list elif url.endswith('/'): - # Drop trailing slash off the end and continue to try to resolve more breadcrumbs - return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix, seen) - - # Drop trailing non-slash off the end and continue to try to resolve more breadcrumbs - return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix, seen) + # Drop trailing slash off the end and continue to try to + # resolve more breadcrumbs + url = url.rstrip('/') + return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen) + + # Drop trailing non-slash off the end and continue to try to + # resolve more breadcrumbs + url = url[:url.rfind('/') + 1] + return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen) prefix = get_script_prefix().rstrip('/') url = url[len(prefix):] diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py new file mode 100644 index 00000000..ebadb3a6 --- /dev/null +++ b/rest_framework/utils/formatting.py @@ -0,0 +1,80 @@ +""" +Utility functions to return a formatted name and description for a given view. +""" +from __future__ import unicode_literals + +from django.utils.html import escape +from django.utils.safestring import mark_safe +from rest_framework.compat import apply_markdown +import re + + +def _remove_trailing_string(content, trailing): + """ + Strip trailing component `trailing` from `content` if it exists. + Used when generating names from view classes. + """ + if content.endswith(trailing) and content != trailing: + return content[:-len(trailing)] + return content + + +def _remove_leading_indent(content): + """ + Remove leading indent from a block of text. + Used when generating descriptions from docstrings. + """ + whitespace_counts = [len(line) - len(line.lstrip(' ')) + for line in content.splitlines()[1:] if line.lstrip()] + + # unindent the content if needed + if whitespace_counts: + whitespace_pattern = '^' + (' ' * min(whitespace_counts)) + content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) + content = content.strip('\n') + return content + + +def _camelcase_to_spaces(content): + """ + Translate 'CamelCaseNames' to 'Camel Case Names'. + Used when generating names from view classes. + """ + camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))' + content = re.sub(camelcase_boundry, ' \\1', content).strip() + return ' '.join(content.split('_')).title() + + +def get_view_name(cls, suffix=None): + """ + Return a formatted name for an `APIView` class or `@api_view` function. + """ + name = cls.__name__ + name = _remove_trailing_string(name, 'View') + name = _remove_trailing_string(name, 'ViewSet') + name = _camelcase_to_spaces(name) + if suffix: + name += ' ' + suffix + return name + + +def get_view_description(cls, html=False): + """ + Return a description for an `APIView` class or `@api_view` function. + """ + description = cls.__doc__ or '' + description = _remove_leading_indent(description) + if html: + return markup_description(description) + return description + + +def markup_description(description): + """ + Apply HTML markup to the given description. + """ + if apply_markdown: + description = apply_markdown(description) + else: + description = escape(description).replace('\n', '<br />') + return mark_safe(description) diff --git a/rest_framework/views.py b/rest_framework/views.py index 81cbdcbb..555fa2f4 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -1,54 +1,16 @@ """ -Provides an APIView class that is used as the base of all class-based views. +Provides an APIView class that is the base of all views in REST framework. """ from __future__ import unicode_literals from django.core.exceptions import PermissionDenied -from django.http import Http404 -from django.utils.html import escape -from django.utils.safestring import mark_safe +from django.http import Http404, HttpResponse from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions -from rest_framework.compat import View, apply_markdown +from rest_framework.compat import View from rest_framework.response import Response from rest_framework.request import Request from rest_framework.settings import api_settings -import re - - -def _remove_trailing_string(content, trailing): - """ - Strip trailing component `trailing` from `content` if it exists. - Used when generating names from view classes. - """ - if content.endswith(trailing) and content != trailing: - return content[:-len(trailing)] - return content - - -def _remove_leading_indent(content): - """ - Remove leading indent from a block of text. - Used when generating descriptions from docstrings. - """ - whitespace_counts = [len(line) - len(line.lstrip(' ')) - for line in content.splitlines()[1:] if line.lstrip()] - - # unindent the content if needed - if whitespace_counts: - whitespace_pattern = '^' + (' ' * min(whitespace_counts)) - content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) - content = content.strip('\n') - return content - - -def _camelcase_to_spaces(content): - """ - Translate 'CamelCaseNames' to 'Camel Case Names'. - Used when generating names from view classes. - """ - camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))' - content = re.sub(camelcase_boundry, ' \\1', content).strip() - return ' '.join(content.split('_')).title() +from rest_framework.utils.formatting import get_view_name, get_view_description class APIView(View): @@ -64,22 +26,21 @@ class APIView(View): @classmethod def as_view(cls, **initkwargs): """ - Override the default :meth:`as_view` to store an instance of the view - as an attribute on the callable function. This allows us to discover - information about the view when we do URL reverse lookups. + Store the original class on the view function. + + This allows us to discover information about the view when we do URL + reverse lookups. Used for breadcrumb generation. """ - # TODO: deprecate? view = super(APIView, cls).as_view(**initkwargs) - view.cls_instance = cls(**initkwargs) + view.cls = cls return view @property def allowed_methods(self): """ - Return the list of allowed HTTP methods, uppercased. + Wrap Django's private `_allowed_methods` interface in a public property. """ - return [method.upper() for method in self.http_method_names - if hasattr(self, method)] + return self._allowed_methods() @property def default_response_headers(self): @@ -90,43 +51,10 @@ class APIView(View): 'Vary': 'Accept' } - def get_name(self): - """ - Return the resource or view class name for use as this view's name. - Override to customize. - """ - # TODO: deprecate? - name = self.__class__.__name__ - name = _remove_trailing_string(name, 'View') - return _camelcase_to_spaces(name) - - def get_description(self, html=False): - """ - Return the resource or view docstring for use as this view's description. - Override to customize. - """ - # TODO: deprecate? - description = self.__doc__ or '' - description = _remove_leading_indent(description) - if html: - return self.markup_description(description) - return description - - def markup_description(self, description): - """ - Apply HTML markup to the description of this view. - """ - # TODO: deprecate? - if apply_markdown: - description = apply_markdown(description) - else: - description = escape(description).replace('\n', '<br />') - return mark_safe(description) - def metadata(self, request): return { - 'name': self.get_name(), - 'description': self.get_description(), + 'name': get_view_name(self.__class__), + 'description': get_view_description(self.__class__), 'renders': [renderer.media_type for renderer in self.renderer_classes], 'parses': [parser.media_type for parser in self.parser_classes], } @@ -140,7 +68,8 @@ class APIView(View): def http_method_not_allowed(self, request, *args, **kwargs): """ - Called if `request.method` does not correspond to a handler method. + If `request.method` does not correspond to a handler method, + determine what kind of exception to raise. """ raise exceptions.MethodNotAllowed(request.method) @@ -327,6 +256,12 @@ class APIView(View): """ Returns the final response object. """ + # Make the error obvious if a proper response is not returned + assert isinstance(response, HttpResponse), ( + 'Expected a `Response` to be returned from the view, ' + 'but received a `%s`' % type(response) + ) + if isinstance(response, Response): if not getattr(request, 'accepted_renderer', None): neg = self.perform_content_negotiation(request, force=True) diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py new file mode 100644 index 00000000..d91323f2 --- /dev/null +++ b/rest_framework/viewsets.py @@ -0,0 +1,139 @@ +""" +ViewSets are essentially just a type of class based view, that doesn't provide +any method handlers, such as `get()`, `post()`, etc... but instead has actions, +such as `list()`, `retrieve()`, `create()`, etc... + +Actions are only bound to methods at the point of instantiating the views. + + user_list = UserViewSet.as_view({'get': 'list'}) + user_detail = UserViewSet.as_view({'get': 'retrieve'}) + +Typically, rather than instantiate views from viewsets directly, you'll +regsiter the viewset with a router and let the URL conf be determined +automatically. + + router = DefaultRouter() + router.register(r'users', UserViewSet, 'user') + urlpatterns = router.urls +""" +from __future__ import unicode_literals + +from functools import update_wrapper +from django.utils.decorators import classonlymethod +from rest_framework import views, generics, mixins + + +class ViewSetMixin(object): + """ + This is the magic. + + Overrides `.as_view()` so that it takes an `actions` keyword that performs + the binding of HTTP methods to actions on the Resource. + + For example, to create a concrete view binding the 'GET' and 'POST' methods + to the 'list' and 'create' actions... + + view = MyViewSet.as_view({'get': 'list', 'post': 'create'}) + """ + + @classonlymethod + def as_view(cls, actions=None, **initkwargs): + """ + Because of the way class based views create a closure around the + instantiated view, we need to totally reimplement `.as_view`, + and slightly modify the view function that is created and returned. + """ + # The suffix initkwarg is reserved for identifing the viewset type + # eg. 'List' or 'Instance'. + cls.suffix = None + + # sanitize keyword arguments + for key in initkwargs: + if key in cls.http_method_names: + raise TypeError("You tried to pass in the %s method name as a " + "keyword argument to %s(). Don't do that." + % (key, cls.__name__)) + if not hasattr(cls, key): + raise TypeError("%s() received an invalid keyword %r" % ( + cls.__name__, key)) + + def view(request, *args, **kwargs): + self = cls(**initkwargs) + # We also store the mapping of request methods to actions, + # so that we can later set the action attribute. + # eg. `self.action = 'list'` on an incoming GET request. + self.action_map = actions + + # Bind methods to actions + # This is the bit that's different to a standard view + for method, action in actions.items(): + handler = getattr(self, action) + setattr(self, method, handler) + + # Patch this in as it's otherwise only present from 1.5 onwards + if hasattr(self, 'get') and not hasattr(self, 'head'): + self.head = self.get + + # And continue as usual + return self.dispatch(request, *args, **kwargs) + + # take name and docstring from class + update_wrapper(view, cls, updated=()) + + # and possible attributes set by decorators + # like csrf_exempt from dispatch + update_wrapper(view, cls.dispatch, assigned=()) + + # We need to set these on the view function, so that breadcrumb + # generation can pick out these bits of information from a + # resolved URL. + view.cls = cls + view.suffix = initkwargs.get('suffix', None) + return view + + def initialize_request(self, request, *args, **kargs): + """ + Set the `.action` attribute on the view, + depending on the request method. + """ + request = super(ViewSetMixin, self).initialize_request(request, *args, **kargs) + self.action = self.action_map.get(request.method.lower()) + return request + + +class ViewSet(ViewSetMixin, views.APIView): + """ + The base ViewSet class does not provide any actions by default. + """ + pass + + +class GenericViewSet(ViewSetMixin, generics.GenericAPIView): + """ + The GenericViewSet class does not provide any actions by default, + but does include the base set of generic view behavior, such as + the `get_object` and `get_queryset` methods. + """ + pass + + +class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, + mixins.ListModelMixin, + GenericViewSet): + """ + A viewset that provides default `list()` and `retrieve()` actions. + """ + pass + + +class ModelViewSet(mixins.CreateModelMixin, + mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + mixins.DestroyModelMixin, + mixins.ListModelMixin, + GenericViewSet): + """ + A viewset that provides default `create()`, `retrieve()`, `update()`, + `partial_update()`, `destroy()` and `list()` actions. + """ + pass |
