diff options
| author | Mark Aaron Shirley | 2013-01-08 08:33:01 -0800 | 
|---|---|---|
| committer | Mark Aaron Shirley | 2013-01-08 08:33:01 -0800 | 
| commit | 81691ff9008c69ee02d4a337dc91ddc523c81b6a (patch) | |
| tree | 99886aa8aacafeec89bc90aa04c616be3429ce5a /rest_framework | |
| parent | a897eb5480348838b11fdb428ce0d110e8bc8da1 (diff) | |
| parent | 431ced66e49905fd76db0c36f62794dc3f42470b (diff) | |
| download | django-rest-framework-81691ff9008c69ee02d4a337dc91ddc523c81b6a.tar.bz2 | |
Merge remote-tracking branch 'upstream/master' into null-one-to-one
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/authtoken/views.py | 3 | ||||
| -rw-r--r-- | rest_framework/parsers.py | 2 | ||||
| -rw-r--r-- | rest_framework/relations.py | 79 | ||||
| -rw-r--r-- | rest_framework/renderers.py | 2 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 9 | ||||
| -rw-r--r-- | rest_framework/settings.py | 4 | ||||
| -rw-r--r-- | rest_framework/templatetags/rest_framework.py | 2 | ||||
| -rw-r--r-- | rest_framework/tests/authentication.py | 2 | ||||
| -rw-r--r-- | rest_framework/tests/extras/__init__.py | 0 | ||||
| -rw-r--r-- | rest_framework/tests/extras/bad_import.py | 1 | ||||
| -rw-r--r-- | rest_framework/tests/generics.py | 2 | ||||
| -rw-r--r-- | rest_framework/tests/hyperlinkedserializers.py | 2 | ||||
| -rw-r--r-- | rest_framework/tests/pagination.py | 4 | ||||
| -rw-r--r-- | rest_framework/tests/relations.py | 33 | ||||
| -rw-r--r-- | rest_framework/tests/request.py | 2 | ||||
| -rw-r--r-- | rest_framework/tests/serializer.py | 20 | ||||
| -rw-r--r-- | rest_framework/tests/settings.py | 21 | ||||
| -rw-r--r-- | rest_framework/utils/encoders.py | 2 | 
18 files changed, 157 insertions, 33 deletions
diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index d318c723..7c03cb76 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -12,10 +12,11 @@ class ObtainAuthToken(APIView):      permission_classes = ()      parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)      renderer_classes = (renderers.JSONRenderer,) +    serializer_class = AuthTokenSerializer      model = Token      def post(self, request): -        serializer = AuthTokenSerializer(data=request.DATA) +        serializer = self.serializer_class(data=request.DATA)          if serializer.is_valid():              token, created = Token.objects.get_or_create(user=serializer.object['user'])              return Response({'token': token.key}) diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 4841676c..149d6431 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -8,11 +8,11 @@ on the request, such as form content or json encoded data.  from django.http import QueryDict  from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser  from django.http.multipartparser import MultiPartParserError -from django.utils import simplejson as json  from rest_framework.compat import yaml, ETParseError  from rest_framework.exceptions import ParseError  from xml.etree import ElementTree as ET  from xml.parsers.expat import ExpatError +import json  import datetime  import decimal diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 6c1d4f5b..5e4552b7 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -4,6 +4,7 @@ from django import forms  from django.forms import widgets  from django.forms.models import ModelChoiceIterator  from django.utils.encoding import smart_unicode +from django.utils.translation import ugettext_lazy as _  from rest_framework.fields import Field, WritableField  from rest_framework.reverse import reverse  from urlparse import urlparse @@ -171,6 +172,11 @@ class PrimaryKeyRelatedField(RelatedField):      default_read_only = False      form_field_class = forms.ChoiceField +    default_error_messages = { +        'does_not_exist': _("Invalid pk '%s' - object does not exist."), +        'invalid': _('Invalid value.'), +    } +      # TODO: Remove these field hacks...      def prepare_value(self, obj):          return self.to_native(obj.pk) @@ -196,7 +202,10 @@ class PrimaryKeyRelatedField(RelatedField):          try:              return self.queryset.get(pk=data)          except ObjectDoesNotExist: -            msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data) +            msg = self.error_messages['does_not_exist'] % smart_unicode(data) +            raise ValidationError(msg) +        except (TypeError, ValueError): +            msg = self.error_messages['invalid']              raise ValidationError(msg)      def field_to_native(self, obj, field_name): @@ -221,6 +230,11 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):      default_read_only = False      form_field_class = forms.MultipleChoiceField +    default_error_messages = { +        'does_not_exist': _("Invalid pk '%s' - object does not exist."), +        'invalid': _('Invalid value.'), +    } +      def prepare_value(self, obj):          return self.to_native(obj.pk) @@ -255,7 +269,10 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):          try:              return self.queryset.get(pk=data)          except ObjectDoesNotExist: -            msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data) +            msg = self.error_messages['does_not_exist'] % smart_unicode(data) +            raise ValidationError(msg) +        except (TypeError, ValueError): +            msg = self.error_messages['invalid']              raise ValidationError(msg)  ### Slug relationships @@ -265,6 +282,11 @@ class SlugRelatedField(RelatedField):      default_read_only = False      form_field_class = forms.ChoiceField +    default_error_messages = { +        'does_not_exist': _("Object with %s=%s does not exist."), +        'invalid': _('Invalid value.'), +    } +      def __init__(self, *args, **kwargs):          self.slug_field = kwargs.pop('slug_field', None)          assert self.slug_field, 'slug_field is required' @@ -280,8 +302,11 @@ class SlugRelatedField(RelatedField):          try:              return self.queryset.get(**{self.slug_field: data})          except ObjectDoesNotExist: -            raise ValidationError('Object with %s=%s does not exist.' % +            raise ValidationError(self.error_messages['does_not_exist'] %                                    (self.slug_field, unicode(data))) +        except (TypeError, ValueError): +            msg = self.error_messages['invalid'] +            raise ValidationError(msg)  class ManySlugRelatedField(ManyRelatedMixin, SlugRelatedField): @@ -300,6 +325,14 @@ class HyperlinkedRelatedField(RelatedField):      default_read_only = False      form_field_class = forms.ChoiceField +    default_error_messages = { +        'no_match': _('Invalid hyperlink - No URL match'), +        'incorrect_match': _('Invalid hyperlink - Incorrect URL match'), +        'configuration_error': _('Invalid hyperlink due to configuration error'), +        'does_not_exist': _("Invalid hyperlink - object does not exist."), +        'invalid': _('Invalid value.'), +    } +      def __init__(self, *args, **kwargs):          try:              self.view_name = kwargs.pop('view_name') @@ -336,21 +369,21 @@ class HyperlinkedRelatedField(RelatedField):          slug = getattr(obj, self.slug_field, None)          if not slug: -            raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) +            raise Exception('Could not resolve URL for field using view name "%s"' % view_name)          kwargs = {self.slug_url_kwarg: slug}          try: -            return reverse(self.view_name, kwargs=kwargs, request=request, format=format) +            return reverse(view_name, kwargs=kwargs, request=request, format=format)          except:              pass          kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}          try: -            return reverse(self.view_name, kwargs=kwargs, request=request, format=format) +            return reverse(view_name, kwargs=kwargs, request=request, format=format)          except:              pass -        raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) +        raise Exception('Could not resolve URL for field using view name "%s"' % view_name)      def from_native(self, value):          # Convert URL -> model instance pk @@ -358,7 +391,13 @@ class HyperlinkedRelatedField(RelatedField):          if self.queryset is None:              raise Exception('Writable related fields must include a `queryset` argument') -        if value.startswith('http:') or value.startswith('https:'): +        try: +            http_prefix = value.startswith('http:') or value.startswith('https:') +        except AttributeError: +            msg = self.error_messages['invalid'] +            raise ValidationError(msg) + +        if http_prefix:              # If needed convert absolute URLs to relative path              value = urlparse(value).path              prefix = get_script_prefix() @@ -368,10 +407,10 @@ class HyperlinkedRelatedField(RelatedField):          try:              match = resolve(value)          except: -            raise ValidationError('Invalid hyperlink - No URL match') +            raise ValidationError(self.error_messages['no_match']) -        if match.url_name != self.view_name: -            raise ValidationError('Invalid hyperlink - Incorrect URL match') +        if match.view_name != self.view_name: +            raise ValidationError(self.error_messages['incorrect_match'])          pk = match.kwargs.get(self.pk_url_kwarg, None)          slug = match.kwargs.get(self.slug_url_kwarg, None) @@ -383,14 +422,18 @@ class HyperlinkedRelatedField(RelatedField):          elif slug is not None:              slug_field = self.get_slug_field()              queryset = self.queryset.filter(**{slug_field: slug}) -        # If none of those are defined, it's an error. +        # If none of those are defined, it's probably a configuation error.          else: -            raise ValidationError('Invalid hyperlink') +            raise ValidationError(self.error_messages['configuration_error'])          try:              obj = queryset.get()          except ObjectDoesNotExist: -            raise ValidationError('Invalid hyperlink - object does not exist.') +            raise ValidationError(self.error_messages['does_not_exist']) +        except (TypeError, ValueError): +            msg = self.error_messages['invalid'] +            raise ValidationError(msg) +          return obj @@ -449,18 +492,18 @@ class HyperlinkedIdentityField(Field):          slug = getattr(obj, self.slug_field, None)          if not slug: -            raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) +            raise Exception('Could not resolve URL for field using view name "%s"' % view_name)          kwargs = {self.slug_url_kwarg: slug}          try: -            return reverse(self.view_name, kwargs=kwargs, request=request, format=format) +            return reverse(view_name, kwargs=kwargs, request=request, format=format)          except:              pass          kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}          try: -            return reverse(self.view_name, kwargs=kwargs, request=request, format=format) +            return reverse(view_name, kwargs=kwargs, request=request, format=format)          except:              pass -        raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) +        raise Exception('Could not resolve URL for field using view name "%s"' % view_name) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index a4ae717d..0a34abaa 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -8,10 +8,10 @@ REST framework also provides an HTML renderer the renders the browsable API.  """  import copy  import string +import json  from django import forms  from django.http.multipartparser import parse_header  from django.template import RequestContext, loader, Template -from django.utils import simplejson as json  from rest_framework.compat import yaml  from rest_framework.exceptions import ConfigurationError  from rest_framework.settings import api_settings diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 3391a262..da0af467 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -208,6 +208,11 @@ class BaseSerializer(Field):          Converts a dictionary of data into a dictionary of deserialized fields.          """          reverted_data = {} + +        if data is not None and not isinstance(data, dict): +            self._errors['non_field_errors'] = [u'Invalid data'] +            return None +          for field_name, field in self.fields.items():              field.initialize(parent=self, field_name=field_name)              try: @@ -276,7 +281,7 @@ class BaseSerializer(Field):          """          if hasattr(data, '__iter__') and not isinstance(data, dict):              # TODO: error data when deserializing lists -            return (self.from_native(item) for item in data) +            return [self.from_native(item, None) for item in data]          self._errors = {}          if data is not None or files is not None: @@ -428,7 +433,7 @@ class ModelSerializer(Serializer):          # TODO: filter queryset using:          # .using(db).complex_filter(self.rel.limit_choices_to)          kwargs = { -            'null': model_field.null, +            'null': model_field.null or model_field.blank,              'queryset': model_field.rel.to._default_manager          } diff --git a/rest_framework/settings.py b/rest_framework/settings.py index ee24a4ad..5c77c55c 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -115,8 +115,8 @@ def import_from_string(val, setting_name):          module_path, class_name = '.'.join(parts[:-1]), parts[-1]          module = importlib.import_module(module_path)          return getattr(module, class_name) -    except: -        msg = "Could not import '%s' for API setting '%s'" % (val, setting_name) +    except ImportError as e: +        msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e)          raise ImportError(msg) diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 09c658bc..82fcdfe7 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -23,7 +23,7 @@ register = template.Library()  # conflicts with this rest_framework template tag module.  try:  # Django 1.5+ -    from django.contrib.staticfiles.templatetags import StaticFilesNode +    from django.contrib.staticfiles.templatetags.staticfiles import StaticFilesNode      @register.tag('static')      def do_static(parser, token): diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index 838e081b..e86041bc 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -1,7 +1,6 @@  from django.contrib.auth.models import User  from django.http import HttpResponse  from django.test import Client, TestCase -from django.utils import simplejson as json  from rest_framework import permissions  from rest_framework.authtoken.models import Token @@ -9,6 +8,7 @@ from rest_framework.authentication import TokenAuthentication  from rest_framework.compat import patterns  from rest_framework.views import APIView +import json  import base64 diff --git a/rest_framework/tests/extras/__init__.py b/rest_framework/tests/extras/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/rest_framework/tests/extras/__init__.py diff --git a/rest_framework/tests/extras/bad_import.py b/rest_framework/tests/extras/bad_import.py new file mode 100644 index 00000000..68263d94 --- /dev/null +++ b/rest_framework/tests/extras/bad_import.py @@ -0,0 +1 @@ +raise ValueError diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 843017eb..4799a04b 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -1,6 +1,6 @@ +import json  from django.db import models  from django.test import TestCase -from django.utils import simplejson as json  from rest_framework import generics, serializers, status  from rest_framework.tests.utils import RequestFactory  from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index ee4d8e57..c6a8224b 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -1,6 +1,6 @@ +import json  from django.test import TestCase  from django.test.client import RequestFactory -from django.utils import simplejson as json  from rest_framework import generics, status, serializers  from rest_framework.compat import patterns, url  from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 81d297a1..3b550877 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -181,10 +181,10 @@ class UnitTestPagination(TestCase):          """          Ensure context gets passed through to the object serializer.          """ -        serializer = PassOnContextPaginationSerializer(self.first_page) +        serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})          serializer.data          results = serializer.fields[serializer.results_field] -        self.assertTrue(serializer.context is results.context) +        self.assertEquals(serializer.context, results.context)  class TestUnpaginated(TestCase): diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py new file mode 100644 index 00000000..91daea8a --- /dev/null +++ b/rest_framework/tests/relations.py @@ -0,0 +1,33 @@ +""" +General tests for relational fields. +""" + +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +class NullModel(models.Model): +    pass + + +class FieldTests(TestCase): +    def test_pk_related_field_with_empty_string(self): +        """ +        Regression test for #446 + +        https://github.com/tomchristie/django-rest-framework/issues/446 +        """ +        field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) +        self.assertRaises(serializers.ValidationError, field.from_native, '') +        self.assertRaises(serializers.ValidationError, field.from_native, []) + +    def test_hyperlinked_related_field_with_empty_string(self): +        field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') +        self.assertRaises(serializers.ValidationError, field.from_native, '') +        self.assertRaises(serializers.ValidationError, field.from_native, []) + +    def test_slug_related_field_with_empty_string(self): +        field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') +        self.assertRaises(serializers.ValidationError, field.from_native, '') +        self.assertRaises(serializers.ValidationError, field.from_native, []) diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py index 1f05ff8f..4b032405 100644 --- a/rest_framework/tests/request.py +++ b/rest_framework/tests/request.py @@ -1,12 +1,12 @@  """  Tests for content parsing, and form-overloaded content parsing.  """ +import json  from django.contrib.auth.models import User  from django.contrib.auth import authenticate, login, logout  from django.contrib.sessions.middleware import SessionMiddleware  from django.test import TestCase, Client  from django.test.client import RequestFactory -from django.utils import simplejson as json  from rest_framework import status  from rest_framework.authentication import SessionAuthentication  from rest_framework.compat import patterns diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 8767385e..bd96ba23 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -69,6 +69,7 @@ class AlbumsSerializer(serializers.ModelSerializer):          model = Album          fields = ['title']  # lists are also valid options +  class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):      class Meta:          model = HasPositiveIntegerAsChoice @@ -240,6 +241,25 @@ class ValidationTests(TestCase):          self.assertFalse(serializer.is_valid())          self.assertEquals(serializer.errors, {'content': [u'Test not in value']}) +    def test_bad_type_data_is_false(self): +        """ +        Data of the wrong type is not valid. +        """ +        data = ['i am', 'a', 'list'] +        serializer = CommentSerializer(self.comment, data=data) +        self.assertEquals(serializer.is_valid(), False) +        self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) + +        data = 'and i am a string' +        serializer = CommentSerializer(self.comment, data=data) +        self.assertEquals(serializer.is_valid(), False) +        self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) + +        data = 42 +        serializer = CommentSerializer(self.comment, data=data) +        self.assertEquals(serializer.is_valid(), False) +        self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) +      def test_cross_field_validation(self):          class CommentSerializerWithCrossFieldValidator(CommentSerializer): diff --git a/rest_framework/tests/settings.py b/rest_framework/tests/settings.py new file mode 100644 index 00000000..0293fdc3 --- /dev/null +++ b/rest_framework/tests/settings.py @@ -0,0 +1,21 @@ +"""Tests for the settings module""" +from django.test import TestCase + +from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS + + +class TestSettings(TestCase): +    """Tests relating to the api settings""" + +    def test_non_import_errors(self): +        """Make sure other errors aren't suppressed.""" +        settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) +        with self.assertRaises(ValueError): +            settings.DEFAULT_MODEL_SERIALIZER_CLASS + +    def test_import_error_message_maintained(self): +        """Make sure real import errors are captured and raised sensibly.""" +        settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) +        with self.assertRaises(ImportError) as cm: +            settings.DEFAULT_MODEL_SERIALIZER_CLASS +        self.assertTrue('ImportError' in str(cm.exception)) diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 2d1fb353..c70b24dd 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -4,7 +4,7 @@ Helper classes for parsers.  import datetime  import decimal  import types -from django.utils import simplejson as json +import json  from django.utils.datastructures import SortedDict  from rest_framework.compat import timezone  from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata  | 
