diff options
| author | Xavier Ordoquy | 2014-04-13 00:05:57 +0200 |
|---|---|---|
| committer | Xavier Ordoquy | 2014-04-13 00:05:57 +0200 |
| commit | d08536ad9d026fb7126c430f6d9c18f8540aacd6 (patch) | |
| tree | a8a1d36ce76867e57da23379694ea0609801990b /rest_framework | |
| parent | 2911cd64ad67ba193e3d37322ee71692cb482623 (diff) | |
| parent | 93b9245b8714287a440023451ff7880a2f6e5b32 (diff) | |
| download | django-rest-framework-d08536ad9d026fb7126c430f6d9c18f8540aacd6.tar.bz2 | |
Merge remote-tracking branch 'origin/master' into 2.4.0
Conflicts:
.travis.yml
docs/api-guide/fields.md
docs/api-guide/routers.md
docs/topics/release-notes.md
rest_framework/authentication.py
rest_framework/serializers.py
rest_framework/templatetags/rest_framework.py
rest_framework/tests/test_authentication.py
rest_framework/tests/test_filters.py
rest_framework/tests/test_hyperlinkedserializers.py
rest_framework/tests/test_serializer.py
rest_framework/tests/test_testing.py
rest_framework/utils/encoders.py
tox.ini
Diffstat (limited to 'rest_framework')
55 files changed, 1169 insertions, 232 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index f5483b9d..2d76b55d 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -8,10 +8,10 @@ ______ _____ _____ _____ __ _ """ __title__ = 'Django REST framework' -__version__ = '2.3.10' +__version__ = '2.3.13' __author__ = 'Tom Christie' __license__ = 'BSD 2-Clause' -__copyright__ = 'Copyright 2011-2013 Tom Christie' +__copyright__ = 'Copyright 2011-2014 Tom Christie' # Version synonym VERSION = __version__ diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 1f8d37fa..cbc83574 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -7,6 +7,7 @@ import base64 from django.contrib.auth import authenticate from django.core.exceptions import ImproperlyConfigured from django.middleware.csrf import CsrfViewMiddleware +from django.conf import settings from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework.compat import oauth, oauth_provider, oauth_provider_store from rest_framework.compat import oauth2_provider, provider_now, check_nonce @@ -291,6 +292,7 @@ class OAuth2Authentication(BaseAuthentication): OAuth 2 authentication backend using `django-oauth2-provider` """ www_authenticate_realm = 'api' + allow_query_params_token = settings.DEBUG def __init__(self, *args, **kwargs): super(OAuth2Authentication, self).__init__(*args, **kwargs) @@ -308,7 +310,13 @@ class OAuth2Authentication(BaseAuthentication): auth = get_authorization_header(request).split() - if not auth or auth[0].lower() != b'bearer': + if auth and auth[0].lower() == b'bearer': + access_token = auth[1] + elif 'access_token' in request.POST: + access_token = request.POST['access_token'] + elif 'access_token' in request.GET and self.allow_query_params_token: + access_token = request.GET['access_token'] + else: return None if len(auth) == 1: @@ -318,7 +326,7 @@ class OAuth2Authentication(BaseAuthentication): msg = 'Invalid bearer header. Token string should not contain spaces.' raise exceptions.AuthenticationFailed(msg) - return self.authenticate_credentials(request, auth[1]) + return self.authenticate_credentials(request, access_token) def authenticate_credentials(self, request, access_token): """ @@ -326,11 +334,11 @@ class OAuth2Authentication(BaseAuthentication): """ try: - token = oauth2_provider.models.AccessToken.objects.select_related('user') + token = oauth2_provider.oauth2.models.AccessToken.objects.select_related('user') # provider_now switches to timezone aware datetime when # the oauth2_provider version supports to it. token = token.get(token=access_token, expires__gt=provider_now()) - except oauth2_provider.models.AccessToken.DoesNotExist: + except oauth2_provider.oauth2.models.AccessToken.DoesNotExist: raise exceptions.AuthenticationFailed('Invalid token') user = token.user diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py index 024f62bf..8eac2cc4 100644 --- a/rest_framework/authtoken/models.py +++ b/rest_framework/authtoken/models.py @@ -1,5 +1,5 @@ -import uuid -import hmac +import binascii +import os from hashlib import sha1 from django.conf import settings from django.db import models @@ -34,8 +34,7 @@ class Token(models.Model): return super(Token, self).save(*args, **kwargs) def generate_key(self): - unique = uuid.uuid4() - return hmac.new(unique.bytes, digestmod=sha1).hexdigest() + return binascii.hexlify(os.urandom(20)) def __unicode__(self): return self.key diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 45045c0f..a013a155 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -121,7 +121,7 @@ from django.test.client import RequestFactory as DjangoRequestFactory from django.test.client import FakePayload try: # In 1.5 the test client uses force_bytes - from django.utils.encoding import force_bytes_or_smart_bytes + from django.utils.encoding import force_bytes as force_bytes_or_smart_bytes except ImportError: # In 1.4 the test client just uses smart_str from django.utils.encoding import smart_str as force_bytes_or_smart_bytes @@ -216,13 +216,10 @@ except (ImportError, ImproperlyConfigured): # OAuth 2 support is optional try: - import provider.oauth2 as oauth2_provider - from provider.oauth2 import models as oauth2_provider_models - from provider.oauth2 import forms as oauth2_provider_forms + import provider as oauth2_provider from provider import scope as oauth2_provider_scope from provider import constants as oauth2_constants - from provider import __version__ as provider_version - if provider_version in ('0.2.3', '0.2.4'): + if oauth2_provider.__version__ in ('0.2.3', '0.2.4'): # 0.2.3 and 0.2.4 are supported version that do not support # timezone aware datetimes import datetime @@ -232,8 +229,6 @@ try: from django.utils.timezone import now as provider_now except ImportError: oauth2_provider = None - oauth2_provider_models = None - oauth2_provider_forms = None oauth2_provider_scope = None oauth2_constants = None provider_now = None @@ -251,3 +246,23 @@ if six.PY3: else: def is_non_str_iterable(obj): return hasattr(obj, '__iter__') + + +try: + from django.utils.encoding import python_2_unicode_compatible +except ImportError: + def python_2_unicode_compatible(klass): + """ + A decorator that defines __unicode__ and __str__ methods under Python 2. + Under Python 3 it does nothing. + + To support Python 2 and 3 with a single code base, define a __str__ method + returning text and apply this decorator to the class. + """ + if '__str__' not in klass.__dict__: + raise ValueError("@python_2_unicode_compatible cannot be applied " + "to %s because it doesn't define __str__()." % + klass.__name__) + klass.__unicode__ = klass.__str__ + klass.__str__ = lambda self: self.__unicode__().encode('utf-8') + return klass diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 425a7214..5f774a9f 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -6,47 +6,42 @@ In addition Django's built in 403 and 404 exceptions are handled. """ from __future__ import unicode_literals from rest_framework import status +import math class APIException(Exception): """ Base class for REST framework exceptions. - Subclasses should provide `.status_code` and `.detail` properties. + Subclasses should provide `.status_code` and `.default_detail` properties. """ - pass + status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + default_detail = '' + def __init__(self, detail=None): + self.detail = detail or self.default_detail + + def __str__(self): + return self.detail class ParseError(APIException): status_code = status.HTTP_400_BAD_REQUEST default_detail = 'Malformed request.' - def __init__(self, detail=None): - self.detail = detail or self.default_detail - class AuthenticationFailed(APIException): status_code = status.HTTP_401_UNAUTHORIZED default_detail = 'Incorrect authentication credentials.' - def __init__(self, detail=None): - self.detail = detail or self.default_detail - class NotAuthenticated(APIException): status_code = status.HTTP_401_UNAUTHORIZED default_detail = 'Authentication credentials were not provided.' - def __init__(self, detail=None): - self.detail = detail or self.default_detail - class PermissionDenied(APIException): status_code = status.HTTP_403_FORBIDDEN default_detail = 'You do not have permission to perform this action.' - def __init__(self, detail=None): - self.detail = detail or self.default_detail - class MethodNotAllowed(APIException): status_code = status.HTTP_405_METHOD_NOT_ALLOWED @@ -75,14 +70,14 @@ class UnsupportedMediaType(APIException): class Throttled(APIException): status_code = status.HTTP_429_TOO_MANY_REQUESTS - default_detail = "Request was throttled." + default_detail = 'Request was throttled.' extra_detail = "Expected available in %d second%s." def __init__(self, wait=None, detail=None): - import math - self.wait = wait and math.ceil(wait) or None - if wait is not None: - format = detail or self.default_detail + self.extra_detail - self.detail = format % (self.wait, self.wait != 1 and 's' or '') - else: + if wait is None: self.detail = detail or self.default_detail + self.wait = None + else: + format = (detail or self.default_detail) + self.extra_detail + self.detail = format % (wait, wait != 1 and 's' or '') + self.wait = math.ceil(wait) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 16485b41..533de28c 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -166,7 +166,7 @@ class Field(object): Called to set up a field prior to field_to_native or field_from_native. parent - The parent serializer. - model_field - The model field this field corresponds to, if one exists. + field_name - The name of the field being initialized. """ self.parent = parent self.root = parent.root or parent @@ -248,6 +248,7 @@ class WritableField(Field): """ Base for read/write fields. """ + write_only = False default_validators = [] default_error_messages = { 'required': _('This field is required.'), @@ -257,13 +258,17 @@ class WritableField(Field): default = None def __init__(self, source=None, label=None, help_text=None, - read_only=False, required=None, + read_only=False, write_only=False, required=None, validators=[], error_messages=None, widget=None, default=None, blank=None): super(WritableField, self).__init__(source=source, label=label, help_text=help_text) self.read_only = read_only + self.write_only = write_only + + assert not (read_only and write_only), "Cannot set read_only=True and write_only=True" + if required is None: self.required = not(read_only) else: @@ -291,6 +296,11 @@ class WritableField(Field): result.validators = self.validators[:] return result + def get_default_value(self): + if is_simple_callable(self.default): + return self.default() + return self.default + def validate(self, value): if value in validators.EMPTY_VALUES and self.required: raise ValidationError(self.error_messages['required']) @@ -313,6 +323,11 @@ class WritableField(Field): if errors: raise ValidationError(errors) + def field_to_native(self, obj, field_name): + if self.write_only: + return None + return super(WritableField, self).field_to_native(obj, field_name) + def field_from_native(self, data, files, field_name, into): """ Given a dictionary and a field name, updates the dictionary `into`, @@ -334,10 +349,7 @@ class WritableField(Field): except KeyError: if self.default is not None and not self.partial: # Note: partial updates shouldn't set defaults - if is_simple_callable(self.default): - native = self.default() - else: - native = self.default + native = self.get_default_value() else: if self.required: raise ValidationError(self.error_messages['required']) @@ -465,7 +477,8 @@ class URLField(CharField): type_label = 'url' def __init__(self, **kwargs): - kwargs['validators'] = [validators.URLValidator()] + if not 'validators' in kwargs: + kwargs['validators'] = [validators.URLValidator()] super(URLField, self).__init__(**kwargs) diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 5c6a187c..96d15eb9 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -3,8 +3,10 @@ Provides generic filtering backends that can be used to filter the results returned by list views. """ from __future__ import unicode_literals +from django.core.exceptions import ImproperlyConfigured from django.db import models from rest_framework.compat import django_filters, six, guardian, get_model_name +from rest_framework.settings import api_settings from functools import reduce import operator @@ -68,7 +70,8 @@ class DjangoFilterBackend(BaseFilterBackend): class SearchFilter(BaseFilterBackend): - search_param = 'search' # The URL query parameter used for the search. + # The URL query parameter used for the search. + search_param = api_settings.SEARCH_PARAM def get_search_terms(self, request): """ @@ -106,7 +109,9 @@ class SearchFilter(BaseFilterBackend): class OrderingFilter(BaseFilterBackend): - ordering_param = 'ordering' # The URL query parameter used for the ordering. + # The URL query parameter used for the ordering. + ordering_param = api_settings.ORDERING_PARAM + ordering_fields = None def get_ordering(self, request): """ @@ -122,17 +127,34 @@ class OrderingFilter(BaseFilterBackend): return (ordering,) return ordering - def remove_invalid_fields(self, queryset, ordering): - field_names = [field.name for field in queryset.model._meta.fields] - field_names += queryset.query.aggregates.keys() - return [term for term in ordering if term.lstrip('-') in field_names] + def remove_invalid_fields(self, queryset, ordering, view): + valid_fields = getattr(view, 'ordering_fields', self.ordering_fields) + + if valid_fields is None: + # Default to allowing filtering on serializer fields + serializer_class = getattr(view, 'serializer_class') + if serializer_class is None: + msg = ("Cannot use %s on a view which does not have either a " + "'serializer_class' or 'ordering_fields' attribute.") + raise ImproperlyConfigured(msg % self.__class__.__name__) + valid_fields = [ + field.source or field_name + for field_name, field in serializer_class().fields.items() + if not getattr(field, 'write_only', False) + ] + elif valid_fields == '__all__': + # View explictly allows filtering on any model field + valid_fields = [field.name for field in queryset.model._meta.fields] + valid_fields += queryset.query.aggregates.keys() + + return [term for term in ordering if term.lstrip('-') in valid_fields] def filter_queryset(self, request, queryset, view): ordering = self.get_ordering(request) if ordering: # Skip any incorrect parameters - ordering = self.remove_invalid_fields(queryset, ordering) + ordering = self.remove_invalid_fields(queryset, ordering, view) if not ordering: # Use 'ordering' attribute by default diff --git a/rest_framework/generics.py b/rest_framework/generics.py index bd33c01a..c3256844 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -352,7 +352,7 @@ class GenericAPIView(views.APIView): def post_delete(self, obj): """ - Placeholder method for calling after saving an object. + Placeholder method for calling after deleting an object. """ pass diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index b62a4cc1..2cc87eef 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -11,6 +11,7 @@ from django.http import Http404 from rest_framework import status from rest_framework.response import Response from rest_framework.request import clone_request +from rest_framework.settings import api_settings import warnings @@ -60,7 +61,7 @@ class CreateModelMixin(object): def get_success_headers(self, data): try: - return {'Location': data['url']} + return {'Location': data[api_settings.URL_FIELD_NAME]} except (TypeError, KeyError): return {} @@ -115,30 +116,27 @@ class UpdateModelMixin(object): partial = kwargs.pop('partial', False) self.object = self.get_object_or_none() - if self.object is None: - created = True - save_kwargs = {'force_insert': True} - success_status_code = status.HTTP_201_CREATED - else: - created = False - save_kwargs = {'force_update': True} - success_status_code = status.HTTP_200_OK - serializer = self.get_serializer(self.object, data=request.DATA, files=request.FILES, partial=partial) - if serializer.is_valid(): - try: - self.pre_save(serializer.object) - except ValidationError as err: - # full_clean on model instance may be called in pre_save, so we - # have to handle eventual errors. - return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST) - self.object = serializer.save(**save_kwargs) - self.post_save(self.object, created=created) - return Response(serializer.data, status=success_status_code) + if not serializer.is_valid(): + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + try: + self.pre_save(serializer.object) + except ValidationError as err: + # full_clean on model instance may be called in pre_save, + # so we have to handle eventual errors. + return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST) + + if self.object is None: + self.object = serializer.save(force_insert=True) + self.post_save(self.object, created=True) + return Response(serializer.data, status=status.HTTP_201_CREATED) + + self.object = serializer.save(force_update=True) + self.post_save(self.object, created=False) + return Response(serializer.data, status=status.HTTP_200_OK) def partial_update(self, request, *args, **kwargs): kwargs['partial'] = True diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 4785c009..3b234dd5 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -33,6 +33,7 @@ class RelatedField(WritableField): many_widget = widgets.SelectMultiple form_field_class = forms.ChoiceField many_form_field_class = forms.MultipleChoiceField + null_values = (None, '', 'None') cache_choices = False empty_label = None @@ -50,6 +51,8 @@ class RelatedField(WritableField): super(RelatedField, self).__init__(*args, **kwargs) if not self.required: + # Accessed in ModelChoiceIterator django/forms/models.py:1034 + # If set adds empty choice. self.empty_label = BLANK_CHOICE_DASH[0][1] self.queryset = queryset @@ -57,16 +60,11 @@ class RelatedField(WritableField): def initialize(self, parent, field_name): super(RelatedField, self).initialize(parent, field_name) if self.queryset is None and not self.read_only: - try: - manager = getattr(self.parent.opts.model, self.source or field_name) - if hasattr(manager, 'related'): # Forward - self.queryset = manager.related.model._default_manager.all() - else: # Reverse - self.queryset = manager.field.rel.to._default_manager.all() - except Exception: - msg = ('Serializer related fields must include a `queryset`' + - ' argument or set `read_only=True') - raise Exception(msg) + manager = getattr(self.parent.opts.model, self.source or field_name) + if hasattr(manager, 'related'): # Forward + self.queryset = manager.related.model._default_manager.all() + else: # Reverse + self.queryset = manager.field.rel.to._default_manager.all() ### We need this stuff to make form choices work... @@ -115,6 +113,14 @@ class RelatedField(WritableField): choices = property(_get_choices, _set_choices) + ### Default value handling + + def get_default_value(self): + default = super(RelatedField, self).get_default_value() + if self.many and default is None: + return [] + return default + ### Regular serializer stuff... def field_to_native(self, obj, field_name): @@ -163,11 +169,11 @@ class RelatedField(WritableField): except KeyError: if self.partial: return - value = [] if self.many else None + value = self.get_default_value() - if value in (None, '') and self.required: - raise ValidationError(self.error_messages['required']) - elif value in (None, ''): + if value in self.null_values: + if self.required: + raise ValidationError(self.error_messages['required']) into[(self.source or field_name)] = None elif self.many: into[(self.source or field_name)] = [self.from_native(item) for item in value] diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 2fdd3337..7a7da561 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -10,6 +10,7 @@ from __future__ import unicode_literals import copy import json +import django from django import forms from django.core.exceptions import ImproperlyConfigured from django.http.multipartparser import parse_header @@ -145,7 +146,7 @@ class XMLRenderer(BaseRenderer): def render(self, data, accepted_media_type=None, renderer_context=None): """ - Renders *obj* into serialized XML. + Renders `data` into serialized XML. """ if data is None: return '' @@ -195,7 +196,7 @@ class YAMLRenderer(BaseRenderer): def render(self, data, accepted_media_type=None, renderer_context=None): """ - Renders *obj* into serialized YAML. + Renders `data` into serialized YAML. """ assert yaml, 'YAMLRenderer requires pyyaml to be installed' @@ -426,7 +427,7 @@ class BrowsableAPIRenderer(BaseRenderer): files = request.FILES except ParseError: data = None - files = None + files = None else: data = None files = None @@ -543,6 +544,14 @@ class BrowsableAPIRenderer(BaseRenderer): raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request) raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form + response_headers = dict(response.items()) + renderer_content_type = '' + if renderer: + renderer_content_type = '%s' % renderer.media_type + if renderer.charset: + renderer_content_type += ' ;%s' % renderer.charset + response_headers['Content-Type'] = renderer_content_type + context = { 'content': self.get_content(renderer, data, accepted_media_type, renderer_context), 'view': view, @@ -554,6 +563,7 @@ class BrowsableAPIRenderer(BaseRenderer): 'breadcrumblist': self.get_breadcrumbs(request), 'allowed_methods': view.allowed_methods, 'available_formats': [renderer.format for renderer in view.renderer_classes], + 'response_headers': response_headers, 'put_form': self.get_rendered_html_form(view, 'PUT', request), 'post_form': self.get_rendered_html_form(view, 'POST', request), @@ -597,7 +607,7 @@ class MultiPartRenderer(BaseRenderer): media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg' format = 'multipart' charset = 'utf-8' - BOUNDARY = 'BoUnDaRyStRiNg' + BOUNDARY = 'BoUnDaRyStRiNg' if django.VERSION >= (1, 5) else b'BoUnDaRyStRiNg' def render(self, data, accepted_media_type=None, renderer_context=None): return encode_multipart(self.BOUNDARY, data) diff --git a/rest_framework/request.py b/rest_framework/request.py index fcea2508..40467c03 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -223,7 +223,7 @@ class Request(object): def user(self, value): """ Sets the user on the current request. This is necessary to maintain - compatilbility with django.contrib.auth where the user proprety is + compatibility with django.contrib.auth where the user property is set in the login and logout functions. """ self._user = value @@ -279,10 +279,9 @@ class Request(object): if not _hasattr(self, '_method'): self._method = self._request.method - if self._method == 'POST': - # Allow X-HTTP-METHOD-OVERRIDE header - self._method = self.META.get('HTTP_X_HTTP_METHOD_OVERRIDE', - self._method) + # Allow X-HTTP-METHOD-OVERRIDE header + self._method = self.META.get('HTTP_X_HTTP_METHOD_OVERRIDE', + self._method) def _load_stream(self): """ @@ -347,7 +346,7 @@ class Request(object): media_type = self.content_type if stream is None or media_type is None: - empty_data = QueryDict('', self._request._encoding) + empty_data = QueryDict('', encoding=self._request._encoding) empty_files = MultiValueDict() return (empty_data, empty_files) @@ -363,7 +362,7 @@ class Request(object): # re-raise. Ensures we don't simply repeat the error when # attempting to render the browsable renderer response, or when # logging the request or similar. - self._data = QueryDict('', self._request._encoding) + self._data = QueryDict('', encoding=self._request._encoding) self._files = MultiValueDict() raise diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py index da36d23f..2daaae4e 100755 --- a/rest_framework/runtests/runtests.py +++ b/rest_framework/runtests/runtests.py @@ -26,6 +26,10 @@ def usage(): def main(): + try: + django.setup() + except AttributeError: + pass TestRunner = get_runner(settings) test_runner = TestRunner() diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py index 12aa73e7..36283d8e 100644 --- a/rest_framework/runtests/settings.py +++ b/rest_framework/runtests/settings.py @@ -97,6 +97,9 @@ INSTALLED_APPS = ( 'rest_framework', 'rest_framework.authtoken', 'rest_framework.tests', + 'rest_framework.tests.accounts', + 'rest_framework.tests.records', + 'rest_framework.tests.users', ) # OAuth is optional and won't work if there is no oauth_provider & oauth2 diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index cbf73fc3..5b14e403 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -13,12 +13,15 @@ response content is handled by parsers and renderers. from __future__ import unicode_literals import copy import datetime +import inspect import types from decimal import Decimal from django.db import models from django.forms import widgets from django.utils.datastructures import SortedDict -from rest_framework.compat import six +from rest_framework.compat import get_concrete_model, six +from rest_framework.settings import api_settings + # Note: We do the following so that users of the framework can use this style: # @@ -31,6 +34,27 @@ from rest_framework.relations import * from rest_framework.fields import * +def _resolve_model(obj): + """ + Resolve supplied `obj` to a Django model class. + + `obj` must be a Django model class itself, or a string + representation of one. Useful in situtations like GH #1225 where + Django may not have resolved a string-based reference to a model in + another model's foreign key definition. + + String representations should have the format: + 'appname.ModelName' + """ + if type(obj) == str and len(obj.split('.')) == 2: + app_name, model_name = obj.split('.') + return models.get_model(app_name, model_name) + elif inspect.isclass(obj) and issubclass(obj, models.Model): + return obj + else: + raise ValueError("{0} is not a Django model".format(obj)) + + def pretty_name(name): """Converts 'first_name' to 'First name'""" if not name: @@ -325,12 +349,13 @@ class BaseSerializer(WritableField): method = getattr(self, 'transform_%s' % field_name, None) if callable(method): value = method(obj, value) - ret[key] = value + if not getattr(field, 'write_only', False): + ret[key] = value ret.fields[key] = self.augment_field(field, field_name, key, value) return ret - def from_native(self, data, files): + def from_native(self, data, files=None): """ Deserialize primitives -> objects. """ @@ -360,6 +385,9 @@ class BaseSerializer(WritableField): Override default so that the serializer can be used as a nested field across relationships. """ + if self.write_only: + return None + if self.source == '*': return self.to_native(obj) @@ -404,16 +432,6 @@ class BaseSerializer(WritableField): raise ValidationError(self.error_messages['required']) return - # Set the serializer object if it exists - obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None - - # If we have a model manager or similar object then we need - # to iterate through each instance. - if (self.many and - not hasattr(obj, '__iter__') and - is_simple_callable(getattr(obj, 'all', None))): - obj = obj.all() - if self.source == '*': if value: reverted_data = self.restore_fields(value, {}) @@ -423,6 +441,16 @@ class BaseSerializer(WritableField): if value in (None, ''): into[(self.source or field_name)] = None else: + # Set the serializer object if it exists + obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None + + # If we have a model manager or similar object then we need + # to iterate through each instance. + if (self.many and + not hasattr(obj, '__iter__') and + is_simple_callable(getattr(obj, 'all', None))): + obj = obj.all() + kwargs = { 'instance': obj, 'data': value, @@ -467,7 +495,7 @@ class BaseSerializer(WritableField): else: many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type)) if many: - warnings.warn('Implict list/queryset serialization is deprecated. ' + warnings.warn('Implicit list/queryset serialization is deprecated. ' 'Use the `many=True` flag when instantiating the serializer.', DeprecationWarning, stacklevel=3) @@ -524,7 +552,16 @@ class BaseSerializer(WritableField): if self._data is None: obj = self.object - if self.many: + if self.many is not None: + many = self.many + else: + many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict)) + if many: + warnings.warn('Implicit list/queryset serialization is deprecated. ' + 'Use the `many=True` flag when instantiating the serializer.', + DeprecationWarning, stacklevel=2) + + if many: self._data = [self.to_native(item) for item in obj] else: self._data = self.to_native(obj) @@ -578,6 +615,7 @@ class ModelSerializerOptions(SerializerOptions): super(ModelSerializerOptions, self).__init__(meta) self.model = getattr(meta, 'model', None) self.read_only_fields = getattr(meta, 'read_only_fields', ()) + self.write_only_fields = getattr(meta, 'write_only_fields', ()) class ModelSerializer(Serializer): @@ -641,7 +679,7 @@ class ModelSerializer(Serializer): if model_field.rel: to_many = isinstance(model_field, models.fields.related.ManyToManyField) - related_model = model_field.rel.to + related_model = _resolve_model(model_field.rel.to) if to_many and not model_field.rel.through._meta.auto_created: has_through_model = True @@ -713,20 +751,38 @@ class ModelSerializer(Serializer): field.read_only = True ret[accessor_name] = field + + # Ensure that 'read_only_fields' is an iterable + assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple' - # Add the `read_only` flag to any fields that have bee specified + # Add the `read_only` flag to any fields that have been specified # in the `read_only_fields` option for field_name in self.opts.read_only_fields: - assert field_name not in self.base_fields.keys(), \ - "field '%s' on serializer '%s' specified in " \ - "`read_only_fields`, but also added " \ - "as an explicit field. Remove it from `read_only_fields`." % \ - (field_name, self.__class__.__name__) - assert field_name in ret, \ - "Non-existant field '%s' specified in `read_only_fields` " \ - "on serializer '%s'." % \ - (field_name, self.__class__.__name__) + assert field_name not in self.base_fields.keys(), ( + "field '%s' on serializer '%s' specified in " + "`read_only_fields`, but also added " + "as an explicit field. Remove it from `read_only_fields`." % + (field_name, self.__class__.__name__)) + assert field_name in ret, ( + "Non-existant field '%s' specified in `read_only_fields` " + "on serializer '%s'." % + (field_name, self.__class__.__name__)) ret[field_name].read_only = True + + # Ensure that 'write_only_fields' is an iterable + assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple' + + for field_name in self.opts.write_only_fields: + assert field_name not in self.base_fields.keys(), ( + "field '%s' on serializer '%s' specified in " + "`write_only_fields`, but also added " + "as an explicit field. Remove it from `write_only_fields`." % + (field_name, self.__class__.__name__)) + assert field_name in ret, ( + "Non-existant field '%s' specified in `write_only_fields` " + "on serializer '%s'." % + (field_name, self.__class__.__name__)) + ret[field_name].write_only = True return ret @@ -829,7 +885,7 @@ class ModelSerializer(Serializer): except KeyError: return ModelField(model_field=model_field, **kwargs) - def get_validation_exclusions(self): + def get_validation_exclusions(self, instance=None): """ Return a list of field names to exclude from model validation. """ @@ -841,6 +897,7 @@ class ModelSerializer(Serializer): field_name = field.source or field_name if field_name in exclusions \ and not field.read_only \ + and (field.required or hasattr(instance, field_name)) \ and not isinstance(field, Serializer): exclusions.remove(field_name) return exclusions @@ -855,7 +912,7 @@ class ModelSerializer(Serializer): the full_clean validation checking. """ try: - instance.full_clean(exclude=self.get_validation_exclusions()) + instance.full_clean(exclude=self.get_validation_exclusions(instance)) except ValidationError as err: self._errors = err.message_dict return None @@ -883,7 +940,7 @@ class ModelSerializer(Serializer): m2m_data[field_name] = attrs.pop(field_name) # Forward m2m relations - for field in meta.many_to_many: + for field in meta.many_to_many + meta.virtual_fields: if field.name in attrs: m2m_data[field.name] = attrs.pop(field.name) @@ -979,6 +1036,7 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): super(HyperlinkedModelSerializerOptions, self).__init__(meta) self.view_name = getattr(meta, 'view_name', None) self.lookup_field = getattr(meta, 'lookup_field', None) + self.url_field_name = getattr(meta, 'url_field_name', api_settings.URL_FIELD_NAME) class HyperlinkedModelSerializer(ModelSerializer): @@ -997,13 +1055,13 @@ class HyperlinkedModelSerializer(ModelSerializer): if self.opts.view_name is None: self.opts.view_name = self._get_default_view_name(self.opts.model) - if 'url' not in fields: + if self.opts.url_field_name not in fields: url_field = self._hyperlink_identify_field_class( view_name=self.opts.view_name, lookup_field=self.opts.lookup_field ) ret = self._dict_class() - ret['url'] = url_field + ret[self.opts.url_field_name] = url_field ret.update(fields) fields = ret @@ -1039,7 +1097,7 @@ class HyperlinkedModelSerializer(ModelSerializer): We need to override the default, to use the url as the identity. """ try: - return data.get('url', None) + return data.get(self.opts.url_field_name, None) except AttributeError: return None diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 383de72e..189131f1 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -70,6 +70,10 @@ DEFAULTS = { 'PAGINATE_BY_PARAM': None, 'MAX_PAGINATE_BY': None, + # Filtering + 'SEARCH_PARAM': 'search', + 'ORDERING_PARAM': 'ordering', + # Authentication 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, @@ -96,6 +100,7 @@ DEFAULTS = { 'URL_FORMAT_OVERRIDE': 'format', 'FORMAT_SUFFIX_KWARG': 'format', + 'URL_FIELD_NAME': 'url', # Input and output formats 'DATE_INPUT_FORMATS': ( diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 42ede968..210741ed 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -34,7 +34,7 @@ <div class="navbar-inner"> <div class="container-fluid"> <span href="/"> - {% block branding %}<a class='brand' href='http://django-rest-framework.org'>Django REST framework <span class="version">{{ version }}</span></a>{% endblock %} + {% block branding %}<a class='brand' rel="nofollow" href='http://www.django-rest-framework.org'>Django REST framework <span class="version">{{ version }}</span></a>{% endblock %} </span> <ul class="nav pull-right"> {% block userlinks %} @@ -119,7 +119,7 @@ </div> <div class="response-info"> <pre class="prettyprint"><div class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %} -{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|break_long_headers|urlize_quoted_links }}</span> +{% for key, val in response_headers.items %}<b>{{ key }}:</b> <span class="lit">{{ val|break_long_headers|urlize_quoted_links }}</span> {% endfor %} </div>{{ content|urlize_quoted_links }}</pre>{% endautoescape %} </div> diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 55f36149..a0f9c841 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -2,10 +2,12 @@ from __future__ import unicode_literals, absolute_import from django import template from django.core.urlresolvers import reverse, NoReverseMatch from django.http import QueryDict -from django.utils.html import escape, smart_urlquote +from django.utils.encoding import iri_to_uri +from django.utils.html import escape from django.utils.safestring import SafeData, mark_safe from rest_framework.compat import urlparse, force_text, six -import re, string +from django.utils.html import smart_urlquote +import re register = template.Library() @@ -61,7 +63,9 @@ def add_query_param(request, key, val): """ Add a query parameter to the current request url, and return the new url. """ - return replace_query_param(request.get_full_path(), key, val) + iri = request.get_full_path() + uri = iri_to_uri(iri) + return replace_query_param(uri, key, val) @register.filter @@ -103,6 +107,17 @@ simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net simple_email_re = re.compile(r'^\S+@\S+\.\S+$') +def smart_urlquote_wrapper(matched_url): + """ + Simple wrapper for smart_urlquote. ValueError("Invalid IPv6 URL") can + be raised here, see issue #1386 + """ + try: + return smart_urlquote(matched_url) + except ValueError: + return None + + @register.filter def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True): """ @@ -125,7 +140,6 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru safe_input = isinstance(text, SafeData) words = word_split_re.split(force_text(text)) for i, word in enumerate(words): - match = None if '.' in word or '@' in word or ':' in word: # Deal with punctuation. lead, middle, trail = '', word, '' @@ -147,9 +161,9 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru url = None nofollow_attr = ' rel="nofollow"' if nofollow else '' if simple_url_re.match(middle): - url = smart_urlquote(middle) + url = smart_urlquote_wrapper(middle) elif simple_url_2_re.match(middle): - url = smart_urlquote('http://%s' % middle) + url = smart_urlquote_wrapper('http://%s' % middle) elif not ':' in middle and simple_email_re.match(middle): local, domain = middle.rsplit('@', 1) try: diff --git a/rest_framework/test.py b/rest_framework/test.py index 234d10a4..df5a5b3b 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -8,6 +8,7 @@ from django.conf import settings from django.test.client import Client as DjangoClient from django.test.client import ClientHandler from django.test import testcases +from django.utils.http import urlencode from rest_framework.settings import api_settings from rest_framework.compat import RequestFactory as DjangoRequestFactory from rest_framework.compat import force_bytes_or_smart_bytes, six @@ -71,6 +72,17 @@ class APIRequestFactory(DjangoRequestFactory): return ret, content_type + def get(self, path, data=None, **extra): + r = { + 'QUERY_STRING': urlencode(data or {}, doseq=True), + } + # Fix to support old behavior where you have the arguments in the url + # See #1461 + if not data and '?' in path: + r['QUERY_STRING'] = path.split('?')[1] + r.update(extra) + return self.generic('GET', path, **r) + def post(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) return self.generic('POST', path, data, content_type, **extra) diff --git a/rest_framework/tests/accounts/__init__.py b/rest_framework/tests/accounts/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/rest_framework/tests/accounts/__init__.py diff --git a/rest_framework/tests/accounts/models.py b/rest_framework/tests/accounts/models.py new file mode 100644 index 00000000..525e601b --- /dev/null +++ b/rest_framework/tests/accounts/models.py @@ -0,0 +1,8 @@ +from django.db import models + +from rest_framework.tests.users.models import User + + +class Account(models.Model): + owner = models.ForeignKey(User, related_name='accounts_owned') + admins = models.ManyToManyField(User, blank=True, null=True, related_name='accounts_administered') diff --git a/rest_framework/tests/accounts/serializers.py b/rest_framework/tests/accounts/serializers.py new file mode 100644 index 00000000..a27b9ca6 --- /dev/null +++ b/rest_framework/tests/accounts/serializers.py @@ -0,0 +1,11 @@ +from rest_framework import serializers + +from rest_framework.tests.accounts.models import Account +from rest_framework.tests.users.serializers import UserSerializer + + +class AccountSerializer(serializers.ModelSerializer): + admins = UserSerializer(many=True) + + class Meta: + model = Account diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 32a726c0..6c8f2342 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -103,7 +103,7 @@ class BlogPostComment(RESTFrameworkModel): class Album(RESTFrameworkModel): title = models.CharField(max_length=100, unique=True) - + ref = models.CharField(max_length=10, unique=True, null=True, blank=True) class Photo(RESTFrameworkModel): description = models.TextField() @@ -168,3 +168,10 @@ class NullableOneToOneSource(RESTFrameworkModel): class BasicModelSerializer(serializers.ModelSerializer): class Meta: model = BasicModel + + +# Models to test filters +class FilterableItem(models.Model): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() diff --git a/rest_framework/tests/records/__init__.py b/rest_framework/tests/records/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/rest_framework/tests/records/__init__.py diff --git a/rest_framework/tests/records/models.py b/rest_framework/tests/records/models.py new file mode 100644 index 00000000..76954807 --- /dev/null +++ b/rest_framework/tests/records/models.py @@ -0,0 +1,6 @@ +from django.db import models + + +class Record(models.Model): + account = models.ForeignKey('accounts.Account', blank=True, null=True) + owner = models.ForeignKey('users.User', blank=True, null=True) diff --git a/rest_framework/tests/serializers.py b/rest_framework/tests/serializers.py new file mode 100644 index 00000000..cc943c7d --- /dev/null +++ b/rest_framework/tests/serializers.py @@ -0,0 +1,8 @@ +from rest_framework import serializers + +from rest_framework.tests.models import NullableForeignKeySource + + +class NullableFKSourceSerializer(serializers.ModelSerializer): + class Meta: + model = NullableForeignKeySource diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index fb0bc694..6c14debb 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -4,6 +4,7 @@ from django.contrib.auth.models import User from django.http import HttpResponse from django.test import TestCase from django.utils import unittest +from django.utils.http import urlencode from rest_framework import HTTP_HEADER_ENCODING from rest_framework import exceptions from rest_framework import permissions @@ -19,7 +20,7 @@ from rest_framework.authentication import ( OAuth2Authentication ) from rest_framework.authtoken.models import Token -from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope +from rest_framework.compat import oauth2_provider, oauth2_provider_scope from rest_framework.compat import oauth, oauth_provider from rest_framework.test import APIRequestFactory, APIClient from rest_framework.views import APIView @@ -53,10 +54,14 @@ urlpatterns = patterns('', permission_classes=[permissions.TokenHasReadWriteScope])) ) +class OAuth2AuthenticationDebug(OAuth2Authentication): + allow_query_params_token = True + if oauth2_provider is not None: urlpatterns += patterns('', url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), + url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])), url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], permission_classes=[permissions.TokenHasReadWriteScope])), ) @@ -488,7 +493,7 @@ class OAuth2Tests(TestCase): self.ACCESS_TOKEN = "access_token" self.REFRESH_TOKEN = "refresh_token" - self.oauth2_client = oauth2_provider_models.Client.objects.create( + self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create( client_id=self.CLIENT_ID, client_secret=self.CLIENT_SECRET, redirect_uri='', @@ -497,12 +502,12 @@ class OAuth2Tests(TestCase): user=None, ) - self.access_token = oauth2_provider_models.AccessToken.objects.create( + self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create( token=self.ACCESS_TOKEN, client=self.oauth2_client, user=self.user, ) - self.refresh_token = oauth2_provider_models.RefreshToken.objects.create( + self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create( user=self.user, access_token=self.access_token, client=self.oauth2_client @@ -546,6 +551,27 @@ class OAuth2Tests(TestCase): self.assertEqual(response.status_code, 200) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_passing_auth_url_transport(self): + """Ensure GETing form over OAuth with correct client credentials in form data succeed""" + response = self.csrf_client.post('/oauth2-test/', + data={'access_token': self.access_token.token}) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_passing_auth_url_transport(self): + """Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True""" + query = urlencode({'access_token': self.access_token.token}) + response = self.csrf_client.get('/oauth2-test-debug/?%s' % query) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_failing_auth_url_transport(self): + """Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False""" + query = urlencode({'access_token': self.access_token.token}) + response = self.csrf_client.get('/oauth2-test/?%s' % query) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_post_form_passing_auth(self): """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" auth = self._create_authorization_header() diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py index 5c96bce9..e127feef 100644 --- a/rest_framework/tests/test_fields.py +++ b/rest_framework/tests/test_fields.py @@ -860,7 +860,9 @@ class SlugFieldTests(TestCase): class URLFieldTests(TestCase): """ - Tests for URLField attribute values + Tests for URLField attribute values. + + (Includes test for #1210, checking that validators can be overridden.) """ class URLFieldModel(RESTFrameworkModel): @@ -902,6 +904,11 @@ class URLFieldTests(TestCase): self.assertEqual(getattr(serializer.fields['url_field'], 'max_length'), 20) + def test_validators_can_be_overridden(self): + url_field = serializers.URLField(validators=[]) + validators = url_field.validators + self.assertEqual([], validators, 'Passing `validators` kwarg should have overridden default validators') + class FieldMetadata(TestCase): def setUp(self): diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py index 8a03a077..2aa6f81a 100644 --- a/rest_framework/tests/test_filters.py +++ b/rest_framework/tests/test_filters.py @@ -1,25 +1,21 @@ from __future__ import unicode_literals import datetime from decimal import Decimal -from django.conf.urls import patterns, url from django.db import models from django.core.urlresolvers import reverse from django.test import TestCase from django.utils import unittest +from django.conf.urls import patterns, url from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel +from .models import FilterableItem +from .utils import temporary_setting factory = APIRequestFactory() -class FilterableItem(models.Model): - text = models.CharField(max_length=100) - decimal = models.DecimalField(max_digits=4, decimal_places=2) - date = models.DateField() - - if django_filters: # Basic filter on a list view. class FilterFieldsRootView(generics.ListCreateAPIView): @@ -129,7 +125,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): # Tests that the decimal filter works. search_decimal = Decimal('2.25') - request = factory.get('/?decimal=%s' % search_decimal) + request = factory.get('/', {'decimal': '%s' % search_decimal}) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if f['decimal'] == search_decimal] @@ -137,7 +133,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): # Tests that the date filter works. search_date = datetime.date(2012, 9, 22) - request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' + request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-09-22' response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if f['date'] == search_date] @@ -152,7 +148,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): # Tests that the decimal filter works. search_decimal = Decimal('2.25') - request = factory.get('/?decimal=%s' % search_decimal) + request = factory.get('/', {'decimal': '%s' % search_decimal}) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if f['decimal'] == search_decimal] @@ -185,7 +181,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): # Tests that the decimal filter set with 'lt' in the filter class works. search_decimal = Decimal('4.25') - request = factory.get('/?decimal=%s' % search_decimal) + request = factory.get('/', {'decimal': '%s' % search_decimal}) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if f['decimal'] < search_decimal] @@ -193,7 +189,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): # Tests that the date filter set with 'gt' in the filter class works. search_date = datetime.date(2012, 10, 2) - request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' + request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-10-02' response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if f['date'] > search_date] @@ -201,7 +197,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): # Tests that the text filter set with 'icontains' in the filter class works. search_text = 'ff' - request = factory.get('/?text=%s' % search_text) + request = factory.get('/', {'text': '%s' % search_text}) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if search_text in f['text'].lower()] @@ -210,7 +206,10 @@ class IntegrationTestFiltering(CommonFilteringTestCase): # Tests that multiple filters works. search_decimal = Decimal('5.25') search_date = datetime.date(2012, 10, 2) - request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) + request = factory.get('/', { + 'decimal': '%s' % (search_decimal,), + 'date': '%s' % (search_date,) + }) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if f['date'] > search_date and @@ -235,7 +234,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): view = FilterFieldsRootView.as_view() search_integer = 10 - request = factory.get('/?integer=%s' % search_integer) + request = factory.get('/', {'integer': '%s' % search_integer}) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -266,14 +265,18 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): # Tests that the decimal filter set that should fail. search_decimal = Decimal('4.25') high_item = self.objects.filter(decimal__gt=search_decimal)[0] - response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) + response = self.client.get( + '{url}'.format(url=self._get_url(high_item)), + {'decimal': '{param}'.format(param=search_decimal)}) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) # Tests that the decimal filter set that should succeed. search_decimal = Decimal('4.25') low_item = self.objects.filter(decimal__lt=search_decimal)[0] low_item_data = self._serialize_object(low_item) - response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) + response = self.client.get( + '{url}'.format(url=self._get_url(low_item)), + {'decimal': '{param}'.format(param=search_decimal)}) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, low_item_data) @@ -282,7 +285,11 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): search_date = datetime.date(2012, 10, 2) valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] valid_item_data = self._serialize_object(valid_item) - response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) + response = self.client.get( + '{url}'.format(url=self._get_url(valid_item)), { + 'decimal': '{decimal}'.format(decimal=search_decimal), + 'date': '{date}'.format(date=search_date) + }) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, valid_item_data) @@ -316,7 +323,7 @@ class SearchFilterTests(TestCase): search_fields = ('title', 'text') view = SearchListView.as_view() - request = factory.get('?search=b') + request = factory.get('/', {'search': 'b'}) response = view(request) self.assertEqual( response.data, @@ -333,7 +340,7 @@ class SearchFilterTests(TestCase): search_fields = ('=title', 'text') view = SearchListView.as_view() - request = factory.get('?search=zzz') + request = factory.get('/', {'search': 'zzz'}) response = view(request) self.assertEqual( response.data, @@ -349,7 +356,7 @@ class SearchFilterTests(TestCase): search_fields = ('title', '^text') view = SearchListView.as_view() - request = factory.get('?search=b') + request = factory.get('/', {'search': 'b'}) response = view(request) self.assertEqual( response.data, @@ -358,6 +365,24 @@ class SearchFilterTests(TestCase): ] ) + def test_search_with_nonstandard_search_param(self): + with temporary_setting('SEARCH_PARAM', 'query', module=filters): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('title', 'text') + + view = SearchListView.as_view() + request = factory.get('/', {'query': 'b'}) + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 1, 'title': 'z', 'text': 'abc'}, + {'id': 2, 'title': 'zz', 'text': 'bcd'} + ] + ) + class OrdringFilterModel(models.Model): title = models.CharField(max_length=20) @@ -369,7 +394,6 @@ class OrderingFilterRelatedModel(models.Model): related_name="relateds") - class OrderingFilterTests(TestCase): def setUp(self): # Sequence of title/text is: @@ -395,9 +419,10 @@ class OrderingFilterTests(TestCase): model = OrdringFilterModel filter_backends = (filters.OrderingFilter,) ordering = ('title',) + ordering_fields = ('text',) view = OrderingListView.as_view() - request = factory.get('?ordering=text') + request = factory.get('/', {'ordering': 'text'}) response = view(request) self.assertEqual( response.data, @@ -413,9 +438,10 @@ class OrderingFilterTests(TestCase): model = OrdringFilterModel filter_backends = (filters.OrderingFilter,) ordering = ('title',) + ordering_fields = ('text',) view = OrderingListView.as_view() - request = factory.get('?ordering=-text') + request = factory.get('/', {'ordering': '-text'}) response = view(request) self.assertEqual( response.data, @@ -431,9 +457,10 @@ class OrderingFilterTests(TestCase): model = OrdringFilterModel filter_backends = (filters.OrderingFilter,) ordering = ('title',) + ordering_fields = ('text',) view = OrderingListView.as_view() - request = factory.get('?ordering=foobar') + request = factory.get('/', {'ordering': 'foobar'}) response = view(request) self.assertEqual( response.data, @@ -449,6 +476,7 @@ class OrderingFilterTests(TestCase): model = OrdringFilterModel filter_backends = (filters.OrderingFilter,) ordering = ('title',) + oredering_fields = ('text',) view = OrderingListView.as_view() request = factory.get('') @@ -467,6 +495,7 @@ class OrderingFilterTests(TestCase): model = OrdringFilterModel filter_backends = (filters.OrderingFilter,) ordering = 'title' + ordering_fields = ('text',) view = OrderingListView.as_view() request = factory.get('') @@ -495,11 +524,12 @@ class OrderingFilterTests(TestCase): model = OrdringFilterModel filter_backends = (filters.OrderingFilter,) ordering = 'title' + ordering_fields = '__all__' queryset = OrdringFilterModel.objects.all().annotate( models.Count("relateds")) view = OrderingListView.as_view() - request = factory.get('?ordering=relateds__count') + request = factory.get('/', {'ordering': 'relateds__count'}) response = view(request) self.assertEqual( response.data, @@ -510,5 +540,122 @@ class OrderingFilterTests(TestCase): ] ) + def test_ordering_with_nonstandard_ordering_param(self): + with temporary_setting('ORDERING_PARAM', 'order', filters): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + ordering_fields = ('text',) + + view = OrderingListView.as_view() + request = factory.get('/', {'order': 'text'}) + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + ] + ) + + +class SensitiveOrderingFilterModel(models.Model): + username = models.CharField(max_length=20) + password = models.CharField(max_length=100) + + +# Three different styles of serializer. +# All should allow ordering by username, but not by password. +class SensitiveDataSerializer1(serializers.ModelSerializer): + username = serializers.CharField() + class Meta: + model = SensitiveOrderingFilterModel + fields = ('id', 'username') + +class SensitiveDataSerializer2(serializers.ModelSerializer): + username = serializers.CharField() + password = serializers.CharField(write_only=True) + + class Meta: + model = SensitiveOrderingFilterModel + fields = ('id', 'username', 'password') + + +class SensitiveDataSerializer3(serializers.ModelSerializer): + user = serializers.CharField(source='username') + + class Meta: + model = SensitiveOrderingFilterModel + fields = ('id', 'user') + + +class SensitiveOrderingFilterTests(TestCase): + def setUp(self): + for idx in range(3): + username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx] + password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx] + SensitiveOrderingFilterModel(username=username, password=password).save() + + def test_order_by_serializer_fields(self): + for serializer_cls in [ + SensitiveDataSerializer1, + SensitiveDataSerializer2, + SensitiveDataSerializer3 + ]: + class OrderingListView(generics.ListAPIView): + queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') + filter_backends = (filters.OrderingFilter,) + serializer_class = serializer_cls + + view = OrderingListView.as_view() + request = factory.get('/', {'ordering': '-username'}) + response = view(request) + + if serializer_cls == SensitiveDataSerializer3: + username_field = 'user' + else: + username_field = 'username' + + # Note: Inverse username ordering correctly applied. + self.assertEqual( + response.data, + [ + {'id': 3, username_field: 'userC'}, + {'id': 2, username_field: 'userB'}, + {'id': 1, username_field: 'userA'}, + ] + ) + + def test_cannot_order_by_non_serializer_fields(self): + for serializer_cls in [ + SensitiveDataSerializer1, + SensitiveDataSerializer2, + SensitiveDataSerializer3 + ]: + class OrderingListView(generics.ListAPIView): + queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') + filter_backends = (filters.OrderingFilter,) + serializer_class = serializer_cls + + view = OrderingListView.as_view() + request = factory.get('/', {'ordering': 'password'}) + response = view(request) + + if serializer_cls == SensitiveDataSerializer3: + username_field = 'user' + else: + username_field = 'username' + + # Note: The passwords are not in order. Default ordering is used. + self.assertEqual( + response.data, + [ + {'id': 1, username_field: 'userA'}, # PassB + {'id': 2, username_field: 'userB'}, # PassC + {'id': 3, username_field: 'userC'}, # PassA + ] + ) diff --git a/rest_framework/tests/test_genericrelations.py b/rest_framework/tests/test_genericrelations.py index c38bfb9f..fa09c9e6 100644 --- a/rest_framework/tests/test_genericrelations.py +++ b/rest_framework/tests/test_genericrelations.py @@ -4,8 +4,10 @@ from django.contrib.contenttypes.generic import GenericRelation, GenericForeignK from django.db import models from django.test import TestCase from rest_framework import serializers +from rest_framework.compat import python_2_unicode_compatible +@python_2_unicode_compatible class Tag(models.Model): """ Tags have a descriptive slug, and are attached to an arbitrary object. @@ -15,10 +17,11 @@ class Tag(models.Model): object_id = models.PositiveIntegerField() tagged_item = GenericForeignKey('content_type', 'object_id') - def __unicode__(self): + def __str__(self): return self.tag +@python_2_unicode_compatible class Bookmark(models.Model): """ A URL bookmark that may have multiple tags attached. @@ -26,10 +29,11 @@ class Bookmark(models.Model): url = models.URLField() tags = GenericRelation(Tag) - def __unicode__(self): + def __str__(self): return 'Bookmark: %s' % self.url +@python_2_unicode_compatible class Note(models.Model): """ A textual note that may have multiple tags attached. @@ -37,7 +41,7 @@ class Note(models.Model): text = models.TextField() tags = GenericRelation(Tag) - def __unicode__(self): + def __str__(self): return 'Note: %s' % self.text @@ -69,6 +73,35 @@ class TestGenericRelations(TestCase): } self.assertEqual(serializer.data, expected) + def test_generic_nested_relation(self): + """ + Test saving a GenericRelation field via a nested serializer. + """ + + class TagSerializer(serializers.ModelSerializer): + class Meta: + model = Tag + exclude = ('content_type', 'object_id') + + class BookmarkSerializer(serializers.ModelSerializer): + tags = TagSerializer() + + class Meta: + model = Bookmark + exclude = ('id',) + + data = { + 'url': 'https://docs.djangoproject.com/', + 'tags': [ + {'tag': 'contenttypes'}, + {'tag': 'genericrelations'}, + ] + } + serializer = BookmarkSerializer(data=data) + self.assertTrue(serializer.is_valid()) + serializer.save() + self.assertEqual(serializer.object.tags.count(), 2) + def test_generic_fk(self): """ Test a relationship that spans a GenericForeignKey field. diff --git a/rest_framework/tests/test_htmlrenderer.py b/rest_framework/tests/test_htmlrenderer.py index 6c570dfd..1cbca04c 100644 --- a/rest_framework/tests/test_htmlrenderer.py +++ b/rest_framework/tests/test_htmlrenderer.py @@ -50,7 +50,7 @@ class TemplateHTMLRendererTests(TestCase): """ self.get_template = django.template.loader.get_template - def get_template(template_name): + def get_template(template_name, dirs=None): if template_name == 'example.html': return Template("example: {{ object }}") raise TemplateDoesNotExist(template_name) @@ -108,11 +108,13 @@ class TemplateHTMLRendererExceptionTests(TestCase): def test_not_found_html_view_with_template(self): response = self.client.get('/not_found') self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - self.assertEqual(response.content, six.b("404: Not found")) + self.assertTrue(response.content in ( + six.b("404: Not found"), six.b("404 Not Found"))) self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') def test_permission_denied_html_view_with_template(self): response = self.client.get('/permission_denied') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self.assertEqual(response.content, six.b("403: Permission denied")) + self.assertTrue(response.content in ( + six.b("403: Permission denied"), six.b("403 Forbidden"))) self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') diff --git a/rest_framework/tests/test_hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py index ea7f70f2..5fb1b47e 100644 --- a/rest_framework/tests/test_hyperlinkedserializers.py +++ b/rest_framework/tests/test_hyperlinkedserializers.py @@ -1,8 +1,9 @@ from __future__ import unicode_literals import json -from django.conf.urls import patterns, url from django.test import TestCase from rest_framework import generics, status, serializers +from django.conf.urls import patterns, url +from rest_framework.settings import api_settings from rest_framework.test import APIRequestFactory from rest_framework.tests.models import ( Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, @@ -331,3 +332,48 @@ class TestOverriddenURLField(TestCase): serializer.data, {'title': 'New blog post', 'url': 'foo bar'} ) + + +class TestURLFieldNameBySettings(TestCase): + urls = 'rest_framework.tests.test_hyperlinkedserializers' + + def setUp(self): + self.saved_url_field_name = api_settings.URL_FIELD_NAME + api_settings.URL_FIELD_NAME = 'global_url_field' + + class Serializer(serializers.HyperlinkedModelSerializer): + + class Meta: + model = BlogPost + fields = ('title', api_settings.URL_FIELD_NAME) + + self.Serializer = Serializer + self.obj = BlogPost.objects.create(title="New blog post") + + def tearDown(self): + api_settings.URL_FIELD_NAME = self.saved_url_field_name + + def test_overridden_url_field_name(self): + request = factory.get('/posts/') + serializer = self.Serializer(self.obj, context={'request': request}) + self.assertIn(api_settings.URL_FIELD_NAME, serializer.data) + + +class TestURLFieldNameByOptions(TestCase): + urls = 'rest_framework.tests.test_hyperlinkedserializers' + + def setUp(self): + class Serializer(serializers.HyperlinkedModelSerializer): + + class Meta: + model = BlogPost + fields = ('title', 'serializer_url_field') + url_field_name = 'serializer_url_field' + + self.Serializer = Serializer + self.obj = BlogPost.objects.create(title="New blog post") + + def test_overridden_url_field_name(self): + request = factory.get('/posts/') + serializer = self.Serializer(self.obj, context={'request': request}) + self.assertIn(self.Serializer.Meta.url_field_name, serializer.data) diff --git a/rest_framework/tests/test_nullable_fields.py b/rest_framework/tests/test_nullable_fields.py new file mode 100644 index 00000000..4812530e --- /dev/null +++ b/rest_framework/tests/test_nullable_fields.py @@ -0,0 +1,30 @@ +from django.core.urlresolvers import reverse + +from django.conf.urls import patterns, url +from rest_framework.test import APITestCase +from rest_framework.tests.models import NullableForeignKeySource +from rest_framework.tests.serializers import NullableFKSourceSerializer +from rest_framework.tests.views import NullableFKSourceDetail + + +urlpatterns = patterns( + '', + url(r'^objects/(?P<pk>\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'), +) + + +class NullableForeignKeyTests(APITestCase): + """ + DRF should be able to handle nullable foreign keys when a test + Client POST/PUT request is made with its own serialized object. + """ + urls = 'rest_framework.tests.test_nullable_fields' + + def test_updating_object_with_null_fk(self): + obj = NullableForeignKeySource(name='example', target=None) + obj.save() + serialized_data = NullableFKSourceSerializer(obj).data + + response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data) + + self.assertEqual(response.data, serialized_data) diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index cadb515f..24c1ba39 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -9,14 +9,18 @@ from rest_framework import generics, status, pagination, filters, serializers from rest_framework.compat import django_filters from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel +from .models import FilterableItem factory = APIRequestFactory() +# Helper function to split arguments out of an url +def split_arguments_from_url(url): + if '?' not in url: + return url -class FilterableItem(models.Model): - text = models.CharField(max_length=100) - decimal = models.DecimalField(max_digits=4, decimal_places=2) - date = models.DateField() + path, args = url.split('?') + args = dict(r.split('=') for r in args.split('&')) + return path, args class RootView(generics.ListCreateAPIView): @@ -84,7 +88,7 @@ class IntegrationTestPagination(TestCase): self.assertNotEqual(response.data['next'], None) self.assertEqual(response.data['previous'], None) - request = factory.get(response.data['next']) + request = factory.get(*split_arguments_from_url(response.data['next'])) with self.assertNumQueries(2): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -93,7 +97,7 @@ class IntegrationTestPagination(TestCase): self.assertNotEqual(response.data['next'], None) self.assertNotEqual(response.data['previous'], None) - request = factory.get(response.data['next']) + request = factory.get(*split_arguments_from_url(response.data['next'])) with self.assertNumQueries(2): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -146,7 +150,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): EXPECTED_NUM_QUERIES = 2 - request = factory.get('/?decimal=15.20') + request = factory.get('/', {'decimal': '15.20'}) with self.assertNumQueries(EXPECTED_NUM_QUERIES): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -155,7 +159,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.assertNotEqual(response.data['next'], None) self.assertEqual(response.data['previous'], None) - request = factory.get(response.data['next']) + request = factory.get(*split_arguments_from_url(response.data['next'])) with self.assertNumQueries(EXPECTED_NUM_QUERIES): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -164,7 +168,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.assertEqual(response.data['next'], None) self.assertNotEqual(response.data['previous'], None) - request = factory.get(response.data['previous']) + request = factory.get(*split_arguments_from_url(response.data['previous'])) with self.assertNumQueries(EXPECTED_NUM_QUERIES): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -191,7 +195,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): view = BasicFilterFieldsRootView.as_view() - request = factory.get('/?decimal=15.20') + request = factory.get('/', {'decimal': '15.20'}) with self.assertNumQueries(2): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -200,7 +204,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.assertNotEqual(response.data['next'], None) self.assertEqual(response.data['previous'], None) - request = factory.get(response.data['next']) + request = factory.get(*split_arguments_from_url(response.data['next'])) with self.assertNumQueries(2): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -209,7 +213,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.assertEqual(response.data['next'], None) self.assertNotEqual(response.data['previous'], None) - request = factory.get(response.data['previous']) + request = factory.get(*split_arguments_from_url(response.data['previous'])) with self.assertNumQueries(2): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -317,7 +321,7 @@ class TestCustomPaginateByParam(TestCase): """ If paginate_by_param is set, the new kwarg should limit per view requests. """ - request = factory.get('/?page_size=5') + request = factory.get('/', {'page_size': 5}) response = self.view(request).render() self.assertEqual(response.data['count'], 13) self.assertEqual(response.data['results'], self.data[:5]) @@ -345,7 +349,7 @@ class TestMaxPaginateByParam(TestCase): """ If max_paginate_by is set, it should limit page size for the view. """ - request = factory.get('/?page_size=10') + request = factory.get('/', data={'page_size': 10}) response = self.view(request).render() self.assertEqual(response.data['count'], 13) self.assertEqual(response.data['results'], self.data[:5]) diff --git a/rest_framework/tests/test_relations.py b/rest_framework/tests/test_relations.py index d19219c9..37ac826b 100644 --- a/rest_framework/tests/test_relations.py +++ b/rest_framework/tests/test_relations.py @@ -2,8 +2,10 @@ General tests for relational fields. """ from __future__ import unicode_literals +from django import get_version from django.db import models from django.test import TestCase +from django.utils import unittest from rest_framework import serializers from rest_framework.tests.models import BlogPost @@ -98,3 +100,45 @@ class RelatedFieldSourceTests(TestCase): obj = ClassWithQuerysetMethod() value = field.field_to_native(obj, 'field_name') self.assertEqual(value, ['BlogPost object']) + + # Regression for #1129 + def test_exception_for_incorect_fk(self): + """ + Check that the exception message are correct if the source field + doesn't exist. + """ + from rest_framework.tests.models import ManyToManySource + class Meta: + model = ManyToManySource + attrs = { + 'name': serializers.SlugRelatedField( + slug_field='name', source='banzai'), + 'Meta': Meta, + } + + TestSerializer = type(str('TestSerializer'), + (serializers.ModelSerializer,), attrs) + with self.assertRaises(AttributeError): + TestSerializer(data={'name': 'foo'}) + +@unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6') +class RelatedFieldChoicesTests(TestCase): + """ + Tests for #1408 "Web browseable API doesn't have blank option on drop down list box" + https://github.com/tomchristie/django-rest-framework/issues/1408 + """ + def test_blank_option_is_added_to_choice_if_required_equals_false(self): + """ + + """ + post = BlogPost(title="Checking blank option is added") + post.save() + + queryset = BlogPost.objects.all() + field = serializers.RelatedField(required=False, queryset=queryset) + + choice_count = BlogPost.objects.count() + widget_count = len(field.widget.choices) + + self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added') + diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py index d393b0c3..4d9da489 100644 --- a/rest_framework/tests/test_relations_nested.py +++ b/rest_framework/tests/test_relations_nested.py @@ -3,9 +3,7 @@ from django.db import models from django.test import TestCase from rest_framework import serializers - -class OneToOneTarget(models.Model): - name = models.CharField(max_length=100) +from .models import OneToOneTarget class OneToOneSource(models.Model): diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py index 9cb68233..460c02a9 100644 --- a/rest_framework/tests/test_renderers.py +++ b/rest_framework/tests/test_renderers.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals from decimal import Decimal from django.conf.urls import patterns, url, include from django.core.cache import cache +from django.db import models from django.test import TestCase from django.utils import unittest from django.utils.translation import ugettext_lazy as _ @@ -35,6 +36,10 @@ expected_results = [ ] +class DummyTestModel(models.Model): + name = models.CharField(max_length=42, default='') + + class BasicRendererTests(TestCase): def test_expected_results(self): for value, renderer_cls, expected in expected_results: @@ -252,6 +257,18 @@ class RendererEndToEndTests(TestCase): self.assertEqual(resp.get('Content-Type', None), None) self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT) + def test_contains_headers_of_api_response(self): + """ + Issue #1437 + + Test we display the headers of the API response and not those from the + HTML response + """ + resp = self.client.get('/html1') + self.assertContains(resp, '>GET, HEAD, OPTIONS<') + self.assertContains(resp, '>application/json<') + self.assertNotContains(resp, '>text/html; charset=utf-8<') + _flat_repr = '{"foo": ["bar", "baz"]}' _indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}' @@ -277,6 +294,20 @@ class JSONRendererTests(TestCase): ret = JSONRenderer().render(_('test')) self.assertEqual(ret, b'"test"') + def test_render_queryset_values(self): + o = DummyTestModel.objects.create(name='dummy') + qs = DummyTestModel.objects.values('id', 'name') + ret = JSONRenderer().render(qs) + data = json.loads(ret.decode('utf-8')) + self.assertEquals(data, [{'id': o.id, 'name': o.name}]) + + def test_render_queryset_values_list(self): + o = DummyTestModel.objects.create(name='dummy') + qs = DummyTestModel.objects.values_list('id', 'name') + ret = JSONRenderer().render(qs) + data = json.loads(ret.decode('utf-8')) + self.assertEquals(data, [[o.id, o.name]]) + def test_render_dict_abc_obj(self): class Dict(MutableMapping): def __init__(self): @@ -583,6 +614,10 @@ class CacheRenderTest(TestCase): method = getattr(self.client, http_method) resp = method(url) del resp.client, resp.request + try: + del resp.wsgi_request + except AttributeError: + pass return resp def test_obj_pickling(self): diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py index f07c31a3..e0da5fd4 100644 --- a/rest_framework/tests/test_request.py +++ b/rest_framework/tests/test_request.py @@ -68,6 +68,9 @@ class TestMethodOverloading(TestCase): request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE')) self.assertEqual(request.method, 'DELETE') + request = Request(factory.get('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE')) + self.assertEqual(request.method, 'DELETE') + class TestContentParsing(TestCase): def test_standard_behaviour_determines_no_content_GET(self): diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index 6d9b85ee..a09bf6f5 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals from django.db import models from django.db.models.fields import BLANK_CHOICE_DASH from django.test import TestCase +from django.utils import unittest from django.utils.datastructures import MultiValueDict from django.utils.translation import ugettext_lazy as _ from rest_framework import serializers, fields, relations @@ -12,6 +13,31 @@ from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, Acti from rest_framework.tests.models import BasicModelSerializer import datetime import pickle +try: + import PIL +except: + PIL = None + + +if PIL is not None: + class AMOAFModel(RESTFrameworkModel): + char_field = models.CharField(max_length=1024, blank=True) + comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True) + decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True) + email_field = models.EmailField(max_length=1024, blank=True) + file_field = models.FileField(upload_to='test', max_length=1024, blank=True) + image_field = models.ImageField(upload_to='test', max_length=1024, blank=True) + slug_field = models.SlugField(max_length=1024, blank=True) + url_field = models.URLField(max_length=1024, blank=True) + + class DVOAFModel(RESTFrameworkModel): + positive_integer_field = models.PositiveIntegerField(blank=True) + positive_small_integer_field = models.PositiveSmallIntegerField(blank=True) + email_field = models.EmailField(blank=True) + file_field = models.FileField(upload_to='test', blank=True) + image_field = models.ImageField(upload_to='test', blank=True) + slug_field = models.SlugField(blank=True) + url_field = models.URLField(blank=True) class SubComment(object): @@ -71,6 +97,15 @@ class ActionItemSerializer(serializers.ModelSerializer): class Meta: model = ActionItem +class ActionItemSerializerOptionalFields(serializers.ModelSerializer): + """ + Intended to test that fields with `required=False` are excluded from validation. + """ + title = serializers.CharField(required=False) + + class Meta: + model = ActionItem + fields = ('title',) class ActionItemSerializerCustomRestore(serializers.ModelSerializer): @@ -132,7 +167,7 @@ class AlbumsSerializer(serializers.ModelSerializer): class Meta: model = Album - fields = ['title'] # lists are also valid options + fields = ['title', 'ref'] # lists are also valid options class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer): @@ -288,7 +323,13 @@ class BasicTests(TestCase): serializer.save() self.assertIsNotNone(serializer.data.get('id',None), 'Model is saved. `id` should be set.') - + def test_fields_marked_as_not_required_are_excluded_from_validation(self): + """ + Check that fields with `required=False` are included in list of exclusions. + """ + serializer = ActionItemSerializerOptionalFields(self.actionitem) + exclusions = serializer.get_validation_exclusions() + self.assertTrue('title' in exclusions, '`title` field was marked `required=False` and should be excluded') class DictStyleSerializer(serializers.Serializer): @@ -467,6 +508,32 @@ class ValidationTests(TestCase): ) self.assertEqual(serializer.is_valid(), True) + def test_writable_star_source_on_nested_serializer_with_parent_object(self): + class TitleSerializer(serializers.Serializer): + title = serializers.WritableField(source='title') + + class AlbumSerializer(serializers.ModelSerializer): + nested = TitleSerializer(source='*') + + class Meta: + model = Album + fields = ('nested',) + + class PhotoSerializer(serializers.ModelSerializer): + album = AlbumSerializer(source='album') + + class Meta: + model = Photo + fields = ('album', ) + + photo = Photo(album=Album()) + + data = {'album': {'nested': {'title': 'test'}}} + + serializer = PhotoSerializer(photo, data=data) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.data, data) + def test_writable_star_source_with_inner_source_fields(self): """ Tests that a serializer with source="*" correctly expands the @@ -576,12 +643,15 @@ class ModelValidationTests(TestCase): """ Just check if serializers.ModelSerializer handles unique checks via .full_clean() """ - serializer = AlbumsSerializer(data={'title': 'a'}) + serializer = AlbumsSerializer(data={'title': 'a', 'ref': '1'}) serializer.is_valid() serializer.save() second_serializer = AlbumsSerializer(data={'title': 'a'}) self.assertFalse(second_serializer.is_valid()) - self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']}) + self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.'],}) + third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}]) + self.assertFalse(third_serializer.is_valid()) + self.assertEqual(third_serializer.errors, [{'ref': ['Album with this Ref already exists.']}, {}]) def test_foreign_key_is_null_with_partial(self): """ @@ -865,6 +935,58 @@ class DefaultValueTests(TestCase): self.assertEqual(instance.text, 'overridden') +class WritableFieldDefaultValueTests(TestCase): + + def setUp(self): + self.expected = {'default': 'value'} + self.create_field = fields.WritableField + + def test_get_default_value_with_noncallable(self): + field = self.create_field(default=self.expected) + got = field.get_default_value() + self.assertEqual(got, self.expected) + + def test_get_default_value_with_callable(self): + field = self.create_field(default=lambda : self.expected) + got = field.get_default_value() + self.assertEqual(got, self.expected) + + def test_get_default_value_when_not_required(self): + field = self.create_field(default=self.expected, required=False) + got = field.get_default_value() + self.assertEqual(got, self.expected) + + def test_get_default_value_returns_None(self): + field = self.create_field() + got = field.get_default_value() + self.assertIsNone(got) + + def test_get_default_value_returns_non_True_values(self): + values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause + for expected in values: + field = self.create_field(default=expected) + got = field.get_default_value() + self.assertEqual(got, expected) + + +class RelatedFieldDefaultValueTests(WritableFieldDefaultValueTests): + + def setUp(self): + self.expected = {'foo': 'bar'} + self.create_field = relations.RelatedField + + def test_get_default_value_returns_empty_list(self): + field = self.create_field(many=True) + got = field.get_default_value() + self.assertListEqual(got, []) + + def test_get_default_value_returns_expected(self): + expected = [1, 2, 3] + field = self.create_field(many=True, default=expected) + got = field.get_default_value() + self.assertListEqual(got, expected) + + class CallableDefaultValueTests(TestCase): def setUp(self): class CallableDefaultValueSerializer(serializers.ModelSerializer): @@ -1492,19 +1614,10 @@ class ManyFieldHelpTextTest(TestCase): self.assertEqual('Some help text.', rel_field.help_text) +@unittest.skipUnless(PIL is not None, 'PIL is not installed') class AttributeMappingOnAutogeneratedFieldsTests(TestCase): def setUp(self): - class AMOAFModel(RESTFrameworkModel): - char_field = models.CharField(max_length=1024, blank=True) - comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True) - decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True) - email_field = models.EmailField(max_length=1024, blank=True) - file_field = models.FileField(max_length=1024, blank=True) - image_field = models.ImageField(max_length=1024, blank=True) - slug_field = models.SlugField(max_length=1024, blank=True) - url_field = models.URLField(max_length=1024, blank=True) - nullable_char_field = models.CharField(max_length=1024, blank=True, null=True) class AMOAFSerializer(serializers.ModelSerializer): class Meta: @@ -1581,17 +1694,10 @@ class AttributeMappingOnAutogeneratedFieldsTests(TestCase): self.field_test('nullable_char_field') +@unittest.skipUnless(PIL is not None, 'PIL is not installed') class DefaultValuesOnAutogeneratedFieldsTests(TestCase): def setUp(self): - class DVOAFModel(RESTFrameworkModel): - positive_integer_field = models.PositiveIntegerField(blank=True) - positive_small_integer_field = models.PositiveSmallIntegerField(blank=True) - email_field = models.EmailField(blank=True) - file_field = models.FileField(blank=True) - image_field = models.ImageField(blank=True) - slug_field = models.SlugField(blank=True) - url_field = models.URLField(blank=True) class DVOAFSerializer(serializers.ModelSerializer): class Meta: @@ -1830,14 +1936,14 @@ class SerializerDefaultTrueBoolean(TestCase): self.assertEqual(serializer.data['cat'], False) self.assertEqual(serializer.data['dog'], False) - + class BoolenFieldTypeTest(TestCase): ''' Ensure the various Boolean based model fields are rendered as the proper field type - + ''' - + def setUp(self): ''' Setup an ActionItemSerializer for BooleanTesting @@ -1853,11 +1959,11 @@ class BoolenFieldTypeTest(TestCase): ''' bfield = self.serializer.get_fields()['done'] self.assertEqual(type(bfield), fields.BooleanField) - + def test_nullbooleanfield_type(self): ''' - Test that BooleanField is infered from models.NullBooleanField - + Test that BooleanField is infered from models.NullBooleanField + https://groups.google.com/forum/#!topic/django-rest-framework/D9mXEftpuQ8 ''' bfield = self.serializer.get_fields()['started'] diff --git a/rest_framework/tests/test_serializer_import.py b/rest_framework/tests/test_serializer_import.py new file mode 100644 index 00000000..9f30a7ff --- /dev/null +++ b/rest_framework/tests/test_serializer_import.py @@ -0,0 +1,19 @@ +from django.test import TestCase + +from rest_framework import serializers +from rest_framework.tests.accounts.serializers import AccountSerializer + + +class ImportingModelSerializerTests(TestCase): + """ + In some situations like, GH #1225, it is possible, especially in + testing, to import a serializer who's related models have not yet + been resolved by Django. `AccountSerializer` is an example of such + a serializer (imported at the top of this file). + """ + def test_import_model_serializer(self): + """ + The serializer at the top of this file should have been + imported successfully, and we should be able to instantiate it. + """ + self.assertIsInstance(AccountSerializer(), serializers.ModelSerializer) diff --git a/rest_framework/tests/test_serializer_nested.py b/rest_framework/tests/test_serializer_nested.py index 7114a060..6d69ffbd 100644 --- a/rest_framework/tests/test_serializer_nested.py +++ b/rest_framework/tests/test_serializer_nested.py @@ -345,4 +345,3 @@ class NestedModelSerializerUpdateTests(TestCase): result = deserialize.object result.save() self.assertEqual(result.id, john.id) - diff --git a/rest_framework/tests/test_serializers.py b/rest_framework/tests/test_serializers.py new file mode 100644 index 00000000..082a400c --- /dev/null +++ b/rest_framework/tests/test_serializers.py @@ -0,0 +1,28 @@ +from django.db import models +from django.test import TestCase + +from rest_framework.serializers import _resolve_model +from rest_framework.tests.models import BasicModel + + +class ResolveModelTests(TestCase): + """ + `_resolve_model` should return a Django model class given the + provided argument is a Django model class itself, or a properly + formatted string representation of one. + """ + def test_resolve_django_model(self): + resolved_model = _resolve_model(BasicModel) + self.assertEqual(resolved_model, BasicModel) + + def test_resolve_string_representation(self): + resolved_model = _resolve_model('tests.BasicModel') + self.assertEqual(resolved_model, BasicModel) + + def test_resolve_non_django_model(self): + with self.assertRaises(ValueError): + _resolve_model(TestCase) + + def test_resolve_improper_string_representation(self): + with self.assertRaises(ValueError): + _resolve_model('BasicModel') diff --git a/rest_framework/tests/test_templatetags.py b/rest_framework/tests/test_templatetags.py new file mode 100644 index 00000000..d4da0c23 --- /dev/null +++ b/rest_framework/tests/test_templatetags.py @@ -0,0 +1,51 @@ +# encoding: utf-8 +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.test import APIRequestFactory +from rest_framework.templatetags.rest_framework import add_query_param, urlize_quoted_links + +factory = APIRequestFactory() + + +class TemplateTagTests(TestCase): + + def test_add_query_param_with_non_latin_charactor(self): + # Ensure we don't double-escape non-latin characters + # that are present in the querystring. + # See #1314. + request = factory.get("/", {'q': '查询'}) + json_url = add_query_param(request, "format", "json") + self.assertIn("q=%E6%9F%A5%E8%AF%A2", json_url) + self.assertIn("format=json", json_url) + + +class Issue1386Tests(TestCase): + """ + Covers #1386 + """ + + def test_issue_1386(self): + """ + Test function urlize_quoted_links with different args + """ + correct_urls = [ + "asdf.com", + "asdf.net", + "www.as_df.org", + "as.d8f.ghj8.gov", + ] + for i in correct_urls: + res = urlize_quoted_links(i) + self.assertNotEqual(res, i) + self.assertIn(i, res) + + incorrect_urls = [ + "mailto://asdf@fdf.com", + "asdf.netnet", + ] + for i in incorrect_urls: + res = urlize_quoted_links(i) + self.assertEqual(i, res) + + # example from issue #1386, this shouldn't raise an exception + _ = urlize_quoted_links("asdf:[/p]zxcv.com") diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index c08dd493..83ae8148 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -2,6 +2,8 @@ from __future__ import unicode_literals from django.conf.urls import patterns, url +from io import BytesIO + from django.contrib.auth.models import User from django.test import TestCase from rest_framework.decorators import api_view @@ -143,3 +145,20 @@ class TestAPIRequestFactory(TestCase): force_authenticate(request, user=user) response = view(request) self.assertEqual(response.data['user'], 'example') + + def test_upload_file(self): + # This is a 1x1 black png + simple_png = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\rIDATx\x9cc````\x00\x00\x00\x05\x00\x01\xa5\xf6E@\x00\x00\x00\x00IEND\xaeB`\x82') + simple_png.name = 'test.png' + factory = APIRequestFactory() + factory.post('/', data={'image': simple_png}) + + def test_request_factory_url_arguments(self): + """ + This is a non regression test against #1461 + """ + factory = APIRequestFactory() + request = factory.get('/view/?demo=test') + self.assertEqual(dict(request.GET), {'demo': ['test']}) + request = factory.get('/view/', {'demo': 'test'}) + self.assertEqual(dict(request.GET), {'demo': ['test']}) diff --git a/rest_framework/tests/test_validation.py b/rest_framework/tests/test_validation.py index 124c874d..e13e4078 100644 --- a/rest_framework/tests/test_validation.py +++ b/rest_framework/tests/test_validation.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals +from django.core.validators import MaxValueValidator from django.db import models from django.test import TestCase from rest_framework import generics, serializers, status @@ -102,3 +103,46 @@ class TestAvoidValidation(TestCase): self.assertFalse(serializer.is_valid()) self.assertDictEqual(serializer.errors, {'non_field_errors': ['Invalid data']}) + + +# regression tests for issue: 1493 + +class ValidationMaxValueValidatorModel(models.Model): + number_value = models.PositiveIntegerField(validators=[MaxValueValidator(100)]) + + +class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer): + class Meta: + model = ValidationMaxValueValidatorModel + + +class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView): + model = ValidationMaxValueValidatorModel + serializer_class = ValidationMaxValueValidatorModelSerializer + + +class TestMaxValueValidatorValidation(TestCase): + + def test_max_value_validation_serializer_success(self): + serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 99}) + self.assertTrue(serializer.is_valid()) + + def test_max_value_validation_serializer_fails(self): + serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 101}) + self.assertFalse(serializer.is_valid()) + self.assertDictEqual({'number_value': ['Ensure this value is less than or equal to 100.']}, serializer.errors) + + def test_max_value_validation_success(self): + obj = ValidationMaxValueValidatorModel.objects.create(number_value=100) + request = factory.patch('/{0}'.format(obj.pk), {'number_value': 98}, format='json') + view = UpdateMaxValueValidationModel().as_view() + response = view(request, pk=obj.pk).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_max_value_validation_fail(self): + obj = ValidationMaxValueValidatorModel.objects.create(number_value=100) + request = factory.patch('/{0}'.format(obj.pk), {'number_value': 101}, format='json') + view = UpdateMaxValueValidationModel().as_view() + response = view(request, pk=obj.pk).render() + self.assertEqual(response.content, b'{"number_value": ["Ensure this value is less than or equal to 100."]}') + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) diff --git a/rest_framework/tests/test_write_only_fields.py b/rest_framework/tests/test_write_only_fields.py new file mode 100644 index 00000000..aabb18d6 --- /dev/null +++ b/rest_framework/tests/test_write_only_fields.py @@ -0,0 +1,42 @@ +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +class ExampleModel(models.Model): + email = models.EmailField(max_length=100) + password = models.CharField(max_length=100) + + +class WriteOnlyFieldTests(TestCase): + def test_write_only_fields(self): + class ExampleSerializer(serializers.Serializer): + email = serializers.EmailField() + password = serializers.CharField(write_only=True) + + data = { + 'email': 'foo@example.com', + 'password': '123' + } + serializer = ExampleSerializer(data=data) + self.assertTrue(serializer.is_valid()) + self.assertEquals(serializer.object, data) + self.assertEquals(serializer.data, {'email': 'foo@example.com'}) + + def test_write_only_fields_meta(self): + class ExampleSerializer(serializers.ModelSerializer): + class Meta: + model = ExampleModel + fields = ('email', 'password') + write_only_fields = ('password',) + + data = { + 'email': 'foo@example.com', + 'password': '123' + } + serializer = ExampleSerializer(data=data) + self.assertTrue(serializer.is_valid()) + self.assertTrue(isinstance(serializer.object, ExampleModel)) + self.assertEquals(serializer.object.email, data['email']) + self.assertEquals(serializer.object.password, data['password']) + self.assertEquals(serializer.data, {'email': 'foo@example.com'}) diff --git a/rest_framework/tests/users/__init__.py b/rest_framework/tests/users/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/rest_framework/tests/users/__init__.py diff --git a/rest_framework/tests/users/models.py b/rest_framework/tests/users/models.py new file mode 100644 index 00000000..128bac90 --- /dev/null +++ b/rest_framework/tests/users/models.py @@ -0,0 +1,6 @@ +from django.db import models + + +class User(models.Model): + account = models.ForeignKey('accounts.Account', blank=True, null=True, related_name='users') + active_record = models.ForeignKey('records.Record', blank=True, null=True) diff --git a/rest_framework/tests/users/serializers.py b/rest_framework/tests/users/serializers.py new file mode 100644 index 00000000..da496554 --- /dev/null +++ b/rest_framework/tests/users/serializers.py @@ -0,0 +1,8 @@ +from rest_framework import serializers + +from rest_framework.tests.users.models import User + + +class UserSerializer(serializers.ModelSerializer): + class Meta: + model = User diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py new file mode 100644 index 00000000..a8f2eb0b --- /dev/null +++ b/rest_framework/tests/utils.py @@ -0,0 +1,25 @@ +from contextlib import contextmanager +from rest_framework.compat import six +from rest_framework.settings import api_settings + + +@contextmanager +def temporary_setting(setting, value, module=None): + """ + Temporarily change value of setting for test. + + Optionally reload given module, useful when module uses value of setting on + import. + """ + original_value = getattr(api_settings, setting) + setattr(api_settings, setting, value) + + if module is not None: + six.moves.reload_module(module) + + yield + + setattr(api_settings, setting, original_value) + + if module is not None: + six.moves.reload_module(module) diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py new file mode 100644 index 00000000..3917b74a --- /dev/null +++ b/rest_framework/tests/views.py @@ -0,0 +1,8 @@ +from rest_framework import generics +from rest_framework.tests.models import NullableForeignKeySource +from rest_framework.tests.serializers import NullableFKSourceSerializer + + +class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView): + model = NullableForeignKeySource + model_serializer_class = NullableFKSourceSerializer diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index c40f3065..fc24c92e 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -155,6 +155,8 @@ class SimpleRateThrottle(BaseThrottle): remaining_duration = self.duration available_requests = self.num_requests - len(self.history) + 1 + if available_requests <= 0: + return None return remaining_duration / float(available_requests) diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 229b0b28..c125ac8a 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -3,6 +3,7 @@ Helper classes for parsers. """ from __future__ import unicode_literals from django.utils import timezone +from django.db.models.query import QuerySet from django.utils.datastructures import SortedDict from django.utils.functional import Promise from rest_framework.compat import force_text @@ -43,6 +44,8 @@ class JSONEncoder(json.JSONEncoder): return str(o.total_seconds()) elif isinstance(o, decimal.Decimal): return str(o) + elif isinstance(o, QuerySet): + return list(o) elif hasattr(o, 'tolist'): return o.tolist() elif hasattr(o, '__getitem__'): diff --git a/rest_framework/utils/mediatypes.py b/rest_framework/utils/mediatypes.py index c09c2933..92f99efd 100644 --- a/rest_framework/utils/mediatypes.py +++ b/rest_framework/utils/mediatypes.py @@ -74,7 +74,7 @@ class _MediaType(object): return 0 elif self.sub_type == '*': return 1 - elif not self.params or self.params.keys() == ['q']: + elif not self.params or list(self.params.keys()) == ['q']: return 2 return 3 diff --git a/rest_framework/views.py b/rest_framework/views.py index e863af6d..a2668f2c 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -112,12 +112,13 @@ class APIView(View): @property def default_response_headers(self): - # TODO: deprecate? - # TODO: Only vary by accept if multiple renderers - return { + headers = { 'Allow': ', '.join(self.allowed_methods), - 'Vary': 'Accept' } + if len(self.renderer_classes) > 1: + headers['Vary'] = 'Accept' + return headers + def http_method_not_allowed(self, request, *args, **kwargs): """ @@ -130,7 +131,7 @@ class APIView(View): """ If request is not permitted, determine what kind of exception to raise. """ - if not self.request.successful_authenticator: + if not request.successful_authenticator: raise exceptions.NotAuthenticated() raise exceptions.PermissionDenied() @@ -294,7 +295,7 @@ class APIView(View): # Dispatch methods - def initialize_request(self, request, *args, **kargs): + def initialize_request(self, request, *args, **kwargs): """ Returns the initial request object. """ |
