diff options
Diffstat (limited to 'rest_framework')
40 files changed, 748 insertions, 171 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 6759680b..2d76b55d 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -8,10 +8,10 @@ ______ _____ _____ _____ __ _ """ __title__ = 'Django REST framework' -__version__ = '2.3.12' +__version__ = '2.3.13' __author__ = 'Tom Christie' __license__ = 'BSD 2-Clause' -__copyright__ = 'Copyright 2011-2013 Tom Christie' +__copyright__ = 'Copyright 2011-2014 Tom Christie' # Version synonym VERSION = __version__ diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index e491ce5f..da9ca510 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -6,6 +6,7 @@ import base64 from django.contrib.auth import authenticate from django.core.exceptions import ImproperlyConfigured +from django.conf import settings from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework.compat import CsrfViewMiddleware from rest_framework.compat import oauth, oauth_provider, oauth_provider_store @@ -291,6 +292,7 @@ class OAuth2Authentication(BaseAuthentication): OAuth 2 authentication backend using `django-oauth2-provider` """ www_authenticate_realm = 'api' + allow_query_params_token = settings.DEBUG def __init__(self, *args, **kwargs): super(OAuth2Authentication, self).__init__(*args, **kwargs) @@ -308,7 +310,13 @@ class OAuth2Authentication(BaseAuthentication): auth = get_authorization_header(request).split() - if not auth or auth[0].lower() != b'bearer': + if auth and auth[0].lower() == b'bearer': + access_token = auth[1] + elif 'access_token' in request.POST: + access_token = request.POST['access_token'] + elif 'access_token' in request.GET and self.allow_query_params_token: + access_token = request.GET['access_token'] + else: return None if len(auth) == 1: @@ -318,7 +326,7 @@ class OAuth2Authentication(BaseAuthentication): msg = 'Invalid bearer header. Token string should not contain spaces.' raise exceptions.AuthenticationFailed(msg) - return self.authenticate_credentials(request, auth[1]) + return self.authenticate_credentials(request, access_token) def authenticate_credentials(self, request, access_token): """ @@ -326,11 +334,11 @@ class OAuth2Authentication(BaseAuthentication): """ try: - token = oauth2_provider.models.AccessToken.objects.select_related('user') + token = oauth2_provider.oauth2.models.AccessToken.objects.select_related('user') # provider_now switches to timezone aware datetime when # the oauth2_provider version supports to it. token = token.get(token=access_token, expires__gt=provider_now()) - except oauth2_provider.models.AccessToken.DoesNotExist: + except oauth2_provider.oauth2.models.AccessToken.DoesNotExist: raise exceptions.AuthenticationFailed('Invalid token') user = token.user diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py index 8eac2cc4..167fa531 100644 --- a/rest_framework/authtoken/models.py +++ b/rest_framework/authtoken/models.py @@ -34,7 +34,7 @@ class Token(models.Model): return super(Token, self).save(*args, **kwargs) def generate_key(self): - return binascii.hexlify(os.urandom(20)) + return binascii.hexlify(os.urandom(20)).decode() def __unicode__(self): return self.key diff --git a/rest_framework/compat.py b/rest_framework/compat.py index d283e2f5..d155f554 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -550,13 +550,10 @@ except (ImportError, ImproperlyConfigured): # OAuth 2 support is optional try: - import provider.oauth2 as oauth2_provider - from provider.oauth2 import models as oauth2_provider_models - from provider.oauth2 import forms as oauth2_provider_forms + import provider as oauth2_provider from provider import scope as oauth2_provider_scope from provider import constants as oauth2_constants - from provider import __version__ as provider_version - if provider_version in ('0.2.3', '0.2.4'): + if oauth2_provider.__version__ in ('0.2.3', '0.2.4'): # 0.2.3 and 0.2.4 are supported version that do not support # timezone aware datetimes import datetime @@ -566,8 +563,6 @@ try: from django.utils.timezone import now as provider_now except ImportError: oauth2_provider = None - oauth2_provider_models = None - oauth2_provider_forms = None oauth2_provider_scope = None oauth2_constants = None provider_now = None @@ -584,3 +579,23 @@ if six.PY3: else: def is_non_str_iterable(obj): return hasattr(obj, '__iter__') + + +try: + from django.utils.encoding import python_2_unicode_compatible +except ImportError: + def python_2_unicode_compatible(klass): + """ + A decorator that defines __unicode__ and __str__ methods under Python 2. + Under Python 3 it does nothing. + + To support Python 2 and 3 with a single code base, define a __str__ method + returning text and apply this decorator to the class. + """ + if '__str__' not in klass.__dict__: + raise ValueError("@python_2_unicode_compatible cannot be applied " + "to %s because it doesn't define __str__()." % + klass.__name__) + klass.__unicode__ = klass.__str__ + klass.__str__ = lambda self: self.__unicode__().encode('utf-8') + return klass diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 0ac5866e..5f774a9f 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -20,6 +20,8 @@ class APIException(Exception): def __init__(self, detail=None): self.detail = detail or self.default_detail + def __str__(self): + return self.detail class ParseError(APIException): status_code = status.HTTP_400_BAD_REQUEST diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 05daaab7..8cdc5551 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -164,7 +164,7 @@ class Field(object): Called to set up a field prior to field_to_native or field_from_native. parent - The parent serializer. - model_field - The model field this field corresponds to, if one exists. + field_name - The name of the field being initialized. """ self.parent = parent self.root = parent.root or parent @@ -289,7 +289,7 @@ class WritableField(Field): self.validators = self.default_validators + validators self.default = default if default is not None else self.default - # Widgets are ony used for HTML forms. + # Widgets are only used for HTML forms. widget = widget or self.widget if isinstance(widget, type): widget = widget() @@ -301,6 +301,11 @@ class WritableField(Field): result.validators = self.validators[:] return result + def get_default_value(self): + if is_simple_callable(self.default): + return self.default() + return self.default + def validate(self, value): if value in validators.EMPTY_VALUES and self.required: raise ValidationError(self.error_messages['required']) @@ -349,10 +354,7 @@ class WritableField(Field): except KeyError: if self.default is not None and not self.partial: # Note: partial updates shouldn't set defaults - if is_simple_callable(self.default): - native = self.default() - else: - native = self.default + native = self.get_default_value() else: if self.required: raise ValidationError(self.error_messages['required']) diff --git a/rest_framework/filters.py b/rest_framework/filters.py index de91caed..96d15eb9 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -6,6 +6,7 @@ from __future__ import unicode_literals from django.core.exceptions import ImproperlyConfigured from django.db import models from rest_framework.compat import django_filters, six, guardian, get_model_name +from rest_framework.settings import api_settings from functools import reduce import operator @@ -69,7 +70,8 @@ class DjangoFilterBackend(BaseFilterBackend): class SearchFilter(BaseFilterBackend): - search_param = 'search' # The URL query parameter used for the search. + # The URL query parameter used for the search. + search_param = api_settings.SEARCH_PARAM def get_search_terms(self, request): """ @@ -107,7 +109,8 @@ class SearchFilter(BaseFilterBackend): class OrderingFilter(BaseFilterBackend): - ordering_param = 'ordering' # The URL query parameter used for the ordering. + # The URL query parameter used for the ordering. + ordering_param = api_settings.ORDERING_PARAM ordering_fields = None def get_ordering(self, request): diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 5fbcf700..e1a24dc7 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -116,30 +116,27 @@ class UpdateModelMixin(object): partial = kwargs.pop('partial', False) self.object = self.get_object_or_none() - if self.object is None: - created = True - save_kwargs = {'force_insert': True} - success_status_code = status.HTTP_201_CREATED - else: - created = False - save_kwargs = {'force_update': True} - success_status_code = status.HTTP_200_OK - serializer = self.get_serializer(self.object, data=request.DATA, files=request.FILES, partial=partial) - if serializer.is_valid(): - try: - self.pre_save(serializer.object) - except ValidationError as err: - # full_clean on model instance may be called in pre_save, so we - # have to handle eventual errors. - return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST) - self.object = serializer.save(**save_kwargs) - self.post_save(self.object, created=created) - return Response(serializer.data, status=success_status_code) + if not serializer.is_valid(): + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + try: + self.pre_save(serializer.object) + except ValidationError as err: + # full_clean on model instance may be called in pre_save, + # so we have to handle eventual errors. + return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST) + + if self.object is None: + self.object = serializer.save(force_insert=True) + self.post_save(self.object, created=True) + return Response(serializer.data, status=status.HTTP_201_CREATED) + + self.object = serializer.save(force_update=True) + self.post_save(self.object, created=False) + return Response(serializer.data, status=status.HTTP_200_OK) def partial_update(self, request, *args, **kwargs): kwargs['partial'] = True diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index f1b3e38d..4990971b 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -10,7 +10,7 @@ from django.core.files.uploadhandler import StopFutureHandlers from django.http import QueryDict from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter -from rest_framework.compat import etree, six, yaml +from rest_framework.compat import etree, six, yaml, force_text from rest_framework.exceptions import ParseError from rest_framework import renderers import json @@ -288,7 +288,7 @@ class FileUploadParser(BaseParser): try: meta = parser_context['request'].META - disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION']) - return disposition[1]['filename'] + disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'].encode('utf-8')) + return force_text(disposition[1]['filename']) except (AttributeError, KeyError): pass diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 02185c2f..3463954d 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -33,6 +33,7 @@ class RelatedField(WritableField): many_widget = widgets.SelectMultiple form_field_class = forms.ChoiceField many_form_field_class = forms.MultipleChoiceField + null_values = (None, '', 'None') cache_choices = False empty_label = None @@ -58,6 +59,8 @@ class RelatedField(WritableField): super(RelatedField, self).__init__(*args, **kwargs) if not self.required: + # Accessed in ModelChoiceIterator django/forms/models.py:1034 + # If set adds empty choice. self.empty_label = BLANK_CHOICE_DASH[0][1] self.queryset = queryset @@ -118,6 +121,14 @@ class RelatedField(WritableField): choices = property(_get_choices, _set_choices) + ### Default value handling + + def get_default_value(self): + default = super(RelatedField, self).get_default_value() + if self.many and default is None: + return [] + return default + ### Regular serializer stuff... def field_to_native(self, obj, field_name): @@ -166,11 +177,11 @@ class RelatedField(WritableField): except KeyError: if self.partial: return - value = [] if self.many else None + value = self.get_default_value() - if value in (None, '') and self.required: - raise ValidationError(self.error_messages['required']) - elif value in (None, ''): + if value in self.null_values: + if self.required: + raise ValidationError(self.error_messages['required']) into[(self.source or field_name)] = None elif self.many: into[(self.source or field_name)] = [self.from_native(item) for item in value] diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index e8afc26d..484961ad 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -146,7 +146,7 @@ class XMLRenderer(BaseRenderer): def render(self, data, accepted_media_type=None, renderer_context=None): """ - Renders *obj* into serialized XML. + Renders `data` into serialized XML. """ if data is None: return '' @@ -193,17 +193,26 @@ class YAMLRenderer(BaseRenderer): format = 'yaml' encoder = encoders.SafeDumper charset = 'utf-8' + ensure_ascii = True def render(self, data, accepted_media_type=None, renderer_context=None): """ - Renders *obj* into serialized YAML. + Renders `data` into serialized YAML. """ assert yaml, 'YAMLRenderer requires pyyaml to be installed' if data is None: return '' - return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder) + return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder, allow_unicode=not self.ensure_ascii) + + +class UnicodeYAMLRenderer(YAMLRenderer): + """ + Renderer which serializes to YAML. + Does *not* apply character escaping for non-ascii characters. + """ + ensure_ascii = False class TemplateHTMLRenderer(BaseRenderer): @@ -427,7 +436,7 @@ class BrowsableAPIRenderer(BaseRenderer): files = request.FILES except ParseError: data = None - files = None + files = None else: data = None files = None @@ -544,6 +553,14 @@ class BrowsableAPIRenderer(BaseRenderer): raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request) raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form + response_headers = dict(response.items()) + renderer_content_type = '' + if renderer: + renderer_content_type = '%s' % renderer.media_type + if renderer.charset: + renderer_content_type += ' ;%s' % renderer.charset + response_headers['Content-Type'] = renderer_content_type + context = { 'content': self.get_content(renderer, data, accepted_media_type, renderer_context), 'view': view, @@ -555,6 +572,7 @@ class BrowsableAPIRenderer(BaseRenderer): 'breadcrumblist': self.get_breadcrumbs(request), 'allowed_methods': view.allowed_methods, 'available_formats': [renderer.format for renderer in view.renderer_classes], + 'response_headers': response_headers, 'put_form': self.get_rendered_html_form(view, 'PUT', request), 'post_form': self.get_rendered_html_form(view, 'POST', request), diff --git a/rest_framework/request.py b/rest_framework/request.py index ca70b49e..40467c03 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -346,7 +346,7 @@ class Request(object): media_type = self.content_type if stream is None or media_type is None: - empty_data = QueryDict('', self._request._encoding) + empty_data = QueryDict('', encoding=self._request._encoding) empty_files = MultiValueDict() return (empty_data, empty_files) @@ -362,7 +362,7 @@ class Request(object): # re-raise. Ensures we don't simply repeat the error when # attempting to render the browsable renderer response, or when # logging the request or similar. - self._data = QueryDict('', self._request._encoding) + self._data = QueryDict('', encoding=self._request._encoding) self._files = MultiValueDict() raise diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py index da36d23f..2daaae4e 100755 --- a/rest_framework/runtests/runtests.py +++ b/rest_framework/runtests/runtests.py @@ -26,6 +26,10 @@ def usage(): def main(): + try: + django.setup() + except AttributeError: + pass TestRunner = get_runner(settings) test_runner = TestRunner() diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 10256d47..9cb548a5 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -16,6 +16,7 @@ import datetime import inspect import types from decimal import Decimal +from django.contrib.contenttypes.generic import GenericForeignKey from django.core.paginator import Page from django.db import models from django.forms import widgets @@ -438,16 +439,6 @@ class BaseSerializer(WritableField): raise ValidationError(self.error_messages['required']) return - # Set the serializer object if it exists - obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None - - # If we have a model manager or similar object then we need - # to iterate through each instance. - if (self.many and - not hasattr(obj, '__iter__') and - is_simple_callable(getattr(obj, 'all', None))): - obj = obj.all() - if self.source == '*': if value: reverted_data = self.restore_fields(value, {}) @@ -457,6 +448,16 @@ class BaseSerializer(WritableField): if value in (None, ''): into[(self.source or field_name)] = None else: + # Set the serializer object if it exists + obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None + + # If we have a model manager or similar object then we need + # to iterate through each instance. + if (self.many and + not hasattr(obj, '__iter__') and + is_simple_callable(getattr(obj, 'all', None))): + obj = obj.all() + kwargs = { 'instance': obj, 'data': value, @@ -757,8 +758,11 @@ class ModelSerializer(Serializer): field.read_only = True ret[accessor_name] = field + + # Ensure that 'read_only_fields' is an iterable + assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple' - # Add the `read_only` flag to any fields that have bee specified + # Add the `read_only` flag to any fields that have been specified # in the `read_only_fields` option for field_name in self.opts.read_only_fields: assert field_name not in self.base_fields.keys(), ( @@ -771,7 +775,10 @@ class ModelSerializer(Serializer): "on serializer '%s'." % (field_name, self.__class__.__name__)) ret[field_name].read_only = True - + + # Ensure that 'write_only_fields' is an iterable + assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple' + for field_name in self.opts.write_only_fields: assert field_name not in self.base_fields.keys(), ( "field '%s' on serializer '%s' specified in " @@ -821,6 +828,10 @@ class ModelSerializer(Serializer): if model_field: kwargs['required'] = not(model_field.null or model_field.blank) + if model_field.help_text is not None: + kwargs['help_text'] = model_field.help_text + if model_field.verbose_name is not None: + kwargs['label'] = model_field.verbose_name return PrimaryKeyRelatedField(**kwargs) @@ -881,7 +892,7 @@ class ModelSerializer(Serializer): except KeyError: return ModelField(model_field=model_field, **kwargs) - def get_validation_exclusions(self): + def get_validation_exclusions(self, instance=None): """ Return a list of field names to exclude from model validation. """ @@ -893,7 +904,7 @@ class ModelSerializer(Serializer): field_name = field.source or field_name if field_name in exclusions \ and not field.read_only \ - and field.required \ + and (field.required or hasattr(instance, field_name)) \ and not isinstance(field, Serializer): exclusions.remove(field_name) return exclusions @@ -908,7 +919,7 @@ class ModelSerializer(Serializer): the full_clean validation checking. """ try: - instance.full_clean(exclude=self.get_validation_exclusions()) + instance.full_clean(exclude=self.get_validation_exclusions(instance)) except ValidationError as err: self._errors = err.message_dict return None @@ -937,6 +948,8 @@ class ModelSerializer(Serializer): # Forward m2m relations for field in meta.many_to_many + meta.virtual_fields: + if isinstance(field, GenericForeignKey): + continue if field.name in attrs: m2m_data[field.name] = attrs.pop(field.name) @@ -946,17 +959,15 @@ class ModelSerializer(Serializer): if isinstance(self.fields.get(field_name, None), Serializer): nested_forward_relations[field_name] = attrs[field_name] - # Update an existing instance... - if instance is not None: - for key, val in attrs.items(): - try: - setattr(instance, key, val) - except ValueError: - self._errors[key] = self.error_messages['required'] + # Create an empty instance of the model + if instance is None: + instance = self.opts.model() - # ...or create a new instance - else: - instance = self.opts.model(**attrs) + for key, val in attrs.items(): + try: + setattr(instance, key, val) + except ValueError: + self._errors[key] = self.error_messages['required'] # Any relations that cannot be set until we've # saved the model get hidden away on these @@ -1081,6 +1092,10 @@ class HyperlinkedModelSerializer(ModelSerializer): if model_field: kwargs['required'] = not(model_field.null or model_field.blank) + if model_field.help_text is not None: + kwargs['help_text'] = model_field.help_text + if model_field.verbose_name is not None: + kwargs['label'] = model_field.verbose_name if self.opts.lookup_field: kwargs['lookup_field'] = self.opts.lookup_field diff --git a/rest_framework/settings.py b/rest_framework/settings.py index ce171d6d..38753c96 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -69,6 +69,10 @@ DEFAULTS = { 'PAGINATE_BY_PARAM': None, 'MAX_PAGINATE_BY': None, + # Filtering + 'SEARCH_PARAM': 'search', + 'ORDERING_PARAM': 'ordering', + # Authentication 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index d19d5a2b..7067ee2f 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -118,7 +118,7 @@ </div> <div class="response-info"> <pre class="prettyprint"><div class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %} -{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|break_long_headers|urlize_quoted_links }}</span> +{% for key, val in response_headers.items %}<b>{{ key }}:</b> <span class="lit">{{ val|break_long_headers|urlize_quoted_links }}</span> {% endfor %} </div>{{ content|urlize_quoted_links }}</pre>{% endautoescape %} </div> diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 83c046f9..dff176d6 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -6,7 +6,7 @@ from django.utils.encoding import iri_to_uri from django.utils.html import escape from django.utils.safestring import SafeData, mark_safe from rest_framework.compat import urlparse, force_text, six, smart_urlquote -import re, string +import re register = template.Library() @@ -180,7 +180,7 @@ def add_class(value, css_class): # Bunch of stuff cloned from urlize -TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"] +TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "']", "'}", "'"] WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'), ('"', '"'), ("'", "'")] word_split_re = re.compile(r'(\s+)') @@ -189,6 +189,17 @@ simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net simple_email_re = re.compile(r'^\S+@\S+\.\S+$') +def smart_urlquote_wrapper(matched_url): + """ + Simple wrapper for smart_urlquote. ValueError("Invalid IPv6 URL") can + be raised here, see issue #1386 + """ + try: + return smart_urlquote(matched_url) + except ValueError: + return None + + @register.filter def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True): """ @@ -211,7 +222,6 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru safe_input = isinstance(text, SafeData) words = word_split_re.split(force_text(text)) for i, word in enumerate(words): - match = None if '.' in word or '@' in word or ':' in word: # Deal with punctuation. lead, middle, trail = '', word, '' @@ -233,9 +243,9 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru url = None nofollow_attr = ' rel="nofollow"' if nofollow else '' if simple_url_re.match(middle): - url = smart_urlquote(middle) + url = smart_urlquote_wrapper(middle) elif simple_url_2_re.match(middle): - url = smart_urlquote('http://%s' % middle) + url = smart_urlquote_wrapper('http://%s' % middle) elif not ':' in middle and simple_email_re.match(middle): local, domain = middle.rsplit('@', 1) try: diff --git a/rest_framework/test.py b/rest_framework/test.py index 234d10a4..df5a5b3b 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -8,6 +8,7 @@ from django.conf import settings from django.test.client import Client as DjangoClient from django.test.client import ClientHandler from django.test import testcases +from django.utils.http import urlencode from rest_framework.settings import api_settings from rest_framework.compat import RequestFactory as DjangoRequestFactory from rest_framework.compat import force_bytes_or_smart_bytes, six @@ -71,6 +72,17 @@ class APIRequestFactory(DjangoRequestFactory): return ret, content_type + def get(self, path, data=None, **extra): + r = { + 'QUERY_STRING': urlencode(data or {}, doseq=True), + } + # Fix to support old behavior where you have the arguments in the url + # See #1461 + if not data and '?' in path: + r['QUERY_STRING'] = path.split('?')[1] + r.update(extra) + return self.generic('GET', path, **r) + def post(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) return self.generic('POST', path, data, content_type, **extra) diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 32a726c0..0256697a 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -103,7 +103,7 @@ class BlogPostComment(RESTFrameworkModel): class Album(RESTFrameworkModel): title = models.CharField(max_length=100, unique=True) - + ref = models.CharField(max_length=10, unique=True, null=True, blank=True) class Photo(RESTFrameworkModel): description = models.TextField() @@ -143,7 +143,8 @@ class ForeignKeyTarget(RESTFrameworkModel): class ForeignKeySource(RESTFrameworkModel): name = models.CharField(max_length=100) - target = models.ForeignKey(ForeignKeyTarget, related_name='sources') + target = models.ForeignKey(ForeignKeyTarget, related_name='sources', + help_text='Target', verbose_name='Target') # Nullable ForeignKey @@ -168,3 +169,10 @@ class NullableOneToOneSource(RESTFrameworkModel): class BasicModelSerializer(serializers.ModelSerializer): class Meta: model = BasicModel + + +# Models to test filters +class FilterableItem(models.Model): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() diff --git a/rest_framework/tests/serializers.py b/rest_framework/tests/serializers.py new file mode 100644 index 00000000..cc943c7d --- /dev/null +++ b/rest_framework/tests/serializers.py @@ -0,0 +1,8 @@ +from rest_framework import serializers + +from rest_framework.tests.models import NullableForeignKeySource + + +class NullableFKSourceSerializer(serializers.ModelSerializer): + class Meta: + model = NullableForeignKeySource diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index f072b81b..a1c43d9c 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -3,6 +3,7 @@ from django.contrib.auth.models import User from django.http import HttpResponse from django.test import TestCase from django.utils import unittest +from django.utils.http import urlencode from rest_framework import HTTP_HEADER_ENCODING from rest_framework import exceptions from rest_framework import permissions @@ -18,8 +19,8 @@ from rest_framework.authentication import ( OAuth2Authentication ) from rest_framework.authtoken.models import Token -from rest_framework.compat import patterns, url, include -from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope +from rest_framework.compat import patterns, url, include, six +from rest_framework.compat import oauth2_provider, oauth2_provider_scope from rest_framework.compat import oauth, oauth_provider from rest_framework.test import APIRequestFactory, APIClient from rest_framework.views import APIView @@ -53,10 +54,14 @@ urlpatterns = patterns('', permission_classes=[permissions.TokenHasReadWriteScope])) ) +class OAuth2AuthenticationDebug(OAuth2Authentication): + allow_query_params_token = True + if oauth2_provider is not None: urlpatterns += patterns('', url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), + url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])), url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], permission_classes=[permissions.TokenHasReadWriteScope])), ) @@ -190,6 +195,12 @@ class TokenAuthTests(TestCase): token = Token.objects.create(user=self.user) self.assertTrue(bool(token.key)) + def test_generate_key_returns_string(self): + """Ensure generate_key returns a string""" + token = Token() + key = token.generate_key() + self.assertTrue(isinstance(key, six.string_types)) + def test_token_login_json(self): """Ensure token login view using JSON POST works.""" client = APIClient(enforce_csrf_checks=True) @@ -488,7 +499,7 @@ class OAuth2Tests(TestCase): self.ACCESS_TOKEN = "access_token" self.REFRESH_TOKEN = "refresh_token" - self.oauth2_client = oauth2_provider_models.Client.objects.create( + self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create( client_id=self.CLIENT_ID, client_secret=self.CLIENT_SECRET, redirect_uri='', @@ -497,12 +508,12 @@ class OAuth2Tests(TestCase): user=None, ) - self.access_token = oauth2_provider_models.AccessToken.objects.create( + self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create( token=self.ACCESS_TOKEN, client=self.oauth2_client, user=self.user, ) - self.refresh_token = oauth2_provider_models.RefreshToken.objects.create( + self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create( user=self.user, access_token=self.access_token, client=self.oauth2_client @@ -546,6 +557,27 @@ class OAuth2Tests(TestCase): self.assertEqual(response.status_code, 200) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_passing_auth_url_transport(self): + """Ensure GETing form over OAuth with correct client credentials in form data succeed""" + response = self.csrf_client.post('/oauth2-test/', + data={'access_token': self.access_token.token}) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_passing_auth_url_transport(self): + """Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True""" + query = urlencode({'access_token': self.access_token.token}) + response = self.csrf_client.get('/oauth2-test-debug/?%s' % query) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_failing_auth_url_transport(self): + """Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False""" + query = urlencode({'access_token': self.access_token.token}) + response = self.csrf_client.get('/oauth2-test/?%s' % query) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_post_form_passing_auth(self): """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" auth = self._create_authorization_header() diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py index 18188186..23226bbc 100644 --- a/rest_framework/tests/test_filters.py +++ b/rest_framework/tests/test_filters.py @@ -7,18 +7,15 @@ from django.test import TestCase from django.utils import unittest from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters, patterns, url +from rest_framework.settings import api_settings from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel +from .models import FilterableItem +from .utils import temporary_setting factory = APIRequestFactory() -class FilterableItem(models.Model): - text = models.CharField(max_length=100) - decimal = models.DecimalField(max_digits=4, decimal_places=2) - date = models.DateField() - - if django_filters: # Basic filter on a list view. class FilterFieldsRootView(generics.ListCreateAPIView): @@ -128,7 +125,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): # Tests that the decimal filter works. search_decimal = Decimal('2.25') - request = factory.get('/?decimal=%s' % search_decimal) + request = factory.get('/', {'decimal': '%s' % search_decimal}) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if f['decimal'] == search_decimal] @@ -136,7 +133,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): # Tests that the date filter works. search_date = datetime.date(2012, 9, 22) - request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' + request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-09-22' response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if f['date'] == search_date] @@ -151,7 +148,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): # Tests that the decimal filter works. search_decimal = Decimal('2.25') - request = factory.get('/?decimal=%s' % search_decimal) + request = factory.get('/', {'decimal': '%s' % search_decimal}) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if f['decimal'] == search_decimal] @@ -184,7 +181,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): # Tests that the decimal filter set with 'lt' in the filter class works. search_decimal = Decimal('4.25') - request = factory.get('/?decimal=%s' % search_decimal) + request = factory.get('/', {'decimal': '%s' % search_decimal}) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if f['decimal'] < search_decimal] @@ -192,7 +189,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): # Tests that the date filter set with 'gt' in the filter class works. search_date = datetime.date(2012, 10, 2) - request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' + request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-10-02' response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if f['date'] > search_date] @@ -200,7 +197,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): # Tests that the text filter set with 'icontains' in the filter class works. search_text = 'ff' - request = factory.get('/?text=%s' % search_text) + request = factory.get('/', {'text': '%s' % search_text}) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if search_text in f['text'].lower()] @@ -209,7 +206,10 @@ class IntegrationTestFiltering(CommonFilteringTestCase): # Tests that multiple filters works. search_decimal = Decimal('5.25') search_date = datetime.date(2012, 10, 2) - request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) + request = factory.get('/', { + 'decimal': '%s' % (search_decimal,), + 'date': '%s' % (search_date,) + }) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if f['date'] > search_date and @@ -234,7 +234,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): view = FilterFieldsRootView.as_view() search_integer = 10 - request = factory.get('/?integer=%s' % search_integer) + request = factory.get('/', {'integer': '%s' % search_integer}) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -265,14 +265,18 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): # Tests that the decimal filter set that should fail. search_decimal = Decimal('4.25') high_item = self.objects.filter(decimal__gt=search_decimal)[0] - response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) + response = self.client.get( + '{url}'.format(url=self._get_url(high_item)), + {'decimal': '{param}'.format(param=search_decimal)}) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) # Tests that the decimal filter set that should succeed. search_decimal = Decimal('4.25') low_item = self.objects.filter(decimal__lt=search_decimal)[0] low_item_data = self._serialize_object(low_item) - response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) + response = self.client.get( + '{url}'.format(url=self._get_url(low_item)), + {'decimal': '{param}'.format(param=search_decimal)}) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, low_item_data) @@ -281,7 +285,11 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): search_date = datetime.date(2012, 10, 2) valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] valid_item_data = self._serialize_object(valid_item) - response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) + response = self.client.get( + '{url}'.format(url=self._get_url(valid_item)), { + 'decimal': '{decimal}'.format(decimal=search_decimal), + 'date': '{date}'.format(date=search_date) + }) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, valid_item_data) @@ -315,7 +323,7 @@ class SearchFilterTests(TestCase): search_fields = ('title', 'text') view = SearchListView.as_view() - request = factory.get('?search=b') + request = factory.get('/', {'search': 'b'}) response = view(request) self.assertEqual( response.data, @@ -332,7 +340,7 @@ class SearchFilterTests(TestCase): search_fields = ('=title', 'text') view = SearchListView.as_view() - request = factory.get('?search=zzz') + request = factory.get('/', {'search': 'zzz'}) response = view(request) self.assertEqual( response.data, @@ -348,7 +356,7 @@ class SearchFilterTests(TestCase): search_fields = ('title', '^text') view = SearchListView.as_view() - request = factory.get('?search=b') + request = factory.get('/', {'search': 'b'}) response = view(request) self.assertEqual( response.data, @@ -357,6 +365,24 @@ class SearchFilterTests(TestCase): ] ) + def test_search_with_nonstandard_search_param(self): + with temporary_setting('SEARCH_PARAM', 'query', module=filters): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('title', 'text') + + view = SearchListView.as_view() + request = factory.get('/', {'query': 'b'}) + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 1, 'title': 'z', 'text': 'abc'}, + {'id': 2, 'title': 'zz', 'text': 'bcd'} + ] + ) + class OrdringFilterModel(models.Model): title = models.CharField(max_length=20) @@ -396,7 +422,7 @@ class OrderingFilterTests(TestCase): ordering_fields = ('text',) view = OrderingListView.as_view() - request = factory.get('?ordering=text') + request = factory.get('/', {'ordering': 'text'}) response = view(request) self.assertEqual( response.data, @@ -415,7 +441,7 @@ class OrderingFilterTests(TestCase): ordering_fields = ('text',) view = OrderingListView.as_view() - request = factory.get('?ordering=-text') + request = factory.get('/', {'ordering': '-text'}) response = view(request) self.assertEqual( response.data, @@ -434,7 +460,7 @@ class OrderingFilterTests(TestCase): ordering_fields = ('text',) view = OrderingListView.as_view() - request = factory.get('?ordering=foobar') + request = factory.get('/', {'ordering': 'foobar'}) response = view(request) self.assertEqual( response.data, @@ -503,7 +529,7 @@ class OrderingFilterTests(TestCase): models.Count("relateds")) view = OrderingListView.as_view() - request = factory.get('?ordering=relateds__count') + request = factory.get('/', {'ordering': 'relateds__count'}) response = view(request) self.assertEqual( response.data, @@ -514,6 +540,26 @@ class OrderingFilterTests(TestCase): ] ) + def test_ordering_with_nonstandard_ordering_param(self): + with temporary_setting('ORDERING_PARAM', 'order', filters): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + ordering_fields = ('text',) + + view = OrderingListView.as_view() + request = factory.get('/', {'order': 'text'}) + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + ] + ) + class SensitiveOrderingFilterModel(models.Model): username = models.CharField(max_length=20) @@ -566,7 +612,7 @@ class SensitiveOrderingFilterTests(TestCase): serializer_class = serializer_cls view = OrderingListView.as_view() - request = factory.get('?ordering=-username') + request = factory.get('/', {'ordering': '-username'}) response = view(request) if serializer_cls == SensitiveDataSerializer3: @@ -596,7 +642,7 @@ class SensitiveOrderingFilterTests(TestCase): serializer_class = serializer_cls view = OrderingListView.as_view() - request = factory.get('?ordering=password') + request = factory.get('/', {'ordering': 'password'}) response = view(request) if serializer_cls == SensitiveDataSerializer3: @@ -612,4 +658,4 @@ class SensitiveOrderingFilterTests(TestCase): {'id': 2, username_field: 'userB'}, # PassC {'id': 3, username_field: 'userC'}, # PassA ] - )
\ No newline at end of file + ) diff --git a/rest_framework/tests/test_genericrelations.py b/rest_framework/tests/test_genericrelations.py index 2d341344..46a2d863 100644 --- a/rest_framework/tests/test_genericrelations.py +++ b/rest_framework/tests/test_genericrelations.py @@ -4,8 +4,10 @@ from django.contrib.contenttypes.generic import GenericRelation, GenericForeignK from django.db import models from django.test import TestCase from rest_framework import serializers +from rest_framework.compat import python_2_unicode_compatible +@python_2_unicode_compatible class Tag(models.Model): """ Tags have a descriptive slug, and are attached to an arbitrary object. @@ -15,10 +17,11 @@ class Tag(models.Model): object_id = models.PositiveIntegerField() tagged_item = GenericForeignKey('content_type', 'object_id') - def __unicode__(self): + def __str__(self): return self.tag +@python_2_unicode_compatible class Bookmark(models.Model): """ A URL bookmark that may have multiple tags attached. @@ -26,10 +29,11 @@ class Bookmark(models.Model): url = models.URLField() tags = GenericRelation(Tag) - def __unicode__(self): + def __str__(self): return 'Bookmark: %s' % self.url +@python_2_unicode_compatible class Note(models.Model): """ A textual note that may have multiple tags attached. @@ -37,7 +41,7 @@ class Note(models.Model): text = models.TextField() tags = GenericRelation(Tag) - def __unicode__(self): + def __str__(self): return 'Note: %s' % self.text @@ -127,3 +131,21 @@ class TestGenericRelations(TestCase): } ] self.assertEqual(serializer.data, expected) + + def test_restore_object_generic_fk(self): + """ + Ensure an object with a generic foreign key can be restored. + """ + + class TagSerializer(serializers.ModelSerializer): + class Meta: + model = Tag + exclude = ('content_type', 'object_id') + + serializer = TagSerializer() + + bookmark = Bookmark(url='http://example.com') + attrs = {'tagged_item': bookmark, 'tag': 'example'} + + tag = serializer.restore_object(attrs) + self.assertEqual(tag.tagged_item, bookmark) diff --git a/rest_framework/tests/test_htmlrenderer.py b/rest_framework/tests/test_htmlrenderer.py index 8957a43c..514d9e2b 100644 --- a/rest_framework/tests/test_htmlrenderer.py +++ b/rest_framework/tests/test_htmlrenderer.py @@ -50,7 +50,7 @@ class TemplateHTMLRendererTests(TestCase): """ self.get_template = django.template.loader.get_template - def get_template(template_name): + def get_template(template_name, dirs=None): if template_name == 'example.html': return Template("example: {{ object }}") raise TemplateDoesNotExist(template_name) @@ -108,11 +108,13 @@ class TemplateHTMLRendererExceptionTests(TestCase): def test_not_found_html_view_with_template(self): response = self.client.get('/not_found') self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - self.assertEqual(response.content, six.b("404: Not found")) + self.assertTrue(response.content in ( + six.b("404: Not found"), six.b("404 Not Found"))) self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') def test_permission_denied_html_view_with_template(self): response = self.client.get('/permission_denied') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self.assertEqual(response.content, six.b("403: Permission denied")) + self.assertTrue(response.content in ( + six.b("403: Permission denied"), six.b("403 Forbidden"))) self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') diff --git a/rest_framework/tests/test_nullable_fields.py b/rest_framework/tests/test_nullable_fields.py new file mode 100644 index 00000000..6ee55c00 --- /dev/null +++ b/rest_framework/tests/test_nullable_fields.py @@ -0,0 +1,30 @@ +from django.core.urlresolvers import reverse + +from rest_framework.compat import patterns, url +from rest_framework.test import APITestCase +from rest_framework.tests.models import NullableForeignKeySource +from rest_framework.tests.serializers import NullableFKSourceSerializer +from rest_framework.tests.views import NullableFKSourceDetail + + +urlpatterns = patterns( + '', + url(r'^objects/(?P<pk>\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'), +) + + +class NullableForeignKeyTests(APITestCase): + """ + DRF should be able to handle nullable foreign keys when a test + Client POST/PUT request is made with its own serialized object. + """ + urls = 'rest_framework.tests.test_nullable_fields' + + def test_updating_object_with_null_fk(self): + obj = NullableForeignKeySource(name='example', target=None) + obj.save() + serialized_data = NullableFKSourceSerializer(obj).data + + response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data) + + self.assertEqual(response.data, serialized_data) diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index cadb515f..24c1ba39 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -9,14 +9,18 @@ from rest_framework import generics, status, pagination, filters, serializers from rest_framework.compat import django_filters from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel +from .models import FilterableItem factory = APIRequestFactory() +# Helper function to split arguments out of an url +def split_arguments_from_url(url): + if '?' not in url: + return url -class FilterableItem(models.Model): - text = models.CharField(max_length=100) - decimal = models.DecimalField(max_digits=4, decimal_places=2) - date = models.DateField() + path, args = url.split('?') + args = dict(r.split('=') for r in args.split('&')) + return path, args class RootView(generics.ListCreateAPIView): @@ -84,7 +88,7 @@ class IntegrationTestPagination(TestCase): self.assertNotEqual(response.data['next'], None) self.assertEqual(response.data['previous'], None) - request = factory.get(response.data['next']) + request = factory.get(*split_arguments_from_url(response.data['next'])) with self.assertNumQueries(2): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -93,7 +97,7 @@ class IntegrationTestPagination(TestCase): self.assertNotEqual(response.data['next'], None) self.assertNotEqual(response.data['previous'], None) - request = factory.get(response.data['next']) + request = factory.get(*split_arguments_from_url(response.data['next'])) with self.assertNumQueries(2): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -146,7 +150,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): EXPECTED_NUM_QUERIES = 2 - request = factory.get('/?decimal=15.20') + request = factory.get('/', {'decimal': '15.20'}) with self.assertNumQueries(EXPECTED_NUM_QUERIES): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -155,7 +159,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.assertNotEqual(response.data['next'], None) self.assertEqual(response.data['previous'], None) - request = factory.get(response.data['next']) + request = factory.get(*split_arguments_from_url(response.data['next'])) with self.assertNumQueries(EXPECTED_NUM_QUERIES): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -164,7 +168,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.assertEqual(response.data['next'], None) self.assertNotEqual(response.data['previous'], None) - request = factory.get(response.data['previous']) + request = factory.get(*split_arguments_from_url(response.data['previous'])) with self.assertNumQueries(EXPECTED_NUM_QUERIES): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -191,7 +195,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): view = BasicFilterFieldsRootView.as_view() - request = factory.get('/?decimal=15.20') + request = factory.get('/', {'decimal': '15.20'}) with self.assertNumQueries(2): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -200,7 +204,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.assertNotEqual(response.data['next'], None) self.assertEqual(response.data['previous'], None) - request = factory.get(response.data['next']) + request = factory.get(*split_arguments_from_url(response.data['next'])) with self.assertNumQueries(2): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -209,7 +213,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.assertEqual(response.data['next'], None) self.assertNotEqual(response.data['previous'], None) - request = factory.get(response.data['previous']) + request = factory.get(*split_arguments_from_url(response.data['previous'])) with self.assertNumQueries(2): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -317,7 +321,7 @@ class TestCustomPaginateByParam(TestCase): """ If paginate_by_param is set, the new kwarg should limit per view requests. """ - request = factory.get('/?page_size=5') + request = factory.get('/', {'page_size': 5}) response = self.view(request).render() self.assertEqual(response.data['count'], 13) self.assertEqual(response.data['results'], self.data[:5]) @@ -345,7 +349,7 @@ class TestMaxPaginateByParam(TestCase): """ If max_paginate_by is set, it should limit page size for the view. """ - request = factory.get('/?page_size=10') + request = factory.get('/', data={'page_size': 10}) response = self.view(request).render() self.assertEqual(response.data['count'], 13) self.assertEqual(response.data['results'], self.data[:5]) diff --git a/rest_framework/tests/test_parsers.py b/rest_framework/tests/test_parsers.py index 7699e10c..8af90677 100644 --- a/rest_framework/tests/test_parsers.py +++ b/rest_framework/tests/test_parsers.py @@ -96,7 +96,7 @@ class TestFileUploadParser(TestCase): request = MockRequest() request.upload_handlers = (MemoryFileUploadHandler(),) request.META = { - 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'), + 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt', 'HTTP_CONTENT_LENGTH': 14, } self.parser_context = {'request': request, 'kwargs': {}} @@ -112,4 +112,4 @@ class TestFileUploadParser(TestCase): def test_get_filename(self): parser = FileUploadParser() filename = parser.get_filename(self.stream, None, self.parser_context) - self.assertEqual(filename, 'file.txt'.encode('utf-8')) + self.assertEqual(filename, 'file.txt') diff --git a/rest_framework/tests/test_relations.py b/rest_framework/tests/test_relations.py index f52e0e1e..37ac826b 100644 --- a/rest_framework/tests/test_relations.py +++ b/rest_framework/tests/test_relations.py @@ -2,8 +2,10 @@ General tests for relational fields. """ from __future__ import unicode_literals +from django import get_version from django.db import models from django.test import TestCase +from django.utils import unittest from rest_framework import serializers from rest_framework.tests.models import BlogPost @@ -118,3 +120,25 @@ class RelatedFieldSourceTests(TestCase): (serializers.ModelSerializer,), attrs) with self.assertRaises(AttributeError): TestSerializer(data={'name': 'foo'}) + +@unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6') +class RelatedFieldChoicesTests(TestCase): + """ + Tests for #1408 "Web browseable API doesn't have blank option on drop down list box" + https://github.com/tomchristie/django-rest-framework/issues/1408 + """ + def test_blank_option_is_added_to_choice_if_required_equals_false(self): + """ + + """ + post = BlogPost(title="Checking blank option is added") + post.save() + + queryset = BlogPost.objects.all() + field = serializers.RelatedField(required=False, queryset=queryset) + + choice_count = BlogPost.objects.count() + widget_count = len(field.widget.choices) + + self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added') + diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py index d393b0c3..4d9da489 100644 --- a/rest_framework/tests/test_relations_nested.py +++ b/rest_framework/tests/test_relations_nested.py @@ -3,9 +3,7 @@ from django.db import models from django.test import TestCase from rest_framework import serializers - -class OneToOneTarget(models.Model): - name = models.CharField(max_length=100) +from .models import OneToOneTarget class OneToOneSource(models.Model): diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py index fb33df2c..7cb7d0f9 100644 --- a/rest_framework/tests/test_renderers.py +++ b/rest_framework/tests/test_renderers.py @@ -12,7 +12,7 @@ from rest_framework.compat import yaml, etree, patterns, url, include, six, Stri from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ - XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer + XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer, UnicodeYAMLRenderer from rest_framework.parsers import YAMLParser, XMLParser from rest_framework.settings import api_settings from rest_framework.test import APIRequestFactory @@ -256,6 +256,18 @@ class RendererEndToEndTests(TestCase): self.assertEqual(resp.get('Content-Type', None), None) self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT) + def test_contains_headers_of_api_response(self): + """ + Issue #1437 + + Test we display the headers of the API response and not those from the + HTML response + """ + resp = self.client.get('/html1') + self.assertContains(resp, '>GET, HEAD, OPTIONS<') + self.assertContains(resp, '>application/json<') + self.assertNotContains(resp, '>text/html; charset=utf-8<') + _flat_repr = '{"foo": ["bar", "baz"]}' _indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}' @@ -455,6 +467,17 @@ if yaml: self.assertTrue(string in content, '%r not in %r' % (string, content)) + class UnicodeYAMLRendererTests(TestCase): + """ + Tests specific for the Unicode YAML Renderer + """ + def test_proper_encoding(self): + obj = {'countries': ['United Kingdom', 'France', 'España']} + renderer = UnicodeYAMLRenderer() + content = renderer.render(obj, 'application/yaml') + self.assertEqual(content.strip(), 'countries: [United Kingdom, France, España]'.encode('utf-8')) + + class XMLRendererTestCase(TestCase): """ Tests specific to the XML Renderer @@ -601,6 +624,10 @@ class CacheRenderTest(TestCase): method = getattr(self.client, http_method) resp = method(url) del resp.client, resp.request + try: + del resp.wsgi_request + except AttributeError: + pass return resp def test_obj_pickling(self): diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index 6b1e333e..e688c823 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -3,15 +3,42 @@ from __future__ import unicode_literals from django.db import models from django.db.models.fields import BLANK_CHOICE_DASH from django.test import TestCase +from django.utils import unittest from django.utils.datastructures import MultiValueDict from django.utils.translation import ugettext_lazy as _ from rest_framework import serializers, fields, relations from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel, - ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel) + ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel, + ForeignKeySource, ManyToManySource) from rest_framework.tests.models import BasicModelSerializer import datetime import pickle +try: + import PIL +except: + PIL = None + + +if PIL is not None: + class AMOAFModel(RESTFrameworkModel): + char_field = models.CharField(max_length=1024, blank=True) + comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True) + decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True) + email_field = models.EmailField(max_length=1024, blank=True) + file_field = models.FileField(upload_to='test', max_length=1024, blank=True) + image_field = models.ImageField(upload_to='test', max_length=1024, blank=True) + slug_field = models.SlugField(max_length=1024, blank=True) + url_field = models.URLField(max_length=1024, blank=True) + + class DVOAFModel(RESTFrameworkModel): + positive_integer_field = models.PositiveIntegerField(blank=True) + positive_small_integer_field = models.PositiveSmallIntegerField(blank=True) + email_field = models.EmailField(blank=True) + file_field = models.FileField(upload_to='test', blank=True) + image_field = models.ImageField(upload_to='test', blank=True) + slug_field = models.SlugField(blank=True) + url_field = models.URLField(blank=True) class SubComment(object): @@ -141,7 +168,7 @@ class AlbumsSerializer(serializers.ModelSerializer): class Meta: model = Album - fields = ['title'] # lists are also valid options + fields = ['title', 'ref'] # lists are also valid options class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer): @@ -150,6 +177,16 @@ class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer): fields = ['some_integer'] +class ForeignKeySourceSerializer(serializers.ModelSerializer): + class Meta: + model = ForeignKeySource + + +class HyperlinkedForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = ForeignKeySource + + class BasicTests(TestCase): def setUp(self): self.comment = Comment( @@ -482,6 +519,32 @@ class ValidationTests(TestCase): ) self.assertEqual(serializer.is_valid(), True) + def test_writable_star_source_on_nested_serializer_with_parent_object(self): + class TitleSerializer(serializers.Serializer): + title = serializers.WritableField(source='title') + + class AlbumSerializer(serializers.ModelSerializer): + nested = TitleSerializer(source='*') + + class Meta: + model = Album + fields = ('nested',) + + class PhotoSerializer(serializers.ModelSerializer): + album = AlbumSerializer(source='album') + + class Meta: + model = Photo + fields = ('album', ) + + photo = Photo(album=Album()) + + data = {'album': {'nested': {'title': 'test'}}} + + serializer = PhotoSerializer(photo, data=data) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.data, data) + def test_writable_star_source_with_inner_source_fields(self): """ Tests that a serializer with source="*" correctly expands the @@ -591,12 +654,15 @@ class ModelValidationTests(TestCase): """ Just check if serializers.ModelSerializer handles unique checks via .full_clean() """ - serializer = AlbumsSerializer(data={'title': 'a'}) + serializer = AlbumsSerializer(data={'title': 'a', 'ref': '1'}) serializer.is_valid() serializer.save() second_serializer = AlbumsSerializer(data={'title': 'a'}) self.assertFalse(second_serializer.is_valid()) - self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']}) + self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.'],}) + third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}]) + self.assertFalse(third_serializer.is_valid()) + self.assertEqual(third_serializer.errors, [{'ref': ['Album with this Ref already exists.']}, {}]) def test_foreign_key_is_null_with_partial(self): """ @@ -880,6 +946,58 @@ class DefaultValueTests(TestCase): self.assertEqual(instance.text, 'overridden') +class WritableFieldDefaultValueTests(TestCase): + + def setUp(self): + self.expected = {'default': 'value'} + self.create_field = fields.WritableField + + def test_get_default_value_with_noncallable(self): + field = self.create_field(default=self.expected) + got = field.get_default_value() + self.assertEqual(got, self.expected) + + def test_get_default_value_with_callable(self): + field = self.create_field(default=lambda : self.expected) + got = field.get_default_value() + self.assertEqual(got, self.expected) + + def test_get_default_value_when_not_required(self): + field = self.create_field(default=self.expected, required=False) + got = field.get_default_value() + self.assertEqual(got, self.expected) + + def test_get_default_value_returns_None(self): + field = self.create_field() + got = field.get_default_value() + self.assertIsNone(got) + + def test_get_default_value_returns_non_True_values(self): + values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause + for expected in values: + field = self.create_field(default=expected) + got = field.get_default_value() + self.assertEqual(got, expected) + + +class RelatedFieldDefaultValueTests(WritableFieldDefaultValueTests): + + def setUp(self): + self.expected = {'foo': 'bar'} + self.create_field = relations.RelatedField + + def test_get_default_value_returns_empty_list(self): + field = self.create_field(many=True) + got = field.get_default_value() + self.assertListEqual(got, []) + + def test_get_default_value_returns_expected(self): + expected = [1, 2, 3] + field = self.create_field(many=True, default=expected) + got = field.get_default_value() + self.assertListEqual(got, expected) + + class CallableDefaultValueTests(TestCase): def setUp(self): class CallableDefaultValueSerializer(serializers.ModelSerializer): @@ -1493,18 +1611,23 @@ class ManyFieldHelpTextTest(TestCase): self.assertEqual('Some help text.', rel_field.help_text) +class AttributeMappingOnAutogeneratedRelatedFields(TestCase): + + def test_primary_key_related_field(self): + serializer = ForeignKeySourceSerializer() + self.assertEqual(serializer.fields['target'].help_text, 'Target') + self.assertEqual(serializer.fields['target'].label, 'Target') + + def test_hyperlinked_related_field(self): + serializer = HyperlinkedForeignKeySourceSerializer() + self.assertEqual(serializer.fields['target'].help_text, 'Target') + self.assertEqual(serializer.fields['target'].label, 'Target') + + +@unittest.skipUnless(PIL is not None, 'PIL is not installed') class AttributeMappingOnAutogeneratedFieldsTests(TestCase): def setUp(self): - class AMOAFModel(RESTFrameworkModel): - char_field = models.CharField(max_length=1024, blank=True) - comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True) - decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True) - email_field = models.EmailField(max_length=1024, blank=True) - file_field = models.FileField(max_length=1024, blank=True) - image_field = models.ImageField(max_length=1024, blank=True) - slug_field = models.SlugField(max_length=1024, blank=True) - url_field = models.URLField(max_length=1024, blank=True) class AMOAFSerializer(serializers.ModelSerializer): class Meta: @@ -1574,17 +1697,10 @@ class AttributeMappingOnAutogeneratedFieldsTests(TestCase): self.field_test('url_field') +@unittest.skipUnless(PIL is not None, 'PIL is not installed') class DefaultValuesOnAutogeneratedFieldsTests(TestCase): def setUp(self): - class DVOAFModel(RESTFrameworkModel): - positive_integer_field = models.PositiveIntegerField(blank=True) - positive_small_integer_field = models.PositiveSmallIntegerField(blank=True) - email_field = models.EmailField(blank=True) - file_field = models.FileField(blank=True) - image_field = models.ImageField(blank=True) - slug_field = models.SlugField(blank=True) - url_field = models.URLField(blank=True) class DVOAFSerializer(serializers.ModelSerializer): class Meta: diff --git a/rest_framework/tests/test_templatetags.py b/rest_framework/tests/test_templatetags.py index 609a9e08..d4da0c23 100644 --- a/rest_framework/tests/test_templatetags.py +++ b/rest_framework/tests/test_templatetags.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals from django.test import TestCase from rest_framework.test import APIRequestFactory -from rest_framework.templatetags.rest_framework import add_query_param +from rest_framework.templatetags.rest_framework import add_query_param, urlize_quoted_links factory = APIRequestFactory() @@ -17,3 +17,35 @@ class TemplateTagTests(TestCase): json_url = add_query_param(request, "format", "json") self.assertIn("q=%E6%9F%A5%E8%AF%A2", json_url) self.assertIn("format=json", json_url) + + +class Issue1386Tests(TestCase): + """ + Covers #1386 + """ + + def test_issue_1386(self): + """ + Test function urlize_quoted_links with different args + """ + correct_urls = [ + "asdf.com", + "asdf.net", + "www.as_df.org", + "as.d8f.ghj8.gov", + ] + for i in correct_urls: + res = urlize_quoted_links(i) + self.assertNotEqual(res, i) + self.assertIn(i, res) + + incorrect_urls = [ + "mailto://asdf@fdf.com", + "asdf.netnet", + ] + for i in incorrect_urls: + res = urlize_quoted_links(i) + self.assertEqual(i, res) + + # example from issue #1386, this shouldn't raise an exception + _ = urlize_quoted_links("asdf:[/p]zxcv.com") diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index 71bd8b55..a55d4b22 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -152,3 +152,13 @@ class TestAPIRequestFactory(TestCase): simple_png.name = 'test.png' factory = APIRequestFactory() factory.post('/', data={'image': simple_png}) + + def test_request_factory_url_arguments(self): + """ + This is a non regression test against #1461 + """ + factory = APIRequestFactory() + request = factory.get('/view/?demo=test') + self.assertEqual(dict(request.GET), {'demo': ['test']}) + request = factory.get('/view/', {'demo': 'test'}) + self.assertEqual(dict(request.GET), {'demo': ['test']}) diff --git a/rest_framework/tests/test_urlizer.py b/rest_framework/tests/test_urlizer.py new file mode 100644 index 00000000..3dc8e8fe --- /dev/null +++ b/rest_framework/tests/test_urlizer.py @@ -0,0 +1,38 @@ +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.templatetags.rest_framework import urlize_quoted_links +import sys + + +class URLizerTests(TestCase): + """ + Test if both JSON and YAML URLs are transformed into links well + """ + def _urlize_dict_check(self, data): + """ + For all items in dict test assert that the value is urlized key + """ + for original, urlized in data.items(): + assert urlize_quoted_links(original, nofollow=False) == urlized + + def test_json_with_url(self): + """ + Test if JSON URLs are transformed into links well + """ + data = {} + data['"url": "http://api/users/1/", '] = \ + '"url": "<a href="http://api/users/1/">http://api/users/1/</a>", ' + data['"foo_set": [\n "http://api/foos/1/"\n], '] = \ + '"foo_set": [\n "<a href="http://api/foos/1/">http://api/foos/1/</a>"\n], ' + self._urlize_dict_check(data) + + def test_yaml_with_url(self): + """ + Test if YAML URLs are transformed into links well + """ + data = {} + data['''{users: 'http://api/users/'}'''] = \ + '''{users: '<a href="http://api/users/">http://api/users/</a>'}''' + data['''foo_set: ['http://api/foos/1/']'''] = \ + '''foo_set: ['<a href="http://api/foos/1/">http://api/foos/1/</a>']''' + self._urlize_dict_check(data) diff --git a/rest_framework/tests/test_validation.py b/rest_framework/tests/test_validation.py index 124c874d..e13e4078 100644 --- a/rest_framework/tests/test_validation.py +++ b/rest_framework/tests/test_validation.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals +from django.core.validators import MaxValueValidator from django.db import models from django.test import TestCase from rest_framework import generics, serializers, status @@ -102,3 +103,46 @@ class TestAvoidValidation(TestCase): self.assertFalse(serializer.is_valid()) self.assertDictEqual(serializer.errors, {'non_field_errors': ['Invalid data']}) + + +# regression tests for issue: 1493 + +class ValidationMaxValueValidatorModel(models.Model): + number_value = models.PositiveIntegerField(validators=[MaxValueValidator(100)]) + + +class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer): + class Meta: + model = ValidationMaxValueValidatorModel + + +class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView): + model = ValidationMaxValueValidatorModel + serializer_class = ValidationMaxValueValidatorModelSerializer + + +class TestMaxValueValidatorValidation(TestCase): + + def test_max_value_validation_serializer_success(self): + serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 99}) + self.assertTrue(serializer.is_valid()) + + def test_max_value_validation_serializer_fails(self): + serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 101}) + self.assertFalse(serializer.is_valid()) + self.assertDictEqual({'number_value': ['Ensure this value is less than or equal to 100.']}, serializer.errors) + + def test_max_value_validation_success(self): + obj = ValidationMaxValueValidatorModel.objects.create(number_value=100) + request = factory.patch('/{0}'.format(obj.pk), {'number_value': 98}, format='json') + view = UpdateMaxValueValidationModel().as_view() + response = view(request, pk=obj.pk).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_max_value_validation_fail(self): + obj = ValidationMaxValueValidatorModel.objects.create(number_value=100) + request = factory.patch('/{0}'.format(obj.pk), {'number_value': 101}, format='json') + view = UpdateMaxValueValidationModel().as_view() + response = view(request, pk=obj.pk).render() + self.assertEqual(response.content, b'{"number_value": ["Ensure this value is less than or equal to 100."]}') + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py new file mode 100644 index 00000000..a8f2eb0b --- /dev/null +++ b/rest_framework/tests/utils.py @@ -0,0 +1,25 @@ +from contextlib import contextmanager +from rest_framework.compat import six +from rest_framework.settings import api_settings + + +@contextmanager +def temporary_setting(setting, value, module=None): + """ + Temporarily change value of setting for test. + + Optionally reload given module, useful when module uses value of setting on + import. + """ + original_value = getattr(api_settings, setting) + setattr(api_settings, setting, value) + + if module is not None: + six.moves.reload_module(module) + + yield + + setattr(api_settings, setting, original_value) + + if module is not None: + six.moves.reload_module(module) diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py new file mode 100644 index 00000000..3917b74a --- /dev/null +++ b/rest_framework/tests/views.py @@ -0,0 +1,8 @@ +from rest_framework import generics +from rest_framework.tests.models import NullableForeignKeySource +from rest_framework.tests.serializers import NullableFKSourceSerializer + + +class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView): + model = NullableForeignKeySource + model_serializer_class = NullableFKSourceSerializer diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index c36b58bf..91be9cfd 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -136,6 +136,8 @@ class SimpleRateThrottle(BaseThrottle): remaining_duration = self.duration available_requests = self.num_requests - len(self.history) + 1 + if available_requests <= 0: + return None return remaining_duration / float(available_requests) diff --git a/rest_framework/utils/mediatypes.py b/rest_framework/utils/mediatypes.py index c09c2933..92f99efd 100644 --- a/rest_framework/utils/mediatypes.py +++ b/rest_framework/utils/mediatypes.py @@ -74,7 +74,7 @@ class _MediaType(object): return 0 elif self.sub_type == '*': return 1 - elif not self.params or self.params.keys() == ['q']: + elif not self.params or list(self.params.keys()) == ['q']: return 2 return 3 diff --git a/rest_framework/views.py b/rest_framework/views.py index 02a6e25a..a2668f2c 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -131,7 +131,7 @@ class APIView(View): """ If request is not permitted, determine what kind of exception to raise. """ - if not self.request.successful_authenticator: + if not request.successful_authenticator: raise exceptions.NotAuthenticated() raise exceptions.PermissionDenied() @@ -295,7 +295,7 @@ class APIView(View): # Dispatch methods - def initialize_request(self, request, *args, **kargs): + def initialize_request(self, request, *args, **kwargs): """ Returns the initial request object. """ |
