diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/__init__.py | 2 | ||||
| -rw-r--r-- | rest_framework/authentication.py | 12 | ||||
| -rw-r--r-- | rest_framework/compat.py | 10 | ||||
| -rw-r--r-- | rest_framework/generics.py | 21 | ||||
| -rw-r--r-- | rest_framework/permissions.py | 2 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 103 | ||||
| -rw-r--r-- | rest_framework/templates/rest_framework/base.html | 2 | ||||
| -rw-r--r-- | rest_framework/templatetags/rest_framework.py | 13 | ||||
| -rw-r--r-- | rest_framework/tests/generics.py | 42 | ||||
| -rw-r--r-- | rest_framework/tests/relations_hyperlink.py | 16 | ||||
| -rw-r--r-- | rest_framework/tests/relations_nested.py | 24 | ||||
| -rw-r--r-- | rest_framework/tests/relations_pk.py | 17 | ||||
| -rw-r--r-- | rest_framework/tests/serializer.py | 37 | ||||
| -rw-r--r-- | rest_framework/views.py | 8 |
14 files changed, 229 insertions, 80 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index c86403d8..856badc6 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.2.5' +__version__ = '2.2.7' VERSION = __version__ # synonym diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 145d4295..1eebb5b9 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -10,7 +10,7 @@ from django.core.exceptions import ImproperlyConfigured from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework.compat import CsrfViewMiddleware from rest_framework.compat import oauth, oauth_provider, oauth_provider_store -from rest_framework.compat import oauth2_provider, oauth2_provider_forms +from rest_framework.compat import oauth2_provider from rest_framework.authtoken.models import Token @@ -230,7 +230,7 @@ class OAuthAuthentication(BaseAuthentication): try: consumer_key = oauth_request.get_parameter('oauth_consumer_key') consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key) - except oauth_provider_store.InvalidConsumerError as err: + except oauth_provider.store.InvalidConsumerError as err: raise exceptions.AuthenticationFailed(err) if consumer.status != oauth_provider.consts.ACCEPTED: @@ -240,7 +240,7 @@ class OAuthAuthentication(BaseAuthentication): try: token_param = oauth_request.get_parameter('oauth_token') token = oauth_provider_store.get_access_token(request, oauth_request, consumer, token_param) - except oauth_provider_store.InvalidTokenError: + except oauth_provider.store.InvalidTokenError: msg = 'Invalid access token: %s' % oauth_request.get_parameter('oauth_token') raise exceptions.AuthenticationFailed(msg) @@ -325,11 +325,13 @@ class OAuth2Authentication(BaseAuthentication): except oauth2_provider.models.AccessToken.DoesNotExist: raise exceptions.AuthenticationFailed('Invalid token') - if not token.user.is_active: + user = token.user + + if not user.is_active: msg = 'User inactive or deleted: %s' % user.username raise exceptions.AuthenticationFailed(msg) - return (token.user, token) + return (user, token) def authenticate_header(self, request): """ diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 8bfebe68..3828555b 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -404,19 +404,23 @@ except ImportError: try: from django.utils.html import smart_urlquote except ImportError: + import re + from django.utils.encoding import smart_str try: from urllib.parse import quote, urlsplit, urlunsplit except ImportError: # Python 2 from urllib import quote from urlparse import urlsplit, urlunsplit + unquoted_percents_re = re.compile(r'%(?![0-9A-Fa-f]{2})') + def smart_urlquote(url): "Quotes a URL if it isn't already quoted." # Handle IDN before quoting. scheme, netloc, path, query, fragment = urlsplit(url) try: - netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE - except UnicodeError: # invalid domain part + netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE + except UnicodeError: # invalid domain part pass else: url = urlunsplit((scheme, netloc, path, query, fragment)) @@ -425,7 +429,7 @@ except ImportError: # contains a % not followed by two hexadecimal digits. See #9655. if '%' not in url or unquoted_percents_re.search(url): # See http://bugs.python.org/issue2637 - url = quote(force_bytes(url), safe=b'!*\'();:@&=+$,/?#[]~') + url = quote(smart_str(url), safe=b'!*\'();:@&=+$,/?#[]~') return force_text(url) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index ea62123d..ae03060b 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -21,19 +21,21 @@ class GenericAPIView(views.APIView): queryset = None serializer_class = None + # Shortcut which may be used in place of `queryset`/`serializer_class` + model = None + filter_backend = api_settings.FILTER_BACKEND paginate_by = api_settings.PAGINATE_BY paginate_by_param = api_settings.PAGINATE_BY_PARAM pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS + model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS page_kwarg = 'page' - lookup_kwarg = 'pk' + lookup_field = 'pk' allow_empty = True ###################################### - # These are all pending deprecation... + # These are pending deprecation... - model = None - model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS pk_url_kwarg = 'pk' slug_url_kwarg = 'slug' slug_field = 'slug' @@ -116,6 +118,11 @@ class GenericAPIView(views.APIView): def filter_queryset(self, queryset): """ Given a queryset, filter it with whichever filter backend is in use. + + You are unlikely to want to override this method, although you may need + to call it either from a list view, or from a custom `get_object` + method if you want to apply the configured filtering backend to the + default queryset. """ if not self.filter_backend: return queryset @@ -164,7 +171,7 @@ class GenericAPIView(views.APIView): # TODO: Deprecation warning return self.model._default_manager.all() - raise ImproperlyConfigured("'%s' must define 'queryset'" + raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" % self.__class__.__name__) def get_object(self, queryset=None): @@ -184,10 +191,10 @@ class GenericAPIView(views.APIView): # Perform the lookup filtering. pk = self.kwargs.get(self.pk_url_kwarg, None) slug = self.kwargs.get(self.slug_url_kwarg, None) - lookup = self.kwargs.get(self.lookup_kwarg, None) + lookup = self.kwargs.get(self.lookup_field, None) if lookup is not None: - filter_kwargs = {self.lookup_kwarg: lookup} + filter_kwargs = {self.lookup_field: lookup} elif pk is not None: # TODO: Deprecation warning filter_kwargs = {'pk': pk} diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index ae895f39..2aa45c71 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -25,7 +25,7 @@ class BasePermission(object): """ Return `True` if permission is granted, `False` otherwise. """ - if len(inspect.getargspec(self.has_permission)[0]) == 4: + if len(inspect.getargspec(self.has_permission).args) == 4: warnings.warn('The `obj` argument in `has_permission` is due to be deprecated. ' 'Use `has_object_permission()` instead for object permissions.', PendingDeprecationWarning, stacklevel=2) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 1b2b0821..b4327af1 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -568,36 +568,73 @@ class ModelSerializer(Serializer): assert cls is not None, \ "Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__ opts = get_concrete_model(cls)._meta - pk_field = opts.pk + ret = SortedDict() + nested = bool(self.opts.depth) - # If model is a child via multitable inheritance, use parent's pk + # Deal with adding the primary key field + pk_field = opts.pk while pk_field.rel and pk_field.rel.parent_link: + # If model is a child via multitable inheritance, use parent's pk pk_field = pk_field.rel.to._meta.pk - fields = [pk_field] - fields += [field for field in opts.fields if field.serialize] - fields += [field for field in opts.many_to_many if field.serialize] + field = self.get_pk_field(pk_field) + if field: + ret[pk_field.name] = field - ret = SortedDict() - nested = bool(self.opts.depth) - is_pk = True # First field in the list is the pk - - for model_field in fields: - if is_pk: - field = self.get_pk_field(model_field) - is_pk = False - elif model_field.rel and nested: - field = self.get_nested_field(model_field) - elif model_field.rel: + # Deal with forward relationships + forward_rels = [field for field in opts.fields if field.serialize] + forward_rels += [field for field in opts.many_to_many if field.serialize] + + for model_field in forward_rels: + if model_field.rel: to_many = isinstance(model_field, models.fields.related.ManyToManyField) - field = self.get_related_field(model_field, to_many=to_many) + related_model = model_field.rel.to + + if model_field.rel and nested: + if len(inspect.getargspec(self.get_nested_field).args) == 2: + # TODO: deprecation warning + field = self.get_nested_field(model_field) + else: + field = self.get_nested_field(model_field, related_model, to_many) + elif model_field.rel: + if len(inspect.getargspec(self.get_nested_field).args) == 3: + # TODO: deprecation warning + field = self.get_related_field(model_field, to_many=to_many) + else: + field = self.get_related_field(model_field, related_model, to_many) else: field = self.get_field(model_field) if field: ret[model_field.name] = field + # Deal with reverse relationships + if not self.opts.fields: + reverse_rels = [] + else: + # Reverse relationships are only included if they are explicitly + # present in the `fields` option on the serializer + reverse_rels = opts.get_all_related_objects() + reverse_rels += opts.get_all_related_many_to_many_objects() + + for relation in reverse_rels: + accessor_name = relation.get_accessor_name() + if accessor_name not in self.opts.fields: + continue + related_model = relation.model + to_many = relation.field.rel.multiple + + if nested: + field = self.get_nested_field(None, related_model, to_many) + else: + field = self.get_related_field(None, related_model, to_many) + + if field: + ret[accessor_name] = field + + # Add the `read_only` flag to any fields that have bee specified + # in the `read_only_fields` option for field_name in self.opts.read_only_fields: assert field_name in ret, \ "read_only_fields on '%s' included invalid item '%s'" % \ @@ -612,27 +649,30 @@ class ModelSerializer(Serializer): """ return self.get_field(model_field) - def get_nested_field(self, model_field): + def get_nested_field(self, model_field, related_model, to_many): """ Creates a default instance of a nested relational field. """ class NestedModelSerializer(ModelSerializer): class Meta: - model = model_field.rel.to - return NestedModelSerializer() + model = related_model + return NestedModelSerializer(many=to_many) - def get_related_field(self, model_field, to_many=False): + def get_related_field(self, model_field, related_model, to_many): """ Creates a default instance of a flat relational field. """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) + kwargs = { - 'required': not(model_field.null or model_field.blank), - 'queryset': model_field.rel.to._default_manager, + 'queryset': related_model._default_manager, 'many': to_many } + if model_field: + kwargs['required'] = not(model_field.null or model_field.blank) + return PrimaryKeyRelatedField(**kwargs) def get_field(self, model_field): @@ -741,7 +781,7 @@ class ModelSerializer(Serializer): Override the default method to also include model field validation. """ instance = super(ModelSerializer, self).from_native(data, files) - if instance: + if not self._errors: return self.full_clean(instance) def save_object(self, obj, **kwargs): @@ -797,21 +837,24 @@ class HyperlinkedModelSerializer(ModelSerializer): return self._default_view_name % format_kwargs def get_pk_field(self, model_field): - return None + if self.opts.fields and model_field.name in self.opts.fields: + return self.get_field(model_field) - def get_related_field(self, model_field, to_many): + def get_related_field(self, model_field, related_model, to_many): """ Creates a default instance of a flat relational field. """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) - rel = model_field.rel.to kwargs = { - 'required': not(model_field.null or model_field.blank), - 'queryset': rel._default_manager, - 'view_name': self._get_default_view_name(rel), + 'queryset': related_model._default_manager, + 'view_name': self._get_default_view_name(related_model), 'many': to_many } + + if model_field: + kwargs['required'] = not(model_field.null or model_field.blank) + return HyperlinkedRelatedField(**kwargs) def get_identity(self, data): diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 44633f5a..4410f285 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -115,7 +115,7 @@ </div> <div class="response-info"> <pre class="prettyprint"><div class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %} -{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|urlize_quoted_links }}</span> +{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|break_long_headers|urlize_quoted_links }}</span> {% endfor %} </div>{{ content|urlize_quoted_links }}</pre>{% endautoescape %} </div> diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index b6ab2de3..c86b6456 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -181,7 +181,7 @@ TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"] WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'), ('"', '"'), ("'", "'")] word_split_re = re.compile(r'(\s+)') -simple_url_re = re.compile(r'^https?://\w', re.IGNORECASE) +simple_url_re = re.compile(r'^https?://\[?\w', re.IGNORECASE) simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE) simple_email_re = re.compile(r'^\S+@\S+\.\S+$') @@ -260,3 +260,14 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru elif autoescape: words[i] = escape(word) return ''.join(words) + + +@register.filter +def break_long_headers(header): + """ + Breaks headers longer than 160 characters (~page length) + when possible (are comma separated) + """ + if len(header) > 160 and ',' in header: + header = mark_safe('<br> ' + ', <br>'.join(header.split(','))) + return header diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index f564890c..4a13389a 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals from django.db import models +from django.shortcuts import get_object_or_404 from django.test import TestCase from rest_framework import generics, serializers, status from rest_framework.tests.utils import RequestFactory @@ -302,6 +303,47 @@ class TestInstanceView(TestCase): self.assertEqual(new_obj.text, 'foobar') +class TestOverriddenGetObject(TestCase): + """ + Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the + queryset/model mechanism but instead overrides get_object() + """ + def setUp(self): + """ + Create 3 BasicModel intances. + """ + items = ['foo', 'bar', 'baz'] + for item in items: + BasicModel(text=item).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + + class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView): + """ + Example detail view for override of get_object(). + """ + model = BasicModel + + def get_object(self): + pk = int(self.kwargs['pk']) + return get_object_or_404(BasicModel.objects.all(), id=pk) + + self.view = OverriddenGetObjectView.as_view() + + def test_overridden_get_object_view(self): + """ + GET requests to RetrieveUpdateDestroyAPIView should return a single object. + """ + request = factory.get('/1') + with self.assertNumQueries(1): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[0]) + + # Regression test for #285 class CommentSerializer(serializers.ModelSerializer): diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py index b5702a48..b1eed9a7 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/relations_hyperlink.py @@ -26,42 +26,44 @@ urlpatterns = patterns('', ) +# ManyToMany class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): - sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail') - class Meta: model = ManyToManyTarget + fields = ('url', 'name', 'sources') class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = ManyToManySource + fields = ('url', 'name', 'targets') +# ForeignKey class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): - sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail') - class Meta: model = ForeignKeyTarget + fields = ('url', 'name', 'sources') class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = ForeignKeySource + fields = ('url', 'name', 'target') # Nullable ForeignKey class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = NullableForeignKeySource + fields = ('url', 'name', 'target') -# OneToOne +# Nullable OneToOne class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): - nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail') - class Meta: model = OneToOneTarget + fields = ('url', 'name', 'nullable_source') # TODO: Add test that .data cannot be accessed prior to .is_valid diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index a125ba65..f6d006b3 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -6,38 +6,30 @@ from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, Null class ForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: - depth = 1 - model = ForeignKeySource - - -class FlatForeignKeySourceSerializer(serializers.ModelSerializer): - class Meta: model = ForeignKeySource + fields = ('id', 'name', 'target') + depth = 1 class ForeignKeyTargetSerializer(serializers.ModelSerializer): - sources = FlatForeignKeySourceSerializer(many=True) - class Meta: model = ForeignKeyTarget + fields = ('id', 'name', 'sources') + depth = 1 class NullableForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: - depth = 1 model = NullableForeignKeySource - - -class NullableOneToOneSourceSerializer(serializers.ModelSerializer): - class Meta: - model = NullableOneToOneSource + fields = ('id', 'name', 'target') + depth = 1 class NullableOneToOneTargetSerializer(serializers.ModelSerializer): - nullable_source = NullableOneToOneSourceSerializer() - class Meta: model = OneToOneTarget + fields = ('id', 'name', 'nullable_source') + depth = 1 class ReverseForeignKeyTests(TestCase): diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py index f08e1808..5ce8b567 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/relations_pk.py @@ -5,41 +5,44 @@ from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, Fore from rest_framework.compat import six +# ManyToMany class ManyToManyTargetSerializer(serializers.ModelSerializer): - sources = serializers.PrimaryKeyRelatedField(many=True) - class Meta: model = ManyToManyTarget + fields = ('id', 'name', 'sources') class ManyToManySourceSerializer(serializers.ModelSerializer): class Meta: model = ManyToManySource + fields = ('id', 'name', 'targets') +# ForeignKey class ForeignKeyTargetSerializer(serializers.ModelSerializer): - sources = serializers.PrimaryKeyRelatedField(many=True) - class Meta: model = ForeignKeyTarget + fields = ('id', 'name', 'sources') class ForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: model = ForeignKeySource + fields = ('id', 'name', 'target') +# Nullable ForeignKey class NullableForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: model = NullableForeignKeySource + fields = ('id', 'name', 'target') -# OneToOne +# Nullable OneToOne class NullableOneToOneTargetSerializer(serializers.ModelSerializer): - nullable_source = serializers.PrimaryKeyRelatedField() - class Meta: model = OneToOneTarget + fields = ('id', 'name', 'nullable_source') # TODO: Add test that .data cannot be accessed prior to .is_valid diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 05217f35..3a94fad5 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -738,6 +738,43 @@ class ManyRelatedTests(TestCase): self.assertEqual(serializer.data, expected) + def test_include_reverse_relations(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I hate this blog post") + post.blogpostcomment_set.create(text="I love this blog post") + + class BlogPostSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPost + fields = ('id', 'title', 'blogpostcomment_set') + + serializer = BlogPostSerializer(instance=post) + expected = { + 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2] + } + self.assertEqual(serializer.data, expected) + + def test_depth_include_reverse_relations(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I hate this blog post") + post.blogpostcomment_set.create(text="I love this blog post") + + class BlogPostSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPost + fields = ('id', 'title', 'blogpostcomment_set') + depth = 1 + + serializer = BlogPostSerializer(instance=post) + expected = { + 'id': 1, 'title': 'Test blog post', + 'blogpostcomment_set': [ + {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1}, + {'id': 2, 'text': 'I love this blog post', 'blog_post': 1} + ] + } + self.assertEqual(serializer.data, expected) + def test_callable_source(self): post = BlogPost.objects.create(title="Test blog post") post.blogpostcomment_set.create(text="I love this blog post") diff --git a/rest_framework/views.py b/rest_framework/views.py index d7d3a2e2..b8e948e0 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,7 +3,7 @@ Provides an APIView class that is used as the base of all class-based views. """ from __future__ import unicode_literals from django.core.exceptions import PermissionDenied -from django.http import Http404 +from django.http import Http404, HttpResponse from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions from rest_framework.compat import View @@ -256,6 +256,12 @@ class APIView(View): """ Returns the final response object. """ + # Make the error obvious if a proper response is not returned + assert isinstance(response, HttpResponse), ( + 'Expected a `Response` to be returned from the view, ' + 'but received a `%s`' % type(response) + ) + if isinstance(response, Response): if not getattr(request, 'accepted_renderer', None): neg = self.perform_content_negotiation(request, force=True) |
