From 52847a215d4e8de88e81d9ae79ce8bee9a36a9a2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 15 Jan 2013 17:50:51 +0000 Subject: Fix implementation --- rest_framework/mixins.py | 3 -- rest_framework/resources.py | 67 ++++++++++++++++----------------------------- 2 files changed, 23 insertions(+), 47 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 8873e4ae..9bd566da 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -25,9 +25,6 @@ class CreateModelMixin(object): return Response(serializer.data, status=status.HTTP_201_CREATED) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - def pre_save(self, obj): - pass - class ListModelMixin(object): """ diff --git a/rest_framework/resources.py b/rest_framework/resources.py index dd8a5471..d4019a94 100644 --- a/rest_framework/resources.py +++ b/rest_framework/resources.py @@ -1,31 +1,27 @@ ##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### from functools import update_wrapper -import inspect from django.utils.decorators import classonlymethod -from rest_framework import views, generics - - -def wrapped(source, dest): - """ - Copy public, non-method attributes from source to dest, and return dest. - """ - for attr in [attr for attr in dir(source) - if not attr.startswith('_') and not inspect.ismethod(attr)]: - setattr(dest, attr, getattr(source, attr)) - return dest +from rest_framework import views, generics, mixins ##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### class ResourceMixin(object): """ - Clone Django's `View.as_view()` behaviour *except* using REST framework's - 'method -> action' binding for resources. + This is the magic. + + Overrides `.as_view()` so that it takes an `actions` keyword that performs + the binding of HTTP methods to actions on the Resource. + + For example, to create a concrete view binding the 'GET' and 'POST' methods + to the 'list' and 'create' actions... + + my_resource = MyResource.as_view({'get': 'list', 'post': 'create'}) """ @classonlymethod - def as_view(cls, actions, **initkwargs): + def as_view(cls, actions=None, **initkwargs): """ Main entry point for a request-response process. """ @@ -61,36 +57,19 @@ class ResourceMixin(object): return view -##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### - class Resource(ResourceMixin, views.APIView): pass -##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### - -class ModelResource(ResourceMixin, views.APIView): - # TODO: Actually delegation won't work - root_class = generics.ListCreateAPIView - detail_class = generics.RetrieveUpdateDestroyAPIView - - def root_view(self): - return wrapped(self, self.root_class()) - - def detail_view(self): - return wrapped(self, self.detail_class()) - - def list(self, request, *args, **kwargs): - return self.root_view().list(request, args, kwargs) - - def create(self, request, *args, **kwargs): - return self.root_view().create(request, args, kwargs) - - def retrieve(self, request, *args, **kwargs): - return self.detail_view().retrieve(request, args, kwargs) - - def update(self, request, *args, **kwargs): - return self.detail_view().update(request, args, kwargs) - - def destroy(self, request, *args, **kwargs): - return self.detail_view().destroy(request, args, kwargs) +# Note the inheritence of both MultipleObjectAPIView *and* SingleObjectAPIView +# is a bit weird given the diamond inheritence, but it will work for now. +# There's some implementation clean up that can happen later. +class ModelResource(mixins.CreateModelMixin, + mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + mixins.DestroyModelMixin, + mixins.ListModelMixin, + ResourceMixin, + generics.MultipleObjectAPIView, + generics.SingleObjectAPIView): + pass -- cgit v1.2.3 From 4a7139e41d2500776c30e663c1cebce74b49270d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 15 Jan 2013 21:49:24 +0000 Subject: Tweaks --- rest_framework/routers.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 rest_framework/routers.py (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py new file mode 100644 index 00000000..a5aef5b7 --- /dev/null +++ b/rest_framework/routers.py @@ -0,0 +1,33 @@ +# Not properly implemented yet, just the basic idea + + +class BaseRouter(object): + def __init__(self): + self.resources = [] + + def register(self, name, resource): + self.resources.append((name, resource)) + + @property + def urlpatterns(self): + ret = [] + + for name, resource in self.resources: + list_actions = { + 'get': getattr(resource, 'list', None), + 'post': getattr(resource, 'create', None) + } + detail_actions = { + 'get': getattr(resource, 'retrieve', None), + 'put': getattr(resource, 'update', None), + 'delete': getattr(resource, 'destroy', None) + } + list_regex = r'^%s/$' % name + detail_regex = r'^%s/(?P[0-9]+)/$' % name + list_name = '%s-list' + detail_name = '%s-detail' + + ret += url(list_regex, resource.as_view(list_actions), list_name) + ret += url(detail_regex, resource.as_view(detail_actions), detail_name) + + return ret -- cgit v1.2.3 From 84be169353f0dd2ceb06fe459b72aa2452fcbeb5 Mon Sep 17 00:00:00 2001 From: Craig de Stigter Date: Fri, 1 Mar 2013 16:13:04 +1300 Subject: fix function names and dotted lookups for use in PrimaryKeyRelatedField.field_to_native (they work in RelatedField.field_to_native already) --- rest_framework/relations.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 0c108717..ef465b3c 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -215,12 +215,20 @@ class PrimaryKeyRelatedField(RelatedField): def field_to_native(self, obj, field_name): if self.many: # To-many relationship - try: + + queryset = None + if not self.source: # Prefer obj.serializable_value for performance reasons - queryset = obj.serializable_value(self.source or field_name) - except AttributeError: + try: + queryset = obj.serializable_value(field_name) + except AttributeError: + pass + if queryset is None: # RelatedManager (reverse relationship) - queryset = getattr(obj, self.source or field_name) + source = self.source or field_name + queryset = obj + for component in source.split('.'): + queryset = get_component(queryset, component) # Forward relationship return [self.to_native(item.pk) for item in queryset.all()] -- cgit v1.2.3 From 922ee61d8611b41e2944b6503af736b1790abe83 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 18 Mar 2013 21:05:13 +0000 Subject: Remove erronous pre_save --- rest_framework/generics.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 55918267..36ecf915 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -82,9 +82,6 @@ class GenericAPIView(views.APIView): """ pass - def pre_save(self, obj): - pass - class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): """ -- cgit v1.2.3 From 7eefcf7e53f2bc37733a601041f23d354c7729f5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 25 Mar 2013 20:26:34 +0000 Subject: Bulk update, allow_add_remove flag --- rest_framework/serializers.py | 16 +++++++----- rest_framework/tests/serializer_bulk_update.py | 34 ++++++++++++++++++++++---- 2 files changed, 39 insertions(+), 11 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 6aca2f57..1b2b0821 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -130,14 +130,14 @@ class BaseSerializer(WritableField): def __init__(self, instance=None, data=None, files=None, context=None, partial=False, many=None, - allow_delete=False, **kwargs): + allow_add_remove=False, **kwargs): super(BaseSerializer, self).__init__(**kwargs) self.opts = self._options_class(self.Meta) self.parent = None self.root = None self.partial = partial self.many = many - self.allow_delete = allow_delete + self.allow_add_remove = allow_add_remove self.context = context or {} @@ -154,8 +154,8 @@ class BaseSerializer(WritableField): if many and instance is not None and not hasattr(instance, '__iter__'): raise ValueError('instance should be a queryset or other iterable with many=True') - if allow_delete and not many: - raise ValueError('allow_delete should only be used for bulk updates, but you have not set many=True') + if allow_add_remove and not many: + raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True') ##### # Methods to determine which fields to use when (de)serializing objects. @@ -448,6 +448,10 @@ class BaseSerializer(WritableField): # Determine which object we're updating identity = self.get_identity(item) self.object = identity_to_objects.pop(identity, None) + if self.object is None and not self.allow_add_remove: + ret.append(None) + errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}) + continue ret.append(self.from_native(item, None)) errors.append(self._errors) @@ -457,7 +461,7 @@ class BaseSerializer(WritableField): self._errors = any(errors) and errors or [] else: - self._errors = {'non_field_errors': ['Expected a list of items']} + self._errors = {'non_field_errors': ['Expected a list of items.']} else: ret = self.from_native(data, files) @@ -508,7 +512,7 @@ class BaseSerializer(WritableField): else: self.save_object(self.object, **kwargs) - if self.allow_delete and self._deleted: + if self.allow_add_remove and self._deleted: [self.delete_object(item) for item in self._deleted] return self.object diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py index afc1a1a9..8b0ded1a 100644 --- a/rest_framework/tests/serializer_bulk_update.py +++ b/rest_framework/tests/serializer_bulk_update.py @@ -98,7 +98,7 @@ class BulkCreateSerializerTests(TestCase): serializer = self.BookSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), False) - expected_errors = {'non_field_errors': ['Expected a list of items']} + expected_errors = {'non_field_errors': ['Expected a list of items.']} self.assertEqual(serializer.errors, expected_errors) @@ -115,7 +115,7 @@ class BulkCreateSerializerTests(TestCase): serializer = self.BookSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), False) - expected_errors = {'non_field_errors': ['Expected a list of items']} + expected_errors = {'non_field_errors': ['Expected a list of items.']} self.assertEqual(serializer.errors, expected_errors) @@ -201,11 +201,12 @@ class BulkUpdateSerializerTests(TestCase): 'author': 'Haruki Murakami' } ] - serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.data, data) serializer.save() new_data = self.BookSerializer(self.books(), many=True).data + self.assertEqual(data, new_data) def test_bulk_update_and_create(self): @@ -223,13 +224,36 @@ class BulkUpdateSerializerTests(TestCase): 'author': 'Haruki Murakami' } ] - serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.data, data) serializer.save() new_data = self.BookSerializer(self.books(), many=True).data self.assertEqual(data, new_data) + def test_bulk_update_invalid_create(self): + """ + Bulk update serialization without allow_add_remove may not create items. + """ + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 3, + 'title': 'Kafka on the shore', + 'author': 'Haruki Murakami' + } + ] + expected_errors = [ + {}, + {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']} + ] + serializer = self.BookSerializer(self.books(), data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors) + def test_bulk_update_error(self): """ Incorrect bulk update serialization should return error data. @@ -249,6 +273,6 @@ class BulkUpdateSerializerTests(TestCase): {}, {'id': ['Enter a whole number.']} ] - serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) self.assertEqual(serializer.is_valid(), False) self.assertEqual(serializer.errors, expected_errors) -- cgit v1.2.3 From 92c929094c88125ea4a2fd359ec99d2b4114f081 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 26 Mar 2013 07:48:53 +0000 Subject: Version 2.2.5 --- rest_framework/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index cf005636..c86403d8 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.2.4' +__version__ = '2.2.5' VERSION = __version__ # synonym -- cgit v1.2.3 From f1b8fee4f1e0ea2503d4e0453bdc3049edaa2598 Mon Sep 17 00:00:00 2001 From: Fernando Rocha Date: Wed, 27 Mar 2013 14:05:46 -0300 Subject: client credentials should be optional (fix #759) client credentials should only be required on token request Signed-off-by: Fernando Rocha --- rest_framework/authentication.py | 32 ++++++++++++++++++-------------- rest_framework/tests/authentication.py | 12 ++++++++++++ 2 files changed, 30 insertions(+), 14 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 8f4ec536..f4626a2e 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -2,14 +2,16 @@ Provides a set of pluggable authentication policies. """ from __future__ import unicode_literals +import base64 +from datetime import datetime + from django.contrib.auth import authenticate from django.core.exceptions import ImproperlyConfigured from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework.compat import CsrfViewMiddleware from rest_framework.compat import oauth, oauth_provider, oauth_provider_store -from rest_framework.compat import oauth2_provider, oauth2_provider_forms, oauth2_provider_backends +from rest_framework.compat import oauth2_provider, oauth2_provider_forms from rest_framework.authtoken.models import Token -import base64 def get_authorization_header(request): @@ -314,22 +316,24 @@ class OAuth2Authentication(BaseAuthentication): """ Authenticate the request, given the access token. """ + client = None # Authenticate the client - oauth2_client_form = oauth2_provider_forms.ClientAuthForm(request.REQUEST) - if not oauth2_client_form.is_valid(): - raise exceptions.AuthenticationFailed('Client could not be validated') - client = oauth2_client_form.cleaned_data.get('client') - - # Retrieve the `OAuth2AccessToken` instance from the access_token - auth_backend = oauth2_provider_backends.AccessTokenBackend() - token = auth_backend.authenticate(access_token, client) - if token is None: - raise exceptions.AuthenticationFailed('Invalid token') + if 'client_id' in request.REQUEST: + oauth2_client_form = oauth2_provider_forms.ClientAuthForm(request.REQUEST) + if not oauth2_client_form.is_valid(): + raise exceptions.AuthenticationFailed('Client could not be validated') + client = oauth2_client_form.cleaned_data.get('client') - user = token.user + try: + token = oauth2_provider.models.AccessToken.objects.select_related('user') + if client is not None: + token = token.filter(client=client) + token = token.get(token=access_token, expires__gt=datetime.now()) + except oauth2_provider.models.AccessToken.DoesNotExist: + raise exceptions.AuthenticationFailed('Invalid token') - if not user.is_active: + if not token.user.is_active: msg = 'User inactive or deleted: %s' % user.username raise exceptions.AuthenticationFailed(msg) diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index b663ca48..375b19bd 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -516,6 +516,18 @@ class OAuth2Tests(TestCase): response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_passing_auth_without_client_params(self): + """ + Ensure GETing form over OAuth without client credentials + + Regression test for issue #759: + https://github.com/tomchristie/django-rest-framework/issues/759 + """ + auth = self._create_authorization_header() + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_post_form_passing_auth(self): """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" -- cgit v1.2.3 From 5f48b4a77e0a767694a32310a6368cd32b9a924c Mon Sep 17 00:00:00 2001 From: glic3rinu Date: Wed, 27 Mar 2013 22:43:41 +0100 Subject: Refactored urlize_quoted_links code, now based on Django 1.5 urlize --- rest_framework/templatetags/rest_framework.py | 79 +++++++++++++++------------ 1 file changed, 45 insertions(+), 34 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index c21ddcd7..50e485db 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals, absolute_import from django import template from django.core.urlresolvers import reverse, NoReverseMatch from django.http import QueryDict -from django.utils.html import escape +from django.utils.html import escape, smart_urlquote from django.utils.safestring import SafeData, mark_safe from rest_framework.compat import urlparse from rest_framework.compat import force_text @@ -112,22 +112,6 @@ def replace_query_param(url, key, val): class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') -# Bunch of stuff cloned from urlize -LEADING_PUNCTUATION = ['(', '<', '<', '"', "'"] -TRAILING_PUNCTUATION = ['.', ',', ')', '>', '\n', '>', '"', "'"] -DOTS = ['·', '*', '\xe2\x80\xa2', '•', '•', '•'] -unencoded_ampersands_re = re.compile(r'&(?!(\w+|#\d+);)') -word_split_re = re.compile(r'(\s+)') -punctuation_re = re.compile('^(?P(?:%s)*)(?P.*?)(?P(?:%s)*)$' % \ - ('|'.join([re.escape(x) for x in LEADING_PUNCTUATION]), - '|'.join([re.escape(x) for x in TRAILING_PUNCTUATION]))) -simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$') -link_target_attribute_re = re.compile(r'(]*?)target=[^\s>]+') -html_gunk_re = re.compile(r'(?:
|<\/i>|<\/b>|<\/em>|<\/strong>|<\/?smallcaps>|<\/?uppercase>)', re.IGNORECASE) -hard_coded_bullets_re = re.compile(r'((?:

(?:%s).*?[a-zA-Z].*?

\s*)+)' % '|'.join([re.escape(x) for x in DOTS]), re.DOTALL) -trailing_empty_content_re = re.compile(r'(?:

(?: |\s|
)*?

\s*)+\Z') - - # And the template tags themselves... @register.simple_tag @@ -195,15 +179,25 @@ def add_class(value, css_class): return value +# Bunch of stuff cloned from urlize +TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)'] +WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'), + ('"', '"'), ("'", "'")] +word_split_re = re.compile(r'(\s+)') +simple_url_re = re.compile(r'^https?://\w', re.IGNORECASE) +simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE) +simple_email_re = re.compile(r'^\S+@\S+\.\S+$') + + @register.filter def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True): """ Converts any URLs in text into clickable links. - Works on http://, https://, www. links and links ending in .org, .net or - .com. Links can have trailing punctuation (periods, commas, close-parens) - and leading punctuation (opening parens) and it'll still do the right - thing. + Works on http://, https://, www. links, and also on links ending in one of + the original seven gTLDs (.com, .edu, .gov, .int, .mil, .net, and .org). + Links can have trailing punctuation (periods, commas, close-parens) and + leading punctuation (opening parens) and it'll still do the right thing. If trim_url_limit is not None, the URLs in link text longer than this limit will truncated to trim_url_limit-3 characters and appended with an elipsis. @@ -216,24 +210,41 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x safe_input = isinstance(text, SafeData) words = word_split_re.split(force_text(text)) - nofollow_attr = nofollow and ' rel="nofollow"' or '' for i, word in enumerate(words): match = None if '.' in word or '@' in word or ':' in word: - match = punctuation_re.match(word) - if match: - lead, middle, trail = match.groups() + # Deal with punctuation. + lead, middle, trail = '', word, '' + for punctuation in TRAILING_PUNCTUATION: + if middle.endswith(punctuation): + middle = middle[:-len(punctuation)] + trail = punctuation + trail + for opening, closing in WRAPPING_PUNCTUATION: + if middle.startswith(opening): + middle = middle[len(opening):] + lead = lead + opening + # Keep parentheses at the end only if they're balanced. + if (middle.endswith(closing) + and middle.count(closing) == middle.count(opening) + 1): + middle = middle[:-len(closing)] + trail = closing + trail + # Make URL we want to point to. url = None - if middle.startswith('http://') or middle.startswith('https://'): - url = middle - elif middle.startswith('www.') or ('@' not in middle and \ - middle and middle[0] in string.ascii_letters + string.digits and \ - (middle.endswith('.org') or middle.endswith('.net') or middle.endswith('.com'))): - url = 'http://%s' % middle - elif '@' in middle and not ':' in middle and simple_email_re.match(middle): - url = 'mailto:%s' % middle + nofollow_attr = ' rel="nofollow"' if nofollow else '' + if simple_url_re.match(middle): + url = smart_urlquote(middle) + elif simple_url_2_re.match(middle): + url = smart_urlquote('http://%s' % middle) + elif not ':' in middle and simple_email_re.match(middle): + local, domain = middle.rsplit('@', 1) + try: + domain = domain.encode('idna').decode('ascii') + except UnicodeError: + continue + url = 'mailto:%s@%s' % (local, domain) nofollow_attr = '' + # Make link. if url: trimmed = trim_url(middle) @@ -251,4 +262,4 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru words[i] = mark_safe(word) elif autoescape: words[i] = escape(word) - return mark_safe(''.join(words)) + return ''.join(words) -- cgit v1.2.3 From 2c0363ddaec22ac54385f7e0c2e1401ed3ff0879 Mon Sep 17 00:00:00 2001 From: glic3rinu Date: Wed, 27 Mar 2013 22:58:11 +0100 Subject: Added quotes to TRAILING_PUNCTUATION used by urlize_quoted_links --- rest_framework/templatetags/rest_framework.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 50e485db..78a3a9a1 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -180,7 +180,7 @@ def add_class(value, css_class): # Bunch of stuff cloned from urlize -TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)'] +TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"] WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'), ('"', '"'), ("'", "'")] word_split_re = re.compile(r'(\s+)') -- cgit v1.2.3 From b2cea84fae4f721e8eb6432b3d1bab1309e21a00 Mon Sep 17 00:00:00 2001 From: Fernando Rocha Date: Wed, 27 Mar 2013 19:00:36 -0300 Subject: Complete remove of client checks from oauth2 Signed-off-by: Fernando Rocha --- rest_framework/authentication.py | 12 ++---------- rest_framework/tests/authentication.py | 9 --------- 2 files changed, 2 insertions(+), 19 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index f4626a2e..145d4295 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -316,19 +316,11 @@ class OAuth2Authentication(BaseAuthentication): """ Authenticate the request, given the access token. """ - client = None - - # Authenticate the client - if 'client_id' in request.REQUEST: - oauth2_client_form = oauth2_provider_forms.ClientAuthForm(request.REQUEST) - if not oauth2_client_form.is_valid(): - raise exceptions.AuthenticationFailed('Client could not be validated') - client = oauth2_client_form.cleaned_data.get('client') try: token = oauth2_provider.models.AccessToken.objects.select_related('user') - if client is not None: - token = token.filter(client=client) + # TODO: Change to timezone aware datetime when oauth2_provider add + # support to it. token = token.get(token=access_token, expires__gt=datetime.now()) except oauth2_provider.models.AccessToken.DoesNotExist: raise exceptions.AuthenticationFailed('Invalid token') diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index 375b19bd..629db422 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -499,15 +499,6 @@ class OAuth2Tests(TestCase): response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_with_wrong_client_data_failing_auth(self): - """Ensure GETing form over OAuth with incorrect client credentials fails""" - auth = self._create_authorization_header() - params = self._client_credentials_params() - params['client_id'] += 'a' - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_get_form_passing_auth(self): """Ensure GETing form over OAuth with correct client credentials succeed""" -- cgit v1.2.3 From 8ec60a22e1c14792b7021ff9b4e940e16528788a Mon Sep 17 00:00:00 2001 From: Pierre Dulac Date: Thu, 28 Mar 2013 00:57:23 +0100 Subject: Remove client credentials from all OAuth 2 tests --- rest_framework/tests/authentication.py | 45 ++++++++-------------------------- 1 file changed, 10 insertions(+), 35 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index 629db422..8e6d3e51 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -466,17 +466,13 @@ class OAuth2Tests(TestCase): def _create_authorization_header(self, token=None): return "Bearer {0}".format(token or self.access_token.token) - def _client_credentials_params(self): - return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET} - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_get_form_with_wrong_authorization_header_token_type_failing(self): """Ensure that a wrong token type lead to the correct HTTP error status code""" auth = "Wrong token-type-obsviously" response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -485,8 +481,7 @@ class OAuth2Tests(TestCase): auth = "Bearer wrong token format" response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -495,27 +490,13 @@ class OAuth2Tests(TestCase): auth = "Bearer wrong-token" response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_get_form_passing_auth(self): """Ensure GETing form over OAuth with correct client credentials succeed""" auth = self._create_authorization_header() - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_passing_auth_without_client_params(self): - """ - Ensure GETing form over OAuth without client credentials - - Regression test for issue #759: - https://github.com/tomchristie/django-rest-framework/issues/759 - """ - auth = self._create_authorization_header() response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) @@ -523,8 +504,7 @@ class OAuth2Tests(TestCase): def test_post_form_passing_auth(self): """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" auth = self._create_authorization_header() - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -532,16 +512,14 @@ class OAuth2Tests(TestCase): """Ensure POSTing when there is no OAuth access token in db fails""" self.access_token.delete() auth = self._create_authorization_header() - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_post_form_with_refresh_token_failing_auth(self): """Ensure POSTing with refresh token instead of access token fails""" auth = self._create_authorization_header(token=self.refresh_token.token) - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -550,8 +528,7 @@ class OAuth2Tests(TestCase): self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late self.access_token.save() auth = self._create_authorization_header() - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) self.assertIn('Invalid token', response.content) @@ -562,10 +539,9 @@ class OAuth2Tests(TestCase): read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read'] read_only_access_token.save() auth = self._create_authorization_header(token=read_only_access_token.token) - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) - response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -575,6 +551,5 @@ class OAuth2Tests(TestCase): read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write'] read_write_access_token.save() auth = self._create_authorization_header(token=read_write_access_token.token) - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) -- cgit v1.2.3 From fa61b2b2f10bf07e3cb87ca947ce7f0ca51a2ede Mon Sep 17 00:00:00 2001 From: Pierre Dulac Date: Thu, 28 Mar 2013 01:05:51 +0100 Subject: Remove oauth2-provider backends reference from compat.py --- rest_framework/compat.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 7b2ef738..c3e423e8 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -445,14 +445,12 @@ except ImportError: # OAuth 2 support is optional try: import provider.oauth2 as oauth2_provider - from provider.oauth2 import backends as oauth2_provider_backends from provider.oauth2 import models as oauth2_provider_models from provider.oauth2 import forms as oauth2_provider_forms from provider import scope as oauth2_provider_scope from provider import constants as oauth2_constants except ImportError: oauth2_provider = None - oauth2_provider_backends = None oauth2_provider_models = None oauth2_provider_forms = None oauth2_provider_scope = None -- cgit v1.2.3 From b10663e02408404844aca4b362aa24df816aca98 Mon Sep 17 00:00:00 2001 From: Kevin Stone Date: Wed, 27 Mar 2013 17:55:36 -0700 Subject: Fixed DjangoFilterBackend not returning a query set. Fixed bug unveiled in #682. Signed-off-by: Kevin Stone --- rest_framework/filters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 6fea46fa..413fa0d2 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -55,6 +55,6 @@ class DjangoFilterBackend(BaseFilterBackend): filter_class = self.get_filter_class(view) if filter_class: - return filter_class(request.QUERY_PARAMS, queryset=queryset) + return filter_class(request.QUERY_PARAMS, queryset=queryset).qs return queryset -- cgit v1.2.3 From d4df617f8c1980c1d5f1b91a6b9928185c4c4dce Mon Sep 17 00:00:00 2001 From: Kevin Stone Date: Wed, 27 Mar 2013 18:29:50 -0700 Subject: Added unit test for failing DjangoFilterBackend on SingleObjectMixin that was resolved in b10663e02408404844aca4b362aa24df816aca98 Signed-off-by: Kevin Stone --- rest_framework/tests/filterset.py | 75 +++++++++++++++++++++++++++++++++++---- 1 file changed, 69 insertions(+), 6 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 238da56e..1a71558c 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -1,11 +1,12 @@ from __future__ import unicode_literals import datetime from decimal import Decimal +from django.core.urlresolvers import reverse from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, status, filters -from rest_framework.compat import django_filters +from rest_framework.compat import django_filters, patterns, url from rest_framework.tests.models import FilterableItem, BasicModel factory = RequestFactory() @@ -46,12 +47,21 @@ if django_filters: filter_class = MisconfiguredFilter filter_backend = filters.DjangoFilterBackend + class FilterClassDetailView(generics.RetrieveAPIView): + model = FilterableItem + filter_class = SeveralFieldsFilter + filter_backend = filters.DjangoFilterBackend + + urlpatterns = patterns('', + url(r'^(?P\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), + url(r'^$', FilterClassRootView.as_view(), name='root-view'), + ) -class IntegrationTestFiltering(TestCase): - """ - Integration tests for filtered list views. - """ +class CommonFilteringTestCase(TestCase): + def _serialize_object(self, obj): + return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + def setUp(self): """ Create 10 FilterableItem instances. @@ -65,10 +75,16 @@ class IntegrationTestFiltering(TestCase): self.objects = FilterableItem.objects self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + self._serialize_object(obj) for obj in self.objects.all() ] + +class IntegrationTestFiltering(CommonFilteringTestCase): + """ + Integration tests for filtered list views. + """ + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_filtered_fields_root_view(self): """ @@ -167,3 +183,50 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?integer=%s' % search_integer) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class IntegrationTestDetailFiltering(CommonFilteringTestCase): + """ + Integration tests for filtered detail views. + """ + urls = 'rest_framework.tests.filterset' + + def _get_url(self, item): + return reverse('detail-view', kwargs=dict(pk=item.pk)) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_detail_view(self): + """ + GET requests to filtered RetrieveAPIView that have a filter_class set + should return filtered results. + """ + item = self.objects.all()[0] + data = self._serialize_object(item) + + # Basic test with no filter. + response = self.client.get(self._get_url(item)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, data) + + # Tests that the decimal filter set that should fail. + search_decimal = Decimal('4.25') + high_item = self.objects.filter(decimal__gt=search_decimal)[0] + response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + # Tests that the decimal filter set that should succeed. + search_decimal = Decimal('4.25') + low_item = self.objects.filter(decimal__lt=search_decimal)[0] + low_item_data = self._serialize_object(low_item) + response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, low_item_data) + + # Tests that multiple filters works. + search_decimal = Decimal('5.25') + search_date = datetime.date(2012, 10, 2) + valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] + valid_item_data = self._serialize_object(valid_item) + response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, valid_item_data) -- cgit v1.2.3 From 3774ba3ed2af918563eb6ed945cc13aa7fa2345a Mon Sep 17 00:00:00 2001 From: glic3rinu Date: Thu, 28 Mar 2013 12:01:08 +0100 Subject: Added force_text to compat --- rest_framework/compat.py | 31 +++++++++++++++++++++++++++ rest_framework/templatetags/rest_framework.py | 3 ++- 2 files changed, 33 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 7b2ef738..f0bb9c08 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -395,6 +395,37 @@ except ImportError: kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None) return datetime.datetime(**kw) + +# smart_urlquote is new on Django 1.4 +try: + from django.utils.html import smart_urlquote +except ImportError: + try: + from urllib.parse import quote, urlsplit, urlunsplit + except ImportError: # Python 2 + from urllib import quote + from urlparse import urlsplit, urlunsplit + + def smart_urlquote(url): + "Quotes a URL if it isn't already quoted." + # Handle IDN before quoting. + scheme, netloc, path, query, fragment = urlsplit(url) + try: + netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE + except UnicodeError: # invalid domain part + pass + else: + url = urlunsplit((scheme, netloc, path, query, fragment)) + + # An URL is considered unquoted if it contains no % characters or + # contains a % not followed by two hexadecimal digits. See #9655. + if '%' not in url or unquoted_percents_re.search(url): + # See http://bugs.python.org/issue2637 + url = quote(force_bytes(url), safe=b'!*\'();:@&=+$,/?#[]~') + + return force_text(url) + + # Markdown is optional try: import markdown diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 78a3a9a1..33bae241 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -2,11 +2,12 @@ from __future__ import unicode_literals, absolute_import from django import template from django.core.urlresolvers import reverse, NoReverseMatch from django.http import QueryDict -from django.utils.html import escape, smart_urlquote +from django.utils.html import escape from django.utils.safestring import SafeData, mark_safe from rest_framework.compat import urlparse from rest_framework.compat import force_text from rest_framework.compat import six +from rest_framework.compat import smart_urlquote import re import string -- cgit v1.2.3 From 9c32f048b51ec6852236363932f0ab0dcc7473ac Mon Sep 17 00:00:00 2001 From: glic3rinu Date: Thu, 28 Mar 2013 12:01:47 +0100 Subject: Cleaned imports on templatetags/rest_framework module --- rest_framework/templatetags/rest_framework.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 33bae241..b6ab2de3 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -4,12 +4,8 @@ from django.core.urlresolvers import reverse, NoReverseMatch from django.http import QueryDict from django.utils.html import escape from django.utils.safestring import SafeData, mark_safe -from rest_framework.compat import urlparse -from rest_framework.compat import force_text -from rest_framework.compat import six -from rest_framework.compat import smart_urlquote -import re -import string +from rest_framework.compat import urlparse, force_text, six, smart_urlquote +import re, string register = template.Library() -- cgit v1.2.3 From 4531ded061831a9cf402c6c5d84e42f31bc025ad Mon Sep 17 00:00:00 2001 From: Kevin Stone Date: Thu, 28 Mar 2013 18:48:48 -0700 Subject: Removed pagination regression special case for Django<1.4. Having DjangoFilterBackend return an actual query set fixes this issue. Signed-off-by: Kevin Stone --- rest_framework/tests/pagination.py | 10 ---------- 1 file changed, 10 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index d2c9b051..6b8ef02f 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -129,16 +129,6 @@ class IntegrationTestPaginationAndFiltering(TestCase): view = FilterFieldsRootView.as_view() EXPECTED_NUM_QUERIES = 2 - if django.VERSION < (1, 4): - # On Django 1.3 we need to use django-filter 0.5.4 - # - # The filter objects there don't expose a `.count()` method, - # which means we only make a single query *but* it's a single - # query across *all* of the queryset, instead of a COUNT and then - # a SELECT with a LIMIT. - # - # Although this is fewer queries, it's actually a regression. - EXPECTED_NUM_QUERIES = 1 request = factory.get('/?decimal=15.20') with self.assertNumQueries(EXPECTED_NUM_QUERIES): -- cgit v1.2.3 From ec076a00786c6b89a55b6ffe2556bb3b777100f5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sun, 31 Mar 2013 11:36:58 +0100 Subject: Add viewsets/routers to indexs etc --- rest_framework/routers.py | 33 --------------------------------- rest_framework/viewsets.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 33 deletions(-) delete mode 100644 rest_framework/routers.py create mode 100644 rest_framework/viewsets.py (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py deleted file mode 100644 index a5aef5b7..00000000 --- a/rest_framework/routers.py +++ /dev/null @@ -1,33 +0,0 @@ -# Not properly implemented yet, just the basic idea - - -class BaseRouter(object): - def __init__(self): - self.resources = [] - - def register(self, name, resource): - self.resources.append((name, resource)) - - @property - def urlpatterns(self): - ret = [] - - for name, resource in self.resources: - list_actions = { - 'get': getattr(resource, 'list', None), - 'post': getattr(resource, 'create', None) - } - detail_actions = { - 'get': getattr(resource, 'retrieve', None), - 'put': getattr(resource, 'update', None), - 'delete': getattr(resource, 'destroy', None) - } - list_regex = r'^%s/$' % name - detail_regex = r'^%s/(?P[0-9]+)/$' % name - list_name = '%s-list' - detail_name = '%s-detail' - - ret += url(list_regex, resource.as_view(list_actions), list_name) - ret += url(detail_regex, resource.as_view(detail_actions), detail_name) - - return ret diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py new file mode 100644 index 00000000..a5aef5b7 --- /dev/null +++ b/rest_framework/viewsets.py @@ -0,0 +1,33 @@ +# Not properly implemented yet, just the basic idea + + +class BaseRouter(object): + def __init__(self): + self.resources = [] + + def register(self, name, resource): + self.resources.append((name, resource)) + + @property + def urlpatterns(self): + ret = [] + + for name, resource in self.resources: + list_actions = { + 'get': getattr(resource, 'list', None), + 'post': getattr(resource, 'create', None) + } + detail_actions = { + 'get': getattr(resource, 'retrieve', None), + 'put': getattr(resource, 'update', None), + 'delete': getattr(resource, 'destroy', None) + } + list_regex = r'^%s/$' % name + detail_regex = r'^%s/(?P[0-9]+)/$' % name + list_name = '%s-list' + detail_name = '%s-detail' + + ret += url(list_regex, resource.as_view(list_actions), list_name) + ret += url(detail_regex, resource.as_view(detail_actions), detail_name) + + return ret -- cgit v1.2.3 From 76d1c47905680fafa32596d1dda8d9ae20827acf Mon Sep 17 00:00:00 2001 From: glic3rinu Date: Mon, 1 Apr 2013 20:15:05 +0200 Subject: Fixed IPv6 support for urlize_quoted_links --- rest_framework/templatetags/rest_framework.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index b6ab2de3..1d7a499f 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -181,7 +181,7 @@ TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"] WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'), ('"', '"'), ("'", "'")] word_split_re = re.compile(r'(\s+)') -simple_url_re = re.compile(r'^https?://\w', re.IGNORECASE) +simple_url_re = re.compile(r'^https?://\[?\w', re.IGNORECASE) simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE) simple_email_re = re.compile(r'^\S+@\S+\.\S+$') -- cgit v1.2.3 From 889558365bb3947ab77f47207381d5ff6316fa4f Mon Sep 17 00:00:00 2001 From: J. Paul Reed Date: Tue, 2 Apr 2013 01:41:31 -0700 Subject: Don't have the ModelSerializer trust deserialized objects to not have redefine bool()ean-ness. If the model we're using the ModelSerializer for has redefined methods that act as a boolean (__bool__ or __len__), it may not return the object even though it is_valid(), and should. --- rest_framework/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 1b2b0821..e28bbe81 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -741,7 +741,7 @@ class ModelSerializer(Serializer): Override the default method to also include model field validation. """ instance = super(ModelSerializer, self).from_native(data, files) - if instance: + if not self._errors: return self.full_clean(instance) def save_object(self, obj, **kwargs): -- cgit v1.2.3 From 74fbd5ccc5b2aa2f0aab25ead5ffa36024079fcf Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 3 Apr 2013 09:20:36 +0100 Subject: Fix bug with inactive user accessing OAuth --- rest_framework/authentication.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 145d4295..3e7e89e8 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -10,7 +10,7 @@ from django.core.exceptions import ImproperlyConfigured from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework.compat import CsrfViewMiddleware from rest_framework.compat import oauth, oauth_provider, oauth_provider_store -from rest_framework.compat import oauth2_provider, oauth2_provider_forms +from rest_framework.compat import oauth2_provider from rest_framework.authtoken.models import Token @@ -325,11 +325,13 @@ class OAuth2Authentication(BaseAuthentication): except oauth2_provider.models.AccessToken.DoesNotExist: raise exceptions.AuthenticationFailed('Invalid token') - if not token.user.is_active: + user = token.user + + if not user.is_active: msg = 'User inactive or deleted: %s' % user.username raise exceptions.AuthenticationFailed(msg) - return (token.user, token) + return (user, token) def authenticate_header(self, request): """ -- cgit v1.2.3 From 80d28de03477a8dab3832707ca4489c4b2e78e5d Mon Sep 17 00:00:00 2001 From: Atle Frenvik Sveen Date: Wed, 3 Apr 2013 13:10:41 +0200 Subject: Fix the fact that InvalidConsumerError and InvalidTokenError wasn't imported correctly from oauth_provider --- rest_framework/authentication.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 3e7e89e8..1eebb5b9 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -230,7 +230,7 @@ class OAuthAuthentication(BaseAuthentication): try: consumer_key = oauth_request.get_parameter('oauth_consumer_key') consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key) - except oauth_provider_store.InvalidConsumerError as err: + except oauth_provider.store.InvalidConsumerError as err: raise exceptions.AuthenticationFailed(err) if consumer.status != oauth_provider.consts.ACCEPTED: @@ -240,7 +240,7 @@ class OAuthAuthentication(BaseAuthentication): try: token_param = oauth_request.get_parameter('oauth_token') token = oauth_provider_store.get_access_token(request, oauth_request, consumer, token_param) - except oauth_provider_store.InvalidTokenError: + except oauth_provider.store.InvalidTokenError: msg = 'Invalid access token: %s' % oauth_request.get_parameter('oauth_token') raise exceptions.AuthenticationFailed(msg) -- cgit v1.2.3 From 92b5db593953f03a17ca0fcee2b9ea91a29cb143 Mon Sep 17 00:00:00 2001 From: glic3rinu Date: Thu, 4 Apr 2013 12:11:04 +0200 Subject: Added break_long_headers on templatetags and base template --- rest_framework/templates/rest_framework/base.html | 2 +- rest_framework/templatetags/rest_framework.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 44633f5a..4410f285 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -115,7 +115,7 @@
HTTP {{ response.status_code }} {{ response.status_text }}{% autoescape off %} -{% for key, val in response.items %}{{ key }}: {{ val|urlize_quoted_links }} +{% for key, val in response.items %}{{ key }}: {{ val|break_long_headers|urlize_quoted_links }} {% endfor %}
{{ content|urlize_quoted_links }}
{% endautoescape %}
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 1d7a499f..189e82f6 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -260,3 +260,14 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru elif autoescape: words[i] = escape(word) return ''.join(words) + + +@register.filter +def break_long_headers(header): + """ + Breaks headers longer than 160 characters (~page length) + when possible (are comma separated) + """ + if len(header) > 160: + header = mark_safe('
' + ',
'.join(header.split(','))) + return header -- cgit v1.2.3 From b6c7730d7f31e84b5f120071ddf9c7ab08e4e7da Mon Sep 17 00:00:00 2001 From: glic3rinu Date: Thu, 4 Apr 2013 14:01:47 +0200 Subject: Fixed comma detection in break_long_headers templatetag --- rest_framework/templatetags/rest_framework.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 189e82f6..c86b6456 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -268,6 +268,6 @@ def break_long_headers(header): Breaks headers longer than 160 characters (~page length) when possible (are comma separated) """ - if len(header) > 160: + if len(header) > 160 and ',' in header: header = mark_safe('
' + ',
'.join(header.split(','))) return header -- cgit v1.2.3 From c785628300d2b7cce63862a18915c537f8a3ab24 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Apr 2013 20:00:44 +0100 Subject: Fleshing out viewsets/routers --- rest_framework/resources.py | 75 ---------------------------- rest_framework/routers.py | 43 ++++++++++++++++ rest_framework/viewsets.py | 119 ++++++++++++++++++++++++++++++++------------ 3 files changed, 129 insertions(+), 108 deletions(-) delete mode 100644 rest_framework/resources.py create mode 100644 rest_framework/routers.py (limited to 'rest_framework') diff --git a/rest_framework/resources.py b/rest_framework/resources.py deleted file mode 100644 index d4019a94..00000000 --- a/rest_framework/resources.py +++ /dev/null @@ -1,75 +0,0 @@ -##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### - -from functools import update_wrapper -from django.utils.decorators import classonlymethod -from rest_framework import views, generics, mixins - - -##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### - -class ResourceMixin(object): - """ - This is the magic. - - Overrides `.as_view()` so that it takes an `actions` keyword that performs - the binding of HTTP methods to actions on the Resource. - - For example, to create a concrete view binding the 'GET' and 'POST' methods - to the 'list' and 'create' actions... - - my_resource = MyResource.as_view({'get': 'list', 'post': 'create'}) - """ - - @classonlymethod - def as_view(cls, actions=None, **initkwargs): - """ - Main entry point for a request-response process. - """ - # sanitize keyword arguments - for key in initkwargs: - if key in cls.http_method_names: - raise TypeError("You tried to pass in the %s method name as a " - "keyword argument to %s(). Don't do that." - % (key, cls.__name__)) - if not hasattr(cls, key): - raise TypeError("%s() received an invalid keyword %r" % ( - cls.__name__, key)) - - def view(request, *args, **kwargs): - self = cls(**initkwargs) - - # Bind methods to actions - for method, action in actions.items(): - handler = getattr(self, action) - setattr(self, method, handler) - - # As you were, solider. - if hasattr(self, 'get') and not hasattr(self, 'head'): - self.head = self.get - return self.dispatch(request, *args, **kwargs) - - # take name and docstring from class - update_wrapper(view, cls, updated=()) - - # and possible attributes set by decorators - # like csrf_exempt from dispatch - update_wrapper(view, cls.dispatch, assigned=()) - return view - - -class Resource(ResourceMixin, views.APIView): - pass - - -# Note the inheritence of both MultipleObjectAPIView *and* SingleObjectAPIView -# is a bit weird given the diamond inheritence, but it will work for now. -# There's some implementation clean up that can happen later. -class ModelResource(mixins.CreateModelMixin, - mixins.RetrieveModelMixin, - mixins.UpdateModelMixin, - mixins.DestroyModelMixin, - mixins.ListModelMixin, - ResourceMixin, - generics.MultipleObjectAPIView, - generics.SingleObjectAPIView): - pass diff --git a/rest_framework/routers.py b/rest_framework/routers.py new file mode 100644 index 00000000..63eae5d7 --- /dev/null +++ b/rest_framework/routers.py @@ -0,0 +1,43 @@ +from django.conf.urls import url, patterns + + +class BaseRouter(object): + def __init__(self): + self.registry = [] + + def register(self, prefix, viewset, base_name): + self.registry.append((prefix, viewset, base_name)) + + def get_urlpatterns(self): + raise NotImplemented('get_urlpatterns must be overridden') + + @property + def urlpatterns(self): + if not hasattr(self, '_urlpatterns'): + print self.get_urlpatterns() + self._urlpatterns = patterns('', *self.get_urlpatterns()) + return self._urlpatterns + + +class DefaultRouter(BaseRouter): + route_list = [ + (r'$', {'get': 'list', 'post': 'create'}, '%s-list'), + (r'(?P[^/]+)/$', {'get': 'retrieve', 'put': 'update', 'delete': 'destroy'}, '%s-detail'), + ] + + def get_urlpatterns(self): + ret = [] + for prefix, viewset, base_name in self.registry: + for suffix, action_mapping, name_format in self.route_list: + + # Only actions which actually exist on the viewset will be bound + bound_actions = {} + for method, action in action_mapping.items(): + if hasattr(viewset, action): + bound_actions[method] = action + + regex = prefix + suffix + view = viewset.as_view(bound_actions) + name = name_format % base_name + ret.append(url(regex, view, name=name)) + return ret diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index a5aef5b7..887a9722 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -1,33 +1,86 @@ -# Not properly implemented yet, just the basic idea - - -class BaseRouter(object): - def __init__(self): - self.resources = [] - - def register(self, name, resource): - self.resources.append((name, resource)) - - @property - def urlpatterns(self): - ret = [] - - for name, resource in self.resources: - list_actions = { - 'get': getattr(resource, 'list', None), - 'post': getattr(resource, 'create', None) - } - detail_actions = { - 'get': getattr(resource, 'retrieve', None), - 'put': getattr(resource, 'update', None), - 'delete': getattr(resource, 'destroy', None) - } - list_regex = r'^%s/$' % name - detail_regex = r'^%s/(?P[0-9]+)/$' % name - list_name = '%s-list' - detail_name = '%s-detail' - - ret += url(list_regex, resource.as_view(list_actions), list_name) - ret += url(detail_regex, resource.as_view(detail_actions), detail_name) - - return ret +from functools import update_wrapper +from django.utils.decorators import classonlymethod +from rest_framework import views, generics, mixins + + +class ViewSetMixin(object): + """ + This is the magic. + + Overrides `.as_view()` so that it takes an `actions` keyword that performs + the binding of HTTP methods to actions on the Resource. + + For example, to create a concrete view binding the 'GET' and 'POST' methods + to the 'list' and 'create' actions... + + view = MyViewSet.as_view({'get': 'list', 'post': 'create'}) + """ + + @classonlymethod + def as_view(cls, actions=None, **initkwargs): + """ + Main entry point for a request-response process. + + Because of the way class based views create a closure around the + instantiated view, we need to totally reimplement `.as_view`, + and slightly modify the view function that is created and returned. + """ + # sanitize keyword arguments + for key in initkwargs: + if key in cls.http_method_names: + raise TypeError("You tried to pass in the %s method name as a " + "keyword argument to %s(). Don't do that." + % (key, cls.__name__)) + if not hasattr(cls, key): + raise TypeError("%s() received an invalid keyword %r" % ( + cls.__name__, key)) + + def view(request, *args, **kwargs): + self = cls(**initkwargs) + + # Bind methods to actions + # This is the bit that's different to a standard view + for method, action in actions.items(): + handler = getattr(self, action) + setattr(self, method, handler) + + # Patch this in as it's otherwise only present from 1.5 onwards + if hasattr(self, 'get') and not hasattr(self, 'head'): + self.head = self.get + + # And continue as usual + return self.dispatch(request, *args, **kwargs) + + # take name and docstring from class + update_wrapper(view, cls, updated=()) + + # and possible attributes set by decorators + # like csrf_exempt from dispatch + update_wrapper(view, cls.dispatch, assigned=()) + return view + + +class ViewSet(ViewSetMixin, views.APIView): + pass + + +# Note the inheritence of both MultipleObjectAPIView *and* SingleObjectAPIView +# is a bit weird given the diamond inheritence, but it will work for now. +# There's some implementation clean up that can happen later. +class ModelViewSet(mixins.CreateModelMixin, + mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + mixins.DestroyModelMixin, + mixins.ListModelMixin, + ViewSetMixin, + generics.MultipleObjectAPIView, + generics.SingleObjectAPIView): + pass + + +class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, + mixins.ListModelMixin, + ViewSetMixin, + generics.MultipleObjectAPIView, + generics.SingleObjectAPIView): + pass -- cgit v1.2.3 From fb41d2ac8f495ae0728e3f38c6a21306f0507316 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Apr 2013 20:35:40 +0100 Subject: Add support for action and link routing --- rest_framework/decorators.py | 22 ++++++++++++++++++++++ rest_framework/routers.py | 20 ++++++++++++++++++++ 2 files changed, 42 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 8250cd3b..00b37f8b 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -97,3 +97,25 @@ def permission_classes(permission_classes): func.permission_classes = permission_classes return func return decorator + + +def link(**kwargs): + """ + Used to mark a method on a ViewSet that should be routed for GET requests. + """ + def decorator(func): + func.bind_to_method = 'get' + func.kwargs = kwargs + return func + return decorator + + +def action(**kwargs): + """ + Used to mark a method on a ViewSet that should be routed for POST requests. + """ + def decorator(func): + func.bind_to_method = 'post' + func.kwargs = kwargs + return func + return decorator diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 63eae5d7..d1e96156 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -24,10 +24,12 @@ class DefaultRouter(BaseRouter): (r'$', {'get': 'list', 'post': 'create'}, '%s-list'), (r'(?P[^/]+)/$', {'get': 'retrieve', 'put': 'update', 'delete': 'destroy'}, '%s-detail'), ] + extra_routes = (r'(?P[^/]+)/%s/$', '%s-%s') def get_urlpatterns(self): ret = [] for prefix, viewset, base_name in self.registry: + # Bind standard routes for suffix, action_mapping, name_format in self.route_list: # Only actions which actually exist on the viewset will be bound @@ -36,8 +38,26 @@ class DefaultRouter(BaseRouter): if hasattr(viewset, action): bound_actions[method] = action + # Build the url pattern regex = prefix + suffix view = viewset.as_view(bound_actions) name = name_format % base_name ret.append(url(regex, view, name=name)) + + # Bind any extra @action or @link routes + for attr in dir(viewset): + func = getattr(viewset, attr) + http_method = getattr(func, 'bind_to_method', None) + if not http_method: + continue + + regex_format, name_format = self.extra_routes + + # Build the url pattern + regex = regex_format % attr + view = viewset.as_view({http_method: attr}, **func.kwargs) + name = name_format % (base_name, attr) + ret.append(url(regex, view, name=name)) + + # Return a list of url patterns return ret -- cgit v1.2.3 From 9e24db022cd8da1a588dd43e6239e07798881c02 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Apr 2013 20:38:42 +0100 Subject: Commenting --- rest_framework/routers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py index d1e96156..283add8d 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -29,7 +29,7 @@ class DefaultRouter(BaseRouter): def get_urlpatterns(self): ret = [] for prefix, viewset, base_name in self.registry: - # Bind standard routes + # Bind standard CRUD routes for suffix, action_mapping, name_format in self.route_list: # Only actions which actually exist on the viewset will be bound @@ -44,10 +44,12 @@ class DefaultRouter(BaseRouter): name = name_format % base_name ret.append(url(regex, view, name=name)) - # Bind any extra @action or @link routes + # Bind any extra `@action` or `@link` routes for attr in dir(viewset): func = getattr(viewset, attr) http_method = getattr(func, 'bind_to_method', None) + + # Skip if this is not an @action or @link method if not http_method: continue -- cgit v1.2.3 From f68721ade8d66806296323116ff9a61773ad2be1 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Apr 2013 21:42:26 +0100 Subject: Factor view names/descriptions out of View class --- rest_framework/renderers.py | 11 ++--- rest_framework/routers.py | 34 ++++++++------ rest_framework/utils/breadcrumbs.py | 5 ++- rest_framework/utils/formatting.py | 77 ++++++++++++++++++++++++++++++++ rest_framework/views.py | 89 ++++--------------------------------- rest_framework/viewsets.py | 5 ++- 6 files changed, 117 insertions(+), 104 deletions(-) create mode 100644 rest_framework/utils/formatting.py (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 4c15e0db..752306ad 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -24,6 +24,7 @@ from rest_framework.settings import api_settings from rest_framework.request import clone_request from rest_framework.utils import encoders from rest_framework.utils.breadcrumbs import get_breadcrumbs +from rest_framework.utils.formatting import get_view_name, get_view_description from rest_framework import exceptions, parsers, status, VERSION @@ -438,16 +439,10 @@ class BrowsableAPIRenderer(BaseRenderer): return GenericContentForm() def get_name(self, view): - try: - return view.get_name() - except AttributeError: - return smart_text(view.__class__.__name__) + return get_view_name(view.__class__) def get_description(self, view): - try: - return view.get_description(html=True) - except AttributeError: - return smart_text(view.__doc__ or '') + return get_view_description(view.__class__, html=True) def render(self, data, accepted_media_type=None, renderer_context=None): """ diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 283add8d..c37909ff 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -14,23 +14,31 @@ class BaseRouter(object): @property def urlpatterns(self): if not hasattr(self, '_urlpatterns'): - print self.get_urlpatterns() self._urlpatterns = patterns('', *self.get_urlpatterns()) return self._urlpatterns class DefaultRouter(BaseRouter): route_list = [ - (r'$', {'get': 'list', 'post': 'create'}, '%s-list'), - (r'(?P[^/]+)/$', {'get': 'retrieve', 'put': 'update', 'delete': 'destroy'}, '%s-detail'), + (r'$', {'get': 'list', 'post': 'create'}, 'list'), + (r'(?P[^/]+)/$', {'get': 'retrieve', 'put': 'update', 'delete': 'destroy'}, 'detail'), ] - extra_routes = (r'(?P[^/]+)/%s/$', '%s-%s') + extra_routes = r'(?P[^/]+)/%s/$' + name_format = '%s-%s' def get_urlpatterns(self): ret = [] for prefix, viewset, base_name in self.registry: + # Bind regular views + if not getattr(viewset, '_is_viewset', False): + regex = prefix + view = viewset + name = base_name + ret.append(url(regex, view, name=name)) + continue + # Bind standard CRUD routes - for suffix, action_mapping, name_format in self.route_list: + for suffix, action_mapping, action_name in self.route_list: # Only actions which actually exist on the viewset will be bound bound_actions = {} @@ -40,25 +48,25 @@ class DefaultRouter(BaseRouter): # Build the url pattern regex = prefix + suffix - view = viewset.as_view(bound_actions) - name = name_format % base_name + view = viewset.as_view(bound_actions, name_suffix=action_name) + name = self.name_format % (base_name, action_name) ret.append(url(regex, view, name=name)) # Bind any extra `@action` or `@link` routes - for attr in dir(viewset): - func = getattr(viewset, attr) + for action_name in dir(viewset): + func = getattr(viewset, action_name) http_method = getattr(func, 'bind_to_method', None) # Skip if this is not an @action or @link method if not http_method: continue - regex_format, name_format = self.extra_routes + suffix = self.extra_routes % action_name # Build the url pattern - regex = regex_format % attr - view = viewset.as_view({http_method: attr}, **func.kwargs) - name = name_format % (base_name, attr) + regex = prefix + suffix + view = viewset.as_view({http_method: action_name}, **func.kwargs) + name = self.name_format % (base_name, action_name) ret.append(url(regex, view, name=name)) # Return a list of url patterns diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index af21ac79..18b3b207 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals from django.core.urlresolvers import resolve, get_script_prefix +from rest_framework.utils.formatting import get_view_name def get_breadcrumbs(url): @@ -16,11 +17,11 @@ def get_breadcrumbs(url): pass else: # Check if this is a REST framework view, and if so add it to the breadcrumbs - if isinstance(getattr(view, 'cls_instance', None), APIView): + if issubclass(getattr(view, 'cls', None), APIView): # Don't list the same view twice in a row. # Probably an optional trailing slash. if not seen or seen[-1] != view: - breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url)) + breadcrumbs_list.insert(0, (get_view_name(view.cls), prefix + url)) seen.append(view) if url == '': diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py new file mode 100644 index 00000000..79566db1 --- /dev/null +++ b/rest_framework/utils/formatting.py @@ -0,0 +1,77 @@ +""" +Utility functions to return a formatted name and description for a given view. +""" +from __future__ import unicode_literals + +from django.utils.html import escape +from django.utils.safestring import mark_safe +from rest_framework.compat import apply_markdown +import re + + +def _remove_trailing_string(content, trailing): + """ + Strip trailing component `trailing` from `content` if it exists. + Used when generating names from view classes. + """ + if content.endswith(trailing) and content != trailing: + return content[:-len(trailing)] + return content + + +def _remove_leading_indent(content): + """ + Remove leading indent from a block of text. + Used when generating descriptions from docstrings. + """ + whitespace_counts = [len(line) - len(line.lstrip(' ')) + for line in content.splitlines()[1:] if line.lstrip()] + + # unindent the content if needed + if whitespace_counts: + whitespace_pattern = '^' + (' ' * min(whitespace_counts)) + content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) + content = content.strip('\n') + return content + + +def _camelcase_to_spaces(content): + """ + Translate 'CamelCaseNames' to 'Camel Case Names'. + Used when generating names from view classes. + """ + camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))' + content = re.sub(camelcase_boundry, ' \\1', content).strip() + return ' '.join(content.split('_')).title() + + +def get_view_name(cls): + """ + Return a formatted name for an `APIView` class or `@api_view` function. + """ + name = cls.__name__ + name = _remove_trailing_string(name, 'View') + name = _remove_trailing_string(name, 'ViewSet') + return _camelcase_to_spaces(name) + + +def get_view_description(cls, html=False): + """ + Return a description for an `APIView` class or `@api_view` function. + """ + description = cls.__doc__ or '' + description = _remove_leading_indent(description) + if html: + return markup_description(description) + return description + + +def markup_description(description): + """ + Apply HTML markup to the given description. + """ + if apply_markdown: + description = apply_markdown(description) + else: + description = escape(description).replace('\n', '
') + return mark_safe(description) diff --git a/rest_framework/views.py b/rest_framework/views.py index 81cbdcbb..12298ca5 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -4,51 +4,13 @@ Provides an APIView class that is used as the base of all class-based views. from __future__ import unicode_literals from django.core.exceptions import PermissionDenied from django.http import Http404 -from django.utils.html import escape -from django.utils.safestring import mark_safe from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions -from rest_framework.compat import View, apply_markdown +from rest_framework.compat import View from rest_framework.response import Response from rest_framework.request import Request from rest_framework.settings import api_settings -import re - - -def _remove_trailing_string(content, trailing): - """ - Strip trailing component `trailing` from `content` if it exists. - Used when generating names from view classes. - """ - if content.endswith(trailing) and content != trailing: - return content[:-len(trailing)] - return content - - -def _remove_leading_indent(content): - """ - Remove leading indent from a block of text. - Used when generating descriptions from docstrings. - """ - whitespace_counts = [len(line) - len(line.lstrip(' ')) - for line in content.splitlines()[1:] if line.lstrip()] - - # unindent the content if needed - if whitespace_counts: - whitespace_pattern = '^' + (' ' * min(whitespace_counts)) - content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) - content = content.strip('\n') - return content - - -def _camelcase_to_spaces(content): - """ - Translate 'CamelCaseNames' to 'Camel Case Names'. - Used when generating names from view classes. - """ - camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))' - content = re.sub(camelcase_boundry, ' \\1', content).strip() - return ' '.join(content.split('_')).title() +from rest_framework.utils.formatting import get_view_name, get_view_description class APIView(View): @@ -64,13 +26,13 @@ class APIView(View): @classmethod def as_view(cls, **initkwargs): """ - Override the default :meth:`as_view` to store an instance of the view - as an attribute on the callable function. This allows us to discover - information about the view when we do URL reverse lookups. + Store the original class on the view function. + + This allows us to discover information about the view when we do URL + reverse lookups. Used for breadcrumb generation. """ - # TODO: deprecate? view = super(APIView, cls).as_view(**initkwargs) - view.cls_instance = cls(**initkwargs) + view.cls = cls return view @property @@ -90,43 +52,10 @@ class APIView(View): 'Vary': 'Accept' } - def get_name(self): - """ - Return the resource or view class name for use as this view's name. - Override to customize. - """ - # TODO: deprecate? - name = self.__class__.__name__ - name = _remove_trailing_string(name, 'View') - return _camelcase_to_spaces(name) - - def get_description(self, html=False): - """ - Return the resource or view docstring for use as this view's description. - Override to customize. - """ - # TODO: deprecate? - description = self.__doc__ or '' - description = _remove_leading_indent(description) - if html: - return self.markup_description(description) - return description - - def markup_description(self, description): - """ - Apply HTML markup to the description of this view. - """ - # TODO: deprecate? - if apply_markdown: - description = apply_markdown(description) - else: - description = escape(description).replace('\n', '
') - return mark_safe(description) - def metadata(self, request): return { - 'name': self.get_name(), - 'description': self.get_description(), + 'name': get_view_name(self.__class__), + 'description': get_view_description(self.__class__), 'renders': [renderer.media_type for renderer in self.renderer_classes], 'parses': [parser.media_type for parser in self.parser_classes], } diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 887a9722..0818c0d9 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -15,9 +15,10 @@ class ViewSetMixin(object): view = MyViewSet.as_view({'get': 'list', 'post': 'create'}) """ + _is_viewset = True @classonlymethod - def as_view(cls, actions=None, **initkwargs): + def as_view(cls, actions=None, name_suffix=None, **initkwargs): """ Main entry point for a request-response process. @@ -57,6 +58,8 @@ class ViewSetMixin(object): # and possible attributes set by decorators # like csrf_exempt from dispatch update_wrapper(view, cls.dispatch, assigned=()) + + view.cls = cls return view -- cgit v1.2.3 From fd3f538e9f9ef5d4d929c107b9619e0735e426f1 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Apr 2013 21:48:23 +0100 Subject: Fix up view name/description tests --- rest_framework/tests/description.py | 63 ++++++++++++++----------------------- 1 file changed, 23 insertions(+), 40 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py index 5b3315bc..52c1a34c 100644 --- a/rest_framework/tests/description.py +++ b/rest_framework/tests/description.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals from django.test import TestCase from rest_framework.views import APIView from rest_framework.compat import apply_markdown +from rest_framework.utils.formatting import get_view_name, get_view_description # We check that docstrings get nicely un-indented. DESCRIPTION = """an example docstring @@ -49,22 +50,16 @@ MARKED_DOWN_gte_21 = """

an example docstring

class TestViewNamesAndDescriptions(TestCase): - def test_resource_name_uses_classname_by_default(self): - """Ensure Resource names are based on the classname by default.""" + def test_view_name_uses_class_name(self): + """ + Ensure view names are based on the class name. + """ class MockView(APIView): pass - self.assertEqual(MockView().get_name(), 'Mock') + self.assertEqual(get_view_name(MockView), 'Mock') - def test_resource_name_can_be_set_explicitly(self): - """Ensure Resource names can be set using the 'get_name' method.""" - example = 'Some Other Name' - class MockView(APIView): - def get_name(self): - return example - self.assertEqual(MockView().get_name(), example) - - def test_resource_description_uses_docstring_by_default(self): - """Ensure Resource names are based on the docstring by default.""" + def test_view_description_uses_docstring(self): + """Ensure view descriptions are based on the docstring.""" class MockView(APIView): """an example docstring ==================== @@ -81,44 +76,32 @@ class TestViewNamesAndDescriptions(TestCase): # hash style header #""" - self.assertEqual(MockView().get_description(), DESCRIPTION) - - def test_resource_description_can_be_set_explicitly(self): - """Ensure Resource descriptions can be set using the 'get_description' method.""" - example = 'Some other description' - - class MockView(APIView): - """docstring""" - def get_description(self): - return example - self.assertEqual(MockView().get_description(), example) + self.assertEqual(get_view_description(MockView), DESCRIPTION) - def test_resource_description_supports_unicode(self): + def test_view_description_supports_unicode(self): + """ + Unicode in docstrings should be respected. + """ class MockView(APIView): """Проверка""" pass - self.assertEqual(MockView().get_description(), "Проверка") - - - def test_resource_description_does_not_require_docstring(self): - """Ensure that empty docstrings do not affect the Resource's description if it has been set using the 'get_description' method.""" - example = 'Some other description' - - class MockView(APIView): - def get_description(self): - return example - self.assertEqual(MockView().get_description(), example) + self.assertEqual(get_view_description(MockView), "Проверка") - def test_resource_description_can_be_empty(self): - """Ensure that if a resource has no doctring or 'description' class attribute, then it's description is the empty string.""" + def test_view_description_can_be_empty(self): + """ + Ensure that if a view has no docstring, + then it's description is the empty string. + """ class MockView(APIView): pass - self.assertEqual(MockView().get_description(), '') + self.assertEqual(get_view_description(MockView), '') def test_markdown(self): - """Ensure markdown to HTML works as expected""" + """ + Ensure markdown to HTML works as expected. + """ if apply_markdown: gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21 lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21 -- cgit v1.2.3 From c2280e34ece1867432c87a9654d31a708281b05a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Apr 2013 21:53:15 +0100 Subject: Version 2.2.6 --- rest_framework/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index c86403d8..7ac12058 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.2.5' +__version__ = '2.2.6' VERSION = __version__ # synonym -- cgit v1.2.3 From 371698331c979305b5684f864ee6bf5b6d11a44e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 4 Apr 2013 22:24:30 +0100 Subject: Tweaks --- rest_framework/generics.py | 9 +++------ rest_framework/mixins.py | 4 ++++ rest_framework/routers.py | 12 ++++++++++-- 3 files changed, 17 insertions(+), 8 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 36ecf915..dea980a5 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -187,8 +187,7 @@ class UpdateAPIView(mixins.UpdateModelMixin, return self.update(request, *args, **kwargs) def patch(self, request, *args, **kwargs): - kwargs['partial'] = True - return self.update(request, *args, **kwargs) + return self.partial_update(request, *args, **kwargs) class ListCreateAPIView(mixins.ListModelMixin, @@ -217,8 +216,7 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, return self.update(request, *args, **kwargs) def patch(self, request, *args, **kwargs): - kwargs['partial'] = True - return self.update(request, *args, **kwargs) + return self.partial_update(request, *args, **kwargs) class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, @@ -248,8 +246,7 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, return self.update(request, *args, **kwargs) def patch(self, request, *args, **kwargs): - kwargs['partial'] = True - return self.update(request, *args, **kwargs) + return self.partial_update(request, *args, **kwargs) def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 7d9a6e65..c700602e 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -137,6 +137,10 @@ class UpdateModelMixin(object): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + def partial_update(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + def pre_save(self, obj): """ Set any attributes on the object that are implicit in the request. diff --git a/rest_framework/routers.py b/rest_framework/routers.py index c37909ff..afc51f3b 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -20,8 +20,16 @@ class BaseRouter(object): class DefaultRouter(BaseRouter): route_list = [ - (r'$', {'get': 'list', 'post': 'create'}, 'list'), - (r'(?P[^/]+)/$', {'get': 'retrieve', 'put': 'update', 'delete': 'destroy'}, 'detail'), + (r'$', { + 'get': 'list', + 'post': 'create' + }, 'list'), + (r'(?P[^/]+)/$', { + 'get': 'retrieve', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy' + }, 'detail'), ] extra_routes = r'(?P[^/]+)/%s/$' name_format = '%s-%s' -- cgit v1.2.3 From c73d0e1e39e661c7324eb0df8c3ce6e18f57915b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 9 Apr 2013 18:22:39 +0100 Subject: Minor cleaning up on View --- rest_framework/compat.py | 20 ++++++++++++-------- rest_framework/views.py | 8 ++++---- 2 files changed, 16 insertions(+), 12 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 6551723a..8bfebe68 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -87,9 +87,7 @@ else: raise ImportError("User model is not to be found.") -# First implementation of Django class-based views did not include head method -# in base View class - https://code.djangoproject.com/ticket/15668 -if django.VERSION >= (1, 4): +if django.VERSION >= (1, 5): from django.views.generic import View else: from django.views.generic import View as _View @@ -97,6 +95,8 @@ else: from django.utils.functional import update_wrapper class View(_View): + # 1.3 does not include head method in base View class + # See: https://code.djangoproject.com/ticket/15668 @classonlymethod def as_view(cls, **initkwargs): """ @@ -126,11 +126,15 @@ else: update_wrapper(view, cls.dispatch, assigned=()) return view -# Taken from @markotibold's attempt at supporting PATCH. -# https://github.com/markotibold/django-rest-framework/tree/patch -http_method_names = set(View.http_method_names) -http_method_names.add('patch') -View.http_method_names = list(http_method_names) # PATCH method is not implemented by Django + # _allowed_methods only present from 1.5 onwards + def _allowed_methods(self): + return [m.upper() for m in self.http_method_names if hasattr(self, m)] + + +# PATCH method is not implemented by Django +if 'patch' not in View.http_method_names: + View.http_method_names = View.http_method_names + ['patch'] + # PUT, DELETE do not require CSRF until 1.4. They should. Make it better. if django.VERSION >= (1, 4): diff --git a/rest_framework/views.py b/rest_framework/views.py index 12298ca5..d7d3a2e2 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -38,10 +38,9 @@ class APIView(View): @property def allowed_methods(self): """ - Return the list of allowed HTTP methods, uppercased. + Wrap Django's private `_allowed_methods` interface in a public property. """ - return [method.upper() for method in self.http_method_names - if hasattr(self, method)] + return self._allowed_methods() @property def default_response_headers(self): @@ -69,7 +68,8 @@ class APIView(View): def http_method_not_allowed(self, request, *args, **kwargs): """ - Called if `request.method` does not correspond to a handler method. + If `request.method` does not correspond to a handler method, + determine what kind of exception to raise. """ raise exceptions.MethodNotAllowed(request.method) -- cgit v1.2.3 From 099163f81f9d89746de50f3aed2955ead54dba4e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 9 Apr 2013 18:45:15 +0100 Subject: Removed SingleObjectMixin and MultipleObjectMixin --- rest_framework/generics.py | 139 +++++++++++++++++++++++++++++++++------------ rest_framework/mixins.py | 5 +- 2 files changed, 106 insertions(+), 38 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index dea980a5..af3b69da 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -4,21 +4,35 @@ Generic views that provide commonly needed behaviour. from __future__ import unicode_literals from rest_framework import views, mixins from rest_framework.settings import api_settings -from django.views.generic.detail import SingleObjectMixin -from django.views.generic.list import MultipleObjectMixin - +from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist +from django.core.paginator import Paginator, InvalidPage +from django.http import Http404 +from django.utils.translation import ugettext as _ ### Base classes for the generic views ### + class GenericAPIView(views.APIView): """ Base class for all other generic views. """ - model = None + queryset = None serializer_class = None - model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS + filter_backend = api_settings.FILTER_BACKEND + paginate_by = api_settings.PAGINATE_BY + paginate_by_param = api_settings.PAGINATE_BY_PARAM + pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS + allow_empty = True + page_kwarg = 'page' + + # Pending deprecation + model = None + model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS + pk_url_kwarg = 'pk' # Not provided in Django 1.3 + slug_url_kwarg = 'slug' # Not provided in Django 1.3 + slug_field = 'slug' def filter_queryset(self, queryset): """ @@ -82,15 +96,7 @@ class GenericAPIView(views.APIView): """ pass - -class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): - """ - Base class for generic views onto a queryset. - """ - - paginate_by = api_settings.PAGINATE_BY - paginate_by_param = api_settings.PAGINATE_BY_PARAM - pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS + # Pagination def get_pagination_serializer(self, page=None): """ @@ -116,28 +122,81 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): pass return self.paginate_by - -class SingleObjectAPIView(SingleObjectMixin, GenericAPIView): - """ - Base class for generic views onto a model instance. - """ - - pk_url_kwarg = 'pk' # Not provided in Django 1.3 - slug_url_kwarg = 'slug' # Not provided in Django 1.3 - slug_field = 'slug' + def paginate_queryset(self, queryset, page_size, paginator_class=Paginator): + """ + Paginate a queryset. + """ + paginator = paginator_class(queryset, page_size, allow_empty_first_page=self.allow_empty) + page_kwarg = self.page_kwarg + page = self.kwargs.get(page_kwarg) or self.request.GET.get(page_kwarg) or 1 + try: + page_number = int(page) + except ValueError: + if page == 'last': + page_number = paginator.num_pages + else: + raise Http404(_("Page is not 'last', nor can it be converted to an int.")) + try: + page = paginator.page(page_number) + return (paginator, page, page.object_list, page.has_other_pages()) + except InvalidPage as e: + raise Http404(_('Invalid page (%(page_number)s): %(message)s') % { + 'page_number': page_number, + 'message': str(e) + }) + + def get_queryset(self): + """ + Get the list of items for this view. This must be an iterable, and may + be a queryset (in which qs-specific behavior will be enabled). + """ + if self.queryset is not None: + queryset = self.queryset + if hasattr(queryset, '_clone'): + queryset = queryset._clone() + elif self.model is not None: + queryset = self.model._default_manager.all() + else: + raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" + % self.__class__.__name__) + return queryset def get_object(self, queryset=None): """ - Override default to add support for object-level permissions. + Returns the object the view is displaying. + By default this requires `self.queryset` and a `pk` or `slug` argument + in the URLconf, but subclasses can override this to return any object. """ - obj = super(SingleObjectAPIView, self).get_object(queryset) + # Use a custom queryset if provided; this is required for subclasses + # like DateDetailView + if queryset is None: + queryset = self.get_queryset() + # Next, try looking up by primary key. + pk = self.kwargs.get(self.pk_url_kwarg, None) + slug = self.kwargs.get(self.slug_url_kwarg, None) + if pk is not None: + queryset = queryset.filter(pk=pk) + # Next, try looking up by slug. + elif slug is not None: + queryset = queryset.filter(**{self.slug_field: slug}) + # If none of those are defined, it's an error. + else: + raise AttributeError("Generic detail view %s must be called with " + "either an object pk or a slug." + % self.__class__.__name__) + try: + # Get the single item from the filtered queryset + obj = queryset.get() + except ObjectDoesNotExist: + raise Http404(_("No %(verbose_name)s found matching the query") % + {'verbose_name': queryset.model._meta.verbose_name}) + self.check_object_permissions(self.request, obj) return obj ### Concrete view classes that provide method handlers ### -### by composing the mixin classes with a base view. ### - +### by composing the mixin classes with the base view. ### class CreateAPIView(mixins.CreateModelMixin, GenericAPIView): @@ -150,7 +209,7 @@ class CreateAPIView(mixins.CreateModelMixin, class ListAPIView(mixins.ListModelMixin, - MultipleObjectAPIView): + GenericAPIView): """ Concrete view for listing a queryset. """ @@ -159,7 +218,7 @@ class ListAPIView(mixins.ListModelMixin, class RetrieveAPIView(mixins.RetrieveModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for retrieving a model instance. """ @@ -168,7 +227,7 @@ class RetrieveAPIView(mixins.RetrieveModelMixin, class DestroyAPIView(mixins.DestroyModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for deleting a model instance. @@ -178,7 +237,7 @@ class DestroyAPIView(mixins.DestroyModelMixin, class UpdateAPIView(mixins.UpdateModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for updating a model instance. @@ -192,7 +251,7 @@ class UpdateAPIView(mixins.UpdateModelMixin, class ListCreateAPIView(mixins.ListModelMixin, mixins.CreateModelMixin, - MultipleObjectAPIView): + GenericAPIView): """ Concrete view for listing a queryset or creating a model instance. """ @@ -205,7 +264,7 @@ class ListCreateAPIView(mixins.ListModelMixin, class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for retrieving, updating a model instance. """ @@ -221,7 +280,7 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, mixins.DestroyModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for retrieving or deleting a model instance. """ @@ -235,7 +294,7 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, mixins.DestroyModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for retrieving, updating or deleting a model instance. """ @@ -250,3 +309,13 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) + + +### Deprecated classes ### + +class MultipleObjectAPIView(GenericAPIView): + pass + + +class SingleObjectAPIView(GenericAPIView): + pass diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index c700602e..b15cb11f 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -72,8 +72,7 @@ class ListModelMixin(object): # Default is to allow empty querysets. This can be altered by setting # `.allow_empty = False`, to raise 404 errors on empty querysets. - allow_empty = self.get_allow_empty() - if not allow_empty and not self.object_list: + if not self.allow_empty and not self.object_list: class_name = self.__class__.__name__ error_msg = self.empty_error % {'class_name': class_name} raise Http404(error_msg) @@ -148,7 +147,7 @@ class UpdateModelMixin(object): # pk and/or slug attributes are implicit in the URL. pk = self.kwargs.get(self.pk_url_kwarg, None) slug = self.kwargs.get(self.slug_url_kwarg, None) - slug_field = slug and self.get_slug_field() or None + slug_field = slug and self.slug_field or None if pk: setattr(obj, 'pk', pk) -- cgit v1.2.3 From dc45bc7bfad64a17f3e5ed0f5a487bccc379aac2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 9 Apr 2013 19:01:01 +0100 Subject: Add lookup_kwarg --- rest_framework/generics.py | 18 ++++++++++++------ rest_framework/tests/filterset.py | 6 +++--- 2 files changed, 15 insertions(+), 9 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index af3b69da..d4a50dcd 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -26,6 +26,7 @@ class GenericAPIView(views.APIView): pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS allow_empty = True page_kwarg = 'page' + lookup_kwarg = 'pk' # Pending deprecation model = None @@ -167,23 +168,26 @@ class GenericAPIView(views.APIView): By default this requires `self.queryset` and a `pk` or `slug` argument in the URLconf, but subclasses can override this to return any object. """ - # Use a custom queryset if provided; this is required for subclasses - # like DateDetailView + # Determine the base queryset to use. if queryset is None: queryset = self.get_queryset() - # Next, try looking up by primary key. + + # Perform the lookup filtering. pk = self.kwargs.get(self.pk_url_kwarg, None) slug = self.kwargs.get(self.slug_url_kwarg, None) - if pk is not None: + lookup = self.kwargs.get(self.lookup_kwarg, None) + + if lookup is not None: + queryset = queryset.filter(**{self.lookup_kwarg: lookup}) + elif pk is not None: queryset = queryset.filter(pk=pk) - # Next, try looking up by slug. elif slug is not None: queryset = queryset.filter(**{self.slug_field: slug}) - # If none of those are defined, it's an error. else: raise AttributeError("Generic detail view %s must be called with " "either an object pk or a slug." % self.__class__.__name__) + try: # Get the single item from the filtered queryset obj = queryset.get() @@ -191,7 +195,9 @@ class GenericAPIView(views.APIView): raise Http404(_("No %(verbose_name)s found matching the query") % {'verbose_name': queryset.model._meta.verbose_name}) + # May raise a permission denied self.check_object_permissions(self.request, obj) + return obj diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 1a71558c..1e53a5cd 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -61,7 +61,7 @@ if django_filters: class CommonFilteringTestCase(TestCase): def _serialize_object(self, obj): return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} - + def setUp(self): """ Create 10 FilterableItem instances. @@ -190,7 +190,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): Integration tests for filtered detail views. """ urls = 'rest_framework.tests.filterset' - + def _get_url(self, item): return reverse('detail-view', kwargs=dict(pk=item.pk)) @@ -221,7 +221,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, low_item_data) - + # Tests that multiple filters works. search_decimal = Decimal('5.25') search_date = datetime.date(2012, 10, 2) -- cgit v1.2.3 From 1de6cff11b71e4aaa7b76219d4d2118021e23a00 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 9 Apr 2013 19:06:49 +0100 Subject: Cleaning up get_object and get_queryset --- rest_framework/generics.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index d4a50dcd..4ae2ac8e 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -148,25 +148,22 @@ class GenericAPIView(views.APIView): def get_queryset(self): """ - Get the list of items for this view. This must be an iterable, and may - be a queryset (in which qs-specific behavior will be enabled). + Get the list of items for this view. + + This must be an iterable, and may be a queryset. """ if self.queryset is not None: - queryset = self.queryset - if hasattr(queryset, '_clone'): - queryset = queryset._clone() - elif self.model is not None: - queryset = self.model._default_manager.all() - else: - raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" - % self.__class__.__name__) - return queryset + return self.queryset._clone() + + if self.model is not None: + return self.model._default_manager.all() + + raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" + % self.__class__.__name__) def get_object(self, queryset=None): """ Returns the object the view is displaying. - By default this requires `self.queryset` and a `pk` or `slug` argument - in the URLconf, but subclasses can override this to return any object. """ # Determine the base queryset to use. if queryset is None: -- cgit v1.2.3 From 9bb1277e512a88e6c11c52457d0c24e73f30bb98 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 9 Apr 2013 19:37:19 +0100 Subject: Cleaning up around bits of API that will be pending deprecation --- rest_framework/generics.py | 116 +++++++++++++++++++++++++++------------------ rest_framework/mixins.py | 9 ++-- rest_framework/viewsets.py | 6 +-- 3 files changed, 75 insertions(+), 56 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 4ae2ac8e..124dba38 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -35,15 +35,6 @@ class GenericAPIView(views.APIView): slug_url_kwarg = 'slug' # Not provided in Django 1.3 slug_field = 'slug' - def filter_queryset(self, queryset): - """ - Given a queryset, filter it with whichever filter backend is in use. - """ - if not self.filter_backend: - return queryset - backend = self.filter_backend() - return backend.filter_queryset(self.request, queryset, self) - def get_serializer_context(self): """ Extra context provided to the serializer class. @@ -54,24 +45,6 @@ class GenericAPIView(views.APIView): 'view': self } - def get_serializer_class(self): - """ - Return the class to use for the serializer. - - Defaults to using `self.serializer_class`, falls back to constructing a - model serializer class using `self.model_serializer_class`, with - `self.model` as the model. - """ - serializer_class = self.serializer_class - - if serializer_class is None: - class DefaultSerializer(self.model_serializer_class): - class Meta: - model = self.model - serializer_class = DefaultSerializer - - return serializer_class - def get_serializer(self, instance=None, data=None, files=None, many=False, partial=False): """ @@ -83,22 +56,6 @@ class GenericAPIView(views.APIView): return serializer_class(instance, data=data, files=files, many=many, partial=partial, context=context) - def pre_save(self, obj): - """ - Placeholder method for calling before saving an object. - May be used eg. to set attributes on the object that are implicit - in either the request, or the url. - """ - pass - - def post_save(self, obj, created=False): - """ - Placeholder method for calling after saving an object. - """ - pass - - # Pagination - def get_pagination_serializer(self, page=None): """ Return a serializer instance to use with paginated data. @@ -111,9 +68,14 @@ class GenericAPIView(views.APIView): context = self.get_serializer_context() return pagination_serializer_class(instance=page, context=context) - def get_paginate_by(self, queryset): + def get_paginate_by(self, queryset=None): """ Return the size of pages to use with pagination. + + If `PAGINATE_BY_PARAM` is set it will attempt to get the page size + from a named query parameter in the url, eg. ?page_size=100 + + Otherwise defaults to using `self.paginate_by`. """ if self.paginate_by_param: query_params = self.request.QUERY_PARAMS @@ -121,6 +83,7 @@ class GenericAPIView(views.APIView): return int(query_params[self.paginate_by_param]) except (KeyError, ValueError): pass + return self.paginate_by def paginate_queryset(self, queryset, page_size, paginator_class=Paginator): @@ -146,16 +109,54 @@ class GenericAPIView(views.APIView): 'message': str(e) }) + def filter_queryset(self, queryset): + """ + Given a queryset, filter it with whichever filter backend is in use. + """ + if not self.filter_backend: + return queryset + backend = self.filter_backend() + return backend.filter_queryset(self.request, queryset, self) + + ### The following methods provide default implementations + ### that you may want to override for more complex cases. + + def get_serializer_class(self): + """ + Return the class to use for the serializer. + Defaults to using `self.serializer_class`. + + You may want to override this if you need to provide different + serializations depending on the incoming request. + + (Eg. admins get full serialization, others get basic serilization) + """ + serializer_class = self.serializer_class + if serializer_class is not None: + return serializer_class + + # TODO: Deprecation warning + class DefaultSerializer(self.model_serializer_class): + class Meta: + model = self.model + return DefaultSerializer + def get_queryset(self): """ Get the list of items for this view. - This must be an iterable, and may be a queryset. + Defaults to using `self.queryset`. + + You may want to override this if you need to provide different + querysets depending on the incoming request. + + (Eg. return a list of items that is specific to the user) """ if self.queryset is not None: return self.queryset._clone() if self.model is not None: + # TODO: Deprecation warning return self.model._default_manager.all() raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" @@ -164,10 +165,14 @@ class GenericAPIView(views.APIView): def get_object(self, queryset=None): """ Returns the object the view is displaying. + + You may want to override this if you need to provide non-standard + queryset lookups. Eg if objects are referenced using multiple + keyword arguments in the url conf. """ # Determine the base queryset to use. if queryset is None: - queryset = self.get_queryset() + queryset = self.filter_queryset(self.get_queryset()) # Perform the lookup filtering. pk = self.kwargs.get(self.pk_url_kwarg, None) @@ -177,8 +182,10 @@ class GenericAPIView(views.APIView): if lookup is not None: queryset = queryset.filter(**{self.lookup_kwarg: lookup}) elif pk is not None: + # TODO: Deprecation warning queryset = queryset.filter(pk=pk) elif slug is not None: + # TODO: Deprecation warning queryset = queryset.filter(**{self.slug_field: slug}) else: raise AttributeError("Generic detail view %s must be called with " @@ -197,6 +204,23 @@ class GenericAPIView(views.APIView): return obj + ### The following methods are intended to be overridden. + + def pre_save(self, obj): + """ + Placeholder method for calling before saving an object. + + May be used to set attributes on the object that are implicit + in either the request, or the url. + """ + pass + + def post_save(self, obj, created=False): + """ + Placeholder method for calling after saving an object. + """ + pass + ### Concrete view classes that provide method handlers ### ### by composing the mixin classes with the base view. ### diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index b15cb11f..6e40b5c4 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -67,8 +67,7 @@ class ListModelMixin(object): empty_error = "Empty list and '%(class_name)s.allow_empty' is False." def list(self, request, *args, **kwargs): - queryset = self.get_queryset() - self.object_list = self.filter_queryset(queryset) + self.object_list = self.filter_queryset(self.get_queryset()) # Default is to allow empty querysets. This can be altered by setting # `.allow_empty = False`, to raise 404 errors on empty querysets. @@ -79,7 +78,7 @@ class ListModelMixin(object): # Pagination size is set by the `.paginate_by` attribute, # which may be `None` to disable pagination. - page_size = self.get_paginate_by(self.object_list) + page_size = self.get_paginate_by() if page_size: packed = self.paginate_queryset(self.object_list, page_size) paginator, page, queryset, is_paginated = packed @@ -96,9 +95,7 @@ class RetrieveModelMixin(object): Should be mixed in with `SingleObjectAPIView`. """ def retrieve(self, request, *args, **kwargs): - queryset = self.get_queryset() - filtered_queryset = self.filter_queryset(queryset) - self.object = self.get_object(filtered_queryset) + self.object = self.get_object() serializer = self.get_serializer(self.object) return Response(serializer.data) diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 0818c0d9..28ab30e2 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -76,14 +76,12 @@ class ModelViewSet(mixins.CreateModelMixin, mixins.DestroyModelMixin, mixins.ListModelMixin, ViewSetMixin, - generics.MultipleObjectAPIView, - generics.SingleObjectAPIView): + generics.GenericAPIView): pass class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, mixins.ListModelMixin, ViewSetMixin, - generics.MultipleObjectAPIView, - generics.SingleObjectAPIView): + generics.GenericAPIView): pass -- cgit v1.2.3 From 07af4373616c28e7600ee2ec7981b5a1d0a92f7d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 9 Apr 2013 19:47:16 +0100 Subject: Cleaning up around bits of API that will be pending deprecation --- rest_framework/generics.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 124dba38..ba7d1f43 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -24,15 +24,17 @@ class GenericAPIView(views.APIView): paginate_by = api_settings.PAGINATE_BY paginate_by_param = api_settings.PAGINATE_BY_PARAM pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS - allow_empty = True page_kwarg = 'page' lookup_kwarg = 'pk' + allow_empty = True + + ###################################### + # These are all pending deprecation... - # Pending deprecation model = None model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS - pk_url_kwarg = 'pk' # Not provided in Django 1.3 - slug_url_kwarg = 'slug' # Not provided in Django 1.3 + pk_url_kwarg = 'pk' + slug_url_kwarg = 'slug' slug_field = 'slug' def get_serializer_context(self): @@ -90,7 +92,8 @@ class GenericAPIView(views.APIView): """ Paginate a queryset. """ - paginator = paginator_class(queryset, page_size, allow_empty_first_page=self.allow_empty) + paginator = paginator_class(queryset, page_size, + allow_empty_first_page=self.allow_empty) page_kwarg = self.page_kwarg page = self.kwargs.get(page_kwarg) or self.request.GET.get(page_kwarg) or 1 try: @@ -118,6 +121,7 @@ class GenericAPIView(views.APIView): backend = self.filter_backend() return backend.filter_queryset(self.request, queryset, self) + ######################## ### The following methods provide default implementations ### that you may want to override for more complex cases. @@ -204,7 +208,9 @@ class GenericAPIView(views.APIView): return obj - ### The following methods are intended to be overridden. + ######################## + ### The following are placeholder methods, + ### and are intended to be overridden. def pre_save(self, obj): """ @@ -222,8 +228,10 @@ class GenericAPIView(views.APIView): pass +########################################################## ### Concrete view classes that provide method handlers ### ### by composing the mixin classes with the base view. ### +########################################################## class CreateAPIView(mixins.CreateModelMixin, GenericAPIView): @@ -338,7 +346,9 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, return self.destroy(request, *args, **kwargs) +########################## ### Deprecated classes ### +########################## class MultipleObjectAPIView(GenericAPIView): pass -- cgit v1.2.3 From 3f91379e4eaf07418a99fda1932af91511c55e7b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 10 Apr 2013 09:24:24 +0100 Subject: Fix 1.3 compat issue. Closes #780 --- rest_framework/compat.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 6551723a..067e9018 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -400,19 +400,23 @@ except ImportError: try: from django.utils.html import smart_urlquote except ImportError: + import re + from django.utils.encoding import smart_str try: from urllib.parse import quote, urlsplit, urlunsplit except ImportError: # Python 2 from urllib import quote from urlparse import urlsplit, urlunsplit + unquoted_percents_re = re.compile(r'%(?![0-9A-Fa-f]{2})') + def smart_urlquote(url): "Quotes a URL if it isn't already quoted." # Handle IDN before quoting. scheme, netloc, path, query, fragment = urlsplit(url) try: - netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE - except UnicodeError: # invalid domain part + netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE + except UnicodeError: # invalid domain part pass else: url = urlunsplit((scheme, netloc, path, query, fragment)) @@ -421,7 +425,7 @@ except ImportError: # contains a % not followed by two hexadecimal digits. See #9655. if '%' not in url or unquoted_percents_re.search(url): # See http://bugs.python.org/issue2637 - url = quote(force_bytes(url), safe=b'!*\'();:@&=+$,/?#[]~') + url = quote(smart_str(url), safe=b'!*\'();:@&=+$,/?#[]~') return force_text(url) -- cgit v1.2.3 From 76e039d70e8fc7f1d5c65180cb544abab81e600e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 10 Apr 2013 22:38:02 +0100 Subject: First pass on automatically including reverse relationship --- rest_framework/serializers.py | 43 ++++++++++++++++++++++++++++++++------ rest_framework/tests/serializer.py | 37 ++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 6 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index e28bbe81..eac909c7 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -598,6 +598,24 @@ class ModelSerializer(Serializer): if field: ret[model_field.name] = field + # Reverse relationships are only included if they are explicitly + # present in `Meta.fields`. + if self.opts.fields: + reverse = opts.get_all_related_objects() + reverse += opts.get_all_related_many_to_many_objects() + for rel in reverse: + name = rel.get_accessor_name() + if name not in self.opts.fields: + continue + + if nested: + field = self.get_nested_field(None, rel) + else: + field = self.get_related_field(None, rel, to_many=True) + + if field: + ret[name] = field + for field_name in self.opts.read_only_fields: assert field_name in ret, \ "read_only_fields on '%s' included invalid item '%s'" % \ @@ -612,24 +630,36 @@ class ModelSerializer(Serializer): """ return self.get_field(model_field) - def get_nested_field(self, model_field): + def get_nested_field(self, model_field, rel=None): """ Creates a default instance of a nested relational field. """ + if rel: + model_class = rel.model + else: + model_class = model_field.rel.to + class NestedModelSerializer(ModelSerializer): class Meta: - model = model_field.rel.to + model = model_class return NestedModelSerializer() - def get_related_field(self, model_field, to_many=False): + def get_related_field(self, model_field, rel=None, to_many=False): """ Creates a default instance of a flat relational field. """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) + if rel: + model_class = rel.model + required = True + else: + model_class = model_field.rel.to + required = not(model_field.null or model_field.blank) + kwargs = { - 'required': not(model_field.null or model_field.blank), - 'queryset': model_field.rel.to._default_manager, + 'required': required, + 'queryset': model_class._default_manager, 'many': to_many } @@ -797,7 +827,8 @@ class HyperlinkedModelSerializer(ModelSerializer): return self._default_view_name % format_kwargs def get_pk_field(self, model_field): - return None + if self.opts.fields and model_field.name in self.opts.fields: + return self.get_field(model_field) def get_related_field(self, model_field, to_many): """ diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 05217f35..3a94fad5 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -738,6 +738,43 @@ class ManyRelatedTests(TestCase): self.assertEqual(serializer.data, expected) + def test_include_reverse_relations(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I hate this blog post") + post.blogpostcomment_set.create(text="I love this blog post") + + class BlogPostSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPost + fields = ('id', 'title', 'blogpostcomment_set') + + serializer = BlogPostSerializer(instance=post) + expected = { + 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2] + } + self.assertEqual(serializer.data, expected) + + def test_depth_include_reverse_relations(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I hate this blog post") + post.blogpostcomment_set.create(text="I love this blog post") + + class BlogPostSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPost + fields = ('id', 'title', 'blogpostcomment_set') + depth = 1 + + serializer = BlogPostSerializer(instance=post) + expected = { + 'id': 1, 'title': 'Test blog post', + 'blogpostcomment_set': [ + {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1}, + {'id': 2, 'text': 'I love this blog post', 'blog_post': 1} + ] + } + self.assertEqual(serializer.data, expected) + def test_callable_source(self): post = BlogPost.objects.create(title="Test blog post") post.blogpostcomment_set.create(text="I love this blog post") -- cgit v1.2.3 From e0020c5b033308cd789408a8823d6707deed8032 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 11 Apr 2013 15:48:18 +0100 Subject: Simplify get_object --- rest_framework/generics.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index ba7d1f43..ea62123d 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -4,9 +4,10 @@ Generic views that provide commonly needed behaviour. from __future__ import unicode_literals from rest_framework import views, mixins from rest_framework.settings import api_settings -from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist +from django.core.exceptions import ImproperlyConfigured from django.core.paginator import Paginator, InvalidPage from django.http import Http404 +from django.shortcuts import get_object_or_404 from django.utils.translation import ugettext as _ ### Base classes for the generic views ### @@ -163,7 +164,7 @@ class GenericAPIView(views.APIView): # TODO: Deprecation warning return self.model._default_manager.all() - raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" + raise ImproperlyConfigured("'%s' must define 'queryset'" % self.__class__.__name__) def get_object(self, queryset=None): @@ -177,6 +178,8 @@ class GenericAPIView(views.APIView): # Determine the base queryset to use. if queryset is None: queryset = self.filter_queryset(self.get_queryset()) + else: + pass # Deprecation warning # Perform the lookup filtering. pk = self.kwargs.get(self.pk_url_kwarg, None) @@ -184,24 +187,19 @@ class GenericAPIView(views.APIView): lookup = self.kwargs.get(self.lookup_kwarg, None) if lookup is not None: - queryset = queryset.filter(**{self.lookup_kwarg: lookup}) + filter_kwargs = {self.lookup_kwarg: lookup} elif pk is not None: # TODO: Deprecation warning - queryset = queryset.filter(pk=pk) + filter_kwargs = {'pk': pk} elif slug is not None: # TODO: Deprecation warning - queryset = queryset.filter(**{self.slug_field: slug}) + filter_kwargs = {self.slug_field: slug} else: raise AttributeError("Generic detail view %s must be called with " "either an object pk or a slug." % self.__class__.__name__) - try: - # Get the single item from the filtered queryset - obj = queryset.get() - except ObjectDoesNotExist: - raise Http404(_("No %(verbose_name)s found matching the query") % - {'verbose_name': queryset.model._meta.verbose_name}) + obj = get_object_or_404(queryset, **filter_kwargs) # May raise a permission denied self.check_object_permissions(self.request, obj) -- cgit v1.2.3 From 5a5a602f8ad2e84b36aa88d86334c5afecc40295 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 13 Apr 2013 20:07:36 +0100 Subject: Allow overriding get_object to work correctly. Fixes #784 --- rest_framework/generics.py | 1 + rest_framework/mixins.py | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 36ecf915..f9133c73 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -130,6 +130,7 @@ class SingleObjectAPIView(SingleObjectMixin, GenericAPIView): """ Override default to add support for object-level permissions. """ + queryset = self.filter_queryset(self.get_queryset()) obj = super(SingleObjectAPIView, self).get_object(queryset) self.check_object_permissions(self.request, obj) return obj diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 7d9a6e65..3bd7d6df 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -97,9 +97,7 @@ class RetrieveModelMixin(object): Should be mixed in with `SingleObjectAPIView`. """ def retrieve(self, request, *args, **kwargs): - queryset = self.get_queryset() - filtered_queryset = self.filter_queryset(queryset) - self.object = self.get_object(filtered_queryset) + self.object = self.get_object() serializer = self.get_serializer(self.object) return Response(serializer.data) -- cgit v1.2.3 From 750451f5b4de61684f4a4e69dd5776bd84ac054c Mon Sep 17 00:00:00 2001 From: Johannes Spielmann Date: Sun, 14 Apr 2013 18:30:44 +0200 Subject: adding test case for generic view with overriden get_object() --- rest_framework/tests/generics.py | 173 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index f564890c..b40b0102 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -24,6 +24,28 @@ class InstanceView(generics.RetrieveUpdateDestroyAPIView): model = BasicModel +class InstanceDetailView(generics.RetrieveUpdateDestroyAPIView): + """ + Example detail view for override of get_object(). + """ + + # we have to implement this too, otherwise we can't be sure that get_object + # will be called + def get_serializer(self, instance=None, data=None, files=None, partial=None): + class InstanceDetailSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel + return InstanceDetailSerializer(instance=instance, data=data, files=files, partial=partial) + + def get_object(self): + try: + pk = int(self.kwargs['pk']) + self.object = BasicModel.objects.get(id=pk) + return self.object + except BasicModel.DoesNotExist: + return self.permission_denied(self.request) + + class SlugSerializer(serializers.ModelSerializer): slug = serializers.Field() # read only @@ -301,6 +323,157 @@ class TestInstanceView(TestCase): new_obj = SlugBasedModel.objects.get(slug='test_slug') self.assertEqual(new_obj.text, 'foobar') +class TestInstanceDetailView(TestCase): + """ + Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the + queryset/model mechanism but instead overrides get_object() + """ + def setUp(self): + """ + Create 3 BasicModel intances. + """ + items = ['foo', 'bar', 'baz'] + for item in items: + BasicModel(text=item).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.view_class = InstanceDetailView + self.view = InstanceDetailView.as_view() + + def test_get_instance_view(self): + """ + GET requests to RetrieveUpdateDestroyAPIView should return a single object. + """ + request = factory.get('/1') + with self.assertNumQueries(1): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[0]) + + def test_post_instance_view(self): + """ + POST requests to RetrieveUpdateDestroyAPIView should not be allowed + """ + content = {'text': 'foobar'} + request = factory.post('/', json.dumps(content), + content_type='application/json') + with self.assertNumQueries(0): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."}) + + def test_put_instance_view(self): + """ + PUT requests to RetrieveUpdateDestroyAPIView should update an object. + """ + content = {'text': 'foobar'} + request = factory.put('/1', json.dumps(content), + content_type='application/json') + with self.assertNumQueries(2): + response = self.view(request, pk='1').render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) + updated = self.objects.get(id=1) + self.assertEqual(updated.text, 'foobar') + + def test_patch_instance_view(self): + """ + PATCH requests to RetrieveUpdateDestroyAPIView should update an object. + """ + content = {'text': 'foobar'} + request = factory.patch('/1', json.dumps(content), + content_type='application/json') + + with self.assertNumQueries(2): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) + updated = self.objects.get(id=1) + self.assertEqual(updated.text, 'foobar') + + def test_delete_instance_view(self): + """ + DELETE requests to RetrieveUpdateDestroyAPIView should delete an object. + """ + request = factory.delete('/1') + with self.assertNumQueries(2): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertEqual(response.content, six.b('')) + ids = [obj.id for obj in self.objects.all()] + self.assertEqual(ids, [2, 3]) + + def test_options_instance_view(self): + """ + OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata + """ + request = factory.options('/') + with self.assertNumQueries(0): + response = self.view(request).render() + expected = { + 'parses': [ + 'application/json', + 'application/x-www-form-urlencoded', + 'multipart/form-data' + ], + 'renders': [ + 'application/json', + 'text/html' + ], + 'name': 'Instance Detail', + 'description': 'Example detail view for override of get_object().' + } + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, expected) + + def test_put_cannot_set_id(self): + """ + PUT requests to create a new object should not be able to set the id. + """ + content = {'id': 999, 'text': 'foobar'} + request = factory.put('/1', json.dumps(content), + content_type='application/json') + with self.assertNumQueries(2): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) + updated = self.objects.get(id=1) + self.assertEqual(updated.text, 'foobar') + + def test_put_to_deleted_instance(self): + """ + PUT requests to RetrieveUpdateDestroyAPIView should create an object + if it does not currently exist. In our DetailView, however, + we cannot access any other id's than those that already exist. + See the InstanceView for the normal behaviour. + """ + self.objects.get(id=1).delete() + content = {'text': 'foobar'} + request = factory.put('/1', json.dumps(content), + content_type='application/json') + with self.assertNumQueries(1): + response = self.view(request, pk=5).render() + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_put_as_create_on_id_based_url(self): + """ + PUT requests to RetrieveUpdateDestroyAPIView should create an object + at the requested url if it doesn't exist. In our DetailView, however, + we cannot access any other id's than those that already exist. + See the InstanceView for the normal behaviour. + """ + content = {'text': 'foobar'} + # pk fields can not be created on demand, only the database can set the pk for a new object + request = factory.put('/5', json.dumps(content), + content_type='application/json') + with self.assertNumQueries(1): + response = self.view(request, pk=5).render() + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + # Regression test for #285 -- cgit v1.2.3 From ad436d966fa9ee2f5817aa5c26612c82558c4262 Mon Sep 17 00:00:00 2001 From: Stephan Groß Date: Mon, 15 Apr 2013 12:40:18 +0200 Subject: Add DecimalField support --- rest_framework/fields.py | 75 +++++++++++++++++++ rest_framework/tests/fields.py | 165 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 240 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f3496b53..a1b9f546 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals import copy import datetime +from decimal import Decimal, DecimalException import inspect import re import warnings @@ -721,6 +722,80 @@ class FloatField(WritableField): raise ValidationError(msg) +class DecimalField(WritableField): + type_name = 'DecimalField' + form_field_class = forms.DecimalField + + default_error_messages = { + 'invalid': _('Enter a number.'), + 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), + 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), + 'max_digits': _('Ensure that there are no more than %s digits in total.'), + 'max_decimal_places': _('Ensure that there are no more than %s decimal places.'), + 'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.') + } + + def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): + self.max_value, self.min_value = max_value, min_value + self.max_digits, self.decimal_places = max_digits, decimal_places + super(DecimalField, self).__init__(self, *args, **kwargs) + + if max_value is not None: + self.validators.append(validators.MaxValueValidator(max_value)) + if min_value is not None: + self.validators.append(validators.MinValueValidator(min_value)) + + def from_native(self, value): + """ + Validates that the input is a decimal number. Returns a Decimal + instance. Returns None for empty values. Ensures that there are no more + than max_digits in the number, and no more than decimal_places digits + after the decimal point. + """ + if value in validators.EMPTY_VALUES: + return None + value = smart_text(value).strip() + try: + value = Decimal(value) + except DecimalException: + raise ValidationError(self.error_messages['invalid']) + return value + + def to_native(self, value): + if value is not None: + return str(value) + return value + + def validate(self, value): + super(DecimalField, self).validate(value) + if value in validators.EMPTY_VALUES: + return + # Check for NaN, Inf and -Inf values. We can't compare directly for NaN, + # since it is never equal to itself. However, NaN is the only value that + # isn't equal to itself, so we can use this to identify NaN + if value != value or value == Decimal("Inf") or value == Decimal("-Inf"): + raise ValidationError(self.error_messages['invalid']) + sign, digittuple, exponent = value.as_tuple() + decimals = abs(exponent) + # digittuple doesn't include any leading zeros. + digits = len(digittuple) + if decimals > digits: + # We have leading zeros up to or past the decimal point. Count + # everything past the decimal point as a digit. We do not count + # 0 before the decimal point as a digit since that would mean + # we would not allow max_digits = decimal_places. + digits = decimals + whole_digits = digits - decimals + + if self.max_digits is not None and digits > self.max_digits: + raise ValidationError(self.error_messages['max_digits'] % self.max_digits) + if self.decimal_places is not None and decimals > self.decimal_places: + raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places) + if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): + raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places)) + return value + + class FileField(WritableField): use_files = True type_name = 'FileField' diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index 19c663d8..f833aa32 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -3,12 +3,14 @@ General serializer field tests. """ from __future__ import unicode_literals import datetime +from decimal import Decimal from django.db import models from django.test import TestCase from django.core import validators from rest_framework import serializers +from rest_framework.serializers import Serializer class TimestampedModel(models.Model): @@ -481,3 +483,166 @@ class TimeFieldTest(TestCase): self.assertEqual('04 - 00 [000000]', result_1) self.assertEqual('04 - 59 [000000]', result_2) self.assertEqual('04 - 59 [000200]', result_3) + + +class DecimalFieldTest(TestCase): + """ + Tests for the DecimalField from_native() and to_native() behavior + """ + + def test_from_native_string(self): + """ + Make sure from_native() accepts string values + """ + f = serializers.DecimalField() + result_1 = f.from_native('9000') + result_2 = f.from_native('1.00000001') + + self.assertEqual(Decimal('9000'), result_1) + self.assertEqual(Decimal('1.00000001'), result_2) + + def test_from_native_invalid_string(self): + """ + Make sure from_native() raises ValidationError on passing invalid string + """ + f = serializers.DecimalField() + + try: + f.from_native('123.45.6') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Enter a number."]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_integer(self): + """ + Make sure from_native() accepts integer values + """ + f = serializers.DecimalField() + result = f.from_native(9000) + + self.assertEqual(Decimal('9000'), result) + + def test_from_native_float(self): + """ + Make sure from_native() accepts float values + """ + f = serializers.DecimalField() + result = f.from_native(1.00000001) + + self.assertEqual(Decimal('1.00000001'), result) + + def test_from_native_empty(self): + """ + Make sure from_native() returns None on empty param. + """ + f = serializers.DecimalField() + result = f.from_native('') + + self.assertEqual(result, None) + + def test_from_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DecimalField() + result = f.from_native(None) + + self.assertEqual(result, None) + + def test_to_native(self): + """ + Make sure to_native() returns Decimal as string. + """ + f = serializers.DecimalField() + + result_1 = f.to_native(Decimal('9000')) + result_2 = f.to_native(Decimal('1.00000001')) + + self.assertEqual('9000', result_1) + self.assertEqual('1.00000001', result_2) + + def test_to_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DecimalField(required=False) + self.assertEqual(None, f.to_native(None)) + + def test_valid_serialization(self): + """ + Make sure the serializer works correctly + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(max_value=9010, + min_value=9000, + max_digits=6, + decimal_places=2) + + self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid()) + self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid()) + self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid()) + + self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid()) + self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid()) + self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid()) + + def test_raise_max_value(self): + """ + Make sure max_value violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(max_value=100) + + s = DecimalSerializer(data={'decimal_field': '123'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': [u'Ensure this value is less than or equal to 100.']}) + + def test_raise_min_value(self): + """ + Make sure min_value violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(min_value=100) + + s = DecimalSerializer(data={'decimal_field': '99'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': [u'Ensure this value is greater than or equal to 100.']}) + + def test_raise_max_digits(self): + """ + Make sure max_digits violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(max_digits=5) + + s = DecimalSerializer(data={'decimal_field': '123.456'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': [u'Ensure that there are no more than 5 digits in total.']}) + + def test_raise_max_decimal_places(self): + """ + Make sure max_decimal_places violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(decimal_places=3) + + s = DecimalSerializer(data={'decimal_field': '123.4567'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': [u'Ensure that there are no more than 3 decimal places.']}) + + def test_raise_max_whole_digits(self): + """ + Make sure max_whole_digits violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3) + + s = DecimalSerializer(data={'decimal_field': '12345.6'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': [u'Ensure that there are no more than 4 digits in total.']}) \ No newline at end of file -- cgit v1.2.3 From 37f7d8bc0f00feb1a4d23c0e163eab8b47faaec3 Mon Sep 17 00:00:00 2001 From: Stephan Groß Date: Mon, 15 Apr 2013 12:55:29 +0200 Subject: Fix unicodes --- rest_framework/tests/fields.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index f833aa32..597180b4 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -597,7 +597,7 @@ class DecimalFieldTest(TestCase): s = DecimalSerializer(data={'decimal_field': '123'}) self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': [u'Ensure this value is less than or equal to 100.']}) + self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']}) def test_raise_min_value(self): """ @@ -609,7 +609,7 @@ class DecimalFieldTest(TestCase): s = DecimalSerializer(data={'decimal_field': '99'}) self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': [u'Ensure this value is greater than or equal to 100.']}) + self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']}) def test_raise_max_digits(self): """ @@ -621,7 +621,7 @@ class DecimalFieldTest(TestCase): s = DecimalSerializer(data={'decimal_field': '123.456'}) self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': [u'Ensure that there are no more than 5 digits in total.']}) + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']}) def test_raise_max_decimal_places(self): """ @@ -633,7 +633,7 @@ class DecimalFieldTest(TestCase): s = DecimalSerializer(data={'decimal_field': '123.4567'}) self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': [u'Ensure that there are no more than 3 decimal places.']}) + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']}) def test_raise_max_whole_digits(self): """ @@ -645,4 +645,4 @@ class DecimalFieldTest(TestCase): s = DecimalSerializer(data={'decimal_field': '12345.6'}) self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': [u'Ensure that there are no more than 4 digits in total.']}) \ No newline at end of file + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) \ No newline at end of file -- cgit v1.2.3 From c329d2f08511dbc7660af9b8fc94e92d97c015cc Mon Sep 17 00:00:00 2001 From: Stephan Groß Date: Mon, 15 Apr 2013 13:11:41 +0200 Subject: Add DecimalField to field_mapping --- rest_framework/serializers.py | 1 + 1 file changed, 1 insertion(+) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index e28bbe81..cbc6586d 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -548,6 +548,7 @@ class ModelSerializer(Serializer): models.DateTimeField: DateTimeField, models.DateField: DateField, models.TimeField: TimeField, + models.DecimalField: DecimalField, models.EmailField: EmailField, models.CharField: CharField, models.URLField: URLField, -- cgit v1.2.3 From 9d80f01bced913dae0859be525b39eaa9df1fdbf Mon Sep 17 00:00:00 2001 From: Stephan Groß Date: Mon, 15 Apr 2013 15:15:55 +0200 Subject: Fix init call --- rest_framework/fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index a1b9f546..6be633db 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -738,7 +738,7 @@ class DecimalField(WritableField): def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): self.max_value, self.min_value = max_value, min_value self.max_digits, self.decimal_places = max_digits, decimal_places - super(DecimalField, self).__init__(self, *args, **kwargs) + super(DecimalField, self).__init__(*args, **kwargs) if max_value is not None: self.validators.append(validators.MaxValueValidator(max_value)) -- cgit v1.2.3 From cac669702596cdf768971267e6355fb9223a69e8 Mon Sep 17 00:00:00 2001 From: Stephan Groß Date: Mon, 15 Apr 2013 15:24:14 +0200 Subject: Return Decimal instance instead of string --- rest_framework/fields.py | 5 ----- rest_framework/tests/fields.py | 4 ++-- 2 files changed, 2 insertions(+), 7 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 6be633db..926195be 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -761,11 +761,6 @@ class DecimalField(WritableField): raise ValidationError(self.error_messages['invalid']) return value - def to_native(self, value): - if value is not None: - return str(value) - return value - def validate(self, value): super(DecimalField, self).validate(value) if value in validators.EMPTY_VALUES: diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index 597180b4..3cdfa0f6 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -559,8 +559,8 @@ class DecimalFieldTest(TestCase): result_1 = f.to_native(Decimal('9000')) result_2 = f.to_native(Decimal('1.00000001')) - self.assertEqual('9000', result_1) - self.assertEqual('1.00000001', result_2) + self.assertEqual(Decimal('9000'), result_1) + self.assertEqual(Decimal('1.00000001'), result_2) def test_to_native_none(self): """ -- cgit v1.2.3 From 23289b023db230f73e4a5bfae24a56c79e3fcd4b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 16 Apr 2013 14:32:46 +0100 Subject: Explicit error if dev does not return a response from the view --- rest_framework/views.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/views.py b/rest_framework/views.py index 81cbdcbb..7c97607b 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,7 +3,7 @@ Provides an APIView class that is used as the base of all class-based views. """ from __future__ import unicode_literals from django.core.exceptions import PermissionDenied -from django.http import Http404 +from django.http import Http404, HttpResponse from django.utils.html import escape from django.utils.safestring import mark_safe from django.views.decorators.csrf import csrf_exempt @@ -327,6 +327,12 @@ class APIView(View): """ Returns the final response object. """ + # Make the error obvious if a proper response is not returned + assert isinstance(response, HttpResponse), ( + 'Expected a `Response` to be returned from the view, ' + 'but received a `%s`' % type(response) + ) + if isinstance(response, Response): if not getattr(request, 'accepted_renderer', None): neg = self.perform_content_negotiation(request, force=True) -- cgit v1.2.3 From 37fe0bf0de25d28d792a291d5a84987ab71c4cb6 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Apr 2013 09:03:24 +0100 Subject: Remove unneccessary tests from #789, and bit of cleanup. --- rest_framework/tests/generics.py | 165 ++++----------------------------------- 1 file changed, 17 insertions(+), 148 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index b40b0102..4a13389a 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals from django.db import models +from django.shortcuts import get_object_or_404 from django.test import TestCase from rest_framework import generics, serializers, status from rest_framework.tests.utils import RequestFactory @@ -24,28 +25,6 @@ class InstanceView(generics.RetrieveUpdateDestroyAPIView): model = BasicModel -class InstanceDetailView(generics.RetrieveUpdateDestroyAPIView): - """ - Example detail view for override of get_object(). - """ - - # we have to implement this too, otherwise we can't be sure that get_object - # will be called - def get_serializer(self, instance=None, data=None, files=None, partial=None): - class InstanceDetailSerializer(serializers.ModelSerializer): - class Meta: - model = BasicModel - return InstanceDetailSerializer(instance=instance, data=data, files=files, partial=partial) - - def get_object(self): - try: - pk = int(self.kwargs['pk']) - self.object = BasicModel.objects.get(id=pk) - return self.object - except BasicModel.DoesNotExist: - return self.permission_denied(self.request) - - class SlugSerializer(serializers.ModelSerializer): slug = serializers.Field() # read only @@ -323,7 +302,8 @@ class TestInstanceView(TestCase): new_obj = SlugBasedModel.objects.get(slug='test_slug') self.assertEqual(new_obj.text, 'foobar') -class TestInstanceDetailView(TestCase): + +class TestOverriddenGetObject(TestCase): """ Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the queryset/model mechanism but instead overrides get_object() @@ -340,139 +320,28 @@ class TestInstanceDetailView(TestCase): {'id': obj.id, 'text': obj.text} for obj in self.objects.all() ] - self.view_class = InstanceDetailView - self.view = InstanceDetailView.as_view() - def test_get_instance_view(self): - """ - GET requests to RetrieveUpdateDestroyAPIView should return a single object. - """ - request = factory.get('/1') - with self.assertNumQueries(1): - response = self.view(request, pk=1).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, self.data[0]) + class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView): + """ + Example detail view for override of get_object(). + """ + model = BasicModel - def test_post_instance_view(self): - """ - POST requests to RetrieveUpdateDestroyAPIView should not be allowed - """ - content = {'text': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') - with self.assertNumQueries(0): - response = self.view(request).render() - self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."}) - - def test_put_instance_view(self): - """ - PUT requests to RetrieveUpdateDestroyAPIView should update an object. - """ - content = {'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') - with self.assertNumQueries(2): - response = self.view(request, pk='1').render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) - updated = self.objects.get(id=1) - self.assertEqual(updated.text, 'foobar') - - def test_patch_instance_view(self): - """ - PATCH requests to RetrieveUpdateDestroyAPIView should update an object. - """ - content = {'text': 'foobar'} - request = factory.patch('/1', json.dumps(content), - content_type='application/json') - - with self.assertNumQueries(2): - response = self.view(request, pk=1).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) - updated = self.objects.get(id=1) - self.assertEqual(updated.text, 'foobar') - - def test_delete_instance_view(self): - """ - DELETE requests to RetrieveUpdateDestroyAPIView should delete an object. - """ - request = factory.delete('/1') - with self.assertNumQueries(2): - response = self.view(request, pk=1).render() - self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - self.assertEqual(response.content, six.b('')) - ids = [obj.id for obj in self.objects.all()] - self.assertEqual(ids, [2, 3]) + def get_object(self): + pk = int(self.kwargs['pk']) + return get_object_or_404(BasicModel.objects.all(), id=pk) - def test_options_instance_view(self): - """ - OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata - """ - request = factory.options('/') - with self.assertNumQueries(0): - response = self.view(request).render() - expected = { - 'parses': [ - 'application/json', - 'application/x-www-form-urlencoded', - 'multipart/form-data' - ], - 'renders': [ - 'application/json', - 'text/html' - ], - 'name': 'Instance Detail', - 'description': 'Example detail view for override of get_object().' - } - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, expected) + self.view = OverriddenGetObjectView.as_view() - def test_put_cannot_set_id(self): + def test_overridden_get_object_view(self): """ - PUT requests to create a new object should not be able to set the id. + GET requests to RetrieveUpdateDestroyAPIView should return a single object. """ - content = {'id': 999, 'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') - with self.assertNumQueries(2): + request = factory.get('/1') + with self.assertNumQueries(1): response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) - updated = self.objects.get(id=1) - self.assertEqual(updated.text, 'foobar') - - def test_put_to_deleted_instance(self): - """ - PUT requests to RetrieveUpdateDestroyAPIView should create an object - if it does not currently exist. In our DetailView, however, - we cannot access any other id's than those that already exist. - See the InstanceView for the normal behaviour. - """ - self.objects.get(id=1).delete() - content = {'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') - with self.assertNumQueries(1): - response = self.view(request, pk=5).render() - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - - def test_put_as_create_on_id_based_url(self): - """ - PUT requests to RetrieveUpdateDestroyAPIView should create an object - at the requested url if it doesn't exist. In our DetailView, however, - we cannot access any other id's than those that already exist. - See the InstanceView for the normal behaviour. - """ - content = {'text': 'foobar'} - # pk fields can not be created on demand, only the database can set the pk for a new object - request = factory.put('/5', json.dumps(content), - content_type='application/json') - with self.assertNumQueries(1): - response = self.view(request, pk=5).render() - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - + self.assertEqual(response.data, self.data[0]) # Regression test for #285 -- cgit v1.2.3 From ea55143a2308b396c8df6f59a0f6d663c1067163 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Apr 2013 09:07:20 +0100 Subject: Version 2.2.7 --- rest_framework/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 7ac12058..856badc6 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.2.6' +__version__ = '2.2.7' VERSION = __version__ # synonym -- cgit v1.2.3 From 33f494fcc89711ab7e97f47fe8d9b287aac4730f Mon Sep 17 00:00:00 2001 From: forgingdestiny Date: Wed, 17 Apr 2013 10:14:36 -0400 Subject: add branding and style blocks --- .../templates/rest_framework/login_base.html | 55 ++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 rest_framework/templates/rest_framework/login_base.html (limited to 'rest_framework') diff --git a/rest_framework/templates/rest_framework/login_base.html b/rest_framework/templates/rest_framework/login_base.html new file mode 100644 index 00000000..380d5820 --- /dev/null +++ b/rest_framework/templates/rest_framework/login_base.html @@ -0,0 +1,55 @@ +{% load url from future %} +{% load rest_framework %} + + + + {% block style %} + {% block bootstrap_theme %}{% endblock %} + + + {% endblock %} + + + + +
+
+ +
+
+
+ {% block branding %}

Django REST framework

{% endblock %} +
+
+ +
+
+
+ {% csrf_token %} +
+
+ + +
+
+
+
+ + +
+
+ +
+ +
+
+
+
+
+ +
+
+ + + + -- cgit v1.2.3 From 03c736338fa04092da99d7d9ea202c8778998b38 Mon Sep 17 00:00:00 2001 From: forgingdestiny Date: Wed, 17 Apr 2013 10:15:02 -0400 Subject: extend base login template --- rest_framework/templates/rest_framework/login.html | 54 +--------------------- 1 file changed, 2 insertions(+), 52 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html index e10ce20f..b7629327 100644 --- a/rest_framework/templates/rest_framework/login.html +++ b/rest_framework/templates/rest_framework/login.html @@ -1,53 +1,3 @@ -{% load url from future %} -{% load rest_framework %} - +{% extends "rest_framework/login_base.html" %} - - - - - - - - -
-
- -
-
-
-

Django REST framework

-
-
- -
-
-
- {% csrf_token %} -
-
- - -
-
-
-
- - -
-
- -
- -
-
-
-
-
- -
-
- - - - +{# Override this template in your own templates directory to customize #} -- cgit v1.2.3 From 4bf1a09baeb885863e6028b97c2d51b26fb18534 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 23 Apr 2013 11:31:38 +0100 Subject: Ensure implementation of reverse relations in 'fields' is backwards compatible --- rest_framework/permissions.py | 2 +- rest_framework/serializers.py | 122 +++++++++++++++------------- rest_framework/tests/relations_hyperlink.py | 16 ++-- rest_framework/tests/relations_nested.py | 24 ++---- rest_framework/tests/relations_pk.py | 17 ++-- 5 files changed, 95 insertions(+), 86 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index ae895f39..2aa45c71 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -25,7 +25,7 @@ class BasePermission(object): """ Return `True` if permission is granted, `False` otherwise. """ - if len(inspect.getargspec(self.has_permission)[0]) == 4: + if len(inspect.getargspec(self.has_permission).args) == 4: warnings.warn('The `obj` argument in `has_permission` is due to be deprecated. ' 'Use `has_object_permission()` instead for object permissions.', PendingDeprecationWarning, stacklevel=2) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index eac909c7..b4327af1 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -568,54 +568,73 @@ class ModelSerializer(Serializer): assert cls is not None, \ "Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__ opts = get_concrete_model(cls)._meta - pk_field = opts.pk + ret = SortedDict() + nested = bool(self.opts.depth) - # If model is a child via multitable inheritance, use parent's pk + # Deal with adding the primary key field + pk_field = opts.pk while pk_field.rel and pk_field.rel.parent_link: + # If model is a child via multitable inheritance, use parent's pk pk_field = pk_field.rel.to._meta.pk - fields = [pk_field] - fields += [field for field in opts.fields if field.serialize] - fields += [field for field in opts.many_to_many if field.serialize] + field = self.get_pk_field(pk_field) + if field: + ret[pk_field.name] = field - ret = SortedDict() - nested = bool(self.opts.depth) - is_pk = True # First field in the list is the pk - - for model_field in fields: - if is_pk: - field = self.get_pk_field(model_field) - is_pk = False - elif model_field.rel and nested: - field = self.get_nested_field(model_field) - elif model_field.rel: + # Deal with forward relationships + forward_rels = [field for field in opts.fields if field.serialize] + forward_rels += [field for field in opts.many_to_many if field.serialize] + + for model_field in forward_rels: + if model_field.rel: to_many = isinstance(model_field, models.fields.related.ManyToManyField) - field = self.get_related_field(model_field, to_many=to_many) + related_model = model_field.rel.to + + if model_field.rel and nested: + if len(inspect.getargspec(self.get_nested_field).args) == 2: + # TODO: deprecation warning + field = self.get_nested_field(model_field) + else: + field = self.get_nested_field(model_field, related_model, to_many) + elif model_field.rel: + if len(inspect.getargspec(self.get_nested_field).args) == 3: + # TODO: deprecation warning + field = self.get_related_field(model_field, to_many=to_many) + else: + field = self.get_related_field(model_field, related_model, to_many) else: field = self.get_field(model_field) if field: ret[model_field.name] = field - # Reverse relationships are only included if they are explicitly - # present in `Meta.fields`. - if self.opts.fields: - reverse = opts.get_all_related_objects() - reverse += opts.get_all_related_many_to_many_objects() - for rel in reverse: - name = rel.get_accessor_name() - if name not in self.opts.fields: - continue - - if nested: - field = self.get_nested_field(None, rel) - else: - field = self.get_related_field(None, rel, to_many=True) + # Deal with reverse relationships + if not self.opts.fields: + reverse_rels = [] + else: + # Reverse relationships are only included if they are explicitly + # present in the `fields` option on the serializer + reverse_rels = opts.get_all_related_objects() + reverse_rels += opts.get_all_related_many_to_many_objects() + + for relation in reverse_rels: + accessor_name = relation.get_accessor_name() + if accessor_name not in self.opts.fields: + continue + related_model = relation.model + to_many = relation.field.rel.multiple - if field: - ret[name] = field + if nested: + field = self.get_nested_field(None, related_model, to_many) + else: + field = self.get_related_field(None, related_model, to_many) + + if field: + ret[accessor_name] = field + # Add the `read_only` flag to any fields that have bee specified + # in the `read_only_fields` option for field_name in self.opts.read_only_fields: assert field_name in ret, \ "read_only_fields on '%s' included invalid item '%s'" % \ @@ -630,39 +649,30 @@ class ModelSerializer(Serializer): """ return self.get_field(model_field) - def get_nested_field(self, model_field, rel=None): + def get_nested_field(self, model_field, related_model, to_many): """ Creates a default instance of a nested relational field. """ - if rel: - model_class = rel.model - else: - model_class = model_field.rel.to - class NestedModelSerializer(ModelSerializer): class Meta: - model = model_class - return NestedModelSerializer() + model = related_model + return NestedModelSerializer(many=to_many) - def get_related_field(self, model_field, rel=None, to_many=False): + def get_related_field(self, model_field, related_model, to_many): """ Creates a default instance of a flat relational field. """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) - if rel: - model_class = rel.model - required = True - else: - model_class = model_field.rel.to - required = not(model_field.null or model_field.blank) kwargs = { - 'required': required, - 'queryset': model_class._default_manager, + 'queryset': related_model._default_manager, 'many': to_many } + if model_field: + kwargs['required'] = not(model_field.null or model_field.blank) + return PrimaryKeyRelatedField(**kwargs) def get_field(self, model_field): @@ -830,19 +840,21 @@ class HyperlinkedModelSerializer(ModelSerializer): if self.opts.fields and model_field.name in self.opts.fields: return self.get_field(model_field) - def get_related_field(self, model_field, to_many): + def get_related_field(self, model_field, related_model, to_many): """ Creates a default instance of a flat relational field. """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) - rel = model_field.rel.to kwargs = { - 'required': not(model_field.null or model_field.blank), - 'queryset': rel._default_manager, - 'view_name': self._get_default_view_name(rel), + 'queryset': related_model._default_manager, + 'view_name': self._get_default_view_name(related_model), 'many': to_many } + + if model_field: + kwargs['required'] = not(model_field.null or model_field.blank) + return HyperlinkedRelatedField(**kwargs) def get_identity(self, data): diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py index b5702a48..b1eed9a7 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/relations_hyperlink.py @@ -26,42 +26,44 @@ urlpatterns = patterns('', ) +# ManyToMany class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): - sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail') - class Meta: model = ManyToManyTarget + fields = ('url', 'name', 'sources') class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = ManyToManySource + fields = ('url', 'name', 'targets') +# ForeignKey class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): - sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail') - class Meta: model = ForeignKeyTarget + fields = ('url', 'name', 'sources') class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = ForeignKeySource + fields = ('url', 'name', 'target') # Nullable ForeignKey class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = NullableForeignKeySource + fields = ('url', 'name', 'target') -# OneToOne +# Nullable OneToOne class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): - nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail') - class Meta: model = OneToOneTarget + fields = ('url', 'name', 'nullable_source') # TODO: Add test that .data cannot be accessed prior to .is_valid diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index a125ba65..f6d006b3 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -5,39 +5,31 @@ from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, Null class ForeignKeySourceSerializer(serializers.ModelSerializer): - class Meta: - depth = 1 - model = ForeignKeySource - - -class FlatForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: model = ForeignKeySource + fields = ('id', 'name', 'target') + depth = 1 class ForeignKeyTargetSerializer(serializers.ModelSerializer): - sources = FlatForeignKeySourceSerializer(many=True) - class Meta: model = ForeignKeyTarget + fields = ('id', 'name', 'sources') + depth = 1 class NullableForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: - depth = 1 model = NullableForeignKeySource - - -class NullableOneToOneSourceSerializer(serializers.ModelSerializer): - class Meta: - model = NullableOneToOneSource + fields = ('id', 'name', 'target') + depth = 1 class NullableOneToOneTargetSerializer(serializers.ModelSerializer): - nullable_source = NullableOneToOneSourceSerializer() - class Meta: model = OneToOneTarget + fields = ('id', 'name', 'nullable_source') + depth = 1 class ReverseForeignKeyTests(TestCase): diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py index f08e1808..5ce8b567 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/relations_pk.py @@ -5,41 +5,44 @@ from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, Fore from rest_framework.compat import six +# ManyToMany class ManyToManyTargetSerializer(serializers.ModelSerializer): - sources = serializers.PrimaryKeyRelatedField(many=True) - class Meta: model = ManyToManyTarget + fields = ('id', 'name', 'sources') class ManyToManySourceSerializer(serializers.ModelSerializer): class Meta: model = ManyToManySource + fields = ('id', 'name', 'targets') +# ForeignKey class ForeignKeyTargetSerializer(serializers.ModelSerializer): - sources = serializers.PrimaryKeyRelatedField(many=True) - class Meta: model = ForeignKeyTarget + fields = ('id', 'name', 'sources') class ForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: model = ForeignKeySource + fields = ('id', 'name', 'target') +# Nullable ForeignKey class NullableForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: model = NullableForeignKeySource + fields = ('id', 'name', 'target') -# OneToOne +# Nullable OneToOne class NullableOneToOneTargetSerializer(serializers.ModelSerializer): - nullable_source = serializers.PrimaryKeyRelatedField() - class Meta: model = OneToOneTarget + fields = ('id', 'name', 'nullable_source') # TODO: Add test that .data cannot be accessed prior to .is_valid -- cgit v1.2.3 From b94da2468cdda6b0ad491574d35097d0e336ea7f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 24 Apr 2013 22:40:24 +0100 Subject: Various clean up and lots of docs --- rest_framework/generics.py | 74 ++++++++++----- rest_framework/routers.py | 231 ++++++++++++++++++++++++++++++++++----------- 2 files changed, 222 insertions(+), 83 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index ae03060b..3440c01d 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -18,21 +18,35 @@ class GenericAPIView(views.APIView): Base class for all other generic views. """ + # You'll need to either set these attributes, + # or override `get_queryset`/`get_serializer_class`. queryset = None serializer_class = None - # Shortcut which may be used in place of `queryset`/`serializer_class` - model = None + # If you want to use object lookups other than pk, set this attribute. + lookup_field = 'pk' - filter_backend = api_settings.FILTER_BACKEND + # Pagination settings paginate_by = api_settings.PAGINATE_BY paginate_by_param = api_settings.PAGINATE_BY_PARAM pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS - model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS page_kwarg = 'page' - lookup_field = 'pk' + + # The filter backend class to use for queryset filtering + filter_backend = api_settings.FILTER_BACKEND + + # Determines if the view will return 200 or 404 responses for empty lists. allow_empty = True + # This shortcut may be used instead of setting either (or both) + # of the `queryset`/`serializer_class` attributes, although using + # the explicit style is generally preferred. + model = None + + # If the `model` shortcut is used instead of `serializer_class`, then the + # serializer class will be constructed using this class as the base. + model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS + ###################################### # These are pending deprecation... @@ -61,7 +75,7 @@ class GenericAPIView(views.APIView): return serializer_class(instance, data=data, files=files, many=many, partial=partial, context=context) - def get_pagination_serializer(self, page=None): + def get_pagination_serializer(self, page): """ Return a serializer instance to use with paginated data. """ @@ -73,32 +87,15 @@ class GenericAPIView(views.APIView): context = self.get_serializer_context() return pagination_serializer_class(instance=page, context=context) - def get_paginate_by(self, queryset=None): - """ - Return the size of pages to use with pagination. - - If `PAGINATE_BY_PARAM` is set it will attempt to get the page size - from a named query parameter in the url, eg. ?page_size=100 - - Otherwise defaults to using `self.paginate_by`. - """ - if self.paginate_by_param: - query_params = self.request.QUERY_PARAMS - try: - return int(query_params[self.paginate_by_param]) - except (KeyError, ValueError): - pass - - return self.paginate_by - def paginate_queryset(self, queryset, page_size, paginator_class=Paginator): """ Paginate a queryset. """ paginator = paginator_class(queryset, page_size, allow_empty_first_page=self.allow_empty) - page_kwarg = self.page_kwarg - page = self.kwargs.get(page_kwarg) or self.request.GET.get(page_kwarg) or 1 + page_kwarg = self.kwargs.get(self.page_kwarg) + page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) + page = page_kwarg or page_query_param or 1 try: page_number = int(page) except ValueError: @@ -133,6 +130,27 @@ class GenericAPIView(views.APIView): ### The following methods provide default implementations ### that you may want to override for more complex cases. + def get_paginate_by(self, queryset=None): + """ + Return the size of pages to use with pagination. + + If `PAGINATE_BY_PARAM` is set it will attempt to get the page size + from a named query parameter in the url, eg. ?page_size=100 + + Otherwise defaults to using `self.paginate_by`. + """ + if queryset is not None: + pass # TODO: Deprecation warning + + if self.paginate_by_param: + query_params = self.request.QUERY_PARAMS + try: + return int(query_params[self.paginate_by_param]) + except (KeyError, ValueError): + pass + + return self.paginate_by + def get_serializer_class(self): """ Return the class to use for the serializer. @@ -202,6 +220,7 @@ class GenericAPIView(views.APIView): # TODO: Deprecation warning filter_kwargs = {self.slug_field: slug} else: + # TODO: Fix error message raise AttributeError("Generic detail view %s must be called with " "either an object pk or a slug." % self.__class__.__name__) @@ -216,6 +235,9 @@ class GenericAPIView(views.APIView): ######################## ### The following are placeholder methods, ### and are intended to be overridden. + ### + ### The are not called by GenericAPIView directly, + ### but are used by the mixin methods. def pre_save(self, obj): """ diff --git a/rest_framework/routers.py b/rest_framework/routers.py index afc51f3b..febb02b3 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -1,81 +1,198 @@ +""" +Routers provide a convenient and consistent way of automatically +determining the URL conf for your API. + +They are used by simply instantiating a Router class, and then registering +all the required ViewSets with that router. + +For example, you might have a `urls.py` that looks something like this: + + router = routers.DefaultRouter() + router.register('users', UserViewSet, 'user') + router.register('accounts', AccountViewSet, 'account') + + urlpatterns = router.urls +""" from django.conf.urls import url, patterns +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.reverse import reverse +from rest_framework.urlpatterns import format_suffix_patterns class BaseRouter(object): def __init__(self): self.registry = [] - def register(self, prefix, viewset, base_name): - self.registry.append((prefix, viewset, base_name)) + def register(self, prefix, viewset, basename): + self.registry.append((prefix, viewset, basename)) - def get_urlpatterns(self): - raise NotImplemented('get_urlpatterns must be overridden') + def get_urls(self): + raise NotImplemented('get_urls must be overridden') @property - def urlpatterns(self): - if not hasattr(self, '_urlpatterns'): - self._urlpatterns = patterns('', *self.get_urlpatterns()) - return self._urlpatterns - - -class DefaultRouter(BaseRouter): - route_list = [ - (r'$', { - 'get': 'list', - 'post': 'create' - }, 'list'), - (r'(?P[^/]+)/$', { - 'get': 'retrieve', - 'put': 'update', - 'patch': 'partial_update', - 'delete': 'destroy' - }, 'detail'), + def urls(self): + if not hasattr(self, '_urls'): + self._urls = patterns('', *self.get_urls()) + return self._urls + + +class SimpleRouter(BaseRouter): + routes = [ + # List route. + ( + r'^{prefix}/$', + { + 'get': 'list', + 'post': 'create' + }, + '{basename}-list' + ), + # Detail route. + ( + r'^{prefix}/{lookup}/$', + { + 'get': 'retrieve', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy' + }, + '{basename}-detail' + ), + # Dynamically generated routes. + # Generated using @action or @link decorators on methods of the viewset. + ( + r'^{prefix}/{lookup}/{methodname}/$', + { + '{httpmethod}': '{methodname}', + }, + '{basename}-{methodname}' + ), ] - extra_routes = r'(?P[^/]+)/%s/$' - name_format = '%s-%s' - def get_urlpatterns(self): + def get_routes(self, viewset): + """ + Augment `self.routes` with any dynamically generated routes. + + Returns a list of 4-tuples, of the form: + `(url_format, method_map, name_format, extra_kwargs)` + """ + + # Determine any `@action` or `@link` decorated methods on the viewset + dynamic_routes = {} + for methodname in dir(viewset): + attr = getattr(viewset, methodname) + httpmethod = getattr(attr, 'bind_to_method', None) + if httpmethod: + dynamic_routes[httpmethod] = methodname + ret = [] - for prefix, viewset, base_name in self.registry: - # Bind regular views - if not getattr(viewset, '_is_viewset', False): - regex = prefix - view = viewset - name = base_name - ret.append(url(regex, view, name=name)) - continue + for url_format, method_map, name_format in self.routes: + if method_map == {'{httpmethod}': '{methodname}'}: + # Dynamic routes + for httpmethod, methodname in dynamic_routes.items(): + extra_kwargs = getattr(viewset, methodname).kwargs + ret.append(( + url_format.replace('{methodname}', methodname), + {httpmethod: methodname}, + name_format.replace('{methodname}', methodname), + extra_kwargs + )) + else: + # Standard route + extra_kwargs = {} + ret.append((url_format, method_map, name_format, extra_kwargs)) - # Bind standard CRUD routes - for suffix, action_mapping, action_name in self.route_list: + return ret - # Only actions which actually exist on the viewset will be bound - bound_actions = {} - for method, action in action_mapping.items(): - if hasattr(viewset, action): - bound_actions[method] = action + def get_method_map(self, viewset, method_map): + """ + Given a viewset, and a mapping of http methods to actions, + return a new mapping which only includes any mappings that + are actually implemented by the viewset. + """ + bound_methods = {} + for method, action in method_map.items(): + if hasattr(viewset, action): + bound_methods[method] = action + return bound_methods + + def get_lookup_regex(self, viewset): + """ + Given a viewset, return the portion of URL regex that is used + to match against a single instance. + """ + base_regex = '(?P<{lookup_field}>[^/]+)' + lookup_field = getattr(viewset, 'lookup_field', 'pk') + return base_regex.format(lookup_field=lookup_field) + + def get_urls(self): + """ + Use the registered viewsets to generate a list of URL patterns. + """ + ret = [] - # Build the url pattern - regex = prefix + suffix - view = viewset.as_view(bound_actions, name_suffix=action_name) - name = self.name_format % (base_name, action_name) - ret.append(url(regex, view, name=name)) + for prefix, viewset, basename in self.registry: + lookup = self.get_lookup_regex(viewset) + routes = self.get_routes(viewset) - # Bind any extra `@action` or `@link` routes - for action_name in dir(viewset): - func = getattr(viewset, action_name) - http_method = getattr(func, 'bind_to_method', None) + for url_format, method_map, name_format, extra_kwargs in routes: - # Skip if this is not an @action or @link method - if not http_method: + # Only actions which actually exist on the viewset will be bound + method_map = self.get_method_map(viewset, method_map) + if not method_map: continue - suffix = self.extra_routes % action_name - # Build the url pattern - regex = prefix + suffix - view = viewset.as_view({http_method: action_name}, **func.kwargs) - name = self.name_format % (base_name, action_name) + regex = url_format.format(prefix=prefix, lookup=lookup) + view = viewset.as_view(method_map, **extra_kwargs) + name = name_format.format(basename=basename) ret.append(url(regex, view, name=name)) - # Return a list of url patterns return ret + + +class DefaultRouter(SimpleRouter): + """ + The default router extends the SimpleRouter, but also adds in a default + API root view, and adds format suffix patterns to the URLs. + """ + include_root_view = True + include_format_suffixes = True + + def get_api_root_view(self): + """ + Return a view to use as the API root. + """ + api_root_dict = {} + list_name = self.routes[0][-1] + for prefix, viewset, basename in self.registry: + api_root_dict[prefix] = list_name.format(basename=basename) + + @api_view(('GET',)) + def api_root(request, format=None): + ret = {} + for key, url_name in api_root_dict.items(): + ret[key] = reverse(url_name, request=request, format=format) + return Response(ret) + + return api_root + + def get_urls(self): + """ + Generate the list of URL patterns, including a default root view + for the API, and appending `.json` style format suffixes. + """ + urls = [] + + if self.include_root_view: + root_url = url(r'^$', self.get_api_root_view(), name='api-root') + urls.append(root_url) + + default_urls = super(DefaultRouter, self).get_urls() + urls.extend(default_urls) + + if self.include_format_suffixes: + urls = format_suffix_patterns(urls) + + return urls -- cgit v1.2.3 From 95abe6e8445f59f9e52609b0c54d9276830dbfd3 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 25 Apr 2013 12:47:34 +0100 Subject: Cleanup docstrings --- rest_framework/authentication.py | 2 +- rest_framework/decorators.py | 8 +++++++ rest_framework/fields.py | 5 +++++ rest_framework/filters.py | 4 ++++ rest_framework/generics.py | 2 -- rest_framework/negotiation.py | 4 ++++ rest_framework/pagination.py | 6 +++-- rest_framework/relations.py | 6 +++++ rest_framework/request.py | 5 ++--- rest_framework/response.py | 6 +++++ rest_framework/serializers.py | 12 ++++++++++ rest_framework/throttling.py | 8 ++++--- rest_framework/views.py | 2 +- rest_framework/viewsets.py | 48 +++++++++++++++++++++++++++++----------- 14 files changed, 93 insertions(+), 25 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 1eebb5b9..9caca788 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -1,5 +1,5 @@ """ -Provides a set of pluggable authentication policies. +Provides various authentication policies. """ from __future__ import unicode_literals import base64 diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 00b37f8b..81e585e1 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -1,3 +1,11 @@ +""" +The most imporant decorator in this module is `@api_view`, which is used +for writing function-based views with REST framework. + +There are also various decorators for setting the API policies on function +based views, as well as the `@action` and `@link` decorators, which are +used to annotate methods on viewsets that should be included by routers. +""" from __future__ import unicode_literals from rest_framework.compat import six from rest_framework.views import APIView diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f3496b53..949f68d6 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,3 +1,8 @@ +""" +Serializer fields perform validation on incoming data. + +They are very similar to Django's form fields. +""" from __future__ import unicode_literals import copy diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 413fa0d2..5e1cdbac 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -1,3 +1,7 @@ +""" +Provides generic filtering backends that can be used to filter the results +returned by list views. +""" from __future__ import unicode_literals from rest_framework.compat import django_filters diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 3440c01d..56471cfa 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -10,8 +10,6 @@ from django.http import Http404 from django.shortcuts import get_object_or_404 from django.utils.translation import ugettext as _ -### Base classes for the generic views ### - class GenericAPIView(views.APIView): """ diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py index 0694d35f..4d205c0e 100644 --- a/rest_framework/negotiation.py +++ b/rest_framework/negotiation.py @@ -1,3 +1,7 @@ +""" +Content negotiation deals with selecting an appropriate renderer given the +incoming request. Typically this will be based on the request's Accept header. +""" from __future__ import unicode_literals from django.http import Http404 from rest_framework import exceptions diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 03a7a30f..d51ea929 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -1,9 +1,11 @@ +""" +Pagination serializers determine the structure of the output that should +be used for paginated responses. +""" from __future__ import unicode_literals from rest_framework import serializers from rest_framework.templatetags.rest_framework import replace_query_param -# TODO: Support URLconf kwarg-style paging - class NextPageField(serializers.Field): """ diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 2a10e9af..6bda7418 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,3 +1,9 @@ +""" +Serializer fields that deal with relationships. + +These fields allow you to specify the style that should be used to represent +model relationships, including hyperlinks, primary keys, or slugs. +""" from __future__ import unicode_literals from django.core.exceptions import ObjectDoesNotExist, ValidationError from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch diff --git a/rest_framework/request.py b/rest_framework/request.py index ffbbab33..a434659c 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -1,11 +1,10 @@ """ -The :mod:`request` module provides a :class:`Request` class used to wrap the standard `request` -object received in all the views. +The Request class is used as a wrapper around the standard request object. The wrapped request then offers a richer API, in particular : - content automatically parsed according to `Content-Type` header, - and available as :meth:`.DATA` + and available as `request.DATA` - full support of PUT method, including support for file uploads - form overloading of HTTP method, content type and content """ diff --git a/rest_framework/response.py b/rest_framework/response.py index 5e1bf46e..26e4ab37 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -1,3 +1,9 @@ +""" +The Response class in REST framework is similiar to HTTPResponse, except that +it is initialized with unrendered data, instead of a pre-rendered string. + +The appropriate renderer is called during Django's template response rendering. +""" from __future__ import unicode_literals from django.core.handlers.wsgi import STATUS_CODE_TEXT from django.template.response import SimpleTemplateResponse diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index b4327af1..fb438b12 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1,3 +1,15 @@ +""" +Serializers and ModelSerializers are similar to Forms and ModelForms. +Unlike forms, they are not constrained to dealing with HTML output, and +form encoded input. + +Serialization in REST framework is a two-phase process: + +1. Serializers marshal between complex types like model instances, and +python primatives. +2. The process of marshalling between python primatives and request and +response content is handled by parsers and renderers. +""" from __future__ import unicode_literals import copy import datetime diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 810cad63..93ea9816 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -1,3 +1,6 @@ +""" +Provides various throttling policies. +""" from __future__ import unicode_literals from django.core.cache import cache from rest_framework import exceptions @@ -28,9 +31,8 @@ class SimpleRateThrottle(BaseThrottle): A simple cache implementation, that only requires `.get_cache_key()` to be overridden. - The rate (requests / seconds) is set by a :attr:`throttle` attribute - on the :class:`.View` class. The attribute is a string of the form 'number of - requests/period'. + The rate (requests / seconds) is set by a `throttle` attribute on the View + class. The attribute is a string of the form 'number_of_requests/period'. Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day') diff --git a/rest_framework/views.py b/rest_framework/views.py index b8e948e0..555fa2f4 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -1,5 +1,5 @@ """ -Provides an APIView class that is used as the base of all class-based views. +Provides an APIView class that is the base of all views in REST framework. """ from __future__ import unicode_literals from django.core.exceptions import PermissionDenied diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 28ab30e2..9133fd44 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -1,3 +1,21 @@ +""" +ViewSets are essentially just a type of class based view, that doesn't provide +any method handlers, such as `get()`, `post()`, etc... but instead has actions, +such as `list()`, `retrieve()`, `create()`, etc... + +Actions are only bound to methods at the point of instantiating the views. + + user_list = UserViewSet.as_view({'get': 'list'}) + user_detail = UserViewSet.as_view({'get': 'retrieve'}) + +Typically, rather than instantiate views from viewsets directly, you'll +regsiter the viewset with a router and let the URL conf be determined +automatically. + + router = DefaultRouter() + router.register(r'users', UserViewSet, 'user') + urlpatterns = router.urls +""" from functools import update_wrapper from django.utils.decorators import classonlymethod from rest_framework import views, generics, mixins @@ -15,13 +33,10 @@ class ViewSetMixin(object): view = MyViewSet.as_view({'get': 'list', 'post': 'create'}) """ - _is_viewset = True @classonlymethod def as_view(cls, actions=None, name_suffix=None, **initkwargs): """ - Main entry point for a request-response process. - Because of the way class based views create a closure around the instantiated view, we need to totally reimplement `.as_view`, and slightly modify the view function that is created and returned. @@ -64,12 +79,22 @@ class ViewSetMixin(object): class ViewSet(ViewSetMixin, views.APIView): + """ + The base ViewSet class does not provide any actions by default. + """ + pass + + +class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, + mixins.ListModelMixin, + ViewSetMixin, + generics.GenericAPIView): + """ + A viewset that provides default `list()` and `retrieve()` actions. + """ pass -# Note the inheritence of both MultipleObjectAPIView *and* SingleObjectAPIView -# is a bit weird given the diamond inheritence, but it will work for now. -# There's some implementation clean up that can happen later. class ModelViewSet(mixins.CreateModelMixin, mixins.RetrieveModelMixin, mixins.UpdateModelMixin, @@ -77,11 +102,8 @@ class ModelViewSet(mixins.CreateModelMixin, mixins.ListModelMixin, ViewSetMixin, generics.GenericAPIView): - pass - - -class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, - mixins.ListModelMixin, - ViewSetMixin, - generics.GenericAPIView): + """ + A viewset that provides default `create()`, `retrieve()`, `update()`, + `partial_update()`, `destroy()` and `list()` actions. + """ pass -- cgit v1.2.3 From 5d01ae661fcf85016718041e021b4bca524dfcdc Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 25 Apr 2013 17:40:17 +0100 Subject: Simplify paginate_queryset method --- rest_framework/generics.py | 29 ++++++++++++++++++++++------- rest_framework/mixins.py | 9 +++------ 2 files changed, 25 insertions(+), 13 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 56471cfa..a18584d4 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -45,6 +45,8 @@ class GenericAPIView(views.APIView): # serializer class will be constructed using this class as the base. model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS + _paginator_class = Paginator + ###################################### # These are pending deprecation... @@ -85,12 +87,24 @@ class GenericAPIView(views.APIView): context = self.get_serializer_context() return pagination_serializer_class(instance=page, context=context) - def paginate_queryset(self, queryset, page_size, paginator_class=Paginator): + def paginate_queryset(self, queryset, page_size=None): """ - Paginate a queryset. + Paginate a queryset if required, either returning a page object, + or `None` if pagination is not configured for this view. """ - paginator = paginator_class(queryset, page_size, - allow_empty_first_page=self.allow_empty) + deprecated_style = False + if page_size is not None: + # TODO: Deperecation warning + deprecated_style = True + else: + # Determine the required page size. + # If pagination is not configured, simply return None. + page_size = self.get_paginate_by() + if not page_size: + return None + + paginator = self._paginator_class(queryset, page_size, + allow_empty_first_page=self.allow_empty) page_kwarg = self.kwargs.get(self.page_kwarg) page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) page = page_kwarg or page_query_param or 1 @@ -103,13 +117,16 @@ class GenericAPIView(views.APIView): raise Http404(_("Page is not 'last', nor can it be converted to an int.")) try: page = paginator.page(page_number) - return (paginator, page, page.object_list, page.has_other_pages()) except InvalidPage as e: raise Http404(_('Invalid page (%(page_number)s): %(message)s') % { 'page_number': page_number, 'message': str(e) }) + if deprecated_style: + return (paginator, page, page.object_list, page.has_other_pages()) + return page + def filter_queryset(self, queryset): """ Given a queryset, filter it with whichever filter backend is in use. @@ -163,7 +180,6 @@ class GenericAPIView(views.APIView): if serializer_class is not None: return serializer_class - # TODO: Deprecation warning class DefaultSerializer(self.model_serializer_class): class Meta: model = self.model @@ -184,7 +200,6 @@ class GenericAPIView(views.APIView): return self.queryset._clone() if self.model is not None: - # TODO: Deprecation warning return self.model._default_manager.all() raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 6e40b5c4..ec751e24 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -76,12 +76,9 @@ class ListModelMixin(object): error_msg = self.empty_error % {'class_name': class_name} raise Http404(error_msg) - # Pagination size is set by the `.paginate_by` attribute, - # which may be `None` to disable pagination. - page_size = self.get_paginate_by() - if page_size: - packed = self.paginate_queryset(self.object_list, page_size) - paginator, page, queryset, is_paginated = packed + # Switch between paginated or standard style responses + page = self.paginate_queryset(self.object_list) + if page is not None: serializer = self.get_pagination_serializer(page) else: serializer = self.get_serializer(self.object_list, many=True) -- cgit v1.2.3 From 7268a5c571bce323ccc75eb039b7c3f1b2b32391 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 25 Apr 2013 17:41:47 +0100 Subject: Added AutoRouter. Don't know if this is a good idea. --- rest_framework/routers.py | 46 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py index febb02b3..b7052218 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -14,12 +14,26 @@ For example, you might have a `urls.py` that looks something like this: urlpatterns = router.urls """ from django.conf.urls import url, patterns +from django.db import models from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.reverse import reverse +from rest_framework.viewsets import ModelViewSet from rest_framework.urlpatterns import format_suffix_patterns +def replace_methodname(format_string, methodname): + """ + Partially format a format_string, swapping out any + '{methodname}'' or '{methodnamehyphen}'' components. + """ + methodnamehyphen = methodname.replace('_', '-') + ret = format_string + ret = ret.replace('{methodname}', methodname) + ret = ret.replace('{methodnamehyphen}', methodnamehyphen) + return ret + + class BaseRouter(object): def __init__(self): self.registry = [] @@ -66,7 +80,7 @@ class SimpleRouter(BaseRouter): { '{httpmethod}': '{methodname}', }, - '{basename}-{methodname}' + '{basename}-{methodnamehyphen}' ), ] @@ -89,13 +103,13 @@ class SimpleRouter(BaseRouter): ret = [] for url_format, method_map, name_format in self.routes: if method_map == {'{httpmethod}': '{methodname}'}: - # Dynamic routes + # Dynamic routes (@link or @action decorator) for httpmethod, methodname in dynamic_routes.items(): extra_kwargs = getattr(viewset, methodname).kwargs ret.append(( - url_format.replace('{methodname}', methodname), + replace_methodname(url_format, methodname), {httpmethod: methodname}, - name_format.replace('{methodname}', methodname), + replace_methodname(name_format, methodname), extra_kwargs )) else: @@ -196,3 +210,27 @@ class DefaultRouter(SimpleRouter): urls = format_suffix_patterns(urls) return urls + + +class AutoRouter(DefaultRouter): + """ + A router class that doesn't require you to register any viewsets, + but instead automatically creates routes for all installed models. + + Useful for quick and dirty prototyping. + """ + def __init__(self): + super(AutoRouter, self).__init__() + for model in models.get_models(): + prefix = model._meta.verbose_name_plural.replace(' ', '_') + basename = model._meta.object_name.lower() + classname = model.__name__ + + DynamicViewSet = type( + classname, + (ModelViewSet,), + {} + ) + DynamicViewSet.model = model + + self.register(prefix, DynamicViewSet, basename) -- cgit v1.2.3 From 8fa79a7fd38dda015afa658084361c6da2856e46 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 26 Apr 2013 14:59:21 +0100 Subject: Deal with List/Instance suffixes for viewsets --- rest_framework/renderers.py | 2 +- rest_framework/routers.py | 72 ++++++++++++++++++++----------------- rest_framework/utils/breadcrumbs.py | 3 +- rest_framework/utils/formatting.py | 7 ++-- rest_framework/viewsets.py | 10 +++++- 5 files changed, 56 insertions(+), 38 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 752306ad..a0829c8f 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -439,7 +439,7 @@ class BrowsableAPIRenderer(BaseRenderer): return GenericContentForm() def get_name(self, view): - return get_view_name(view.__class__) + return get_view_name(view.__class__, getattr(view, 'suffix', None)) def get_description(self, view): return get_view_description(view.__class__, html=True) diff --git a/rest_framework/routers.py b/rest_framework/routers.py index b7052218..3a8c4508 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -13,6 +13,7 @@ For example, you might have a `urls.py` that looks something like this: urlpatterns = router.urls """ +from collections import namedtuple from django.conf.urls import url, patterns from django.db import models from rest_framework.decorators import api_view @@ -22,6 +23,9 @@ from rest_framework.viewsets import ModelViewSet from rest_framework.urlpatterns import format_suffix_patterns +Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs']) + + def replace_methodname(format_string, methodname): """ Partially format a format_string, swapping out any @@ -38,8 +42,8 @@ class BaseRouter(object): def __init__(self): self.registry = [] - def register(self, prefix, viewset, basename): - self.registry.append((prefix, viewset, basename)) + def register(self, prefix, viewset, name): + self.registry.append((prefix, viewset, name)) def get_urls(self): raise NotImplemented('get_urls must be overridden') @@ -54,33 +58,36 @@ class BaseRouter(object): class SimpleRouter(BaseRouter): routes = [ # List route. - ( - r'^{prefix}/$', - { + Route( + url=r'^{prefix}/$', + mapping={ 'get': 'list', 'post': 'create' }, - '{basename}-list' + name='{basename}-list', + initkwargs={'suffix': 'List'} ), # Detail route. - ( - r'^{prefix}/{lookup}/$', - { + Route( + url=r'^{prefix}/{lookup}/$', + mapping={ 'get': 'retrieve', 'put': 'update', 'patch': 'partial_update', 'delete': 'destroy' }, - '{basename}-detail' + name='{basename}-detail', + initkwargs={'suffix': 'Instance'} ), # Dynamically generated routes. # Generated using @action or @link decorators on methods of the viewset. - ( - r'^{prefix}/{lookup}/{methodname}/$', - { + Route( + url=r'^{prefix}/{lookup}/{methodname}/$', + mapping={ '{httpmethod}': '{methodname}', }, - '{basename}-{methodnamehyphen}' + name='{basename}-{methodnamehyphen}', + initkwargs={} ), ] @@ -88,8 +95,7 @@ class SimpleRouter(BaseRouter): """ Augment `self.routes` with any dynamically generated routes. - Returns a list of 4-tuples, of the form: - `(url_format, method_map, name_format, extra_kwargs)` + Returns a list of the Route namedtuple. """ # Determine any `@action` or `@link` decorated methods on the viewset @@ -101,21 +107,21 @@ class SimpleRouter(BaseRouter): dynamic_routes[httpmethod] = methodname ret = [] - for url_format, method_map, name_format in self.routes: - if method_map == {'{httpmethod}': '{methodname}'}: + for route in self.routes: + if route.mapping == {'{httpmethod}': '{methodname}'}: # Dynamic routes (@link or @action decorator) for httpmethod, methodname in dynamic_routes.items(): - extra_kwargs = getattr(viewset, methodname).kwargs - ret.append(( - replace_methodname(url_format, methodname), - {httpmethod: methodname}, - replace_methodname(name_format, methodname), - extra_kwargs + initkwargs = route.initkwargs.copy() + initkwargs.update(getattr(viewset, methodname).kwargs) + ret.append(Route( + url=replace_methodname(route.url, methodname), + mapping={httpmethod: methodname}, + name=replace_methodname(route.name, methodname), + initkwargs=initkwargs, )) else: # Standard route - extra_kwargs = {} - ret.append((url_format, method_map, name_format, extra_kwargs)) + ret.append(route) return ret @@ -150,17 +156,17 @@ class SimpleRouter(BaseRouter): lookup = self.get_lookup_regex(viewset) routes = self.get_routes(viewset) - for url_format, method_map, name_format, extra_kwargs in routes: + for route in routes: # Only actions which actually exist on the viewset will be bound - method_map = self.get_method_map(viewset, method_map) - if not method_map: + mapping = self.get_method_map(viewset, route.mapping) + if not mapping: continue # Build the url pattern - regex = url_format.format(prefix=prefix, lookup=lookup) - view = viewset.as_view(method_map, **extra_kwargs) - name = name_format.format(basename=basename) + regex = route.url.format(prefix=prefix, lookup=lookup) + view = viewset.as_view(mapping, **route.initkwargs) + name = route.name.format(basename=basename) ret.append(url(regex, view, name=name)) return ret @@ -179,7 +185,7 @@ class DefaultRouter(SimpleRouter): Return a view to use as the API root. """ api_root_dict = {} - list_name = self.routes[0][-1] + list_name = self.routes[0].name for prefix, viewset, basename in self.registry: api_root_dict[prefix] = list_name.format(basename=basename) diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index 18b3b207..8f8e5710 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -21,7 +21,8 @@ def get_breadcrumbs(url): # Don't list the same view twice in a row. # Probably an optional trailing slash. if not seen or seen[-1] != view: - breadcrumbs_list.insert(0, (get_view_name(view.cls), prefix + url)) + suffix = getattr(view, 'suffix', None) + breadcrumbs_list.insert(0, (get_view_name(view.cls, suffix), prefix + url)) seen.append(view) if url == '': diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index 79566db1..ebadb3a6 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -45,14 +45,17 @@ def _camelcase_to_spaces(content): return ' '.join(content.split('_')).title() -def get_view_name(cls): +def get_view_name(cls, suffix=None): """ Return a formatted name for an `APIView` class or `@api_view` function. """ name = cls.__name__ name = _remove_trailing_string(name, 'View') name = _remove_trailing_string(name, 'ViewSet') - return _camelcase_to_spaces(name) + name = _camelcase_to_spaces(name) + if suffix: + name += ' ' + suffix + return name def get_view_description(cls, html=False): diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 9133fd44..bd25df77 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -35,12 +35,16 @@ class ViewSetMixin(object): """ @classonlymethod - def as_view(cls, actions=None, name_suffix=None, **initkwargs): + def as_view(cls, actions=None, **initkwargs): """ Because of the way class based views create a closure around the instantiated view, we need to totally reimplement `.as_view`, and slightly modify the view function that is created and returned. """ + # The suffix initkwarg is reserved for identifing the viewset type + # eg. 'List' or 'Instance'. + cls.suffix = None + # sanitize keyword arguments for key in initkwargs: if key in cls.http_method_names: @@ -74,7 +78,11 @@ class ViewSetMixin(object): # like csrf_exempt from dispatch update_wrapper(view, cls.dispatch, assigned=()) + # We need to set these on the view function, so that breadcrumb + # generation can pick out these bits of information from a + # resolved URL. view.cls = cls + view.suffix = initkwargs.get('suffix', None) return view -- cgit v1.2.3 From 018d8b8dced31309196496e625cf8a746b98d65e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 26 Apr 2013 15:09:55 +0100 Subject: Bits of cleanup --- rest_framework/routers.py | 2 +- rest_framework/utils/breadcrumbs.py | 30 +++++++++++++++++++++--------- 2 files changed, 22 insertions(+), 10 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 3a8c4508..33e88a81 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -29,7 +29,7 @@ Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs']) def replace_methodname(format_string, methodname): """ Partially format a format_string, swapping out any - '{methodname}'' or '{methodnamehyphen}'' components. + '{methodname}' or '{methodnamehyphen}' components. """ methodnamehyphen = methodname.replace('_', '-') ret = format_string diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index 8f8e5710..28801d09 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -4,25 +4,33 @@ from rest_framework.utils.formatting import get_view_name def get_breadcrumbs(url): - """Given a url returns a list of breadcrumbs, which are each a tuple of (name, url).""" + """ + Given a url returns a list of breadcrumbs, which are each a + tuple of (name, url). + """ from rest_framework.views import APIView def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen): - """Add tuples of (name, url) to the breadcrumbs list, progressively chomping off parts of the url.""" + """ + Add tuples of (name, url) to the breadcrumbs list, + progressively chomping off parts of the url. + """ try: (view, unused_args, unused_kwargs) = resolve(url) except Exception: pass else: - # Check if this is a REST framework view, and if so add it to the breadcrumbs + # Check if this is a REST framework view, + # and if so add it to the breadcrumbs if issubclass(getattr(view, 'cls', None), APIView): # Don't list the same view twice in a row. # Probably an optional trailing slash. if not seen or seen[-1] != view: suffix = getattr(view, 'suffix', None) - breadcrumbs_list.insert(0, (get_view_name(view.cls, suffix), prefix + url)) + name = get_view_name(view.cls, suffix) + breadcrumbs_list.insert(0, (name, prefix + url)) seen.append(view) if url == '': @@ -30,11 +38,15 @@ def get_breadcrumbs(url): return breadcrumbs_list elif url.endswith('/'): - # Drop trailing slash off the end and continue to try to resolve more breadcrumbs - return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix, seen) - - # Drop trailing non-slash off the end and continue to try to resolve more breadcrumbs - return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix, seen) + # Drop trailing slash off the end and continue to try to + # resolve more breadcrumbs + url = url.rstrip('/') + return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen) + + # Drop trailing non-slash off the end and continue to try to + # resolve more breadcrumbs + url = url[:url.rfind('/') + 1] + return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen) prefix = get_script_prefix().rstrip('/') url = url[len(prefix):] -- cgit v1.2.3 From 3b0fa3ebaa9d42723d970bb88be0dfe2586d1a5e Mon Sep 17 00:00:00 2001 From: JC Date: Sat, 27 Apr 2013 13:10:39 -0700 Subject: Changed DepthTest to have depth=2 --- rest_framework/tests/serializer.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 05217f35..bd874253 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -3,7 +3,7 @@ from django.utils.datastructures import MultiValueDict from django.test import TestCase from rest_framework import serializers from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, - BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel, + BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel, ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo) import datetime import pickle @@ -767,8 +767,6 @@ class RelatedTraversalTest(TestCase): post = BlogPost.objects.create(title="Test blog post", writer=user) post.blogpostcomment_set.create(text="I love this blog post") - from rest_framework.tests.models import BlogPostComment - class PersonSerializer(serializers.ModelSerializer): class Meta: model = Person @@ -968,23 +966,26 @@ class SerializerPickleTests(TestCase): class DepthTest(TestCase): def test_implicit_nesting(self): + writer = Person.objects.create(name="django", age=1) post = BlogPost.objects.create(title="Test blog post", writer=writer) + comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post) - class BlogPostSerializer(serializers.ModelSerializer): + class BlogPostCommentSerializer(serializers.ModelSerializer): class Meta: - model = BlogPost - depth = 1 + model = BlogPostComment + depth = 2 - serializer = BlogPostSerializer(instance=post) - expected = {'id': 1, 'title': 'Test blog post', - 'writer': {'id': 1, 'name': 'django', 'age': 1}} + serializer = BlogPostCommentSerializer(instance=comment) + expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post', + 'writer': {'id': 1, 'name': 'django', 'age': 1}}} self.assertEqual(serializer.data, expected) def test_explicit_nesting(self): writer = Person.objects.create(name="django", age=1) post = BlogPost.objects.create(title="Test blog post", writer=writer) + comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post) class PersonSerializer(serializers.ModelSerializer): class Meta: @@ -996,9 +997,15 @@ class DepthTest(TestCase): class Meta: model = BlogPost - serializer = BlogPostSerializer(instance=post) - expected = {'id': 1, 'title': 'Test blog post', - 'writer': {'id': 1, 'name': 'django', 'age': 1}} + class BlogPostCommentSerializer(serializers.ModelSerializer): + blog_post = BlogPostSerializer() + + class Meta: + model = BlogPostComment + + serializer = BlogPostCommentSerializer(instance=comment) + expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post', + 'writer': {'id': 1, 'name': 'django', 'age': 1}}} self.assertEqual(serializer.data, expected) -- cgit v1.2.3 From 8cbb715f4c5550d76e397828608a31a4f254a37d Mon Sep 17 00:00:00 2001 From: JC Date: Sat, 27 Apr 2013 13:23:55 -0700 Subject: Changed definition of NestedModelSerializer to correct depth handling --- rest_framework/serializers.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index e28bbe81..add46566 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -205,18 +205,6 @@ class BaseSerializer(WritableField): return ret - ##### - # Field methods - used when the serializer class is itself used as a field. - - def initialize(self, parent, field_name): - """ - Same behaviour as usual Field, except that we need to keep track - of state so that we can deal with handling maximum depth. - """ - super(BaseSerializer, self).initialize(parent, field_name) - if parent.opts.depth: - self.opts.depth = parent.opts.depth - 1 - ##### # Methods to convert or revert from objects <--> primitive representations. @@ -619,6 +607,8 @@ class ModelSerializer(Serializer): class NestedModelSerializer(ModelSerializer): class Meta: model = model_field.rel.to + depth = self.opts.depth - 1 + return NestedModelSerializer() def get_related_field(self, model_field, to_many=False): -- cgit v1.2.3 From dc7b1d643020cac5d585aac42f98962cc7aa6bf7 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 29 Apr 2013 12:45:00 +0100 Subject: 2.2's PendingDeprecationWarnings now become DeprecationWarnings. 2.3's PendingDeprecationWarnings added. --- rest_framework/fields.py | 4 +-- rest_framework/generics.py | 60 +++++++++++++++++++++++-------- rest_framework/permissions.py | 8 +++-- rest_framework/relations.py | 28 +++++++-------- rest_framework/routers.py | 2 ++ rest_framework/serializers.py | 24 +++++++++---- rest_framework/tests/serializer.py | 5 ++- rest_framework/tests/serializer_nested.py | 4 +-- rest_framework/viewsets.py | 2 ++ 9 files changed, 93 insertions(+), 44 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 38fe025d..f934fc39 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -200,9 +200,9 @@ class WritableField(Field): # 'blank' is to be deprecated in favor of 'required' if blank is not None: - warnings.warn('The `blank` keyword argument is due to deprecated. ' + warnings.warn('The `blank` keyword argument is deprecated. ' 'Use the `required` keyword argument instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) required = not(blank) super(WritableField, self).__init__(source=source) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index a18584d4..972424e6 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -2,13 +2,16 @@ Generic views that provide commonly needed behaviour. """ from __future__ import unicode_literals -from rest_framework import views, mixins -from rest_framework.settings import api_settings + from django.core.exceptions import ImproperlyConfigured from django.core.paginator import Paginator, InvalidPage from django.http import Http404 from django.shortcuts import get_object_or_404 from django.utils.translation import ugettext as _ +from rest_framework import views, mixins +from rest_framework.exceptions import ConfigurationError +from rest_framework.settings import api_settings +import warnings class GenericAPIView(views.APIView): @@ -94,7 +97,12 @@ class GenericAPIView(views.APIView): """ deprecated_style = False if page_size is not None: - # TODO: Deperecation warning + warnings.warn('The `page_size` parameter to `paginate_queryset()` ' + 'is due to be deprecated. ' + 'Note that the return style of this method is also ' + 'changed, and will simply return a page object ' + 'when called without a `page_size` argument.', + PendingDeprecationWarning, stacklevel=2) deprecated_style = True else: # Determine the required page size. @@ -155,7 +163,9 @@ class GenericAPIView(views.APIView): Otherwise defaults to using `self.paginate_by`. """ if queryset is not None: - pass # TODO: Deprecation warning + warnings.warn('The `queryset` parameter to `get_paginate_by()` ' + 'is due to be deprecated.', + PendingDeprecationWarning, stacklevel=2) if self.paginate_by_param: query_params = self.request.QUERY_PARAMS @@ -226,17 +236,27 @@ class GenericAPIView(views.APIView): if lookup is not None: filter_kwargs = {self.lookup_field: lookup} - elif pk is not None: - # TODO: Deprecation warning + elif pk is not None and self.lookup_field == 'pk': + warnings.warn( + 'The `pk_url_kwarg` attribute is due to be deprecated. ' + 'Use the `lookup_field` attribute instead', + PendingDeprecationWarning + ) filter_kwargs = {'pk': pk} - elif slug is not None: - # TODO: Deprecation warning + elif slug is not None and self.lookup_field == 'pk': + warnings.warn( + 'The `slug_url_kwarg` attribute is due to be deprecated. ' + 'Use the `lookup_field` attribute instead', + PendingDeprecationWarning + ) filter_kwargs = {self.slug_field: slug} else: - # TODO: Fix error message - raise AttributeError("Generic detail view %s must be called with " - "either an object pk or a slug." - % self.__class__.__name__) + raise ConfigurationError( + 'Expected view %s to be called with a URL keyword argument ' + 'named "%s". Fix your URL conf, or set the `.lookup_field` ' + 'attribute on the view correctly.' % + (self.__class__.__name__, self.lookup_field) + ) obj = get_object_or_404(queryset, **filter_kwargs) @@ -391,8 +411,20 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, ########################## class MultipleObjectAPIView(GenericAPIView): - pass + def __init__(self, *args, **kwargs): + warnings.warn( + 'Subclassing `MultipleObjectAPIView` is due to be deprecated. ' + 'You should simply subclass `GenericAPIView` instead.', + PendingDeprecationWarning, stacklevel=2 + ) + super(MultipleObjectAPIView, self).__init__(*args, **kwargs) class SingleObjectAPIView(GenericAPIView): - pass + def __init__(self, *args, **kwargs): + warnings.warn( + 'Subclassing `SingleObjectAPIView` is due to be deprecated. ' + 'You should simply subclass `GenericAPIView` instead.', + PendingDeprecationWarning, stacklevel=2 + ) + super(SingleObjectAPIView, self).__init__(*args, **kwargs) diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 2aa45c71..91bf5ad6 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -26,9 +26,11 @@ class BasePermission(object): Return `True` if permission is granted, `False` otherwise. """ if len(inspect.getargspec(self.has_permission).args) == 4: - warnings.warn('The `obj` argument in `has_permission` is due to be deprecated. ' - 'Use `has_object_permission()` instead for object permissions.', - PendingDeprecationWarning, stacklevel=2) + warnings.warn( + 'The `obj` argument in `has_permission` is deprecated. ' + 'Use `has_object_permission()` instead for object permissions.', + DeprecationWarning, stacklevel=2 + ) return self.has_permission(request, view, obj) return True diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 6bda7418..abe5203b 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -42,9 +42,9 @@ class RelatedField(WritableField): # 'null' is to be deprecated in favor of 'required' if 'null' in kwargs: - warnings.warn('The `null` keyword argument is due to be deprecated. ' + warnings.warn('The `null` keyword argument is deprecated. ' 'Use the `required` keyword argument instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) kwargs['required'] = not kwargs.pop('null') self.queryset = kwargs.pop('queryset', None) @@ -328,9 +328,9 @@ class HyperlinkedRelatedField(RelatedField): if request is None: warnings.warn("Using `HyperlinkedRelatedField` without including the " - "request in the serializer context is due to be deprecated. " + "request in the serializer context is deprecated. " "Add `context={'request': request}` when instantiating the serializer.", - PendingDeprecationWarning, stacklevel=4) + DeprecationWarning, stacklevel=4) pk = getattr(obj, 'pk', None) if pk is None: @@ -443,9 +443,9 @@ class HyperlinkedIdentityField(Field): if request is None: warnings.warn("Using `HyperlinkedIdentityField` without including the " - "request in the serializer context is due to be deprecated. " + "request in the serializer context is deprecated. " "Add `context={'request': request}` when instantiating the serializer.", - PendingDeprecationWarning, stacklevel=4) + DeprecationWarning, stacklevel=4) # By default use whatever format is given for the current context # unless the target is a different type to the source. @@ -488,35 +488,35 @@ class HyperlinkedIdentityField(Field): class ManyRelatedField(RelatedField): def __init__(self, *args, **kwargs): - warnings.warn('`ManyRelatedField()` is due to be deprecated. ' + warnings.warn('`ManyRelatedField()` is deprecated. ' 'Use `RelatedField(many=True)` instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) kwargs['many'] = True super(ManyRelatedField, self).__init__(*args, **kwargs) class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField): def __init__(self, *args, **kwargs): - warnings.warn('`ManyPrimaryKeyRelatedField()` is due to be deprecated. ' + warnings.warn('`ManyPrimaryKeyRelatedField()` is deprecated. ' 'Use `PrimaryKeyRelatedField(many=True)` instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) kwargs['many'] = True super(ManyPrimaryKeyRelatedField, self).__init__(*args, **kwargs) class ManySlugRelatedField(SlugRelatedField): def __init__(self, *args, **kwargs): - warnings.warn('`ManySlugRelatedField()` is due to be deprecated. ' + warnings.warn('`ManySlugRelatedField()` is deprecated. ' 'Use `SlugRelatedField(many=True)` instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) kwargs['many'] = True super(ManySlugRelatedField, self).__init__(*args, **kwargs) class ManyHyperlinkedRelatedField(HyperlinkedRelatedField): def __init__(self, *args, **kwargs): - warnings.warn('`ManyHyperlinkedRelatedField()` is due to be deprecated. ' + warnings.warn('`ManyHyperlinkedRelatedField()` is deprecated. ' 'Use `HyperlinkedRelatedField(many=True)` instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) kwargs['many'] = True super(ManyHyperlinkedRelatedField, self).__init__(*args, **kwargs) diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 33e88a81..2bbf519c 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -13,6 +13,8 @@ For example, you might have a `urls.py` that looks something like this: urlpatterns = router.urls """ +from __future__ import unicode_literals + from collections import namedtuple from django.conf.urls import url, patterns from django.db import models diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 3d956e4d..3afb7475 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -436,9 +436,9 @@ class BaseSerializer(WritableField): else: many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type)) if many: - warnings.warn('Implict list/queryset serialization is due to be deprecated. ' + warnings.warn('Implict list/queryset serialization is deprecated. ' 'Use the `many=True` flag when instantiating the serializer.', - PendingDeprecationWarning, stacklevel=3) + DeprecationWarning, stacklevel=3) if many: ret = [] @@ -498,9 +498,9 @@ class BaseSerializer(WritableField): else: many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict)) if many: - warnings.warn('Implict list/queryset serialization is due to be deprecated. ' + warnings.warn('Implict list/queryset serialization is deprecated. ' 'Use the `many=True` flag when instantiating the serializer.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) if many: self._data = [self.to_native(item) for item in obj] @@ -606,13 +606,25 @@ class ModelSerializer(Serializer): if model_field.rel and nested: if len(inspect.getargspec(self.get_nested_field).args) == 2: - # TODO: deprecation warning + warnings.warn( + 'The `get_nested_field(model_field)` call signature ' + 'is due to be deprecated. ' + 'Use `get_nested_field(model_field, related_model, ' + 'to_many) instead', + PendingDeprecationWarning + ) field = self.get_nested_field(model_field) else: field = self.get_nested_field(model_field, related_model, to_many) elif model_field.rel: if len(inspect.getargspec(self.get_nested_field).args) == 3: - # TODO: deprecation warning + warnings.warn( + 'The `get_related_field(model_field, to_many)` call ' + 'signature is due to be deprecated. ' + 'Use `get_related_field(model_field, related_model, ' + 'to_many) instead', + PendingDeprecationWarning + ) field = self.get_related_field(model_field, to_many=to_many) else: field = self.get_related_field(model_field, related_model, to_many) diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 3a94fad5..ae8d09dc 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -357,7 +357,6 @@ class CustomValidationTests(TestCase): def validate_email(self, attrs, source): value = attrs[source] - return attrs def validate_content(self, attrs, source): @@ -1103,7 +1102,7 @@ class DeserializeListTestCase(TestCase): def test_no_errors(self): data = [self.data.copy() for x in range(0, 3)] - serializer = CommentSerializer(data=data) + serializer = CommentSerializer(data=data, many=True) self.assertTrue(serializer.is_valid()) self.assertTrue(isinstance(serializer.object, list)) self.assertTrue( @@ -1115,7 +1114,7 @@ class DeserializeListTestCase(TestCase): invalid_item['email'] = '' data = [self.data.copy(), invalid_item, self.data.copy()] - serializer = CommentSerializer(data=data) + serializer = CommentSerializer(data=data, many=True) self.assertFalse(serializer.is_valid()) expected = [{}, {'email': ['This field is required.']}, {}] self.assertEqual(serializer.errors, expected) diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/serializer_nested.py index 6a29c652..71d0e24b 100644 --- a/rest_framework/tests/serializer_nested.py +++ b/rest_framework/tests/serializer_nested.py @@ -109,7 +109,7 @@ class WritableNestedSerializerBasicTests(TestCase): } ] - serializer = self.AlbumSerializer(data=data) + serializer = self.AlbumSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), False) self.assertEqual(serializer.errors, expected_errors) @@ -241,6 +241,6 @@ class WritableNestedSerializerObjectTests(TestCase): ) ] - serializer = self.AlbumSerializer(data=data) + serializer = self.AlbumSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.object, expected_object) diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index bd25df77..a54467d7 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -16,6 +16,8 @@ automatically. router.register(r'users', UserViewSet, 'user') urlpatterns = router.urls """ +from __future__ import unicode_literals + from functools import update_wrapper from django.utils.decorators import classonlymethod from rest_framework import views, generics, mixins -- cgit v1.2.3 From d17e2d852fc6ebc738e324b8797d390dc0287d37 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 29 Apr 2013 12:46:57 +0100 Subject: Remove AutoRouter. (Adding shortcut to generic views/viewsets means it's unneccessary) --- rest_framework/routers.py | 26 -------------------------- 1 file changed, 26 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 2bbf519c..923405e8 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -17,11 +17,9 @@ from __future__ import unicode_literals from collections import namedtuple from django.conf.urls import url, patterns -from django.db import models from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.reverse import reverse -from rest_framework.viewsets import ModelViewSet from rest_framework.urlpatterns import format_suffix_patterns @@ -218,27 +216,3 @@ class DefaultRouter(SimpleRouter): urls = format_suffix_patterns(urls) return urls - - -class AutoRouter(DefaultRouter): - """ - A router class that doesn't require you to register any viewsets, - but instead automatically creates routes for all installed models. - - Useful for quick and dirty prototyping. - """ - def __init__(self): - super(AutoRouter, self).__init__() - for model in models.get_models(): - prefix = model._meta.verbose_name_plural.replace(' ', '_') - basename = model._meta.object_name.lower() - classname = model.__name__ - - DynamicViewSet = type( - classname, - (ModelViewSet,), - {} - ) - DynamicViewSet.model = model - - self.register(prefix, DynamicViewSet, basename) -- cgit v1.2.3 From 53f9d4a380ee0066cbee8382ae265ea6005d8c88 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 29 Apr 2013 13:20:15 +0100 Subject: fields shortcut on views --- rest_framework/generics.py | 5 +++++ rest_framework/serializers.py | 2 +- rest_framework/tests/generics.py | 24 ++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 972424e6..0b8e4a15 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -44,6 +44,10 @@ class GenericAPIView(views.APIView): # the explicit style is generally preferred. model = None + # This shortcut may be used instead of setting the `serializer_class` + # attribute, although using the explicit style is generally preferred. + fields = None + # If the `model` shortcut is used instead of `serializer_class`, then the # serializer class will be constructed using this class as the base. model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS @@ -193,6 +197,7 @@ class GenericAPIView(views.APIView): class DefaultSerializer(self.model_serializer_class): class Meta: model = self.model + fields = self.fields return DefaultSerializer def get_queryset(self): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 3afb7475..f4a20097 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -645,7 +645,7 @@ class ModelSerializer(Serializer): for relation in reverse_rels: accessor_name = relation.get_accessor_name() - if accessor_name not in self.opts.fields: + if not self.opts.fields or accessor_name not in self.opts.fields: continue related_model = relation.model to_many = relation.field.rel.multiple diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 4a13389a..12c9b677 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -344,6 +344,30 @@ class TestOverriddenGetObject(TestCase): self.assertEqual(response.data, self.data[0]) +class TestFieldsShortcut(TestCase): + """ + Test cases for setting the `fields` attribute on a view. + """ + def setUp(self): + class OverriddenFieldsView(generics.RetrieveUpdateDestroyAPIView): + model = BasicModel + fields = ('text',) + + class RegularView(generics.RetrieveUpdateDestroyAPIView): + model = BasicModel + + self.overridden_fields_view = OverriddenFieldsView() + self.regular_view = RegularView() + + def test_overridden_fields_view(self): + Serializer = self.overridden_fields_view.get_serializer_class() + self.assertEqual(Serializer().fields.keys(), ['text']) + + def test_not_overridden_fields_view(self): + Serializer = self.regular_view.get_serializer_class() + self.assertEqual(Serializer().fields.keys(), ['id', 'text']) + + # Regression test for #285 class CommentSerializer(serializers.ModelSerializer): -- cgit v1.2.3 From 0c1ab584d3d0898d47e0bce6beb5d7c39a55dd52 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 29 Apr 2013 14:08:38 +0100 Subject: Tweaks for preferring .queryset over .model --- rest_framework/generics.py | 19 ++++++++++++------- rest_framework/tests/generics.py | 6 ++++-- 2 files changed, 16 insertions(+), 9 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 0b8e4a15..3ea78b5d 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -48,11 +48,10 @@ class GenericAPIView(views.APIView): # attribute, although using the explicit style is generally preferred. fields = None - # If the `model` shortcut is used instead of `serializer_class`, then the - # serializer class will be constructed using this class as the base. + # The following attributes may be subject to change, + # and should be considered private API. model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS - - _paginator_class = Paginator + paginator_class = Paginator ###################################### # These are pending deprecation... @@ -115,8 +114,8 @@ class GenericAPIView(views.APIView): if not page_size: return None - paginator = self._paginator_class(queryset, page_size, - allow_empty_first_page=self.allow_empty) + paginator = self.paginator_class(queryset, page_size, + allow_empty_first_page=self.allow_empty) page_kwarg = self.kwargs.get(self.page_kwarg) page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) page = page_kwarg or page_query_param or 1 @@ -194,9 +193,15 @@ class GenericAPIView(views.APIView): if serializer_class is not None: return serializer_class + assert self.model is not None or self.queryset is not None, \ + "'%s' should either include a 'serializer_class' attribute, " \ + "or use the 'queryset' or 'model' attribute as a shortcut for " \ + "automatically generating a serializer class." \ + % self.__class__.__name__ + class DefaultSerializer(self.model_serializer_class): class Meta: - model = self.model + model = self.model or self.queryset.model fields = self.fields return DefaultSerializer diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 12c9b677..63ff1fc3 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -350,11 +350,11 @@ class TestFieldsShortcut(TestCase): """ def setUp(self): class OverriddenFieldsView(generics.RetrieveUpdateDestroyAPIView): - model = BasicModel + queryset = BasicModel.objects.all() fields = ('text',) class RegularView(generics.RetrieveUpdateDestroyAPIView): - model = BasicModel + queryset = BasicModel.objects.all() self.overridden_fields_view = OverriddenFieldsView() self.regular_view = RegularView() @@ -362,10 +362,12 @@ class TestFieldsShortcut(TestCase): def test_overridden_fields_view(self): Serializer = self.overridden_fields_view.get_serializer_class() self.assertEqual(Serializer().fields.keys(), ['text']) + self.assertEqual(Serializer().opts.model, BasicModel) def test_not_overridden_fields_view(self): Serializer = self.regular_view.get_serializer_class() self.assertEqual(Serializer().fields.keys(), ['id', 'text']) + self.assertEqual(Serializer().opts.model, BasicModel) # Regression test for #285 -- cgit v1.2.3 From 21ae3a66917acf4ea57e8f7940ce1a6823a2ce92 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 30 Apr 2013 08:24:33 +0100 Subject: Drop out attribute --- rest_framework/generics.py | 24 ++++++++++-------------- rest_framework/serializers.py | 4 ++++ rest_framework/tests/generics.py | 26 -------------------------- 3 files changed, 14 insertions(+), 40 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 3ea78b5d..62129dcc 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -20,11 +20,17 @@ class GenericAPIView(views.APIView): """ # You'll need to either set these attributes, - # or override `get_queryset`/`get_serializer_class`. + # or override `get_queryset()`/`get_serializer_class()`. queryset = None serializer_class = None + # This shortcut may be used instead of setting either or both + # of the `queryset`/`serializer_class` attributes, although using + # the explicit style is generally preferred. + model = None + # If you want to use object lookups other than pk, set this attribute. + # For more complex lookup requirements override `get_object()`. lookup_field = 'pk' # Pagination settings @@ -39,15 +45,6 @@ class GenericAPIView(views.APIView): # Determines if the view will return 200 or 404 responses for empty lists. allow_empty = True - # This shortcut may be used instead of setting either (or both) - # of the `queryset`/`serializer_class` attributes, although using - # the explicit style is generally preferred. - model = None - - # This shortcut may be used instead of setting the `serializer_class` - # attribute, although using the explicit style is generally preferred. - fields = None - # The following attributes may be subject to change, # and should be considered private API. model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS @@ -193,16 +190,15 @@ class GenericAPIView(views.APIView): if serializer_class is not None: return serializer_class - assert self.model is not None or self.queryset is not None, \ + assert self.model is not None, \ "'%s' should either include a 'serializer_class' attribute, " \ - "or use the 'queryset' or 'model' attribute as a shortcut for " \ + "or use the 'model' attribute as a shortcut for " \ "automatically generating a serializer class." \ % self.__class__.__name__ class DefaultSerializer(self.model_serializer_class): class Meta: - model = self.model or self.queryset.model - fields = self.fields + model = self.model return DefaultSerializer def get_queryset(self): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index f4a20097..0f943d79 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -677,6 +677,8 @@ class ModelSerializer(Serializer): def get_nested_field(self, model_field, related_model, to_many): """ Creates a default instance of a nested relational field. + + Note that model_field will be `None` for reverse relationships. """ class NestedModelSerializer(ModelSerializer): class Meta: @@ -686,6 +688,8 @@ class ModelSerializer(Serializer): def get_related_field(self, model_field, related_model, to_many): """ Creates a default instance of a flat relational field. + + Note that model_field will be `None` for reverse relationships. """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 63ff1fc3..4a13389a 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -344,32 +344,6 @@ class TestOverriddenGetObject(TestCase): self.assertEqual(response.data, self.data[0]) -class TestFieldsShortcut(TestCase): - """ - Test cases for setting the `fields` attribute on a view. - """ - def setUp(self): - class OverriddenFieldsView(generics.RetrieveUpdateDestroyAPIView): - queryset = BasicModel.objects.all() - fields = ('text',) - - class RegularView(generics.RetrieveUpdateDestroyAPIView): - queryset = BasicModel.objects.all() - - self.overridden_fields_view = OverriddenFieldsView() - self.regular_view = RegularView() - - def test_overridden_fields_view(self): - Serializer = self.overridden_fields_view.get_serializer_class() - self.assertEqual(Serializer().fields.keys(), ['text']) - self.assertEqual(Serializer().opts.model, BasicModel) - - def test_not_overridden_fields_view(self): - Serializer = self.regular_view.get_serializer_class() - self.assertEqual(Serializer().fields.keys(), ['id', 'text']) - self.assertEqual(Serializer().opts.model, BasicModel) - - # Regression test for #285 class CommentSerializer(serializers.ModelSerializer): -- cgit v1.2.3 From 8dff8d2fdcfcee356c134f4be8235d2a4f122d1a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 30 Apr 2013 14:34:03 +0100 Subject: Add get_breadcrumbs hook to BrowseableAPIRenderer. Closes #733. --- rest_framework/renderers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index a0829c8f..c457ec73 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -444,6 +444,9 @@ class BrowsableAPIRenderer(BaseRenderer): def get_description(self, view): return get_view_description(view.__class__, html=True) + def get_breadcrumbs(self, request): + return get_breadcrumbs(request.path) + def render(self, data, accepted_media_type=None, renderer_context=None): """ Renders *obj* using the :attr:`template` set on the class. @@ -475,7 +478,7 @@ class BrowsableAPIRenderer(BaseRenderer): name = self.get_name(view) description = self.get_description(view) - breadcrumb_list = get_breadcrumbs(request.path) + breadcrumb_list = self.get_breadcrumbs(request) template = loader.get_template(self.template) context = RequestContext(request, { -- cgit v1.2.3 From b65b065375796919a57f4bd6f1dd8187ef0eb165 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 30 Apr 2013 14:34:28 +0100 Subject: Add DjangoModelPermissionsOrAnonReadOnly --- rest_framework/permissions.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 91bf5ad6..751f31a7 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -89,8 +89,8 @@ class DjangoModelPermissions(BasePermission): It ensures that the user is authenticated, and has the appropriate `add`/`change`/`delete` permissions on the model. - This permission will only be applied against view classes that - provide a `.model` attribute, such as the generic class-based views. + This permission can only be applied against view classes that + provide a `.model` or `.queryset` attribute. """ # Map methods into required permission codes. @@ -138,6 +138,14 @@ class DjangoModelPermissions(BasePermission): return False +class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions): + """ + Similar to DjangoModelPermissions, except that anonymous users are + allowed read-only access. + """ + authenticated_users_only = False + + class TokenHasReadWriteScope(BasePermission): """ The request is authenticated as a user and the token used has the right scope -- cgit v1.2.3 From e5040fbf942e021444f629a371bc71c9d47d052f Mon Sep 17 00:00:00 2001 From: Danilo Bargen Date: Tue, 30 Apr 2013 23:24:20 +0200 Subject: Catch ImproperlyConfigured exception in compat.py (fixes #803) --- rest_framework/compat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 067e9018..f8e4e7ca 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -6,6 +6,7 @@ versions of django/python, and compatibility wrappers around optional packages. from __future__ import unicode_literals import django +from django.core.exceptions import ImproperlyConfigured # Try to import six from Django, fallback to included `six`. try: @@ -473,7 +474,7 @@ except ImportError: try: import oauth_provider from oauth_provider.store import store as oauth_provider_store -except ImportError: +except (ImportError, ImproperlyConfigured): oauth_provider = None oauth_provider_store = None -- cgit v1.2.3 From 35f99cddc4a098547389fab7d9f397ad442dfff1 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 1 May 2013 09:03:09 +0100 Subject: lookup_field on hyperlinked fields, and overriddable hyperlinked fields. Closes #688 --- rest_framework/relations.py | 147 +++++++++++++++++++++++++----------------- rest_framework/serializers.py | 3 +- 2 files changed, 91 insertions(+), 59 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index abe5203b..6d8deec1 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -288,10 +288,8 @@ class HyperlinkedRelatedField(RelatedField): """ Represents a relationship using hyperlinking. """ - pk_url_kwarg = 'pk' - slug_field = 'slug' - slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden read_only = False + lookup_field = 'pk' default_error_messages = { 'no_match': _('Invalid hyperlink - No URL match'), @@ -301,69 +299,120 @@ class HyperlinkedRelatedField(RelatedField): 'incorrect_type': _('Incorrect type. Expected url string, received %s.'), } + # These are all pending deprecation + pk_url_kwarg = 'pk' + slug_field = 'slug' + slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden + def __init__(self, *args, **kwargs): try: self.view_name = kwargs.pop('view_name') except KeyError: raise ValueError("Hyperlinked field requires 'view_name' kwarg") + self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) + self.format = kwargs.pop('format', None) + + # These are pending deprecation + self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) self.slug_field = kwargs.pop('slug_field', self.slug_field) default_slug_kwarg = self.slug_url_kwarg or self.slug_field - self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) - self.format = kwargs.pop('format', None) super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) - def get_slug_field(self): + def get_url(self, obj, view_name, request, format): """ - Get the name of a slug field to be used to look up by slug. - """ - return self.slug_field - - def to_native(self, obj): - view_name = self.view_name - request = self.context.get('request', None) - format = self.format or self.context.get('format', None) - - if request is None: - warnings.warn("Using `HyperlinkedRelatedField` without including the " - "request in the serializer context is deprecated. " - "Add `context={'request': request}` when instantiating the serializer.", - DeprecationWarning, stacklevel=4) + Given an object, return the URL that hyperlinks to the object. - pk = getattr(obj, 'pk', None) - if pk is None: - return - kwargs = {self.pk_url_kwarg: pk} + May raise a `NoReverseMatch` if the `view_name` and `lookup_field` + attributes are not configured to correctly match the URL conf. + """ + lookup_field = getattr(obj, self.lookup_field) + kwargs = {self.lookup_field: lookup_field} try: return reverse(view_name, kwargs=kwargs, request=request, format=format) except NoReverseMatch: pass + if self.pk_url_kwarg != 'pk': + # Only try pk if it has been explicitly set. + # Otherwise, the default `lookup_field = 'pk'` has us covered. + pk = obj.pk + kwargs = {self.pk_url_kwarg: pk} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except NoReverseMatch: + pass + slug = getattr(obj, self.slug_field, None) + if slug is not None: + # Only try slug if it corresponds to an attribute on the object. + kwargs = {self.slug_url_kwarg: slug} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except NoReverseMatch: + pass - if not slug: - raise Exception('Could not resolve URL for field using view name "%s"' % view_name) + raise NoReverseMatch() - kwargs = {self.slug_url_kwarg: slug} - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except NoReverseMatch: - pass + def get_object(self, queryset, view_name, view_args, view_kwargs): + """ + Return the object corresponding to a matched URL. - kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} + Takes the matched URL conf arguments, and the queryset, and should + return an object instance, or raise an `ObjectDoesNotExist` exception. + """ + lookup = view_kwargs.get(self.lookup_field, None) + pk = view_kwargs.get(self.pk_url_kwarg, None) + slug = view_kwargs.get(self.slug_url_kwarg, None) + + if lookup is not None: + filter_kwargs = {self.lookup_field: lookup} + elif pk is not None: + filter_kwargs = {'pk': pk} + elif slug is not None: + filter_kwargs = {self.slug_field: slug} + else: + raise ObjectDoesNotExist() + + return queryset.get(**filter_kwargs) + + def to_native(self, obj): + view_name = self.view_name + request = self.context.get('request', None) + format = self.format or self.context.get('format', None) + + if request is None: + msg = ( + "Using `HyperlinkedRelatedField` without including the request " + "in the serializer context is deprecated. " + "Add `context={'request': request}` when instantiating " + "the serializer." + ) + warnings.warn(msg, DeprecationWarning, stacklevel=4) + + # If the object has not yet been saved then we cannot hyperlink to it. + if getattr(obj, 'pk', None) is None: + return + + # Return the hyperlink, or error if incorrectly configured. try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) + return self.get_url(obj, view_name, request, format) except NoReverseMatch: - pass - - raise Exception('Could not resolve URL for field using view name "%s"' % view_name) + msg = ( + 'Could not resolve URL for hyperlinked relationship using ' + 'view name "%s". You may have failed to include the related ' + 'model in your API, or incorrectly configured the ' + '`lookup_field` attribute on this field.' + ) + raise Exception(msg % view_name) def from_native(self, value): # Convert URL -> model instance pk # TODO: Use values_list - if self.queryset is None: + queryset = self.queryset + if queryset is None: raise Exception('Writable related fields must include a `queryset` argument') try: @@ -387,29 +436,11 @@ class HyperlinkedRelatedField(RelatedField): if match.view_name != self.view_name: raise ValidationError(self.error_messages['incorrect_match']) - pk = match.kwargs.get(self.pk_url_kwarg, None) - slug = match.kwargs.get(self.slug_url_kwarg, None) - - # Try explicit primary key. - if pk is not None: - queryset = self.queryset.filter(pk=pk) - # Next, try looking up by slug. - elif slug is not None: - slug_field = self.get_slug_field() - queryset = self.queryset.filter(**{slug_field: slug}) - # If none of those are defined, it's probably a configuation error. - else: - raise ValidationError(self.error_messages['configuration_error']) - try: - obj = queryset.get() - except ObjectDoesNotExist: + return self.get_object(queryset, match.view_name, + match.args, match.kwargs) + except (ObjectDoesNotExist, TypeError, ValueError): raise ValidationError(self.error_messages['does_not_exist']) - except (TypeError, ValueError): - msg = self.error_messages['incorrect_type'] - raise ValidationError(msg % type(value).__name__) - - return obj class HyperlinkedIdentityField(Field): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index b589eca8..d4b34c01 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -836,6 +836,7 @@ class HyperlinkedModelSerializer(ModelSerializer): """ _options_class = HyperlinkedModelSerializerOptions _default_view_name = '%(model_name)s-detail' + _hyperlink_field_class = HyperlinkedRelatedField url = HyperlinkedIdentityField() @@ -874,7 +875,7 @@ class HyperlinkedModelSerializer(ModelSerializer): if model_field: kwargs['required'] = not(model_field.null or model_field.blank) - return HyperlinkedRelatedField(**kwargs) + return self._hyperlink_field_class(**kwargs) def get_identity(self, data): """ -- cgit v1.2.3 From 8cabae22c5330da2e0a15a6d61ef038a6447756a Mon Sep 17 00:00:00 2001 From: Victor Shih Date: Wed, 1 May 2013 21:26:40 -0700 Subject: Example and spelling fixes. Change "browseable" to "browsable" for consistency. --- rest_framework/renderers.py | 2 +- rest_framework/tests/generics.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 4c15e0db..83bbc5b8 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -57,7 +57,7 @@ class JSONRenderer(BaseRenderer): return '' # If 'indent' is provided in the context, then pretty print the result. - # E.g. If we're being called by the BrowseableAPIRenderer. + # E.g. If we're being called by the BrowsableAPIRenderer. renderer_context = renderer_context or {} indent = renderer_context.get('indent', None) diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 4a13389a..eca50d82 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -377,7 +377,7 @@ class TestCreateModelWithAutoNowAddField(TestCase): self.assertEqual(created.content, 'foobar') -# Test for particularly ugly regression with m2m in browseable API +# Test for particularly ugly regression with m2m in browsable API class ClassB(models.Model): name = models.CharField(max_length=255) @@ -402,7 +402,7 @@ class ExampleView(generics.ListCreateAPIView): class TestM2MBrowseableAPI(TestCase): def test_m2m_in_browseable_api(self): """ - Test for particularly ugly regression with m2m in browseable API + Test for particularly ugly regression with m2m in browsable API """ request = factory.get('/', HTTP_ACCEPT='text/html') view = ExampleView().as_view() -- cgit v1.2.3 From e4067bfb75a38851ea865719ebfbb65708187b4e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 2 May 2013 12:07:18 +0100 Subject: introduce lookup_field and add pendingdeprecationwarnings --- rest_framework/mixins.py | 13 +++++++++++-- rest_framework/relations.py | 10 ++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index ec751e24..ae703771 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -12,7 +12,7 @@ from rest_framework.response import Response from rest_framework.request import clone_request -def _get_validation_exclusions(obj, pk=None, slug_field=None): +def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None): """ Given a model instance, and an optional pk and slug field, return the full list of all other field names on that model. @@ -23,14 +23,19 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None): include = [] if pk: + # Pending deprecation pk_field = obj._meta.pk while pk_field.rel: pk_field = pk_field.rel.to._meta.pk include.append(pk_field.name) if slug_field: + # Pending deprecation include.append(slug_field) + if lookup_field and lookup_field != 'pk': + include.append(lookup_field) + return [field.name for field in obj._meta.fields if field.name not in include] @@ -139,10 +144,14 @@ class UpdateModelMixin(object): Set any attributes on the object that are implicit in the request. """ # pk and/or slug attributes are implicit in the URL. + lookup = self.kwargs.get(self.lookup_field, None) pk = self.kwargs.get(self.pk_url_kwarg, None) slug = self.kwargs.get(self.slug_url_kwarg, None) slug_field = slug and self.slug_field or None + if lookup: + setattr(obj, self.lookup_field, lookup) + if pk: setattr(obj, 'pk', pk) @@ -152,7 +161,7 @@ class UpdateModelMixin(object): # Ensure we clean the attributes so that we don't eg return integer # pk using a string representation, as provided by the url conf kwarg. if hasattr(obj, 'full_clean'): - exclude = _get_validation_exclusions(obj, pk, slug_field) + exclude = _get_validation_exclusions(obj, pk, slug_field, self.lookup_field) obj.full_clean(exclude) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 6d8deec1..bc7f112c 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -314,6 +314,16 @@ class HyperlinkedRelatedField(RelatedField): self.format = kwargs.pop('format', None) # These are pending deprecation + if 'pk_url_kwarg' in kwargs: + msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_url_kwarg' in kwargs: + msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_field' in kwargs: + msg = 'slug_field is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) self.slug_field = kwargs.pop('slug_field', self.slug_field) default_slug_kwarg = self.slug_url_kwarg or self.slug_field -- cgit v1.2.3 From 387250bee438a3826191b2d0d196d0c11373f7f3 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 2 May 2013 12:07:37 +0100 Subject: Automagically determine base_name in router class --- rest_framework/routers.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 923405e8..0707635a 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -42,10 +42,22 @@ class BaseRouter(object): def __init__(self): self.registry = [] - def register(self, prefix, viewset, name): - self.registry.append((prefix, viewset, name)) + def register(self, prefix, viewset, base_name=None): + if base_name is None: + base_name = self.get_default_base_name(viewset) + self.registry.append((prefix, viewset, base_name)) + + def get_default_base_name(self, viewset): + """ + If `base_name` is not specified, attempt to automatically determine + it from the viewset. + """ + raise NotImplemented('get_default_base_name must be overridden') def get_urls(self): + """ + Return a list of URL patterns, given the registered viewsets. + """ raise NotImplemented('get_urls must be overridden') @property @@ -91,6 +103,22 @@ class SimpleRouter(BaseRouter): ), ] + def get_default_base_name(self, viewset): + """ + If `base_name` is not specified, attempt to automatically determine + it from the viewset. + """ + model_cls = getattr(viewset, 'model', None) + queryset = getattr(viewset, 'queryset', None) + if model_cls is None and queryset is not None: + model_cls = queryset.model + + assert model_cls, '`name` not argument not specified, and could ' \ + 'not automatically determine the name from the viewset, as ' \ + 'it does not have a `.model` or `.queryset` attribute.' + + return model_cls._meta.object_name.lower() + def get_routes(self, viewset): """ Augment `self.routes` with any dynamically generated routes. -- cgit v1.2.3 From 0c85768435e67133ff219aaddb4ea3bf122bd360 Mon Sep 17 00:00:00 2001 From: Michael Elovskikh Date: Fri, 3 May 2013 01:37:25 +0600 Subject: Added FileUploadParser refs #7 --- rest_framework/parsers.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 491acd68..6ba05aef 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -6,9 +6,10 @@ on the request, such as form content or json encoded data. """ from __future__ import unicode_literals from django.conf import settings +from django.core.files.uploadhandler import StopFutureHandlers from django.http import QueryDict from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser -from django.http.multipartparser import MultiPartParserError +from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter from rest_framework.compat import yaml, etree from rest_framework.exceptions import ParseError from rest_framework.compat import six @@ -205,3 +206,63 @@ class XMLParser(BaseParser): pass return value + + +class FileUploadParser(BaseParser): + """ + Parser for file upload data. + """ + media_type = '*/*' + + def parse(self, stream, media_type=None, parser_context=None): + parser_context = parser_context or {} + request = parser_context['request'] + encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) + meta = request.META + + try: + disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION']) + filename = disposition[1]['filename'] + except KeyError: + filename = None + + content_type = meta.get('HTTP_CONTENT_TYPE', meta.get('CONTENT_TYPE', '')) + try: + content_length = int(meta.get('HTTP_CONTENT_LENGTH', meta.get('CONTENT_LENGTH', 0))) + except (ValueError, TypeError): + content_length = None + + # See if the handler will want to take care of the parsing. + for handler in request.upload_handlers: + result = handler.handle_raw_input(None, + meta, + content_length, + None, + encoding) + if result is not None: + return DataAndFiles(result[0], {'file': result[1]}) + + possible_sizes = [x.chunk_size for x in request.upload_handlers if x.chunk_size] + chunk_size = min([2**31-4] + possible_sizes) + chunks = ChunkIter(stream, chunk_size) + counters = [0] * len(request.upload_handlers) + + for handler in request.upload_handlers: + try: + handler.new_file(None, filename, content_type, content_length, encoding) + except StopFutureHandlers: + break + + for chunk in chunks: + for i, handler in enumerate(request.upload_handlers): + chunk_length = len(chunk) + chunk = handler.receive_data_chunk(chunk, counters[i]) + counters[i] += chunk_length + if chunk is None: + # If the chunk received by the handler is None, then don't continue. + break + + for i, handler in enumerate(request.upload_handlers): + file_obj = handler.file_complete(counters[i]) + if file_obj: + return DataAndFiles(None, {'file': file_obj}) -- cgit v1.2.3 From 318fdaabe560c99de4983e0a3cdcb79756baaf01 Mon Sep 17 00:00:00 2001 From: Michael Elovskikh Date: Fri, 3 May 2013 01:39:08 +0600 Subject: Tests for FileUploadParser --- rest_framework/tests/parsers.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/parsers.py index 539c5b44..b18ecbf2 100644 --- a/rest_framework/tests/parsers.py +++ b/rest_framework/tests/parsers.py @@ -1,10 +1,11 @@ from __future__ import unicode_literals from rest_framework.compat import StringIO from django import forms +from django.core.files.uploadhandler import MemoryFileUploadHandler from django.test import TestCase from django.utils import unittest from rest_framework.compat import etree -from rest_framework.parsers import FormParser +from rest_framework.parsers import FormParser, FileUploadParser from rest_framework.parsers import XMLParser import datetime @@ -82,3 +83,27 @@ class TestXMLParser(TestCase): parser = XMLParser() data = parser.parse(self._complex_data_input) self.assertEqual(data, self._complex_data) + + +class TestFileUploadParser(TestCase): + def setUp(self): + class MockRequest(object): + pass + from io import BytesIO + self.stream = BytesIO( + "Test text file".encode('utf-8') + ) + request = MockRequest() + request.upload_handlers = (MemoryFileUploadHandler(),) + request.META = { + 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'), + 'HTTP_CONTENT_LENGTH': 14, + } + self.parser_context = {'request': request} + + def test_parse(self): + """ Make sure the `QueryDict` works OK """ + parser = FileUploadParser() + data_and_files = parser.parse(self.stream, parser_context=self.parser_context) + file_obj = data_and_files.files['file'] + self.assertEqual(file_obj._size, 14) -- cgit v1.2.3 From e36e4f48ad481b4303e68ed524677add07b224f7 Mon Sep 17 00:00:00 2001 From: Michael Elovskikh Date: Sat, 4 May 2013 14:58:21 +0600 Subject: Codebase improvements on FileUploadParser * Added docstrings. * Added `FileUploadParser.get_filename` to make it easier to override. * Added url kwargs filename detection step. * Updated tests corresponding to these changes. --- rest_framework/parsers.py | 45 +++++++++++++++++++++++++++++------------ rest_framework/tests/parsers.py | 10 +++++++-- 2 files changed, 40 insertions(+), 15 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 6ba05aef..7eb92184 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -215,16 +215,19 @@ class FileUploadParser(BaseParser): media_type = '*/*' def parse(self, stream, media_type=None, parser_context=None): + """ + Returns a DataAndFiles object. + + `.data` will be None (we expect request body to be a file content). + `.files` will be a `QueryDict` containing one 'file' elemnt - a parsed file. + """ + parser_context = parser_context or {} request = parser_context['request'] encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) meta = request.META - - try: - disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION']) - filename = disposition[1]['filename'] - except KeyError: - filename = None + upload_handlers = request.upload_handlers + filename = self.get_filename(stream, media_type, parser_context) content_type = meta.get('HTTP_CONTENT_TYPE', meta.get('CONTENT_TYPE', '')) try: @@ -233,28 +236,28 @@ class FileUploadParser(BaseParser): content_length = None # See if the handler will want to take care of the parsing. - for handler in request.upload_handlers: + for handler in upload_handlers: result = handler.handle_raw_input(None, meta, content_length, None, encoding) if result is not None: - return DataAndFiles(result[0], {'file': result[1]}) + return DataAndFiles(None, {'file': result[1]}) - possible_sizes = [x.chunk_size for x in request.upload_handlers if x.chunk_size] + possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size] chunk_size = min([2**31-4] + possible_sizes) chunks = ChunkIter(stream, chunk_size) - counters = [0] * len(request.upload_handlers) + counters = [0] * len(upload_handlers) - for handler in request.upload_handlers: + for handler in upload_handlers: try: handler.new_file(None, filename, content_type, content_length, encoding) except StopFutureHandlers: break for chunk in chunks: - for i, handler in enumerate(request.upload_handlers): + for i, handler in enumerate(upload_handlers): chunk_length = len(chunk) chunk = handler.receive_data_chunk(chunk, counters[i]) counters[i] += chunk_length @@ -262,7 +265,23 @@ class FileUploadParser(BaseParser): # If the chunk received by the handler is None, then don't continue. break - for i, handler in enumerate(request.upload_handlers): + for i, handler in enumerate(upload_handlers): file_obj = handler.file_complete(counters[i]) if file_obj: return DataAndFiles(None, {'file': file_obj}) + + def get_filename(self, stream, media_type, parser_context): + """ + Detects the uploaded file name. First searches a 'filename' url kwarg. + Then tries to parse Content-Disposition header. + """ + try: + return parser_context['kwargs']['filename'] + except KeyError: + pass + try: + meta = parser_context['request'].META + disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION']) + return disposition[1]['filename'] + except (AttributeError, KeyError): + pass diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/parsers.py index b18ecbf2..7699e10c 100644 --- a/rest_framework/tests/parsers.py +++ b/rest_framework/tests/parsers.py @@ -99,11 +99,17 @@ class TestFileUploadParser(TestCase): 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'), 'HTTP_CONTENT_LENGTH': 14, } - self.parser_context = {'request': request} + self.parser_context = {'request': request, 'kwargs': {}} def test_parse(self): """ Make sure the `QueryDict` works OK """ parser = FileUploadParser() - data_and_files = parser.parse(self.stream, parser_context=self.parser_context) + self.stream.seek(0) + data_and_files = parser.parse(self.stream, None, self.parser_context) file_obj = data_and_files.files['file'] self.assertEqual(file_obj._size, 14) + + def test_get_filename(self): + parser = FileUploadParser() + filename = parser.get_filename(self.stream, None, self.parser_context) + self.assertEqual(filename, 'file.txt'.encode('utf-8')) -- cgit v1.2.3 From a514232815a82ad8a4dc1819afa0d62f9bab1323 Mon Sep 17 00:00:00 2001 From: Michael Elovskikh Date: Sat, 4 May 2013 17:18:10 +0600 Subject: Raise ParseError if can't handle the uploaded file --- rest_framework/parsers.py | 1 + 1 file changed, 1 insertion(+) (limited to 'rest_framework') diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 7eb92184..27a0db65 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -269,6 +269,7 @@ class FileUploadParser(BaseParser): file_obj = handler.file_complete(counters[i]) if file_obj: return DataAndFiles(None, {'file': file_obj}) + raise ParseError("FileUpload parse error - none of upload handlers can handle the stream") def get_filename(self, stream, media_type, parser_context): """ -- cgit v1.2.3 From 538d2e35e7f1e4623a215d1b8c684b284f951c09 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sun, 5 May 2013 16:47:45 +0100 Subject: lookup_field on hyperlink serializers --- rest_framework/relations.py | 10 +++++++++- rest_framework/serializers.py | 4 ++++ 2 files changed, 13 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index bc7f112c..fc5054b2 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -360,7 +360,15 @@ class HyperlinkedRelatedField(RelatedField): # Only try slug if it corresponds to an attribute on the object. kwargs = {self.slug_url_kwarg: slug} try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) + ret = reverse(view_name, kwargs=kwargs, request=request, format=format) + if self.slug_field == 'slug' and self.slug_url_kwarg == 'slug': + # If the lookup succeeds using the default slug params, + # then `slug_field` is being used implicitly, and we + # we need to warn about the pending deprecation. + msg = 'Implicit slug field hyperlinked fields are pending deprecation.' \ + 'You should set `lookup_field=slug` on the HyperlinkedRelatedField.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + return ret except NoReverseMatch: pass diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index d4b34c01..ea5175e2 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -827,6 +827,7 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): def __init__(self, meta): super(HyperlinkedModelSerializerOptions, self).__init__(meta) self.view_name = getattr(meta, 'view_name', None) + self.lookup_field = getattr(meta, 'slug_field', None) class HyperlinkedModelSerializer(ModelSerializer): @@ -875,6 +876,9 @@ class HyperlinkedModelSerializer(ModelSerializer): if model_field: kwargs['required'] = not(model_field.null or model_field.blank) + if self.opts.lookup_field: + kwargs['lookup_field'] = self.opts.lookup_field + return self._hyperlink_field_class(**kwargs) def get_identity(self, data): -- cgit v1.2.3 From 660d2405174519628c72ed84a69ae37531df12f3 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sun, 5 May 2013 16:48:00 +0100 Subject: .action attribute on viewsets --- rest_framework/viewsets.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index a54467d7..0eb3e86d 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -59,6 +59,10 @@ class ViewSetMixin(object): def view(request, *args, **kwargs): self = cls(**initkwargs) + # We also store the mapping of request methods to actions, + # so that we can later set the action attribute. + # eg. `self.action = 'list'` on an incoming GET request. + self.action_map = actions # Bind methods to actions # This is the bit that's different to a standard view @@ -87,6 +91,15 @@ class ViewSetMixin(object): view.suffix = initkwargs.get('suffix', None) return view + def initialize_request(self, request, *args, **kargs): + """ + Set the `.action` attribute on the view, + depending on the request method. + """ + request = super(ViewSetMixin, self).initialize_request(request, *args, **kargs) + self.action = self.action_map.get(request.method.lower()) + return request + class ViewSet(ViewSetMixin, views.APIView): """ -- cgit v1.2.3 From d71a5533f9a8787652244dfb16af37fb7d9059fb Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 7 May 2013 12:25:41 +0100 Subject: allow_empty -> pending deprecation in preference of overridden get_queryset. --- rest_framework/filters.py | 29 +++++++++++++++++++++++++++++ rest_framework/generics.py | 12 +++++++++--- 2 files changed, 38 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 5e1cdbac..571704dc 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -3,7 +3,10 @@ Provides generic filtering backends that can be used to filter the results returned by list views. """ from __future__ import unicode_literals + +from django.db import models from rest_framework.compat import django_filters +import operator FilterSet = django_filters and django_filters.FilterSet or None @@ -62,3 +65,29 @@ class DjangoFilterBackend(BaseFilterBackend): return filter_class(request.QUERY_PARAMS, queryset=queryset).qs return queryset + + +class SearchFilter(BaseFilterBackend): + def construct_search(self, field_name): + if field_name.startswith('^'): + return "%s__istartswith" % field_name[1:] + elif field_name.startswith('='): + return "%s__iexact" % field_name[1:] + elif field_name.startswith('@'): + return "%s__search" % field_name[1:] + else: + return "%s__icontains" % field_name + + def filter_queryset(self, request, queryset, view): + search_fields = getattr(view, 'search_fields', None) + + if not search_fields: + return None + + orm_lookups = [self.construct_search(str(search_field)) + for search_field in self.search_fields] + for bit in self.query.split(): + or_queries = [models.Q(**{orm_lookup: bit}) + for orm_lookup in orm_lookups] + queryset = queryset.filter(reduce(operator.or_, or_queries)) + return queryset diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 62129dcc..2bb23a89 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -42,9 +42,6 @@ class GenericAPIView(views.APIView): # The filter backend class to use for queryset filtering filter_backend = api_settings.FILTER_BACKEND - # Determines if the view will return 200 or 404 responses for empty lists. - allow_empty = True - # The following attributes may be subject to change, # and should be considered private API. model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS @@ -56,6 +53,7 @@ class GenericAPIView(views.APIView): pk_url_kwarg = 'pk' slug_url_kwarg = 'slug' slug_field = 'slug' + allow_empty = True def get_serializer_context(self): """ @@ -111,6 +109,14 @@ class GenericAPIView(views.APIView): if not page_size: return None + if not self.allow_empty: + warnings.warn( + 'The `allow_empty` parameter is due to be deprecated. ' + 'To use `allow_empty=False` style behavior, You should override ' + '`get_queryset()` and explicitly raise a 404 on empty querysets.', + PendingDeprecationWarning, stacklevel=2 + ) + paginator = self.paginator_class(queryset, page_size, allow_empty_first_page=self.allow_empty) page_kwarg = self.kwargs.get(self.page_kwarg) -- cgit v1.2.3 From 3c2bb0666063917707bfbfedf056e5692bfcc471 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 7 May 2013 13:00:44 +0100 Subject: Support for multiple filter classes --- rest_framework/generics.py | 23 +++++++++++++++++------ rest_framework/settings.py | 12 +++++++++--- 2 files changed, 26 insertions(+), 9 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 2bb23a89..05ec93d3 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -39,8 +39,8 @@ class GenericAPIView(views.APIView): pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS page_kwarg = 'page' - # The filter backend class to use for queryset filtering - filter_backend = api_settings.FILTER_BACKEND + # The filter backend classes to use for queryset filtering + filter_backends = api_settings.DEFAULT_FILTER_BACKENDS # The following attributes may be subject to change, # and should be considered private API. @@ -54,6 +54,7 @@ class GenericAPIView(views.APIView): slug_url_kwarg = 'slug' slug_field = 'slug' allow_empty = True + filter_backend = api_settings.FILTER_BACKEND def get_serializer_context(self): """ @@ -150,10 +151,20 @@ class GenericAPIView(views.APIView): method if you want to apply the configured filtering backend to the default queryset. """ - if not self.filter_backend: - return queryset - backend = self.filter_backend() - return backend.filter_queryset(self.request, queryset, self) + filter_backends = self.filter_backends or [] + if not filter_backends and self.filter_backend: + warnings.warn( + 'The `filter_backend` attribute and `FILTER_BACKEND` setting ' + 'are due to be deprecated in favor of a `filter_backends` ' + 'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take ' + 'a *list* of filter backend classes.', + PendingDeprecationWarning, stacklevel=2 + ) + filter_backends = [self.filter_backend] + + for backend in filter_backends: + queryset = backend().filter_queryset(self.request, queryset, self) + return queryset ######################## ### The following methods provide default implementations diff --git a/rest_framework/settings.py b/rest_framework/settings.py index eede0c5a..734d8478 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -29,6 +29,7 @@ from rest_framework.compat import six USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None) DEFAULTS = { + # Base API policies 'DEFAULT_RENDERER_CLASSES': ( 'rest_framework.renderers.JSONRenderer', 'rest_framework.renderers.BrowsableAPIRenderer', @@ -50,11 +51,15 @@ DEFAULTS = { 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', + + # Genric view behavior 'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.serializers.ModelSerializer', 'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer', + 'DEFAULT_FILTER_BACKENDS': (), + # Throttling 'DEFAULT_THROTTLE_RATES': { 'user': None, 'anon': None, @@ -64,9 +69,6 @@ DEFAULTS = { 'PAGINATE_BY': None, 'PAGINATE_BY_PARAM': None, - # Filtering - 'FILTER_BACKEND': None, - # Authentication 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, @@ -95,6 +97,9 @@ DEFAULTS = { ISO_8601, ), 'TIME_FORMAT': ISO_8601, + + # Pending deprecation + 'FILTER_BACKEND': None, } @@ -108,6 +113,7 @@ IMPORT_STRINGS = ( 'DEFAULT_CONTENT_NEGOTIATION_CLASS', 'DEFAULT_MODEL_SERIALIZER_CLASS', 'DEFAULT_PAGINATION_SERIALIZER_CLASS', + 'DEFAULT_FILTER_BACKENDS', 'FILTER_BACKEND', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', -- cgit v1.2.3 From 3353889ae85cc21890469cf00f7073d1ea5c2070 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 7 May 2013 13:27:27 +0100 Subject: Docs for FileUploadParser --- rest_framework/parsers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 27a0db65..614531a1 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -246,7 +246,7 @@ class FileUploadParser(BaseParser): return DataAndFiles(None, {'file': result[1]}) possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size] - chunk_size = min([2**31-4] + possible_sizes) + chunk_size = min([2 ** 31 - 4] + possible_sizes) chunks = ChunkIter(stream, chunk_size) counters = [0] * len(upload_handlers) @@ -280,9 +280,10 @@ class FileUploadParser(BaseParser): return parser_context['kwargs']['filename'] except KeyError: pass + try: meta = parser_context['request'].META disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION']) return disposition[1]['filename'] except (AttributeError, KeyError): - pass + raise ParseError("Filename must be set in Content-Disposition header.") -- cgit v1.2.3 From ed2cf180c961bb337c5d3ab7e5f74a1539c33ae4 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 7 May 2013 13:29:38 +0100 Subject: Version 2.3.0 --- rest_framework/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 856badc6..35196c74 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.2.7' +__version__ = '2.3.0' VERSION = __version__ # synonym -- cgit v1.2.3 From d7c08222f14389b4d61e5ca9032c49b8b917d251 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 7 May 2013 14:11:48 +0100 Subject: Fix breadcrumb rendering issue --- rest_framework/__init__.py | 2 +- rest_framework/utils/breadcrumbs.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 35196c74..819558b5 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.3.0' +__version__ = '2.3.1' VERSION = __version__ # synonym diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index 28801d09..d51374b0 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -24,7 +24,8 @@ def get_breadcrumbs(url): else: # Check if this is a REST framework view, # and if so add it to the breadcrumbs - if issubclass(getattr(view, 'cls', None), APIView): + cls = getattr(view, 'cls', None) + if cls is not None and issubclass(cls, APIView): # Don't list the same view twice in a row. # Probably an optional trailing slash. if not seen or seen[-1] != view: -- cgit v1.2.3 From 429e078eee63a120c408946cf7c1460d4ca9e9b4 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 May 2013 20:07:51 +0100 Subject: Allow None filename on uploaded files --- rest_framework/parsers.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 614531a1..25be2e6a 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -219,7 +219,7 @@ class FileUploadParser(BaseParser): Returns a DataAndFiles object. `.data` will be None (we expect request body to be a file content). - `.files` will be a `QueryDict` containing one 'file' elemnt - a parsed file. + `.files` will be a `QueryDict` containing one 'file' element. """ parser_context = parser_context or {} @@ -229,9 +229,13 @@ class FileUploadParser(BaseParser): upload_handlers = request.upload_handlers filename = self.get_filename(stream, media_type, parser_context) - content_type = meta.get('HTTP_CONTENT_TYPE', meta.get('CONTENT_TYPE', '')) + # Note that this code is extracted from Django's handling of + # file uploads in MultiPartParser. + content_type = meta.get('HTTP_CONTENT_TYPE', + meta.get('CONTENT_TYPE', '')) try: - content_length = int(meta.get('HTTP_CONTENT_LENGTH', meta.get('CONTENT_LENGTH', 0))) + content_length = int(meta.get('HTTP_CONTENT_LENGTH', + meta.get('CONTENT_LENGTH', 0))) except (ValueError, TypeError): content_length = None @@ -245,6 +249,7 @@ class FileUploadParser(BaseParser): if result is not None: return DataAndFiles(None, {'file': result[1]}) + # This is the standard case. possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size] chunk_size = min([2 ** 31 - 4] + possible_sizes) chunks = ChunkIter(stream, chunk_size) @@ -252,7 +257,8 @@ class FileUploadParser(BaseParser): for handler in upload_handlers: try: - handler.new_file(None, filename, content_type, content_length, encoding) + handler.new_file(None, filename, content_type, + content_length, encoding) except StopFutureHandlers: break @@ -262,14 +268,14 @@ class FileUploadParser(BaseParser): chunk = handler.receive_data_chunk(chunk, counters[i]) counters[i] += chunk_length if chunk is None: - # If the chunk received by the handler is None, then don't continue. break for i, handler in enumerate(upload_handlers): file_obj = handler.file_complete(counters[i]) if file_obj: return DataAndFiles(None, {'file': file_obj}) - raise ParseError("FileUpload parse error - none of upload handlers can handle the stream") + raise ParseError("FileUpload parse error - " + "none of upload handlers can handle the stream") def get_filename(self, stream, media_type, parser_context): """ @@ -286,4 +292,4 @@ class FileUploadParser(BaseParser): disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION']) return disposition[1]['filename'] except (AttributeError, KeyError): - raise ParseError("Filename must be set in Content-Disposition header.") + pass -- cgit v1.2.3 From de69a28b9e786b8c759cda4acedb0a1b8542298b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 May 2013 20:18:01 +0100 Subject: Test and fix for #814. --- rest_framework/filters.py | 14 ++++++++++---- rest_framework/tests/filterset.py | 28 +++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 5 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 571704dc..f2163f6f 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -38,21 +38,27 @@ class DjangoFilterBackend(BaseFilterBackend): """ filter_class = getattr(view, 'filter_class', None) filter_fields = getattr(view, 'filter_fields', None) - view_model = getattr(view, 'model', None) + model_cls = getattr(view, 'model', None) + queryset = getattr(view, 'queryset', None) + if model_cls is None and queryset is not None: + model_cls = queryset.model if filter_class: filter_model = filter_class.Meta.model - assert issubclass(filter_model, view_model), \ + assert issubclass(filter_model, model_cls), \ 'FilterSet model %s does not match view model %s' % \ - (filter_model, view_model) + (filter_model, model_cls) return filter_class if filter_fields: + assert model_cls is not None, 'Cannot use DjangoFilterBackend ' \ + 'on a view which does not have a .model or .queryset attribute.' + class AutoFilterSet(self.default_filter_set): class Meta: - model = view_model + model = model_cls fields = filter_fields return AutoFilterSet diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 1e53a5cd..023bd016 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -5,7 +5,7 @@ from django.core.urlresolvers import reverse from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest -from rest_framework import generics, status, filters +from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters, patterns, url from rest_framework.tests.models import FilterableItem, BasicModel @@ -52,6 +52,17 @@ if django_filters: filter_class = SeveralFieldsFilter filter_backend = filters.DjangoFilterBackend + # Regression test for #814 + class FilterableItemSerializer(serializers.ModelSerializer): + class Meta: + model = FilterableItem + + class FilterFieldsQuerysetView(generics.ListCreateAPIView): + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer + filter_fields = ['decimal', 'date'] + filter_backend = filters.DjangoFilterBackend + urlpatterns = patterns('', url(r'^(?P\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), url(r'^$', FilterClassRootView.as_view(), name='root-view'), @@ -114,6 +125,21 @@ class IntegrationTestFiltering(CommonFilteringTestCase): expected_data = [f for f in self.data if f['date'] == search_date] self.assertEqual(response.data, expected_data) + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_filter_with_queryset(self): + """ + Regression test for #814. + """ + view = FilterFieldsQuerysetView.as_view() + + # Tests that the decimal filter works. + search_decimal = Decimal('2.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] == search_decimal] + self.assertEqual(response.data, expected_data) + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_filtered_class_root_view(self): """ -- cgit v1.2.3 From b443560080a20d52a3dd49f625a103810935affd Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 May 2013 20:38:50 +0100 Subject: Fix DATETIME_FORMAT, DATE_FORMAT, TIME_FORMAT settings. Closes #798 --- rest_framework/fields.py | 6 +++--- rest_framework/settings.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f934fc39..c83ee5ec 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -500,7 +500,7 @@ class DateField(WritableField): } empty = None input_formats = api_settings.DATE_INPUT_FORMATS - format = None + format = api_settings.DATE_FORMAT def __init__(self, input_formats=None, format=None, *args, **kwargs): self.input_formats = input_formats if input_formats is not None else self.input_formats @@ -563,7 +563,7 @@ class DateTimeField(WritableField): } empty = None input_formats = api_settings.DATETIME_INPUT_FORMATS - format = None + format = api_settings.DATETIME_FORMAT def __init__(self, input_formats=None, format=None, *args, **kwargs): self.input_formats = input_formats if input_formats is not None else self.input_formats @@ -632,7 +632,7 @@ class TimeField(WritableField): } empty = None input_formats = api_settings.TIME_INPUT_FORMATS - format = None + format = api_settings.TIME_FORMAT def __init__(self, input_formats=None, format=None, *args, **kwargs): self.input_formats = input_formats if input_formats is not None else self.input_formats diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 734d8478..beb511ac 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -86,17 +86,17 @@ DEFAULTS = { 'DATE_INPUT_FORMATS': ( ISO_8601, ), - 'DATE_FORMAT': ISO_8601, + 'DATE_FORMAT': None, 'DATETIME_INPUT_FORMATS': ( ISO_8601, ), - 'DATETIME_FORMAT': ISO_8601, + 'DATETIME_FORMAT': None, 'TIME_INPUT_FORMATS': ( ISO_8601, ), - 'TIME_FORMAT': ISO_8601, + 'TIME_FORMAT': None, # Pending deprecation 'FILTER_BACKEND': None, -- cgit v1.2.3 From 4ab7b8f257f9d3a1b35d34d0f90f0103b0cc6369 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 May 2013 20:49:49 +0100 Subject: Version 2.3.2 --- rest_framework/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 819558b5..b4961e2f 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.3.1' +__version__ = '2.3.2' VERSION = __version__ # synonym -- cgit v1.2.3 From 31f94ab409f1d5f41982a5946b980cf3ad8e3ba9 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 9 May 2013 13:31:42 +0100 Subject: Added GenericViewSet and docs tweaking --- rest_framework/viewsets.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 0eb3e86d..7c820091 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -108,6 +108,15 @@ class ViewSet(ViewSetMixin, views.APIView): pass +class GenericViewSet(ViewSetMixin, generics.GenericAPIView): + """ + The GenericViewSet class does not provide any actions by default, + but does include the base set of generic view behavior, such as + the `get_object` and `get_queryset` methods. + """ + pass + + class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, mixins.ListModelMixin, ViewSetMixin, -- cgit v1.2.3 From 939cc5adba6f5a95aac317134eb841838a0bff3f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 9 May 2013 13:35:01 +0100 Subject: Tweak inheritance --- rest_framework/viewsets.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 7c820091..d91323f2 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -119,8 +119,7 @@ class GenericViewSet(ViewSetMixin, generics.GenericAPIView): class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, mixins.ListModelMixin, - ViewSetMixin, - generics.GenericAPIView): + GenericViewSet): """ A viewset that provides default `list()` and `retrieve()` actions. """ @@ -132,8 +131,7 @@ class ModelViewSet(mixins.CreateModelMixin, mixins.UpdateModelMixin, mixins.DestroyModelMixin, mixins.ListModelMixin, - ViewSetMixin, - generics.GenericAPIView): + GenericViewSet): """ A viewset that provides default `create()`, `retrieve()`, `update()`, `partial_update()`, `destroy()` and `list()` actions. -- cgit v1.2.3 From 0176a5391b5d0c5c5dd61133f17b9b68840d6e1a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 9 May 2013 17:09:40 +0100 Subject: Fix HyperlinkedModelSerializer not respecting lookup_fields --- rest_framework/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index ea5175e2..d7a4c9ef 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -827,7 +827,7 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): def __init__(self, meta): super(HyperlinkedModelSerializerOptions, self).__init__(meta) self.view_name = getattr(meta, 'view_name', None) - self.lookup_field = getattr(meta, 'slug_field', None) + self.lookup_field = getattr(meta, 'lookup_field', None) class HyperlinkedModelSerializer(ModelSerializer): -- cgit v1.2.3 From 773a92eab3ac4b635511483ef906b3b8de9dedc9 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 10 May 2013 21:57:05 +0100 Subject: Move models into test modules, out of models module --- rest_framework/tests/models.py | 7 ------- rest_framework/tests/pagination.py | 10 ++++++++-- 2 files changed, 8 insertions(+), 9 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index f2117538..40e41a64 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -58,13 +58,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel): rel = models.ManyToManyField(Anchor) -# Model to test filtering. -class FilterableItem(RESTFrameworkModel): - text = models.CharField(max_length=100) - decimal = models.DecimalField(max_digits=4, decimal_places=2) - date = models.DateField() - - # Model for regression test for #285 class Comment(RESTFrameworkModel): diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 6b8ef02f..894d53d6 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -1,18 +1,24 @@ from __future__ import unicode_literals import datetime from decimal import Decimal -import django +from django.db import models from django.core.paginator import Paginator from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, status, pagination, filters, serializers from rest_framework.compat import django_filters -from rest_framework.tests.models import BasicModel, FilterableItem +from rest_framework.tests.models import BasicModel factory = RequestFactory() +class FilterableItem(models.Model): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + class RootView(generics.ListCreateAPIView): """ Example description for OPTIONS. -- cgit v1.2.3 From 8ce36d2bf1a899683208dc7de425a238ab27d0b3 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 10 May 2013 21:57:20 +0100 Subject: SearchFilter and tests --- rest_framework/filters.py | 9 ++++- rest_framework/tests/filterset.py | 81 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 87 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index f2163f6f..54cbbde3 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -74,6 +74,8 @@ class DjangoFilterBackend(BaseFilterBackend): class SearchFilter(BaseFilterBackend): + search_param = 'search' + def construct_search(self, field_name): if field_name.startswith('^'): return "%s__istartswith" % field_name[1:] @@ -90,10 +92,13 @@ class SearchFilter(BaseFilterBackend): if not search_fields: return None + search_terms = request.QUERY_PARAMS.get(self.search_param) orm_lookups = [self.construct_search(str(search_field)) - for search_field in self.search_fields] - for bit in self.query.split(): + for search_field in search_fields] + + for bit in search_terms.split(): or_queries = [models.Q(**{orm_lookup: bit}) for orm_lookup in orm_lookups] queryset = queryset.filter(reduce(operator.or_, or_queries)) + return queryset diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 023bd016..7865fedd 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -1,17 +1,24 @@ from __future__ import unicode_literals import datetime from decimal import Decimal +from django.db import models from django.core.urlresolvers import reverse from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters, patterns, url -from rest_framework.tests.models import FilterableItem, BasicModel +from rest_framework.tests.models import BasicModel factory = RequestFactory() +class FilterableItem(models.Model): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + if django_filters: # Basic filter on a list view. class FilterFieldsRootView(generics.ListCreateAPIView): @@ -256,3 +263,75 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, valid_item_data) + + +class SearchFilterModel(models.Model): + title = models.CharField(max_length=20) + text = models.CharField(max_length=100) + + +class SearchFilterTests(TestCase): + def setUp(self): + # Sequence of title/text is: + # + # z abc + # zz bcd + # zzz cde + # ... + for idx in range(10): + title = 'z' * (idx + 1) + text = ( + chr(idx + ord('a')) + + chr(idx + ord('b')) + + chr(idx + ord('c')) + ) + SearchFilterModel(title=title, text=text).save() + + def test_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('title', 'text') + + view = SearchListView.as_view() + request = factory.get('?search=b') + response = view(request) + self.assertEqual( + response.data, + [ + {u'id': 1, 'title': u'z', 'text': u'abc'}, + {u'id': 2, 'title': u'zz', 'text': u'bcd'} + ] + ) + + def test_exact_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('=title', 'text') + + view = SearchListView.as_view() + request = factory.get('?search=zzz') + response = view(request) + self.assertEqual( + response.data, + [ + {u'id': 3, 'title': u'zzz', 'text': u'cde'} + ] + ) + + def test_startswith_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('title', '^text') + + view = SearchListView.as_view() + request = factory.get('?search=b') + response = view(request) + self.assertEqual( + response.data, + [ + {u'id': 2, 'title': u'zz', 'text': u'bcd'} + ] + ) -- cgit v1.2.3 From 293dc3e6d8071fb464a63593831309468e457d6b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 10 May 2013 22:33:11 +0100 Subject: Added SearchFilter --- rest_framework/filters.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 54cbbde3..3edef30d 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -74,7 +74,8 @@ class DjangoFilterBackend(BaseFilterBackend): class SearchFilter(BaseFilterBackend): - search_param = 'search' + search_param = 'search' # The URL query parameter used for the search. + delimiter = None # For example, set to ',' for comma delimited searchs. def construct_search(self, field_name): if field_name.startswith('^'): @@ -96,8 +97,8 @@ class SearchFilter(BaseFilterBackend): orm_lookups = [self.construct_search(str(search_field)) for search_field in search_fields] - for bit in search_terms.split(): - or_queries = [models.Q(**{orm_lookup: bit}) + for search_term in search_terms.split(self.delimiter): + or_queries = [models.Q(**{orm_lookup: search_term}) for orm_lookup in orm_lookups] queryset = queryset.filter(reduce(operator.or_, or_queries)) -- cgit v1.2.3 From dd51d369c8228f3add37cc639702097b0df9cd90 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 10 May 2013 23:02:24 +0100 Subject: Unicode string fix --- rest_framework/tests/filterset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 7865fedd..e5414232 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -299,8 +299,8 @@ class SearchFilterTests(TestCase): self.assertEqual( response.data, [ - {u'id': 1, 'title': u'z', 'text': u'abc'}, - {u'id': 2, 'title': u'zz', 'text': u'bcd'} + {'id': 1, 'title': 'z', 'text': 'abc'}, + {'id': 2, 'title': 'zz', 'text': 'bcd'} ] ) @@ -316,7 +316,7 @@ class SearchFilterTests(TestCase): self.assertEqual( response.data, [ - {u'id': 3, 'title': u'zzz', 'text': u'cde'} + {'id': 3, 'title': 'zzz', 'text': 'cde'} ] ) @@ -332,6 +332,6 @@ class SearchFilterTests(TestCase): self.assertEqual( response.data, [ - {u'id': 2, 'title': u'zz', 'text': u'bcd'} + {'id': 2, 'title': 'zz', 'text': 'bcd'} ] ) -- cgit v1.2.3 From fd4a66cfc7888775d20b18665d63156cf3dae13a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 10 May 2013 23:06:42 +0100 Subject: Fix py3k compat with functools.reduce --- rest_framework/filters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 3edef30d..57f0f7c8 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -3,9 +3,9 @@ Provides generic filtering backends that can be used to filter the results returned by list views. """ from __future__ import unicode_literals - from django.db import models from rest_framework.compat import django_filters +from functools import reduce import operator FilterSet = django_filters and django_filters.FilterSet or None -- cgit v1.2.3 From 9d2580dccfe23e113221c7e150bddebb95d98214 Mon Sep 17 00:00:00 2001 From: Marlon Bailey Date: Sat, 11 May 2013 22:26:34 -0400 Subject: added support for multiple @action and @link decorators on a viewset, along with a router testcase illustrating the failure against the master code base --- rest_framework/routers.py | 6 +++--- rest_framework/tests/routers.py | 46 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) create mode 100644 rest_framework/tests/routers.py (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 0707635a..ebdf2b2a 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -127,18 +127,18 @@ class SimpleRouter(BaseRouter): """ # Determine any `@action` or `@link` decorated methods on the viewset - dynamic_routes = {} + dynamic_routes = [] for methodname in dir(viewset): attr = getattr(viewset, methodname) httpmethod = getattr(attr, 'bind_to_method', None) if httpmethod: - dynamic_routes[httpmethod] = methodname + dynamic_routes.append((httpmethod, methodname)) ret = [] for route in self.routes: if route.mapping == {'{httpmethod}': '{methodname}'}: # Dynamic routes (@link or @action decorator) - for httpmethod, methodname in dynamic_routes.items(): + for httpmethod, methodname in dynamic_routes: initkwargs = route.initkwargs.copy() initkwargs.update(getattr(viewset, methodname).kwargs) ret.append(Route( diff --git a/rest_framework/tests/routers.py b/rest_framework/tests/routers.py new file mode 100644 index 00000000..138d13d7 --- /dev/null +++ b/rest_framework/tests/routers.py @@ -0,0 +1,46 @@ +from __future__ import unicode_literals +from django.test import TestCase +from django.test.client import RequestFactory +from rest_framework import status +from rest_framework.response import Response +from rest_framework import viewsets +from rest_framework.decorators import link, action +from rest_framework.routers import SimpleRouter +import copy + +factory = RequestFactory() + + +class BasicViewSet(viewsets.ViewSet): + def list(self, request, *args, **kwargs): + return Response({'method': 'list'}) + + @action() + def action1(self, request, *args, **kwargs): + return Response({'method': 'action1'}) + + @action() + def action2(self, request, *args, **kwargs): + return Response({'method': 'action2'}) + + @link() + def link1(self, request, *args, **kwargs): + return Response({'method': 'link1'}) + + @link() + def link2(self, request, *args, **kwargs): + return Response({'method': 'link2'}) + + +class TestSimpleRouter(TestCase): + def setUp(self): + self.router = SimpleRouter() + + def test_link_and_action_decorator(self): + routes = self.router.get_routes(BasicViewSet) + # Should be 2 by default, and then four from the @action and @link combined + #self.assertEqual(len(routes), 6) + # + decorator_routes = routes[2:] + for i, method in enumerate(['action1', 'action2', 'link1', 'link2']): + self.assertEqual(decorator_routes[i].mapping.values()[0], method) -- cgit v1.2.3 From 5e2d8052d4bf87c81cc9807c96c933ca975cc483 Mon Sep 17 00:00:00 2001 From: Marlon Bailey Date: Sun, 12 May 2013 09:22:14 -0400 Subject: fix test case to work with Python 3 and make it more explicit --- rest_framework/tests/routers.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/routers.py b/rest_framework/tests/routers.py index 138d13d7..4e4765cb 100644 --- a/rest_framework/tests/routers.py +++ b/rest_framework/tests/routers.py @@ -38,9 +38,18 @@ class TestSimpleRouter(TestCase): def test_link_and_action_decorator(self): routes = self.router.get_routes(BasicViewSet) - # Should be 2 by default, and then four from the @action and @link combined - #self.assertEqual(len(routes), 6) - # decorator_routes = routes[2:] - for i, method in enumerate(['action1', 'action2', 'link1', 'link2']): - self.assertEqual(decorator_routes[i].mapping.values()[0], method) + # Make sure all these endpoints exist and none have been clobbered + for i, endpoint in enumerate(['action1', 'action2', 'link1', 'link2']): + route = decorator_routes[i] + # check url listing + self.assertEqual(route.url, + '^{{prefix}}/{{lookup}}/{0}/$'.format(endpoint)) + # check method to function mapping + if endpoint.startswith('action'): + method_map = 'post' + else: + method_map = 'get' + self.assertEqual(route.mapping[method_map], endpoint) + + -- cgit v1.2.3 From 5074bbe4b21a0fc116e4288743fb78314a76a33b Mon Sep 17 00:00:00 2001 From: James Summerfield Date: Mon, 13 May 2013 07:51:23 +0200 Subject: Remove trailing unmatched in login_base.html template. Reformat indentation and label all closing tags for consistency. --- .../templates/rest_framework/login_base.html | 68 ++++++++++------------ 1 file changed, 32 insertions(+), 36 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/templates/rest_framework/login_base.html b/rest_framework/templates/rest_framework/login_base.html index 380d5820..a3e73b6b 100644 --- a/rest_framework/templates/rest_framework/login_base.html +++ b/rest_framework/templates/rest_framework/login_base.html @@ -12,44 +12,40 @@ -
-
- -
+
-
- {% block branding %}

Django REST framework

{% endblock %} -
-
- -
-
-
- {% csrf_token %} -
-
- - -
-
-
-
- - -
+
+
+
+ {% block branding %}

Django REST framework

{% endblock %}
- -
- +
+ +
+
+ + {% csrf_token %} +
+
+ + +
+
+
+
+ + +
+
+ +
+ +
+
- -
-
-
- -
-
- -
+
+
+
+
-- cgit v1.2.3 From 24c9c455feaa47487196a2c9343746d7d5bdd962 Mon Sep 17 00:00:00 2001 From: Brian Zambrano Date: Mon, 13 May 2013 10:51:51 -0700 Subject: Allow for missing non-required nested objects. Serializer fields which are themselves serializers should not be required. Specifically, if a nested object is set to "required=False", it should be possible to serialize the main object and have the sub-object set to None/null. --- rest_framework/fields.py | 2 +- rest_framework/tests/serializer.py | 47 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index c83ee5ec..1f38b795 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -50,7 +50,7 @@ def get_component(obj, attr_name): return that attribute on the object. """ if isinstance(obj, dict): - val = obj[attr_name] + val = obj.get(attr_name) else: val = getattr(obj, attr_name) diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 84e1ee4e..6e732327 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -43,6 +43,17 @@ class CommentSerializer(serializers.Serializer): return instance +class NamesSerializer(serializers.Serializer): + first = serializers.CharField() + last = serializers.CharField(required=False, default='') + initials = serializers.CharField(required=False, default='') + + +class PersonIdentifierSerializer(serializers.Serializer): + ssn = serializers.CharField() + names = NamesSerializer(source='names', required=False) + + class BookSerializer(serializers.ModelSerializer): isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'}) @@ -141,6 +152,42 @@ class BasicTests(TestCase): self.assertFalse(serializer.object is expected) self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!') + def test_create_nested(self): + """Test a serializer with nested data.""" + names = {'first': 'John', 'last': 'Doe', 'initials': 'jd'} + data = {'ssn': '1234567890', 'names': names} + serializer = PersonIdentifierSerializer(data=data) + + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, data) + self.assertFalse(serializer.object is data) + self.assertEqual(serializer.data['names'], names) + + def test_create_partial_nested(self): + """Test a serializer with nested data which has missing fields.""" + names = {'first': 'John'} + data = {'ssn': '1234567890', 'names': names} + serializer = PersonIdentifierSerializer(data=data) + + expected_names = {'first': 'John', 'last': '', 'initials': ''} + data['names'] = expected_names + + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, data) + self.assertFalse(serializer.object is expected_names) + self.assertEqual(serializer.data['names'], expected_names) + + def test_null_nested(self): + """Test a serializer with a nonexistent nested field""" + data = {'ssn': '1234567890'} + serializer = PersonIdentifierSerializer(data=data) + + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, data) + self.assertFalse(serializer.object is data) + expected = {'ssn': '1234567890', 'names': None} + self.assertEqual(serializer.data, expected) + def test_update(self): serializer = CommentSerializer(self.comment, data=self.data) expected = self.comment -- cgit v1.2.3 From 752c01420f7574cd99e28a17d56df711b675ce71 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 14 May 2013 10:01:05 +0100 Subject: Fix Django 1.3 compat with routers --- rest_framework/routers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 0707635a..ed4dc338 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -16,7 +16,7 @@ For example, you might have a `urls.py` that looks something like this: from __future__ import unicode_literals from collections import namedtuple -from django.conf.urls import url, patterns +from rest_framework.compat import patterns, url from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.reverse import reverse -- cgit v1.2.3 From b2bf5f1f886d131957f99308a0da89b24b3352d4 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 14 May 2013 10:10:44 +0100 Subject: SearchFilter may be comma and/or whitespace seperated --- rest_framework/filters.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 57f0f7c8..c496ec4b 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -75,7 +75,14 @@ class DjangoFilterBackend(BaseFilterBackend): class SearchFilter(BaseFilterBackend): search_param = 'search' # The URL query parameter used for the search. - delimiter = None # For example, set to ',' for comma delimited searchs. + + def get_search_terms(self, request): + """ + Search terms are set by a ?search=... query parameter, + and may be comma and/or whitespace delimited. + """ + params = request.QUERY_PARAMS.get(self.search_param) + return params.replace(',', ' ').split() def construct_search(self, field_name): if field_name.startswith('^'): @@ -93,11 +100,10 @@ class SearchFilter(BaseFilterBackend): if not search_fields: return None - search_terms = request.QUERY_PARAMS.get(self.search_param) orm_lookups = [self.construct_search(str(search_field)) for search_field in search_fields] - for search_term in search_terms.split(self.delimiter): + for search_term in self.get_search_terms(request): or_queries = [models.Q(**{orm_lookup: search_term}) for orm_lookup in orm_lookups] queryset = queryset.filter(reduce(operator.or_, or_queries)) -- cgit v1.2.3 From 08bc97626960f108f01657e4ad12b7fd62e6183d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 14 May 2013 10:16:46 +0100 Subject: Rename filter tests --- rest_framework/tests/filters.py | 337 ++++++++++++++++++++++++++++++++++++++ rest_framework/tests/filterset.py | 337 -------------------------------------- 2 files changed, 337 insertions(+), 337 deletions(-) create mode 100644 rest_framework/tests/filters.py delete mode 100644 rest_framework/tests/filterset.py (limited to 'rest_framework') diff --git a/rest_framework/tests/filters.py b/rest_framework/tests/filters.py new file mode 100644 index 00000000..e5414232 --- /dev/null +++ b/rest_framework/tests/filters.py @@ -0,0 +1,337 @@ +from __future__ import unicode_literals +import datetime +from decimal import Decimal +from django.db import models +from django.core.urlresolvers import reverse +from django.test import TestCase +from django.test.client import RequestFactory +from django.utils import unittest +from rest_framework import generics, serializers, status, filters +from rest_framework.compat import django_filters, patterns, url +from rest_framework.tests.models import BasicModel + +factory = RequestFactory() + + +class FilterableItem(models.Model): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + +if django_filters: + # Basic filter on a list view. + class FilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_fields = ['decimal', 'date'] + filter_backend = filters.DjangoFilterBackend + + # These class are used to test a filter class. + class SeveralFieldsFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + decimal = django_filters.NumberFilter(lookup_type='lt') + date = django_filters.DateFilter(lookup_type='gt') + + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + + class FilterClassRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = SeveralFieldsFilter + filter_backend = filters.DjangoFilterBackend + + # These classes are used to test a misconfigured filter class. + class MisconfiguredFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + + class Meta: + model = BasicModel + fields = ['text'] + + class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = MisconfiguredFilter + filter_backend = filters.DjangoFilterBackend + + class FilterClassDetailView(generics.RetrieveAPIView): + model = FilterableItem + filter_class = SeveralFieldsFilter + filter_backend = filters.DjangoFilterBackend + + # Regression test for #814 + class FilterableItemSerializer(serializers.ModelSerializer): + class Meta: + model = FilterableItem + + class FilterFieldsQuerysetView(generics.ListCreateAPIView): + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer + filter_fields = ['decimal', 'date'] + filter_backend = filters.DjangoFilterBackend + + urlpatterns = patterns('', + url(r'^(?P\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), + url(r'^$', FilterClassRootView.as_view(), name='root-view'), + ) + + +class CommonFilteringTestCase(TestCase): + def _serialize_object(self, obj): + return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + + def setUp(self): + """ + Create 10 FilterableItem instances. + """ + base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) + for i in range(10): + text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. + decimal = base_data[1] + i + date = base_data[2] - datetime.timedelta(days=i * 2) + FilterableItem(text=text, decimal=decimal, date=date).save() + + self.objects = FilterableItem.objects + self.data = [ + self._serialize_object(obj) + for obj in self.objects.all() + ] + + +class IntegrationTestFiltering(CommonFilteringTestCase): + """ + Integration tests for filtered list views. + """ + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_fields_root_view(self): + """ + GET requests to paginated ListCreateAPIView should return paginated results. + """ + view = FilterFieldsRootView.as_view() + + # Basic test with no filter. + request = factory.get('/') + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data) + + # Tests that the decimal filter works. + search_decimal = Decimal('2.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] == search_decimal] + self.assertEqual(response.data, expected_data) + + # Tests that the date filter works. + search_date = datetime.date(2012, 9, 22) + request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] == search_date] + self.assertEqual(response.data, expected_data) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_filter_with_queryset(self): + """ + Regression test for #814. + """ + view = FilterFieldsQuerysetView.as_view() + + # Tests that the decimal filter works. + search_decimal = Decimal('2.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] == search_decimal] + self.assertEqual(response.data, expected_data) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_class_root_view(self): + """ + GET requests to filtered ListCreateAPIView that have a filter_class set + should return filtered results. + """ + view = FilterClassRootView.as_view() + + # Basic test with no filter. + request = factory.get('/') + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data) + + # Tests that the decimal filter set with 'lt' in the filter class works. + search_decimal = Decimal('4.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] < search_decimal] + self.assertEqual(response.data, expected_data) + + # Tests that the date filter set with 'gt' in the filter class works. + search_date = datetime.date(2012, 10, 2) + request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] > search_date] + self.assertEqual(response.data, expected_data) + + # Tests that the text filter set with 'icontains' in the filter class works. + search_text = 'ff' + request = factory.get('/?text=%s' % search_text) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if search_text in f['text'].lower()] + self.assertEqual(response.data, expected_data) + + # Tests that multiple filters works. + search_decimal = Decimal('5.25') + search_date = datetime.date(2012, 10, 2) + request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] > search_date and + f['decimal'] < search_decimal] + self.assertEqual(response.data, expected_data) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_incorrectly_configured_filter(self): + """ + An error should be displayed when the filter class is misconfigured. + """ + view = IncorrectlyConfiguredRootView.as_view() + + request = factory.get('/') + self.assertRaises(AssertionError, view, request) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_unknown_filter(self): + """ + GET requests with filters that aren't configured should return 200. + """ + view = FilterFieldsRootView.as_view() + + search_integer = 10 + request = factory.get('/?integer=%s' % search_integer) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class IntegrationTestDetailFiltering(CommonFilteringTestCase): + """ + Integration tests for filtered detail views. + """ + urls = 'rest_framework.tests.filterset' + + def _get_url(self, item): + return reverse('detail-view', kwargs=dict(pk=item.pk)) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_detail_view(self): + """ + GET requests to filtered RetrieveAPIView that have a filter_class set + should return filtered results. + """ + item = self.objects.all()[0] + data = self._serialize_object(item) + + # Basic test with no filter. + response = self.client.get(self._get_url(item)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, data) + + # Tests that the decimal filter set that should fail. + search_decimal = Decimal('4.25') + high_item = self.objects.filter(decimal__gt=search_decimal)[0] + response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + # Tests that the decimal filter set that should succeed. + search_decimal = Decimal('4.25') + low_item = self.objects.filter(decimal__lt=search_decimal)[0] + low_item_data = self._serialize_object(low_item) + response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, low_item_data) + + # Tests that multiple filters works. + search_decimal = Decimal('5.25') + search_date = datetime.date(2012, 10, 2) + valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] + valid_item_data = self._serialize_object(valid_item) + response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, valid_item_data) + + +class SearchFilterModel(models.Model): + title = models.CharField(max_length=20) + text = models.CharField(max_length=100) + + +class SearchFilterTests(TestCase): + def setUp(self): + # Sequence of title/text is: + # + # z abc + # zz bcd + # zzz cde + # ... + for idx in range(10): + title = 'z' * (idx + 1) + text = ( + chr(idx + ord('a')) + + chr(idx + ord('b')) + + chr(idx + ord('c')) + ) + SearchFilterModel(title=title, text=text).save() + + def test_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('title', 'text') + + view = SearchListView.as_view() + request = factory.get('?search=b') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 1, 'title': 'z', 'text': 'abc'}, + {'id': 2, 'title': 'zz', 'text': 'bcd'} + ] + ) + + def test_exact_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('=title', 'text') + + view = SearchListView.as_view() + request = factory.get('?search=zzz') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'zzz', 'text': 'cde'} + ] + ) + + def test_startswith_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('title', '^text') + + view = SearchListView.as_view() + request = factory.get('?search=b') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 2, 'title': 'zz', 'text': 'bcd'} + ] + ) diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py deleted file mode 100644 index e5414232..00000000 --- a/rest_framework/tests/filterset.py +++ /dev/null @@ -1,337 +0,0 @@ -from __future__ import unicode_literals -import datetime -from decimal import Decimal -from django.db import models -from django.core.urlresolvers import reverse -from django.test import TestCase -from django.test.client import RequestFactory -from django.utils import unittest -from rest_framework import generics, serializers, status, filters -from rest_framework.compat import django_filters, patterns, url -from rest_framework.tests.models import BasicModel - -factory = RequestFactory() - - -class FilterableItem(models.Model): - text = models.CharField(max_length=100) - decimal = models.DecimalField(max_digits=4, decimal_places=2) - date = models.DateField() - - -if django_filters: - # Basic filter on a list view. - class FilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem - filter_fields = ['decimal', 'date'] - filter_backend = filters.DjangoFilterBackend - - # These class are used to test a filter class. - class SeveralFieldsFilter(django_filters.FilterSet): - text = django_filters.CharFilter(lookup_type='icontains') - decimal = django_filters.NumberFilter(lookup_type='lt') - date = django_filters.DateFilter(lookup_type='gt') - - class Meta: - model = FilterableItem - fields = ['text', 'decimal', 'date'] - - class FilterClassRootView(generics.ListCreateAPIView): - model = FilterableItem - filter_class = SeveralFieldsFilter - filter_backend = filters.DjangoFilterBackend - - # These classes are used to test a misconfigured filter class. - class MisconfiguredFilter(django_filters.FilterSet): - text = django_filters.CharFilter(lookup_type='icontains') - - class Meta: - model = BasicModel - fields = ['text'] - - class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): - model = FilterableItem - filter_class = MisconfiguredFilter - filter_backend = filters.DjangoFilterBackend - - class FilterClassDetailView(generics.RetrieveAPIView): - model = FilterableItem - filter_class = SeveralFieldsFilter - filter_backend = filters.DjangoFilterBackend - - # Regression test for #814 - class FilterableItemSerializer(serializers.ModelSerializer): - class Meta: - model = FilterableItem - - class FilterFieldsQuerysetView(generics.ListCreateAPIView): - queryset = FilterableItem.objects.all() - serializer_class = FilterableItemSerializer - filter_fields = ['decimal', 'date'] - filter_backend = filters.DjangoFilterBackend - - urlpatterns = patterns('', - url(r'^(?P\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), - url(r'^$', FilterClassRootView.as_view(), name='root-view'), - ) - - -class CommonFilteringTestCase(TestCase): - def _serialize_object(self, obj): - return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} - - def setUp(self): - """ - Create 10 FilterableItem instances. - """ - base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) - for i in range(10): - text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. - decimal = base_data[1] + i - date = base_data[2] - datetime.timedelta(days=i * 2) - FilterableItem(text=text, decimal=decimal, date=date).save() - - self.objects = FilterableItem.objects - self.data = [ - self._serialize_object(obj) - for obj in self.objects.all() - ] - - -class IntegrationTestFiltering(CommonFilteringTestCase): - """ - Integration tests for filtered list views. - """ - - @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_get_filtered_fields_root_view(self): - """ - GET requests to paginated ListCreateAPIView should return paginated results. - """ - view = FilterFieldsRootView.as_view() - - # Basic test with no filter. - request = factory.get('/') - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, self.data) - - # Tests that the decimal filter works. - search_decimal = Decimal('2.25') - request = factory.get('/?decimal=%s' % search_decimal) - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['decimal'] == search_decimal] - self.assertEqual(response.data, expected_data) - - # Tests that the date filter works. - search_date = datetime.date(2012, 9, 22) - request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['date'] == search_date] - self.assertEqual(response.data, expected_data) - - @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_filter_with_queryset(self): - """ - Regression test for #814. - """ - view = FilterFieldsQuerysetView.as_view() - - # Tests that the decimal filter works. - search_decimal = Decimal('2.25') - request = factory.get('/?decimal=%s' % search_decimal) - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['decimal'] == search_decimal] - self.assertEqual(response.data, expected_data) - - @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_get_filtered_class_root_view(self): - """ - GET requests to filtered ListCreateAPIView that have a filter_class set - should return filtered results. - """ - view = FilterClassRootView.as_view() - - # Basic test with no filter. - request = factory.get('/') - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, self.data) - - # Tests that the decimal filter set with 'lt' in the filter class works. - search_decimal = Decimal('4.25') - request = factory.get('/?decimal=%s' % search_decimal) - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['decimal'] < search_decimal] - self.assertEqual(response.data, expected_data) - - # Tests that the date filter set with 'gt' in the filter class works. - search_date = datetime.date(2012, 10, 2) - request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['date'] > search_date] - self.assertEqual(response.data, expected_data) - - # Tests that the text filter set with 'icontains' in the filter class works. - search_text = 'ff' - request = factory.get('/?text=%s' % search_text) - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if search_text in f['text'].lower()] - self.assertEqual(response.data, expected_data) - - # Tests that multiple filters works. - search_decimal = Decimal('5.25') - search_date = datetime.date(2012, 10, 2) - request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['date'] > search_date and - f['decimal'] < search_decimal] - self.assertEqual(response.data, expected_data) - - @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_incorrectly_configured_filter(self): - """ - An error should be displayed when the filter class is misconfigured. - """ - view = IncorrectlyConfiguredRootView.as_view() - - request = factory.get('/') - self.assertRaises(AssertionError, view, request) - - @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_unknown_filter(self): - """ - GET requests with filters that aren't configured should return 200. - """ - view = FilterFieldsRootView.as_view() - - search_integer = 10 - request = factory.get('/?integer=%s' % search_integer) - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - - -class IntegrationTestDetailFiltering(CommonFilteringTestCase): - """ - Integration tests for filtered detail views. - """ - urls = 'rest_framework.tests.filterset' - - def _get_url(self, item): - return reverse('detail-view', kwargs=dict(pk=item.pk)) - - @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_get_filtered_detail_view(self): - """ - GET requests to filtered RetrieveAPIView that have a filter_class set - should return filtered results. - """ - item = self.objects.all()[0] - data = self._serialize_object(item) - - # Basic test with no filter. - response = self.client.get(self._get_url(item)) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, data) - - # Tests that the decimal filter set that should fail. - search_decimal = Decimal('4.25') - high_item = self.objects.filter(decimal__gt=search_decimal)[0] - response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - - # Tests that the decimal filter set that should succeed. - search_decimal = Decimal('4.25') - low_item = self.objects.filter(decimal__lt=search_decimal)[0] - low_item_data = self._serialize_object(low_item) - response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, low_item_data) - - # Tests that multiple filters works. - search_decimal = Decimal('5.25') - search_date = datetime.date(2012, 10, 2) - valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] - valid_item_data = self._serialize_object(valid_item) - response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, valid_item_data) - - -class SearchFilterModel(models.Model): - title = models.CharField(max_length=20) - text = models.CharField(max_length=100) - - -class SearchFilterTests(TestCase): - def setUp(self): - # Sequence of title/text is: - # - # z abc - # zz bcd - # zzz cde - # ... - for idx in range(10): - title = 'z' * (idx + 1) - text = ( - chr(idx + ord('a')) + - chr(idx + ord('b')) + - chr(idx + ord('c')) - ) - SearchFilterModel(title=title, text=text).save() - - def test_search(self): - class SearchListView(generics.ListAPIView): - model = SearchFilterModel - filter_backends = (filters.SearchFilter,) - search_fields = ('title', 'text') - - view = SearchListView.as_view() - request = factory.get('?search=b') - response = view(request) - self.assertEqual( - response.data, - [ - {'id': 1, 'title': 'z', 'text': 'abc'}, - {'id': 2, 'title': 'zz', 'text': 'bcd'} - ] - ) - - def test_exact_search(self): - class SearchListView(generics.ListAPIView): - model = SearchFilterModel - filter_backends = (filters.SearchFilter,) - search_fields = ('=title', 'text') - - view = SearchListView.as_view() - request = factory.get('?search=zzz') - response = view(request) - self.assertEqual( - response.data, - [ - {'id': 3, 'title': 'zzz', 'text': 'cde'} - ] - ) - - def test_startswith_search(self): - class SearchListView(generics.ListAPIView): - model = SearchFilterModel - filter_backends = (filters.SearchFilter,) - search_fields = ('title', '^text') - - view = SearchListView.as_view() - request = factory.get('?search=b') - response = view(request) - self.assertEqual( - response.data, - [ - {'id': 2, 'title': 'zz', 'text': 'bcd'} - ] - ) -- cgit v1.2.3 From 6a037f63edf33e7a76f56828cf68bfae4ccb4f80 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 14 May 2013 11:27:03 +0100 Subject: Added OrderingFilter --- rest_framework/filters.py | 41 +++++++++++++- rest_framework/tests/filters.py | 116 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 156 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index c496ec4b..308e7da2 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -4,7 +4,7 @@ returned by list views. """ from __future__ import unicode_literals from django.db import models -from rest_framework.compat import django_filters +from rest_framework.compat import django_filters, six from functools import reduce import operator @@ -109,3 +109,42 @@ class SearchFilter(BaseFilterBackend): queryset = queryset.filter(reduce(operator.or_, or_queries)) return queryset + + +class OrderingFilter(BaseFilterBackend): + ordering_param = 'order' # The URL query parameter used for the ordering. + + def get_ordering(self, request): + """ + Search terms are set by a ?search=... query parameter, + and may be comma and/or whitespace delimited. + """ + params = request.QUERY_PARAMS.get(self.ordering_param) + if params: + return [param.strip() for param in params.split(',')] + + def get_default_ordering(self, view): + ordering = getattr(view, 'ordering', None) + if isinstance(ordering, six.string_types): + return (ordering,) + return ordering + + def remove_invalid_fields(self, queryset, ordering): + field_names = [field.name for field in queryset.model._meta.fields] + return [term for term in ordering if term.lstrip('-') in field_names] + + def filter_queryset(self, request, queryset, view): + ordering = self.get_ordering(request) + + if ordering: + # Skip any incorrect parameters + ordering = self.remove_invalid_fields(queryset, ordering) + + if not ordering: + # Use 'ordering' attribtue by default + ordering = self.get_default_ordering(view) + + if ordering: + return queryset.order_by(*ordering) + + return queryset diff --git a/rest_framework/tests/filters.py b/rest_framework/tests/filters.py index e5414232..6b604deb 100644 --- a/rest_framework/tests/filters.py +++ b/rest_framework/tests/filters.py @@ -335,3 +335,119 @@ class SearchFilterTests(TestCase): {'id': 2, 'title': 'zz', 'text': 'bcd'} ] ) + + +class OrdringFilterModel(models.Model): + title = models.CharField(max_length=20) + text = models.CharField(max_length=100) + + +class OrderingFilterTests(TestCase): + def setUp(self): + # Sequence of title/text is: + # + # zyx abc + # yxw bcd + # xwv cde + for idx in range(3): + title = ( + chr(ord('z') - idx) + + chr(ord('y') - idx) + + chr(ord('x') - idx) + ) + text = ( + chr(idx + ord('a')) + + chr(idx + ord('b')) + + chr(idx + ord('c')) + ) + OrdringFilterModel(title=title, text=text).save() + + def test_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?order=text') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + ] + ) + + def test_reverse_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?order=-text') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_incorrectfield_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?order=foobar') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_default_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_default_ordering_using_string(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = 'title' + + view = OrderingListView.as_view() + request = factory.get('') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) -- cgit v1.2.3 From 2cff6e69dbe3828eca56d0ce60ffdfc80fed045c Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 14 May 2013 11:27:08 +0100 Subject: Added OrderingFilter --- rest_framework/filters.py | 2 +- rest_framework/tests/filters.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 308e7da2..6a3e055d 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -112,7 +112,7 @@ class SearchFilter(BaseFilterBackend): class OrderingFilter(BaseFilterBackend): - ordering_param = 'order' # The URL query parameter used for the ordering. + ordering_param = 'ordering' # The URL query parameter used for the ordering. def get_ordering(self, request): """ diff --git a/rest_framework/tests/filters.py b/rest_framework/tests/filters.py index 6b604deb..b20d5980 100644 --- a/rest_framework/tests/filters.py +++ b/rest_framework/tests/filters.py @@ -369,7 +369,7 @@ class OrderingFilterTests(TestCase): ordering = ('title',) view = OrderingListView.as_view() - request = factory.get('?order=text') + request = factory.get('?ordering=text') response = view(request) self.assertEqual( response.data, @@ -387,7 +387,7 @@ class OrderingFilterTests(TestCase): ordering = ('title',) view = OrderingListView.as_view() - request = factory.get('?order=-text') + request = factory.get('?ordering=-text') response = view(request) self.assertEqual( response.data, @@ -405,7 +405,7 @@ class OrderingFilterTests(TestCase): ordering = ('title',) view = OrderingListView.as_view() - request = factory.get('?order=foobar') + request = factory.get('?ordering=foobar') response = view(request) self.assertEqual( response.data, -- cgit v1.2.3 From a303d0f38c4758fc3aad412529922203e5785e29 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 14 May 2013 11:37:59 +0100 Subject: Fix filter test renaming --- rest_framework/tests/filters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/filters.py b/rest_framework/tests/filters.py index b20d5980..18972c84 100644 --- a/rest_framework/tests/filters.py +++ b/rest_framework/tests/filters.py @@ -222,7 +222,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): """ Integration tests for filtered detail views. """ - urls = 'rest_framework.tests.filterset' + urls = 'rest_framework.tests.filters' def _get_url(self, item): return reverse('detail-view', kwargs=dict(pk=item.pk)) -- cgit v1.2.3 From d62414147fa949af4db698afedae7b5506229a9f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 14 May 2013 17:53:37 +0100 Subject: Fix assert messaging on fields/exclude checking. Closes #833 --- rest_framework/serializers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index d7a4c9ef..ecff2c52 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -200,7 +200,7 @@ class BaseSerializer(WritableField): # If 'fields' is specified, use those fields, in that order. if self.opts.fields: - assert isinstance(self.opts.fields, (list, tuple)), '`include` must be a list or tuple' + assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple' new = SortedDict() for key in self.opts.fields: new[key] = ret[key] @@ -208,7 +208,7 @@ class BaseSerializer(WritableField): # Remove anything in 'exclude' if self.opts.exclude: - assert isinstance(self.opts.fields, (list, tuple)), '`exclude` must be a list or tuple' + assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple' for key in self.opts.exclude: ret.pop(key, None) -- cgit v1.2.3 From e939e1755a94b50c87a82c0f777645e28fe91bf0 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 14 May 2013 21:40:55 +0100 Subject: Base automatic filterset model on the queryset model. Fixes #834. --- rest_framework/filters.py | 19 ++++++------------- rest_framework/tests/filters.py | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+), 13 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 6a3e055d..34831dd7 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -32,40 +32,33 @@ class DjangoFilterBackend(BaseFilterBackend): def __init__(self): assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed' - def get_filter_class(self, view): + def get_filter_class(self, view, queryset=None): """ Return the django-filters `FilterSet` used to filter the queryset. """ filter_class = getattr(view, 'filter_class', None) filter_fields = getattr(view, 'filter_fields', None) - model_cls = getattr(view, 'model', None) - queryset = getattr(view, 'queryset', None) - if model_cls is None and queryset is not None: - model_cls = queryset.model if filter_class: filter_model = filter_class.Meta.model - assert issubclass(filter_model, model_cls), \ - 'FilterSet model %s does not match view model %s' % \ - (filter_model, model_cls) + assert issubclass(filter_model, queryset.model), \ + 'FilterSet model %s does not match queryset model %s' % \ + (filter_model, queryset.model) return filter_class if filter_fields: - assert model_cls is not None, 'Cannot use DjangoFilterBackend ' \ - 'on a view which does not have a .model or .queryset attribute.' - class AutoFilterSet(self.default_filter_set): class Meta: - model = model_cls + model = queryset.model fields = filter_fields return AutoFilterSet return None def filter_queryset(self, request, queryset, view): - filter_class = self.get_filter_class(view) + filter_class = self.get_filter_class(view, queryset) if filter_class: return filter_class(request.QUERY_PARAMS, queryset=queryset).qs diff --git a/rest_framework/tests/filters.py b/rest_framework/tests/filters.py index 18972c84..a58c66ae 100644 --- a/rest_framework/tests/filters.py +++ b/rest_framework/tests/filters.py @@ -70,9 +70,19 @@ if django_filters: filter_fields = ['decimal', 'date'] filter_backend = filters.DjangoFilterBackend + class GetQuerysetView(generics.ListCreateAPIView): + serializer_class = FilterableItemSerializer + filter_class = SeveralFieldsFilter + filter_backend = filters.DjangoFilterBackend + + def get_queryset(self): + return FilterableItem.objects.all() + urlpatterns = patterns('', url(r'^(?P\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), url(r'^$', FilterClassRootView.as_view(), name='root-view'), + url(r'^get-queryset/$', GetQuerysetView.as_view(), + name='get-queryset-view'), ) @@ -147,6 +157,17 @@ class IntegrationTestFiltering(CommonFilteringTestCase): expected_data = [f for f in self.data if f['decimal'] == search_decimal] self.assertEqual(response.data, expected_data) + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_filter_with_get_queryset_only(self): + """ + Regression test for #834. + """ + view = GetQuerysetView.as_view() + request = factory.get('/get-queryset/') + view(request).render() + # Used to raise "issubclass() arg 2 must be a class or tuple of classes" + # here when neither `model' nor `queryset' was specified. + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_filtered_class_root_view(self): """ -- cgit v1.2.3 From 092d5223eb7ea1bbf9b6bb967200cb3725e02112 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 15 May 2013 10:29:51 +0100 Subject: Fix searchfilter issues --- rest_framework/filters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 34831dd7..c058bc71 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -74,7 +74,7 @@ class SearchFilter(BaseFilterBackend): Search terms are set by a ?search=... query parameter, and may be comma and/or whitespace delimited. """ - params = request.QUERY_PARAMS.get(self.search_param) + params = request.QUERY_PARAMS.get(self.search_param, '') return params.replace(',', ' ').split() def construct_search(self, field_name): @@ -91,7 +91,7 @@ class SearchFilter(BaseFilterBackend): search_fields = getattr(view, 'search_fields', None) if not search_fields: - return None + return queryset orm_lookups = [self.construct_search(str(search_field)) for search_field in search_fields] -- cgit v1.2.3 From af88a5b1751da32018e8408eac01a91a5f63f8ce Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 15 May 2013 14:25:25 +0100 Subject: Test and fix which closes #652. --- rest_framework/serializers.py | 8 +++++++- rest_framework/tests/serializer.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index ecff2c52..7707de7a 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -649,8 +649,14 @@ class ModelSerializer(Serializer): # Add the `read_only` flag to any fields that have bee specified # in the `read_only_fields` option for field_name in self.opts.read_only_fields: + assert field_name not in self.base_fields.keys(), \ + "field '%s' on serializer '%s' specfied in " \ + "`read_only_fields`, but also added " \ + "as an explict field. Remove it from `read_only_fields`." % \ + (field_name, self.__class__.__name__) assert field_name in ret, \ - "read_only_fields on '%s' included invalid item '%s'" % \ + "Noexistant field '%s' specified in `read_only_fields` " \ + "on serializer '%s'." % \ (self.__class__.__name__, field_name) ret[field_name].read_only = True diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 84e1ee4e..db3881f9 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -78,6 +78,18 @@ class PersonSerializer(serializers.ModelSerializer): read_only_fields = ('age',) +class PersonSerializerInvalidReadOnly(serializers.ModelSerializer): + """ + Testing for #652. + """ + info = serializers.Field(source='info') + + class Meta: + model = Person + fields = ('name', 'age', 'info') + read_only_fields = ('age', 'info') + + class AlbumsSerializer(serializers.ModelSerializer): class Meta: @@ -189,6 +201,12 @@ class BasicTests(TestCase): # Assert age is unchanged (35) self.assertEqual(instance.age, self.person_data['age']) + def test_invalid_read_only_fields(self): + """ + Regression test for #652. + """ + self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, []) + class DictStyleSerializer(serializers.Serializer): """ -- cgit v1.2.3 From aff88d15f7a483bca2da120339b1b346aa8b1d4c Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 16 May 2013 15:08:12 +0100 Subject: Version 2.3.3 --- rest_framework/__init__.py | 2 +- rest_framework/permissions.py | 5 +++++ rest_framework/routers.py | 17 ++++++++++------- 3 files changed, 16 insertions(+), 8 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index b4961e2f..0b1e67fb 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.3.2' +__version__ = '2.3.3' VERSION = __version__ # synonym diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 751f31a7..45fcfd66 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -126,6 +126,11 @@ class DjangoModelPermissions(BasePermission): if model_cls is None and queryset is not None: model_cls = queryset.model + # Workaround to ensure DjangoModelPermissions are not applied + # to the root view when using DefaultRouter. + if model_cls is None and getattr(view, '_ignore_model_permissions'): + return True + assert model_cls, ('Cannot apply DjangoModelPermissions on a view that' ' does not have `.model` or `.queryset` property.') diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 76714fd0..dba104c3 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -16,6 +16,7 @@ For example, you might have a `urls.py` that looks something like this: from __future__ import unicode_literals from collections import namedtuple +from rest_framework import views from rest_framework.compat import patterns, url from rest_framework.decorators import api_view from rest_framework.response import Response @@ -217,14 +218,16 @@ class DefaultRouter(SimpleRouter): for prefix, viewset, basename in self.registry: api_root_dict[prefix] = list_name.format(basename=basename) - @api_view(('GET',)) - def api_root(request, format=None): - ret = {} - for key, url_name in api_root_dict.items(): - ret[key] = reverse(url_name, request=request, format=format) - return Response(ret) + class APIRoot(views.APIView): + _ignore_model_permissions = True - return api_root + def get(self, request, format=None): + ret = {} + for key, url_name in api_root_dict.items(): + ret[key] = reverse(url_name, request=request, format=format) + return Response(ret) + + return APIRoot.as_view() def get_urls(self): """ -- cgit v1.2.3 From abe207b869c771187523efd3d189ffc0beba51c3 Mon Sep 17 00:00:00 2001 From: Andy Freeland Date: Thu, 16 May 2013 11:24:11 -0400 Subject: HyperlinkedIdentityField uses `lookup_field` kwarg. According to the [Serializers API Guide][1], `HyperlinkedIdentityField` takes `lookup_field` as a kwarg like the other related fields and the generic views. However, this was not actually implemented. [1]: http://django-rest-framework.org/api-guide/serializers.html#hyperlinkedmodelserializer --- rest_framework/relations.py | 21 ++++++++++++-- rest_framework/tests/hyperlinkedserializers.py | 40 ++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index fc5054b2..c4b790d4 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -465,10 +465,13 @@ class HyperlinkedIdentityField(Field): """ Represents the instance, or a property on the instance, using hyperlinking. """ + lookup_field = 'pk' + read_only = True + + # These are all pending deprecation pk_url_kwarg = 'pk' slug_field = 'slug' slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden - read_only = True def __init__(self, *args, **kwargs): # TODO: Make view_name mandatory, and have the @@ -477,6 +480,19 @@ class HyperlinkedIdentityField(Field): # Optionally the format of the target hyperlink may be specified self.format = kwargs.pop('format', None) + self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) + + # These are pending deprecation + if 'pk_url_kwarg' in kwargs: + msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_url_kwarg' in kwargs: + msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_field' in kwargs: + msg = 'slug_field is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + self.slug_field = kwargs.pop('slug_field', self.slug_field) default_slug_kwarg = self.slug_url_kwarg or self.slug_field self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) @@ -488,7 +504,8 @@ class HyperlinkedIdentityField(Field): request = self.context.get('request', None) format = self.context.get('format', None) view_name = self.view_name or self.parent.opts.view_name - kwargs = {self.pk_url_kwarg: obj.pk} + lookup_field = getattr(obj, self.lookup_field) + kwargs = {self.lookup_field: lookup_field} if request is None: warnings.warn("Using `HyperlinkedIdentityField` without including the " diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index 9a61f299..8fc6ba77 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -27,6 +27,14 @@ class PhotoSerializer(serializers.Serializer): return Photo(**attrs) +class AlbumSerializer(serializers.ModelSerializer): + url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title') + + class Meta: + model = Album + fields = ('title', 'url') + + class BasicList(generics.ListCreateAPIView): model = BasicModel model_serializer_class = serializers.HyperlinkedModelSerializer @@ -73,6 +81,8 @@ class PhotoListCreate(generics.ListCreateAPIView): class AlbumDetail(generics.RetrieveAPIView): model = Album + serializer_class = AlbumSerializer + lookup_field = 'title' class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView): @@ -180,6 +190,36 @@ class TestManyToManyHyperlinkedView(TestCase): self.assertEqual(response.data, self.data[0]) +class TestHyperlinkedIdentityFieldLookup(TestCase): + urls = 'rest_framework.tests.hyperlinkedserializers' + + def setUp(self): + """ + Create 3 Album instances. + """ + titles = ['foo', 'bar', 'baz'] + for title in titles: + album = Album(title=title) + album.save() + self.detail_view = AlbumDetail.as_view() + self.data = { + 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'}, + 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'}, + 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'} + } + + def test_lookup_field(self): + """ + GET requests to AlbumDetail view should return serialized Albums + with a url field keyed by `title`. + """ + for album in Album.objects.all(): + request = factory.get('/albums/{0}/'.format(album.title)) + response = self.detail_view(request, title=album.title) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[album.title]) + + class TestCreateWithForeignKeys(TestCase): urls = 'rest_framework.tests.hyperlinkedserializers' -- cgit v1.2.3 From 14ded26167b68aaf8316a6bf83b6be3e77c8bbd8 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 17 May 2013 21:28:33 +0100 Subject: PendingDeprecation warning to allow_empty --- rest_framework/mixins.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index ae703771..55d21a70 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -10,6 +10,7 @@ from django.http import Http404 from rest_framework import status from rest_framework.response import Response from rest_framework.request import clone_request +import warnings def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None): @@ -77,6 +78,12 @@ class ListModelMixin(object): # Default is to allow empty querysets. This can be altered by setting # `.allow_empty = False`, to raise 404 errors on empty querysets. if not self.allow_empty and not self.object_list: + warnings.warn( + 'The `allow_empty` parameter is due to be deprecated. ' + 'To use `allow_empty=False` style behavior, You should override ' + '`get_queryset()` and explicitly raise a 404 on empty querysets.', + PendingDeprecationWarning + ) class_name = self.__class__.__name__ error_msg = self.empty_error % {'class_name': class_name} raise Http404(error_msg) -- cgit v1.2.3 From b6fb377c2b4b747597bc3291dadd52b633b135b4 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 17 May 2013 21:57:11 +0100 Subject: Fix PendingDeprecation warnings in tests --- rest_framework/tests/filters.py | 12 ++++++------ rest_framework/tests/generics.py | 25 +++++++++---------------- rest_framework/tests/pagination.py | 4 ++-- 3 files changed, 17 insertions(+), 24 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/filters.py b/rest_framework/tests/filters.py index a58c66ae..8ae6d530 100644 --- a/rest_framework/tests/filters.py +++ b/rest_framework/tests/filters.py @@ -24,7 +24,7 @@ if django_filters: class FilterFieldsRootView(generics.ListCreateAPIView): model = FilterableItem filter_fields = ['decimal', 'date'] - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) # These class are used to test a filter class. class SeveralFieldsFilter(django_filters.FilterSet): @@ -39,7 +39,7 @@ if django_filters: class FilterClassRootView(generics.ListCreateAPIView): model = FilterableItem filter_class = SeveralFieldsFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) # These classes are used to test a misconfigured filter class. class MisconfiguredFilter(django_filters.FilterSet): @@ -52,12 +52,12 @@ if django_filters: class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): model = FilterableItem filter_class = MisconfiguredFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) class FilterClassDetailView(generics.RetrieveAPIView): model = FilterableItem filter_class = SeveralFieldsFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) # Regression test for #814 class FilterableItemSerializer(serializers.ModelSerializer): @@ -68,12 +68,12 @@ if django_filters: queryset = FilterableItem.objects.all() serializer_class = FilterableItemSerializer filter_fields = ['decimal', 'date'] - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) class GetQuerysetView(generics.ListCreateAPIView): serializer_class = FilterableItemSerializer filter_class = SeveralFieldsFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) def get_queryset(self): return FilterableItem.objects.all() diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index eca50d82..2799d143 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -39,6 +39,7 @@ class SlugBasedInstanceView(InstanceView): """ model = SlugBasedModel serializer_class = SlugSerializer + lookup_field = 'slug' class TestRootView(TestCase): @@ -434,22 +435,14 @@ class TestFilterBackendAppliedToViews(TestCase): {'id': obj.id, 'text': obj.text} for obj in self.objects.all() ] - self.root_view = RootView.as_view() - self.instance_view = InstanceView.as_view() - self.original_root_backend = getattr(RootView, 'filter_backend') - self.original_instance_backend = getattr(InstanceView, 'filter_backend') - - def tearDown(self): - setattr(RootView, 'filter_backend', self.original_root_backend) - setattr(InstanceView, 'filter_backend', self.original_instance_backend) def test_get_root_view_filters_by_name_with_filter_backend(self): """ GET requests to ListCreateAPIView should return filtered list. """ - setattr(RootView, 'filter_backend', InclusiveFilterBackend) + root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,)) request = factory.get('/') - response = self.root_view(request).render() + response = root_view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data), 1) self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}]) @@ -458,9 +451,9 @@ class TestFilterBackendAppliedToViews(TestCase): """ GET requests to ListCreateAPIView should return empty list when all models are filtered out. """ - setattr(RootView, 'filter_backend', ExclusiveFilterBackend) + root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,)) request = factory.get('/') - response = self.root_view(request).render() + response = root_view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, []) @@ -468,9 +461,9 @@ class TestFilterBackendAppliedToViews(TestCase): """ GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out. """ - setattr(InstanceView, 'filter_backend', ExclusiveFilterBackend) + instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,)) request = factory.get('/1') - response = self.instance_view(request, pk=1).render() + response = instance_view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.data, {'detail': 'Not found'}) @@ -478,8 +471,8 @@ class TestFilterBackendAppliedToViews(TestCase): """ GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded """ - setattr(InstanceView, 'filter_backend', InclusiveFilterBackend) + instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,)) request = factory.get('/1') - response = self.instance_view(request, pk=1).render() + response = instance_view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, {'id': 1, 'text': 'foo'}) diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 894d53d6..e538a78e 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -130,7 +130,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): model = FilterableItem paginate_by = 10 filter_class = DecimalFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) view = FilterFieldsRootView.as_view() @@ -177,7 +177,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): class BasicFilterFieldsRootView(generics.ListCreateAPIView): model = FilterableItem paginate_by = 10 - filter_backend = DecimalFilterBackend + filter_backends = (DecimalFilterBackend,) view = BasicFilterFieldsRootView.as_view() -- cgit v1.2.3 From 34776da9249a5d73f822b3562bc56a5674b10ac7 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 17 May 2013 22:09:23 +0100 Subject: Minor mixin refactoring --- rest_framework/mixins.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 55d21a70..f3cd5868 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -43,7 +43,6 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None) class CreateModelMixin(object): """ Create a model instance. - Should be mixed in with any `GenericAPIView`. """ def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.DATA, files=request.FILES) @@ -68,7 +67,6 @@ class CreateModelMixin(object): class ListModelMixin(object): """ List a queryset. - Should be mixed in with `MultipleObjectAPIView`. """ empty_error = "Empty list and '%(class_name)s.allow_empty' is False." @@ -101,7 +99,6 @@ class ListModelMixin(object): class RetrieveModelMixin(object): """ Retrieve a model instance. - Should be mixed in with `SingleObjectAPIView`. """ def retrieve(self, request, *args, **kwargs): self.object = self.get_object() @@ -112,17 +109,22 @@ class RetrieveModelMixin(object): class UpdateModelMixin(object): """ Update a model instance. - Should be mixed in with `SingleObjectAPIView`. """ - def update(self, request, *args, **kwargs): - partial = kwargs.pop('partial', False) - self.object = None + def get_object_or_none(self): try: - self.object = self.get_object() + return self.get_object() except Http404: # If this is a PUT-as-create operation, we need to ensure that # we have relevant permissions, as if this was a POST request. - self.check_permissions(clone_request(request, 'POST')) + # This will either raise a PermissionDenied exception, + # or simply return None + self.check_permissions(clone_request(self.request, 'POST')) + + def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + self.object = self.get_object_or_none() + + if self.object is None: created = True save_kwargs = {'force_insert': True} success_status_code = status.HTTP_201_CREATED @@ -175,7 +177,6 @@ class UpdateModelMixin(object): class DestroyModelMixin(object): """ Destroy a model instance. - Should be mixed in with `SingleObjectAPIView`. """ def destroy(self, request, *args, **kwargs): obj = self.get_object() -- cgit v1.2.3 From aea040161ae29ec4b5335be5164aa8e5ada506e3 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 18 May 2013 09:36:09 +0100 Subject: Forms in Broseable API support dynamic serializers based on request method --- rest_framework/renderers.py | 42 +++++++++++++++++++++++++++++++--------- rest_framework/tests/generics.py | 34 +++++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 10 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 1917a080..8361cd40 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -336,7 +336,7 @@ class BrowsableAPIRenderer(BaseRenderer): return # Cannot use form overloading try: - view.check_permissions(clone_request(request, method)) + view.check_permissions(request) except exceptions.APIException: return False # Doesn't have permissions return True @@ -372,6 +372,30 @@ class BrowsableAPIRenderer(BaseRenderer): return fields + def _get_form(self, view, method, request): + # We need to impersonate a request with the correct method, + # so that eg. any dynamic get_serializer_class methods return the + # correct form for each method. + restore = view.request + request = clone_request(request, method) + view.request = request + try: + return self.get_form(view, method, request) + finally: + view.request = restore + + def _get_raw_data_form(self, view, method, request, media_types): + # We need to impersonate a request with the correct method, + # so that eg. any dynamic get_serializer_class methods return the + # correct form for each method. + restore = view.request + request = clone_request(request, method) + view.request = request + try: + return self.get_raw_data_form(view, method, request, media_types) + finally: + view.request = restore + def get_form(self, view, method, request): """ Get a form, possibly bound to either the input or output data. @@ -465,15 +489,15 @@ class BrowsableAPIRenderer(BaseRenderer): renderer = self.get_default_renderer(view) content = self.get_content(renderer, data, accepted_media_type, renderer_context) - put_form = self.get_form(view, 'PUT', request) - post_form = self.get_form(view, 'POST', request) - patch_form = self.get_form(view, 'PATCH', request) - delete_form = self.get_form(view, 'DELETE', request) - options_form = self.get_form(view, 'OPTIONS', request) + put_form = self._get_form(view, 'PUT', request) + post_form = self._get_form(view, 'POST', request) + patch_form = self._get_form(view, 'PATCH', request) + delete_form = self._get_form(view, 'DELETE', request) + options_form = self._get_form(view, 'OPTIONS', request) - raw_data_put_form = self.get_raw_data_form(view, 'PUT', request, media_types) - raw_data_post_form = self.get_raw_data_form(view, 'POST', request, media_types) - raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request, media_types) + raw_data_put_form = self._get_raw_data_form(view, 'PUT', request, media_types) + raw_data_post_form = self._get_raw_data_form(view, 'POST', request, media_types) + raw_data_patch_form = self._get_raw_data_form(view, 'PATCH', request, media_types) raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form name = self.get_name(view) diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 2799d143..15d87e86 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals from django.db import models from django.shortcuts import get_object_or_404 from django.test import TestCase -from rest_framework import generics, serializers, status +from rest_framework import generics, renderers, serializers, status from rest_framework.tests.utils import RequestFactory from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel from rest_framework.compat import six @@ -476,3 +476,35 @@ class TestFilterBackendAppliedToViews(TestCase): response = instance_view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, {'id': 1, 'text': 'foo'}) + + +class TwoFieldModel(models.Model): + field_a = models.CharField(max_length=100) + field_b = models.CharField(max_length=100) + + +class DynamicSerializerView(generics.ListCreateAPIView): + model = TwoFieldModel + renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) + + def get_serializer_class(self): + if self.request.method == 'POST': + class DynamicSerializer(serializers.ModelSerializer): + class Meta: + model = TwoFieldModel + fields = ('field_b',) + return DynamicSerializer + return super(DynamicSerializerView, self).get_serializer_class() + + +class TestFilterBackendAppliedToViews(TestCase): + + def test_dynamic_serializer_form_in_browsable_api(self): + """ + GET requests to ListCreateAPIView should return filtered list. + """ + view = DynamicSerializerView.as_view() + request = factory.get('/') + response = view(request).render() + self.assertContains(response, 'field_b') + self.assertNotContains(response, 'field_a') -- cgit v1.2.3 From ed0bd195f58ae6c0502f9c54cbd34681875adb14 Mon Sep 17 00:00:00 2001 From: Xavier Ordoquy Date: Sat, 18 May 2013 12:07:44 +0200 Subject: Updated the dependencies version and added the ALLOWED_HOSTS for tests. --- rest_framework/runtests/settings.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py index 9b519f27..9dd7b545 100644 --- a/rest_framework/runtests/settings.py +++ b/rest_framework/runtests/settings.py @@ -4,6 +4,8 @@ DEBUG = True TEMPLATE_DEBUG = DEBUG DEBUG_PROPAGATE_EXCEPTIONS = True +ALLOWED_HOSTS = ['*'] + ADMINS = ( # ('Your Name', 'your_email@domain.com'), ) -- cgit v1.2.3 From 0cd7c80e6eaf3ca17d0fb8f8878054ce570e3932 Mon Sep 17 00:00:00 2001 From: Craig de Stigter Date: Sat, 18 May 2013 12:16:30 +0200 Subject: add tests for related field source for RelatedField and PrimaryKeyRelatedField. #694 --- rest_framework/tests/relations.py | 37 +++++++++++++++++++++++++++++++++- rest_framework/tests/relations_pk.py | 39 +++++++++++++++++++++++++++++++++++- rest_framework/tests/serializer.py | 17 ---------------- 3 files changed, 74 insertions(+), 19 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py index cbf93c65..f28f0de9 100644 --- a/rest_framework/tests/relations.py +++ b/rest_framework/tests/relations.py @@ -5,6 +5,7 @@ from __future__ import unicode_literals from django.db import models from django.test import TestCase from rest_framework import serializers +from rest_framework.tests.models import BlogPost class NullModel(models.Model): @@ -33,7 +34,7 @@ class FieldTests(TestCase): self.assertRaises(serializers.ValidationError, field.from_native, []) -class TestManyRelateMixin(TestCase): +class TestManyRelatedMixin(TestCase): def test_missing_many_to_many_related_field(self): ''' Regression test for #632 @@ -45,3 +46,37 @@ class TestManyRelateMixin(TestCase): into = {} field.field_from_native({}, None, 'field_name', into) self.assertEqual(into['field_name'], []) + + +# Regression tests for #694 (`source` attribute on related fields) + +class RelatedFieldSourceTests(TestCase): + def test_related_manager_source(self): + """ + Relational fields should be able to use manager-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.RelatedField(many=True, source='get_blogposts_manager') + + class ClassWithManagerMethod(object): + def get_blogposts_manager(self): + return BlogPost.objects + + obj = ClassWithManagerMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['BlogPost object']) + + def test_related_queryset_source(self): + """ + Relational fields should be able to use queryset-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.RelatedField(many=True, source='get_blogposts_queryset') + + class ClassWithQuerysetMethod(object): + def get_blogposts_queryset(self): + return BlogPost.objects.all() + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['BlogPost object']) diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py index 5ce8b567..51fe59e9 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/relations_pk.py @@ -1,7 +1,10 @@ from __future__ import unicode_literals from django.test import TestCase from rest_framework import serializers -from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource +from rest_framework.tests.models import ( + BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, + NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource, +) from rest_framework.compat import six @@ -421,3 +424,37 @@ class PKNullableOneToOneTests(TestCase): {'id': 2, 'name': 'target-2', 'nullable_source': 1}, ] self.assertEqual(serializer.data, expected) + + +# Regression tests for #694 (`source` attribute on related fields) + +class PrimaryKeyRelatedFieldSourceTests(TestCase): + def test_related_manager_source(self): + """ + Relational fields should be able to use manager-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager') + + class ClassWithManagerMethod(object): + def get_blogposts_manager(self): + return BlogPost.objects + + obj = ClassWithManagerMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, [1]) + + def test_related_queryset_source(self): + """ + Relational fields should be able to use queryset-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset') + + class ClassWithQuerysetMethod(object): + def get_blogposts_queryset(self): + return BlogPost.objects.all() + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, [1]) diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index db3881f9..34acbaab 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -871,23 +871,6 @@ class RelatedTraversalTest(TestCase): self.assertEqual(serializer.data, expected) - def test_queryset_nested_traversal(self): - """ - Relational fields should be able to use methods as their source. - """ - BlogPost.objects.create(title='blah') - - class QuerysetMethodSerializer(serializers.Serializer): - blogposts = serializers.RelatedField(many=True, source='get_all_blogposts') - - class ClassWithQuerysetMethod(object): - def get_all_blogposts(self): - return BlogPost.objects - - obj = ClassWithQuerysetMethod() - serializer = QuerysetMethodSerializer(obj) - self.assertEqual(serializer.data, {'blogposts': ['BlogPost object']}) - class SerializerMethodFieldTests(TestCase): def setUp(self): -- cgit v1.2.3 From 930bd4d0e1f9a74ee56a57ef36c93b1c1d124f91 Mon Sep 17 00:00:00 2001 From: Craig de Stigter Date: Sat, 18 May 2013 12:23:12 +0200 Subject: add tests for related field source for HyperlinkedRelatedField. #694 --- rest_framework/tests/relations_hyperlink.py | 48 +++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py index b1eed9a7..8fb4687f 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/relations_hyperlink.py @@ -4,6 +4,7 @@ from django.test.client import RequestFactory from rest_framework import serializers from rest_framework.compat import patterns, url from rest_framework.tests.models import ( + BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource ) @@ -16,6 +17,7 @@ def dummy_view(request, pk): pass urlpatterns = patterns('', + url(r'^dummyurl/(?P[0-9]+)/$', dummy_view, name='dummy-url'), url(r'^manytomanysource/(?P[0-9]+)/$', dummy_view, name='manytomanysource-detail'), url(r'^manytomanytarget/(?P[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), url(r'^foreignkeysource/(?P[0-9]+)/$', dummy_view, name='foreignkeysource-detail'), @@ -451,3 +453,49 @@ class HyperlinkedNullableOneToOneTests(TestCase): {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, ] self.assertEqual(serializer.data, expected) + + +# Regression tests for #694 (`source` attribute on related fields) + +class HyperlinkedRelatedFieldSourceTests(TestCase): + urls = 'rest_framework.tests.relations_hyperlink' + + def test_related_manager_source(self): + """ + Relational fields should be able to use manager-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.HyperlinkedRelatedField( + many=True, + source='get_blogposts_manager', + view_name='dummy-url', + ) + field.context = {'request': request} + + class ClassWithManagerMethod(object): + def get_blogposts_manager(self): + return BlogPost.objects + + obj = ClassWithManagerMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['http://testserver/dummyurl/1/']) + + def test_related_queryset_source(self): + """ + Relational fields should be able to use queryset-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.HyperlinkedRelatedField( + many=True, + source='get_blogposts_queryset', + view_name='dummy-url', + ) + field.context = {'request': request} + + class ClassWithQuerysetMethod(object): + def get_blogposts_queryset(self): + return BlogPost.objects.all() + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['http://testserver/dummyurl/1/']) -- cgit v1.2.3 From a73c16b85f79aeb9139734a64623b49bc169fce9 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 18 May 2013 11:27:48 +0100 Subject: serializers.Field respects ordering on dicts if it exists. Closes #832 --- rest_framework/fields.py | 7 ++++++- rest_framework/tests/fields.py | 19 ++++++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index c83ee5ec..49d2a6d5 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -19,6 +19,7 @@ from django import forms from django.forms import widgets from django.utils.encoding import is_protected_type from django.utils.translation import ugettext_lazy as _ +from django.utils.datastructures import SortedDict from rest_framework import ISO_8601 from rest_framework.compat import timezone, parse_date, parse_datetime, parse_time @@ -170,7 +171,11 @@ class Field(object): elif hasattr(value, '__iter__') and not isinstance(value, (dict, six.string_types)): return [self.to_native(item) for item in value] elif isinstance(value, dict): - return dict(map(self.to_native, (k, v)) for k, v in value.items()) + # Make sure we preserve field ordering, if it exists + ret = SortedDict() + for key, val in value.items(): + ret[key] = self.to_native(val) + return ret return smart_text(value) def attributes(self): diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index 3cdfa0f6..5b5ce835 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -2,13 +2,12 @@ General serializer field tests. """ from __future__ import unicode_literals +from django.utils.datastructures import SortedDict import datetime from decimal import Decimal - from django.db import models from django.test import TestCase from django.core import validators - from rest_framework import serializers from rest_framework.serializers import Serializer @@ -63,6 +62,20 @@ class BasicFieldTests(TestCase): serializer = CharPrimaryKeyModelSerializer() self.assertEqual(serializer.fields['id'].read_only, False) + def test_dict_field_ordering(self): + """ + Field should preserve dictionary ordering, if it exists. + See: https://github.com/tomchristie/django-rest-framework/issues/832 + """ + ret = SortedDict() + ret['c'] = 1 + ret['b'] = 1 + ret['a'] = 1 + ret['z'] = 1 + field = serializers.Field() + keys = list(field.to_native(ret).keys()) + self.assertEqual(keys, ['c', 'b', 'a', 'z']) + class DateFieldTest(TestCase): """ @@ -645,4 +658,4 @@ class DecimalFieldTest(TestCase): s = DecimalSerializer(data={'decimal_field': '12345.6'}) self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) \ No newline at end of file + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) -- cgit v1.2.3 From c992b600f7b0aefb156cddb5e27b438ccc316b39 Mon Sep 17 00:00:00 2001 From: Craig de Stigter Date: Sat, 18 May 2013 12:32:48 +0200 Subject: add tests for dotted lookup in RelatedField, PrimaryKeyRelatedField, and HyperlinkedRelatedField. #694 --- rest_framework/tests/relations.py | 18 ++++++++++++++++++ rest_framework/tests/relations_hyperlink.py | 23 +++++++++++++++++++++++ rest_framework/tests/relations_pk.py | 18 ++++++++++++++++++ 3 files changed, 59 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py index f28f0de9..d19219c9 100644 --- a/rest_framework/tests/relations.py +++ b/rest_framework/tests/relations.py @@ -80,3 +80,21 @@ class RelatedFieldSourceTests(TestCase): obj = ClassWithQuerysetMethod() value = field.field_to_native(obj, 'field_name') self.assertEqual(value, ['BlogPost object']) + + def test_dotted_source(self): + """ + Source argument should support dotted.source notation. + """ + BlogPost.objects.create(title='blah') + field = serializers.RelatedField(many=True, source='a.b.c') + + class ClassWithQuerysetMethod(object): + a = { + 'b': { + 'c': BlogPost.objects.all() + } + } + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['BlogPost object']) diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py index 8fb4687f..b3efbf52 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/relations_hyperlink.py @@ -499,3 +499,26 @@ class HyperlinkedRelatedFieldSourceTests(TestCase): obj = ClassWithQuerysetMethod() value = field.field_to_native(obj, 'field_name') self.assertEqual(value, ['http://testserver/dummyurl/1/']) + + def test_dotted_source(self): + """ + Source argument should support dotted.source notation. + """ + BlogPost.objects.create(title='blah') + field = serializers.HyperlinkedRelatedField( + many=True, + source='a.b.c', + view_name='dummy-url', + ) + field.context = {'request': request} + + class ClassWithQuerysetMethod(object): + a = { + 'b': { + 'c': BlogPost.objects.all() + } + } + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['http://testserver/dummyurl/1/']) diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py index 51fe59e9..0f8c5247 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/relations_pk.py @@ -458,3 +458,21 @@ class PrimaryKeyRelatedFieldSourceTests(TestCase): obj = ClassWithQuerysetMethod() value = field.field_to_native(obj, 'field_name') self.assertEqual(value, [1]) + + def test_dotted_source(self): + """ + Source argument should support dotted.source notation. + """ + BlogPost.objects.create(title='blah') + field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c') + + class ClassWithQuerysetMethod(object): + a = { + 'b': { + 'c': BlogPost.objects.all() + } + } + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, [1]) -- cgit v1.2.3 From de5cc8de423a22009d2a643f6c268805f715b212 Mon Sep 17 00:00:00 2001 From: Pablo Recio Date: Sat, 18 May 2013 12:40:25 +0200 Subject: A model's field is required if is null or blank --- rest_framework/serializers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 7707de7a..500bb306 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -705,15 +705,14 @@ class ModelSerializer(Serializer): Creates a default instance of a basic non-relational field. """ kwargs = {} - has_default = model_field.has_default() - if model_field.null or model_field.blank or has_default: + if model_field.null or model_field.blank: kwargs['required'] = False if isinstance(model_field, models.AutoField) or not model_field.editable: kwargs['read_only'] = True - if has_default: + if model_field.has_default(): kwargs['default'] = model_field.get_default() if issubclass(model_field.__class__, models.TextField): -- cgit v1.2.3 From ab8bd566f9db327a4c463317011818d421bbf89c Mon Sep 17 00:00:00 2001 From: Pablo Recio Date: Sat, 18 May 2013 12:40:25 +0200 Subject: Adding `BLANK_CHOICE_DASH` as a choice if the model's field isn't required --- rest_framework/fields.py | 3 +++ rest_framework/tests/fields.py | 28 +++++++++++++++++++++++- rest_framework/tests/models.py | 26 ++++++++++++++++++++++ rest_framework/tests/serializer.py | 45 +++++++++++++++++++++++++++++++++++++- 4 files changed, 100 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index c83ee5ec..7fd4c638 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -15,6 +15,7 @@ import warnings from django.core import validators from django.core.exceptions import ValidationError from django.conf import settings +from django.db.models.fields import BLANK_CHOICE_DASH from django import forms from django.forms import widgets from django.utils.encoding import is_protected_type @@ -402,6 +403,8 @@ class ChoiceField(WritableField): def __init__(self, choices=(), *args, **kwargs): super(ChoiceField, self).__init__(*args, **kwargs) self.choices = choices + if not self.required: + self.choices = BLANK_CHOICE_DASH + self.choices def _get_choices(self): return self._choices diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index 3cdfa0f6..f313ba60 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -645,4 +645,30 @@ class DecimalFieldTest(TestCase): s = DecimalSerializer(data={'decimal_field': '12345.6'}) self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) \ No newline at end of file + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) + + +class ChoiceFieldTests(TestCase): + """ + Tests for the ChoiceField options generator + """ + + SAMPLE_CHOICES = [ + ('red', 'Red'), + ('green', 'Green'), + ('blue', 'Blue'), + ] + + def test_choices_required(self): + """ + Make sure proper choices are rendered if field is required + """ + f = serializers.ChoiceField(required=True, choices=self.SAMPLE_CHOICES) + self.assertEqual(f.choices, self.SAMPLE_CHOICES) + + def test_choices_not_required(self): + """ + Make sure proper choices (plus blank) are rendered if the field isn't required + """ + f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES) + self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES) diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 40e41a64..5d98b04b 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -117,6 +117,32 @@ class OptionalRelationModel(RESTFrameworkModel): other = models.ForeignKey('OptionalRelationModel', blank=True, null=True) +# Model for issue #725 +class SeveralChoicesModel(RESTFrameworkModel): + color = models.CharField( + max_length=10, + choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')], + blank=False + ) + drink = models.CharField( + max_length=10, + choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')], + blank=False, + default='beer' + ) + os = models.CharField( + max_length=10, + choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')], + blank=True + ) + music_genre = models.CharField( + max_length=10, + choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')], + blank=True, + default='metal' + ) + + # Model for RegexField class Book(RESTFrameworkModel): isbn = models.CharField(max_length=13) diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index db3881f9..3f39308d 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -1,10 +1,11 @@ from __future__ import unicode_literals +from django.db.models.fields import BLANK_CHOICE_DASH from django.utils.datastructures import MultiValueDict from django.test import TestCase from rest_framework import serializers from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel, - ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo) + ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, SeveralChoicesModel) import datetime import pickle @@ -1018,6 +1019,48 @@ class SerializerPickleTests(TestCase): repr(pickle.loads(pickle.dumps(data, 0))) +# test for issue #725 +class SerializerChoiceFields(TestCase): + + def setUp(self): + super(SerializerChoiceFields, self).setUp() + + class SeveralChoicesSerializer(serializers.ModelSerializer): + class Meta: + model = SeveralChoicesModel + fields = ('color', 'drink', 'os', 'music_genre') + + self.several_choices_serializer = SeveralChoicesSerializer + + def test_choices_blank_false_not_default(self): + serializer = self.several_choices_serializer() + self.assertEqual( + serializer.fields['color'].choices, + [('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')] + ) + + def test_choices_blank_false_with_default(self): + serializer = self.several_choices_serializer() + self.assertEqual( + serializer.fields['drink'].choices, + [('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')] + ) + + def test_choices_blank_true_not_default(self): + serializer = self.several_choices_serializer() + self.assertEqual( + serializer.fields['os'].choices, + BLANK_CHOICE_DASH + [('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')] + ) + + def test_choices_blank_true_with_default(self): + serializer = self.several_choices_serializer() + self.assertEqual( + serializer.fields['music_genre'].choices, + BLANK_CHOICE_DASH + [('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')] + ) + + class DepthTest(TestCase): def test_implicit_nesting(self): -- cgit v1.2.3 From 8fe43236a22e56d1741b49b92f0c53e01cd9e5f6 Mon Sep 17 00:00:00 2001 From: Pablo Recio Date: Sat, 18 May 2013 13:23:38 +0200 Subject: Moved test model into closer to the testcase --- rest_framework/tests/models.py | 26 -------------------------- rest_framework/tests/serializer.py | 28 +++++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 27 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 5d98b04b..40e41a64 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -117,32 +117,6 @@ class OptionalRelationModel(RESTFrameworkModel): other = models.ForeignKey('OptionalRelationModel', blank=True, null=True) -# Model for issue #725 -class SeveralChoicesModel(RESTFrameworkModel): - color = models.CharField( - max_length=10, - choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')], - blank=False - ) - drink = models.CharField( - max_length=10, - choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')], - blank=False, - default='beer' - ) - os = models.CharField( - max_length=10, - choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')], - blank=True - ) - music_genre = models.CharField( - max_length=10, - choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')], - blank=True, - default='metal' - ) - - # Model for RegexField class Book(RESTFrameworkModel): isbn = models.CharField(max_length=13) diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 85b95283..c043f417 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -1,11 +1,12 @@ from __future__ import unicode_literals +from django.db import models from django.db.models.fields import BLANK_CHOICE_DASH from django.utils.datastructures import MultiValueDict from django.test import TestCase from rest_framework import serializers from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel, - ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, SeveralChoicesModel) + ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo) import datetime import pickle @@ -1003,6 +1004,31 @@ class SerializerPickleTests(TestCase): # test for issue #725 +class SeveralChoicesModel(models.Model): + color = models.CharField( + max_length=10, + choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')], + blank=False + ) + drink = models.CharField( + max_length=10, + choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')], + blank=False, + default='beer' + ) + os = models.CharField( + max_length=10, + choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')], + blank=True + ) + music_genre = models.CharField( + max_length=10, + choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')], + blank=True, + default='metal' + ) + + class SerializerChoiceFields(TestCase): def setUp(self): -- cgit v1.2.3