diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/__init__.py | 2 | ||||
| -rw-r--r-- | rest_framework/authentication.py | 80 | ||||
| -rw-r--r-- | rest_framework/decorators.py | 9 | ||||
| -rw-r--r-- | rest_framework/exceptions.py | 16 | ||||
| -rw-r--r-- | rest_framework/pagination.py | 26 | ||||
| -rw-r--r-- | rest_framework/relations.py | 20 | ||||
| -rw-r--r-- | rest_framework/request.py | 29 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 12 | ||||
| -rw-r--r-- | rest_framework/tests/authentication.py | 41 | ||||
| -rw-r--r-- | rest_framework/tests/decorators.py | 22 | ||||
| -rw-r--r-- | rest_framework/tests/genericrelations.py | 96 | ||||
| -rw-r--r-- | rest_framework/tests/models.py | 21 | ||||
| -rw-r--r-- | rest_framework/tests/pagination.py | 43 | ||||
| -rw-r--r-- | rest_framework/tests/relations_hyperlink.py | 9 | ||||
| -rw-r--r-- | rest_framework/tests/relations_pk.py | 7 | ||||
| -rw-r--r-- | rest_framework/tests/relations_slug.py | 162 | ||||
| -rw-r--r-- | rest_framework/tests/serializer.py | 88 | ||||
| -rw-r--r-- | rest_framework/tests/urlpatterns.py | 78 | ||||
| -rw-r--r-- | rest_framework/urlpatterns.py | 45 | ||||
| -rw-r--r-- | rest_framework/views.py | 21 |
20 files changed, 660 insertions, 167 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index bc267fad..f9882c57 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,3 +1,3 @@ -__version__ = '2.1.16' +__version__ = '2.1.17' VERSION = __version__ # synonym diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 30c78ebc..fc169189 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -21,32 +21,46 @@ class BaseAuthentication(object): """ raise NotImplementedError(".authenticate() must be overridden.") + def authenticate_header(self, request): + """ + Return a string to be used as the value of the `WWW-Authenticate` + header in a `401 Unauthenticated` response, or `None` if the + authentication scheme should return `403 Permission Denied` responses. + """ + pass + class BasicAuthentication(BaseAuthentication): """ HTTP Basic authentication against username/password. """ + www_authenticate_realm = 'api' def authenticate(self, request): """ Returns a `User` if a correct username and password have been supplied using HTTP Basic authentication. Otherwise returns `None`. """ - if 'HTTP_AUTHORIZATION' in request.META: - auth = request.META['HTTP_AUTHORIZATION'].split() - if len(auth) == 2 and auth[0].lower() == "basic": - try: - auth_parts = base64.b64decode(auth[1]).partition(':') - except TypeError: - return None - - try: - userid = smart_unicode(auth_parts[0]) - password = smart_unicode(auth_parts[2]) - except DjangoUnicodeDecodeError: - return None - - return self.authenticate_credentials(userid, password) + auth = request.META.get('HTTP_AUTHORIZATION', '').split() + + if not auth or auth[0].lower() != "basic": + return None + + if len(auth) != 2: + raise exceptions.AuthenticationFailed('Invalid basic header') + + try: + auth_parts = base64.b64decode(auth[1]).partition(':') + except TypeError: + raise exceptions.AuthenticationFailed('Invalid basic header') + + try: + userid = smart_unicode(auth_parts[0]) + password = smart_unicode(auth_parts[2]) + except DjangoUnicodeDecodeError: + raise exceptions.AuthenticationFailed('Invalid basic header') + + return self.authenticate_credentials(userid, password) def authenticate_credentials(self, userid, password): """ @@ -55,6 +69,10 @@ class BasicAuthentication(BaseAuthentication): user = authenticate(username=userid, password=password) if user is not None and user.is_active: return (user, None) + raise exceptions.AuthenticationFailed('Invalid username/password') + + def authenticate_header(self, request): + return 'Basic realm="%s"' % self.www_authenticate_realm class SessionAuthentication(BaseAuthentication): @@ -74,7 +92,7 @@ class SessionAuthentication(BaseAuthentication): # Unauthenticated, CSRF validation not required if not user or not user.is_active: - return + return None # Enforce CSRF validation for session based authentication. class CSRFCheck(CsrfViewMiddleware): @@ -85,7 +103,7 @@ class SessionAuthentication(BaseAuthentication): reason = CSRFCheck().process_view(http_request, None, (), {}) if reason: # CSRF failed, bail with explicit error message - raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) + raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason) # CSRF passed with authenticated user return (user, None) @@ -112,14 +130,26 @@ class TokenAuthentication(BaseAuthentication): def authenticate(self, request): auth = request.META.get('HTTP_AUTHORIZATION', '').split() - if len(auth) == 2 and auth[0].lower() == "token": - key = auth[1] - try: - token = self.model.objects.get(key=key) - except self.model.DoesNotExist: - return None + if not auth or auth[0].lower() != "token": + return None + + if len(auth) != 2: + raise exceptions.AuthenticationFailed('Invalid token header') + + return self.authenticate_credentials(auth[1]) + + def authenticate_credentials(self, key): + try: + token = self.model.objects.get(key=key) + except self.model.DoesNotExist: + raise exceptions.AuthenticationFailed('Invalid token') + + if token.user.is_active: + return (token.user, token) + raise exceptions.AuthenticationFailed('User inactive or deleted') + + def authenticate_header(self, request): + return 'Token' - if token.user.is_active: - return (token.user, token) # TODO: OAuthAuthentication diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 1b710a03..7a4103e1 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -1,4 +1,5 @@ from rest_framework.views import APIView +import types def api_view(http_method_names): @@ -23,6 +24,14 @@ def api_view(http_method_names): # pass # WrappedAPIView.__doc__ = func.doc <--- Not possible to do this + # api_view applied without (method_names) + assert not(isinstance(http_method_names, types.FunctionType)), \ + '@api_view missing list of allowed HTTP methods' + + # api_view applied with eg. string instead of list of strings + assert isinstance(http_method_names, (list, tuple)), \ + '@api_view expected a list of strings, recieved %s' % type(http_method_names).__name__ + allowed_methods = set(http_method_names) | set(('options',)) WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 89479deb..d635351c 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -23,6 +23,22 @@ class ParseError(APIException): self.detail = detail or self.default_detail +class AuthenticationFailed(APIException): + status_code = status.HTTP_401_UNAUTHORIZED + default_detail = 'Incorrect authentication credentials.' + + def __init__(self, detail=None): + self.detail = detail or self.default_detail + + +class NotAuthenticated(APIException): + status_code = status.HTTP_401_UNAUTHORIZED + default_detail = 'Authentication credentials were not provided.' + + def __init__(self, detail=None): + self.detail = detail or self.default_detail + + class PermissionDenied(APIException): status_code = status.HTTP_403_FORBIDDEN default_detail = 'You do not have permission to perform this action.' diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index d241ade7..92d41e0e 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -34,6 +34,17 @@ class PreviousPageField(serializers.Field): return replace_query_param(url, self.page_field, page) +class DefaultObjectSerializer(serializers.Field): + """ + If no object serializer is specified, then this serializer will be applied + as the default. + """ + + def __init__(self, source=None, context=None): + # Note: Swallow context kwarg - only required for eg. ModelSerializer. + super(DefaultObjectSerializer, self).__init__(source=source) + + class PaginationSerializerOptions(serializers.SerializerOptions): """ An object that stores the options that may be provided to a @@ -44,7 +55,7 @@ class PaginationSerializerOptions(serializers.SerializerOptions): def __init__(self, meta): super(PaginationSerializerOptions, self).__init__(meta) self.object_serializer_class = getattr(meta, 'object_serializer_class', - serializers.Field) + DefaultObjectSerializer) class BasePaginationSerializer(serializers.Serializer): @@ -62,14 +73,13 @@ class BasePaginationSerializer(serializers.Serializer): super(BasePaginationSerializer, self).__init__(*args, **kwargs) results_field = self.results_field object_serializer = self.opts.object_serializer_class - self.fields[results_field] = object_serializer(source='object_list') - def to_native(self, obj): - """ - Prevent default behaviour of iterating over elements, and serializing - each in turn. - """ - return self.convert_object(obj) + if 'context' in kwargs: + context_kwarg = {'context': kwargs['context']} + else: + context_kwarg = {} + + self.fields[results_field] = object_serializer(source='object_list', **context_kwarg) class PaginationSerializer(BasePaginationSerializer): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 7ded3891..af63ceaa 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -177,7 +177,7 @@ class PrimaryKeyRelatedField(RelatedField): default_error_messages = { 'does_not_exist': _("Invalid pk '%s' - object does not exist."), - 'invalid': _('Invalid value.'), + 'incorrect_type': _('Incorrect type. Expected pk value, received %s.'), } # TODO: Remove these field hacks... @@ -208,7 +208,8 @@ class PrimaryKeyRelatedField(RelatedField): msg = self.error_messages['does_not_exist'] % smart_unicode(data) raise ValidationError(msg) except (TypeError, ValueError): - msg = self.error_messages['invalid'] + received = type(data).__name__ + msg = self.error_messages['incorrect_type'] % received raise ValidationError(msg) def field_to_native(self, obj, field_name): @@ -235,7 +236,7 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField): default_error_messages = { 'does_not_exist': _("Invalid pk '%s' - object does not exist."), - 'invalid': _('Invalid value.'), + 'incorrect_type': _('Incorrect type. Expected pk value, received %s.'), } def prepare_value(self, obj): @@ -275,7 +276,8 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField): msg = self.error_messages['does_not_exist'] % smart_unicode(data) raise ValidationError(msg) except (TypeError, ValueError): - msg = self.error_messages['invalid'] + received = type(data).__name__ + msg = self.error_messages['incorrect_type'] % received raise ValidationError(msg) ### Slug relationships @@ -333,7 +335,7 @@ class HyperlinkedRelatedField(RelatedField): 'incorrect_match': _('Invalid hyperlink - Incorrect URL match'), 'configuration_error': _('Invalid hyperlink due to configuration error'), 'does_not_exist': _("Invalid hyperlink - object does not exist."), - 'invalid': _('Invalid value.'), + 'incorrect_type': _('Incorrect type. Expected url string, received %s.'), } def __init__(self, *args, **kwargs): @@ -397,8 +399,8 @@ class HyperlinkedRelatedField(RelatedField): try: http_prefix = value.startswith('http:') or value.startswith('https:') except AttributeError: - msg = self.error_messages['invalid'] - raise ValidationError(msg) + msg = self.error_messages['incorrect_type'] + raise ValidationError(msg % type(value).__name__) if http_prefix: # If needed convert absolute URLs to relative path @@ -434,8 +436,8 @@ class HyperlinkedRelatedField(RelatedField): except ObjectDoesNotExist: raise ValidationError(self.error_messages['does_not_exist']) except (TypeError, ValueError): - msg = self.error_messages['invalid'] - raise ValidationError(msg) + msg = self.error_messages['incorrect_type'] + raise ValidationError(msg % type(value).__name__) return obj diff --git a/rest_framework/request.py b/rest_framework/request.py index b7133608..1c28cd17 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -86,6 +86,7 @@ class Request(object): self._method = Empty self._content_type = Empty self._stream = Empty + self._authenticator = None if self.parser_context is None: self.parser_context = {} @@ -166,7 +167,7 @@ class Request(object): by the authentication classes provided to the request. """ if not hasattr(self, '_user'): - self._user, self._auth = self._authenticate() + self._authenticator, self._user, self._auth = self._authenticate() return self._user @user.setter @@ -185,7 +186,7 @@ class Request(object): request, such as an authentication token. """ if not hasattr(self, '_auth'): - self._user, self._auth = self._authenticate() + self._authenticator, self._user, self._auth = self._authenticate() return self._auth @auth.setter @@ -196,6 +197,14 @@ class Request(object): """ self._auth = value + @property + def successful_authenticator(self): + """ + Return the instance of the authentication instance class that was used + to authenticate the request, or `None`. + """ + return self._authenticator + def _load_data_and_files(self): """ Parses the request content into self.DATA and self.FILES. @@ -299,21 +308,23 @@ class Request(object): def _authenticate(self): """ - Attempt to authenticate the request using each authentication instance in turn. - Returns a two-tuple of (user, authtoken). + Attempt to authenticate the request using each authentication instance + in turn. + Returns a three-tuple of (authenticator, user, authtoken). """ for authenticator in self.authenticators: user_auth_tuple = authenticator.authenticate(self) if not user_auth_tuple is None: - return user_auth_tuple + user, auth = user_auth_tuple + return (authenticator, user, auth) return self._not_authenticated() def _not_authenticated(self): """ - Return a two-tuple of (user, authtoken), representing an - unauthenticated request. + Return a three-tuple of (authenticator, user, authtoken), representing + an unauthenticated request. - By default this will be (AnonymousUser, None). + By default this will be (None, AnonymousUser, None). """ if api_settings.UNAUTHENTICATED_USER: user = api_settings.UNAUTHENTICATED_USER() @@ -325,7 +336,7 @@ class Request(object): else: auth = None - return (user, auth) + return (None, user, auth) def __getattr__(self, attr): """ diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 27458f96..6ecc7b45 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -2,6 +2,7 @@ import copy import datetime import types from decimal import Decimal +from django.core.paginator import Page from django.db import models from django.forms import widgets from django.utils.datastructures import SortedDict @@ -227,6 +228,8 @@ class BaseSerializer(Field): Run `validate_<fieldname>()` and `validate()` methods on the serializer """ for field_name, field in self.fields.items(): + if field_name in self._errors: + continue try: validate_method = getattr(self, 'validate_%s' % field_name, None) if validate_method: @@ -271,7 +274,11 @@ class BaseSerializer(Field): """ Serialize objects -> primitives. """ - if hasattr(obj, '__iter__'): + # Note: At the moment we have an ugly hack to determine if we should + # walk over iterables. At some point, serializers will require an + # explicit `many=True` in order to iterate over a set, and this hack + # will disappear. + if hasattr(obj, '__iter__') and not isinstance(obj, Page): return [self.convert_object(item) for item in obj] return self.convert_object(obj) @@ -298,6 +305,9 @@ class BaseSerializer(Field): Override default so that we can apply ModelSerializer as a nested field to relationships. """ + if self.source == '*': + return self.to_native(obj) + try: if self.source: for component in self.source.split('.'): diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index e86041bc..1f17e8d2 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -4,7 +4,7 @@ from django.test import Client, TestCase from rest_framework import permissions from rest_framework.authtoken.models import Token -from rest_framework.authentication import TokenAuthentication +from rest_framework.authentication import TokenAuthentication, BasicAuthentication, SessionAuthentication from rest_framework.compat import patterns from rest_framework.views import APIView @@ -21,10 +21,10 @@ class MockView(APIView): def put(self, request): return HttpResponse({'a': 1, 'b': 2, 'c': 3}) -MockView.authentication_classes += (TokenAuthentication,) - urlpatterns = patterns('', - (r'^$', MockView.as_view()), + (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), + (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), + (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), ) @@ -43,24 +43,25 @@ class BasicAuthTests(TestCase): def test_post_form_passing_basic_auth(self): """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF""" auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() - response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) def test_post_json_passing_basic_auth(self): """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) def test_post_form_failing_basic_auth(self): """Ensure POSTing form over basic auth without correct credentials fails""" - response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/basic/', {'example': 'example'}) + self.assertEqual(response.status_code, 401) def test_post_json_failing_basic_auth(self): """Ensure POSTing json over basic auth without correct credentials fails""" - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') + self.assertEqual(response.status_code, 401) + self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') class SessionAuthTests(TestCase): @@ -83,7 +84,7 @@ class SessionAuthTests(TestCase): Ensure POSTing form over session authentication without CSRF token fails. """ self.csrf_client.login(username=self.username, password=self.password) - response = self.csrf_client.post('/', {'example': 'example'}) + response = self.csrf_client.post('/session/', {'example': 'example'}) self.assertEqual(response.status_code, 403) def test_post_form_session_auth_passing(self): @@ -91,7 +92,7 @@ class SessionAuthTests(TestCase): Ensure POSTing form over session authentication with logged in user and CSRF token passes. """ self.non_csrf_client.login(username=self.username, password=self.password) - response = self.non_csrf_client.post('/', {'example': 'example'}) + response = self.non_csrf_client.post('/session/', {'example': 'example'}) self.assertEqual(response.status_code, 200) def test_put_form_session_auth_passing(self): @@ -99,14 +100,14 @@ class SessionAuthTests(TestCase): Ensure PUTting form over session authentication with logged in user and CSRF token passes. """ self.non_csrf_client.login(username=self.username, password=self.password) - response = self.non_csrf_client.put('/', {'example': 'example'}) + response = self.non_csrf_client.put('/session/', {'example': 'example'}) self.assertEqual(response.status_code, 200) def test_post_form_session_auth_failing(self): """ Ensure POSTing form over session authentication without logged in user fails. """ - response = self.csrf_client.post('/', {'example': 'example'}) + response = self.csrf_client.post('/session/', {'example': 'example'}) self.assertEqual(response.status_code, 403) @@ -127,24 +128,24 @@ class TokenAuthTests(TestCase): def test_post_form_passing_token_auth(self): """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" auth = "Token " + self.key - response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) def test_post_json_passing_token_auth(self): """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" auth = "Token " + self.key - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) def test_post_form_failing_token_auth(self): """Ensure POSTing form over token auth without correct credentials fails""" - response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/token/', {'example': 'example'}) + self.assertEqual(response.status_code, 401) def test_post_json_failing_token_auth(self): """Ensure POSTing json over token auth without correct credentials fails""" - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') + self.assertEqual(response.status_code, 401) def test_token_has_auto_assigned_key_if_none_provided(self): """Ensure creating a token with no key will auto-assign a key""" diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py index 5e6bce4e..82f912e9 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/decorators.py @@ -28,13 +28,27 @@ class DecoratorTestCase(TestCase): response.request = request return APIView.finalize_response(self, request, response, *args, **kwargs) - def test_wrap_view(self): + def test_api_view_incorrect(self): + """ + If @api_view is not applied correct, we should raise an assertion. + """ - @api_view(['GET']) + @api_view def view(request): - return Response({}) + return Response() + + request = self.factory.get('/') + self.assertRaises(AssertionError, view, request) + + def test_api_view_incorrect_arguments(self): + """ + If @api_view is missing arguments, we should raise an assertion. + """ - self.assertTrue(isinstance(view.cls_instance, APIView)) + with self.assertRaises(AssertionError): + @api_view('GET') + def view(request): + return Response() def test_calling_method(self): diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py index bc7378e1..146ad1e4 100644 --- a/rest_framework/tests/genericrelations.py +++ b/rest_framework/tests/genericrelations.py @@ -1,25 +1,61 @@ +from django.contrib.contenttypes.models import ContentType +from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey +from django.db import models from django.test import TestCase from rest_framework import serializers -from rest_framework.tests.models import * + + +class Tag(models.Model): + """ + Tags have a descriptive slug, and are attached to an arbitrary object. + """ + tag = models.SlugField() + content_type = models.ForeignKey(ContentType) + object_id = models.PositiveIntegerField() + tagged_item = GenericForeignKey('content_type', 'object_id') + + def __unicode__(self): + return self.tag + + +class Bookmark(models.Model): + """ + A URL bookmark that may have multiple tags attached. + """ + url = models.URLField() + tags = GenericRelation(Tag) + + def __unicode__(self): + return 'Bookmark: %s' % self.url + + +class Note(models.Model): + """ + A textual note that may have multiple tags attached. + """ + text = models.TextField() + tags = GenericRelation(Tag) + + def __unicode__(self): + return 'Note: %s' % self.text class TestGenericRelations(TestCase): def setUp(self): - bookmark = Bookmark(url='https://www.djangoproject.com/') - bookmark.save() - django = Tag(tag_name='django') - django.save() - python = Tag(tag_name='python') - python.save() - t1 = TaggedItem(content_object=bookmark, tag=django) - t1.save() - t2 = TaggedItem(content_object=bookmark, tag=python) - t2.save() - self.bookmark = bookmark - - def test_reverse_generic_relation(self): + self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/') + Tag.objects.create(tagged_item=self.bookmark, tag='django') + Tag.objects.create(tagged_item=self.bookmark, tag='python') + self.note = Note.objects.create(text='Remember the milk') + Tag.objects.create(tagged_item=self.note, tag='reminder') + + def test_generic_relation(self): + """ + Test a relationship that spans a GenericRelation field. + IE. A reverse generic relationship. + """ + class BookmarkSerializer(serializers.ModelSerializer): - tags = serializers.ManyRelatedField(source='tags') + tags = serializers.ManyRelatedField() class Meta: model = Bookmark @@ -31,3 +67,33 @@ class TestGenericRelations(TestCase): 'url': u'https://www.djangoproject.com/' } self.assertEquals(serializer.data, expected) + + def test_generic_fk(self): + """ + Test a relationship that spans a GenericForeignKey field. + IE. A forward generic relationship. + """ + + class TagSerializer(serializers.ModelSerializer): + tagged_item = serializers.RelatedField() + + class Meta: + model = Tag + exclude = ('id', 'content_type', 'object_id') + + serializer = TagSerializer(Tag.objects.all()) + expected = [ + { + 'tag': u'django', + 'tagged_item': u'Bookmark: https://www.djangoproject.com/' + }, + { + 'tag': u'python', + 'tagged_item': u'Bookmark: https://www.djangoproject.com/' + }, + { + 'tag': u'reminder', + 'tagged_item': u'Note: Remember the milk' + } + ] + self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 93f09761..9ab15328 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -86,27 +86,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel): text = models.CharField(max_length=100, default='anchor') rel = models.ManyToManyField(Anchor) -# Models to test generic relations - - -class Tag(RESTFrameworkModel): - tag_name = models.SlugField() - - -class TaggedItem(RESTFrameworkModel): - tag = models.ForeignKey(Tag, related_name='items') - content_type = models.ForeignKey(ContentType) - object_id = models.PositiveIntegerField() - content_object = GenericForeignKey('content_type', 'object_id') - - def __unicode__(self): - return self.tag.tag_name - - -class Bookmark(RESTFrameworkModel): - url = models.URLField() - tags = GenericRelation(TaggedItem) - # Model to test filtering. class FilterableItem(RESTFrameworkModel): diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 3b550877..697dfb5b 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -252,6 +252,8 @@ class TestCustomPaginateByParam(TestCase): self.assertEquals(response.data['results'], self.data[:5]) +### Tests for context in pagination serializers + class CustomField(serializers.Field): def to_native(self, value): if not 'view' in self.context: @@ -262,6 +264,11 @@ class CustomField(serializers.Field): class BasicModelSerializer(serializers.Serializer): text = CustomField() + def __init__(self, *args, **kwargs): + super(BasicModelSerializer, self).__init__(*args, **kwargs) + if not 'view' in self.context: + raise RuntimeError("context isn't getting passed into serializer init") + class TestContextPassedToCustomField(TestCase): def setUp(self): @@ -279,3 +286,39 @@ class TestContextPassedToCustomField(TestCase): self.assertEquals(response.status_code, status.HTTP_200_OK) + +### Tests for custom pagination serializers + +class LinksSerializer(serializers.Serializer): + next = pagination.NextPageField(source='*') + prev = pagination.PreviousPageField(source='*') + + +class CustomPaginationSerializer(pagination.BasePaginationSerializer): + links = LinksSerializer(source='*') # Takes the page object as the source + total_results = serializers.Field(source='paginator.count') + + results_field = 'objects' + + +class TestCustomPaginationSerializer(TestCase): + def setUp(self): + objects = ['john', 'paul', 'george', 'ringo'] + paginator = Paginator(objects, 2) + self.page = paginator.page(1) + + def test_custom_pagination_serializer(self): + request = RequestFactory().get('/foobar') + serializer = CustomPaginationSerializer( + instance=self.page, + context={'request': request} + ) + expected = { + 'links': { + 'next': 'http://testserver/foobar?page=2', + 'prev': None + }, + 'total_results': 4, + 'objects': ['john', 'paul'] + } + self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py index 7d65eae7..6d137f68 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/relations_hyperlink.py @@ -215,6 +215,13 @@ class HyperlinkedForeignKeyTests(TestCase): ] self.assertEquals(serializer.data, expected) + def test_foreign_key_update_incorrect_type(self): + data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': 2} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'target': [u'Incorrect type. Expected url string, received int.']}) + def test_reverse_foreign_key_update(self): data = {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']} instance = ForeignKeyTarget.objects.get(pk=2) @@ -227,7 +234,7 @@ class HyperlinkedForeignKeyTests(TestCase): expected = [ {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']}, {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, - ] + ] self.assertEquals(new_serializer.data, expected) serializer.save() diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py index dd1e86b5..3391e60a 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/relations_pk.py @@ -194,6 +194,13 @@ class PKForeignKeyTests(TestCase): ] self.assertEquals(serializer.data, expected) + def test_foreign_key_update_incorrect_type(self): + data = {'id': 1, 'name': u'source-1', 'target': 'foo'} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'target': [u'Incorrect type. Expected pk value, received str.']}) + def test_reverse_foreign_key_update(self): data = {'id': 2, 'name': u'target-2', 'sources': [1, 3]} instance = ForeignKeyTarget.objects.get(pk=2) diff --git a/rest_framework/tests/relations_slug.py b/rest_framework/tests/relations_slug.py index 503b61e8..37ccc75e 100644 --- a/rest_framework/tests/relations_slug.py +++ b/rest_framework/tests/relations_slug.py @@ -1,9 +1,23 @@ from django.test import TestCase from rest_framework import serializers -from rest_framework.tests.models import NullableForeignKeySource, ForeignKeyTarget +from rest_framework.tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget -class NullableSlugSourceSerializer(serializers.ModelSerializer): +class ForeignKeyTargetSerializer(serializers.ModelSerializer): + sources = serializers.ManySlugRelatedField(slug_field='name') + + class Meta: + model = ForeignKeyTarget + + +class ForeignKeySourceSerializer(serializers.ModelSerializer): + target = serializers.SlugRelatedField(slug_field='name') + + class Meta: + model = ForeignKeySource + + +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): target = serializers.SlugRelatedField(slug_field='name', null=True) class Meta: @@ -11,6 +25,132 @@ class NullableSlugSourceSerializer(serializers.ModelSerializer): # TODO: M2M Tests, FKTests (Non-nulable), One2One +class PKForeignKeyTests(TestCase): + def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + new_target = ForeignKeyTarget(name='target-2') + new_target.save() + for idx in range(1, 4): + source = ForeignKeySource(name='source-%d' % idx, target=target) + source.save() + + def test_foreign_key_retrieve(self): + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 'target-1'}, + {'id': 2, 'name': u'source-2', 'target': 'target-1'}, + {'id': 3, 'name': u'source-3', 'target': 'target-1'} + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_foreign_key_retrieve(self): + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, + {'id': 2, 'name': u'target-2', 'sources': []}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update(self): + data = {'id': 1, 'name': u'source-1', 'target': 'target-2'} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEquals(serializer.data, data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 'target-2'}, + {'id': 2, 'name': u'source-2', 'target': 'target-1'}, + {'id': 3, 'name': u'source-3', 'target': 'target-1'} + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update_incorrect_type(self): + data = {'id': 1, 'name': u'source-1', 'target': 123} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'target': [u'Object with name=123 does not exist.']}) + + def test_reverse_foreign_key_update(self): + data = {'id': 2, 'name': u'target-2', 'sources': ['source-1', 'source-3']} + instance = ForeignKeyTarget.objects.get(pk=2) + serializer = ForeignKeyTargetSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + # We shouldn't have saved anything to the db yet since save + # hasn't been called. + queryset = ForeignKeyTarget.objects.all() + new_serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, + {'id': 2, 'name': u'target-2', 'sources': []}, + ] + self.assertEquals(new_serializer.data, expected) + + serializer.save() + self.assertEquals(serializer.data, data) + + # Ensure target 2 is update, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': ['source-2']}, + {'id': 2, 'name': u'target-2', 'sources': ['source-1', 'source-3']}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_create(self): + data = {'id': 4, 'name': u'source-4', 'target': 'target-2'} + serializer = ForeignKeySourceSerializer(data=data) + serializer.is_valid() + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'source-4') + + # Ensure source 4 is added, and everything else is as expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 'target-1'}, + {'id': 2, 'name': u'source-2', 'target': 'target-1'}, + {'id': 3, 'name': u'source-3', 'target': 'target-1'}, + {'id': 4, 'name': u'source-4', 'target': 'target-2'}, + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_foreign_key_create(self): + data = {'id': 3, 'name': u'target-3', 'sources': ['source-1', 'source-3']} + serializer = ForeignKeyTargetSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'target-3') + + # Ensure target 3 is added, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': ['source-2']}, + {'id': 2, 'name': u'target-2', 'sources': []}, + {'id': 3, 'name': u'target-3', 'sources': ['source-1', 'source-3']}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update_with_invalid_null(self): + data = {'id': 1, 'name': u'source-1', 'target': None} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) + class SlugNullableForeignKeyTests(TestCase): def setUp(self): @@ -24,7 +164,7 @@ class SlugNullableForeignKeyTests(TestCase): def test_foreign_key_retrieve_with_null(self): queryset = NullableForeignKeySource.objects.all() - serializer = NullableSlugSourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset) expected = [ {'id': 1, 'name': u'source-1', 'target': 'target-1'}, {'id': 2, 'name': u'source-2', 'target': 'target-1'}, @@ -34,7 +174,7 @@ class SlugNullableForeignKeyTests(TestCase): def test_foreign_key_create_with_valid_null(self): data = {'id': 4, 'name': u'source-4', 'target': None} - serializer = NullableSlugSourceSerializer(data=data) + serializer = NullableForeignKeySourceSerializer(data=data) self.assertTrue(serializer.is_valid()) obj = serializer.save() self.assertEquals(serializer.data, data) @@ -42,7 +182,7 @@ class SlugNullableForeignKeyTests(TestCase): # Ensure source 4 is created, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableSlugSourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset) expected = [ {'id': 1, 'name': u'source-1', 'target': 'target-1'}, {'id': 2, 'name': u'source-2', 'target': 'target-1'}, @@ -58,7 +198,7 @@ class SlugNullableForeignKeyTests(TestCase): """ data = {'id': 4, 'name': u'source-4', 'target': ''} expected_data = {'id': 4, 'name': u'source-4', 'target': None} - serializer = NullableSlugSourceSerializer(data=data) + serializer = NullableForeignKeySourceSerializer(data=data) self.assertTrue(serializer.is_valid()) obj = serializer.save() self.assertEquals(serializer.data, expected_data) @@ -66,7 +206,7 @@ class SlugNullableForeignKeyTests(TestCase): # Ensure source 4 is created, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableSlugSourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset) expected = [ {'id': 1, 'name': u'source-1', 'target': 'target-1'}, {'id': 2, 'name': u'source-2', 'target': 'target-1'}, @@ -78,14 +218,14 @@ class SlugNullableForeignKeyTests(TestCase): def test_foreign_key_update_with_valid_null(self): data = {'id': 1, 'name': u'source-1', 'target': None} instance = NullableForeignKeySource.objects.get(pk=1) - serializer = NullableSlugSourceSerializer(instance, data=data) + serializer = NullableForeignKeySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) self.assertEquals(serializer.data, data) serializer.save() # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableSlugSourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset) expected = [ {'id': 1, 'name': u'source-1', 'target': None}, {'id': 2, 'name': u'source-2', 'target': 'target-1'}, @@ -101,14 +241,14 @@ class SlugNullableForeignKeyTests(TestCase): data = {'id': 1, 'name': u'source-1', 'target': ''} expected_data = {'id': 1, 'name': u'source-1', 'target': None} instance = NullableForeignKeySource.objects.get(pk=1) - serializer = NullableSlugSourceSerializer(instance, data=data) + serializer = NullableForeignKeySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) self.assertEquals(serializer.data, expected_data) serializer.save() # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableSlugSourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset) expected = [ {'id': 1, 'name': u'source-1', 'target': None}, {'id': 2, 'name': u'source-2', 'target': 'target-1'}, diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index bd96ba23..b4428ca3 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -162,7 +162,6 @@ class BasicTests(TestCase): """ Attempting to update fields set as read_only should have no effect. """ - serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99}) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() @@ -183,8 +182,7 @@ class ValidationTests(TestCase): 'content': 'x' * 1001, 'created': datetime.datetime(2012, 1, 1) } - self.actionitem = ActionItem(title='Some to do item', - ) + self.actionitem = ActionItem(title='Some to do item',) def test_create(self): serializer = CommentSerializer(data=self.data) @@ -216,31 +214,6 @@ class ValidationTests(TestCase): self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.errors, {}) - def test_field_validation(self): - - class CommentSerializerWithFieldValidator(CommentSerializer): - - def validate_content(self, attrs, source): - value = attrs[source] - if "test" not in value: - raise serializers.ValidationError("Test not in value") - return attrs - - data = { - 'email': 'tom@example.com', - 'content': 'A test comment', - 'created': datetime.datetime(2012, 1, 1) - } - - serializer = CommentSerializerWithFieldValidator(data=data) - self.assertTrue(serializer.is_valid()) - - data['content'] = 'This should not validate' - - serializer = CommentSerializerWithFieldValidator(data=data) - self.assertFalse(serializer.is_valid()) - self.assertEquals(serializer.errors, {'content': [u'Test not in value']}) - def test_bad_type_data_is_false(self): """ Data of the wrong type is not valid. @@ -310,12 +283,69 @@ class ValidationTests(TestCase): self.assertEquals(serializer.errors, {'info': [u'Ensure this value has at most 12 characters (it has 13).']}) +class CustomValidationTests(TestCase): + class CommentSerializerWithFieldValidator(CommentSerializer): + + def validate_email(self, attrs, source): + value = attrs[source] + + return attrs + + def validate_content(self, attrs, source): + value = attrs[source] + if "test" not in value: + raise serializers.ValidationError("Test not in value") + return attrs + + def test_field_validation(self): + data = { + 'email': 'tom@example.com', + 'content': 'A test comment', + 'created': datetime.datetime(2012, 1, 1) + } + + serializer = self.CommentSerializerWithFieldValidator(data=data) + self.assertTrue(serializer.is_valid()) + + data['content'] = 'This should not validate' + + serializer = self.CommentSerializerWithFieldValidator(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'content': [u'Test not in value']}) + + def test_missing_data(self): + """ + Make sure that validate_content isn't called if the field is missing + """ + incomplete_data = { + 'email': 'tom@example.com', + 'created': datetime.datetime(2012, 1, 1) + } + serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'content': [u'This field is required.']}) + + def test_wrong_data(self): + """ + Make sure that validate_content isn't called if the field input is wrong + """ + wrong_data = { + 'email': 'not an email', + 'content': 'A test comment', + 'created': datetime.datetime(2012, 1, 1) + } + serializer = self.CommentSerializerWithFieldValidator(data=wrong_data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'email': [u'Enter a valid e-mail address.']}) + + class PositiveIntegerAsChoiceTests(TestCase): def test_positive_integer_in_json_is_correctly_parsed(self): - data = {'some_integer':1} + data = {'some_integer': 1} serializer = PositiveIntegerAsChoiceSerializer(data=data) self.assertEquals(serializer.is_valid(), True) + class ModelValidationTests(TestCase): def test_validate_unique(self): """ diff --git a/rest_framework/tests/urlpatterns.py b/rest_framework/tests/urlpatterns.py new file mode 100644 index 00000000..43e8ef69 --- /dev/null +++ b/rest_framework/tests/urlpatterns.py @@ -0,0 +1,78 @@ +from collections import namedtuple + +from django.core import urlresolvers + +from django.test import TestCase +from django.test.client import RequestFactory + +from rest_framework.compat import patterns, url, include +from rest_framework.urlpatterns import format_suffix_patterns + + +# A container class for test paths for the test case +URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs']) + + +def dummy_view(request, *args, **kwargs): + pass + + +class FormatSuffixTests(TestCase): + """ + Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters. + """ + def _resolve_urlpatterns(self, urlpatterns, test_paths): + factory = RequestFactory() + try: + urlpatterns = format_suffix_patterns(urlpatterns) + except: + self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns") + resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns) + for test_path in test_paths: + request = factory.get(test_path.path) + try: + callback, callback_args, callback_kwargs = resolver.resolve(request.path_info) + except: + self.fail("Failed to resolve URL: %s" % request.path_info) + self.assertEquals(callback_args, test_path.args) + self.assertEquals(callback_kwargs, test_path.kwargs) + + def test_format_suffix(self): + urlpatterns = patterns( + '', + url(r'^test$', dummy_view), + ) + test_paths = [ + URLTestPath('/test', (), {}), + URLTestPath('/test.api', (), {'format': 'api'}), + URLTestPath('/test.asdf', (), {'format': 'asdf'}), + ] + self._resolve_urlpatterns(urlpatterns, test_paths) + + def test_default_args(self): + urlpatterns = patterns( + '', + url(r'^test$', dummy_view, {'foo': 'bar'}), + ) + test_paths = [ + URLTestPath('/test', (), {'foo': 'bar', }), + URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}), + URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}), + ] + self._resolve_urlpatterns(urlpatterns, test_paths) + + def test_included_urls(self): + nested_patterns = patterns( + '', + url(r'^path$', dummy_view) + ) + urlpatterns = patterns( + '', + url(r'^test/', include(nested_patterns), {'foo': 'bar'}), + ) + test_paths = [ + URLTestPath('/test/path', (), {'foo': 'bar', }), + URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}), + URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}), + ] + self._resolve_urlpatterns(urlpatterns, test_paths) diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index 143928c9..47789026 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -1,5 +1,35 @@ -from rest_framework.compat import url +from rest_framework.compat import url, include from rest_framework.settings import api_settings +from django.core.urlresolvers import RegexURLResolver + + +def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required): + ret = [] + for urlpattern in urlpatterns: + if isinstance(urlpattern, RegexURLResolver): + # Set of included URL patterns + regex = urlpattern.regex.pattern + namespace = urlpattern.namespace + app_name = urlpattern.app_name + kwargs = urlpattern.default_kwargs + # Add in the included patterns, after applying the suffixes + patterns = apply_suffix_patterns(urlpattern.url_patterns, + suffix_pattern, + suffix_required) + ret.append(url(regex, include(patterns, namespace, app_name), kwargs)) + + else: + # Regular URL pattern + regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern + view = urlpattern._callback or urlpattern._callback_str + kwargs = urlpattern.default_args + name = urlpattern.name + # Add in both the existing and the new urlpattern + if not suffix_required: + ret.append(urlpattern) + ret.append(url(regex, view, kwargs, name)) + + return ret def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None): @@ -28,15 +58,4 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None): else: suffix_pattern = r'\.(?P<%s>[a-z]+)$' % suffix_kwarg - ret = [] - for urlpattern in urlpatterns: - # Form our complementing '.format' urlpattern - regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern - view = urlpattern._callback or urlpattern._callback_str - kwargs = urlpattern.default_args - name = urlpattern.name - # Add in both the existing and the new urlpattern - if not suffix_required: - ret.append(urlpattern) - ret.append(url(regex, view, kwargs, name)) - return ret + return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required) diff --git a/rest_framework/views.py b/rest_framework/views.py index 10bdd5a5..ac9b3385 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -148,6 +148,8 @@ class APIView(View): """ If request is not permitted, determine what kind of exception to raise. """ + if not self.request.successful_authenticator: + raise exceptions.NotAuthenticated() raise exceptions.PermissionDenied() def throttled(self, request, wait): @@ -156,6 +158,15 @@ class APIView(View): """ raise exceptions.Throttled(wait) + def get_authenticate_header(self, request): + """ + If a request is unauthenticated, determine the WWW-Authenticate + header to use for 401 responses, if any. + """ + authenticators = self.get_authenticators() + if authenticators: + return authenticators[0].authenticate_header(request) + def get_parser_context(self, http_request): """ Returns a dict that is passed through to Parser.parse(), @@ -319,6 +330,16 @@ class APIView(View): # Throttle wait header self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait + if isinstance(exc, (exceptions.NotAuthenticated, + exceptions.AuthenticationFailed)): + # WWW-Authenticate header for 401 responses, else coerce to 403 + auth_header = self.get_authenticate_header(self.request) + + if auth_header: + self.headers['WWW-Authenticate'] = auth_header + else: + exc.status_code = status.HTTP_403_FORBIDDEN + if isinstance(exc, exceptions.APIException): return Response({'detail': exc.detail}, status=exc.status_code, |
