diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/__init__.py | 2 | ||||
| -rw-r--r-- | rest_framework/authtoken/views.py | 2 | ||||
| -rw-r--r-- | rest_framework/fields.py | 28 | ||||
| -rw-r--r-- | rest_framework/negotiation.py | 4 | ||||
| -rw-r--r-- | rest_framework/renderers.py | 28 | ||||
| -rw-r--r-- | rest_framework/request.py | 9 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 36 | ||||
| -rw-r--r-- | rest_framework/tests/authentication.py | 4 | ||||
| -rw-r--r-- | rest_framework/tests/models.py | 25 | ||||
| -rw-r--r-- | rest_framework/tests/request.py | 28 | ||||
| -rw-r--r-- | rest_framework/tests/serializer.py | 41 |
11 files changed, 141 insertions, 66 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index a2233f3d..48cebbc5 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,3 +1,3 @@ -__version__ = '2.1.4' +__version__ = '2.1.6' VERSION = __version__ # synonym diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index 3ac674e2..cfaacbe9 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -18,7 +18,7 @@ class ObtainAuthToken(APIView): if serializer.is_valid(): token, created = Token.objects.get_or_create(user=serializer.object['user']) return Response({'token': token.key}) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + return Response(serializer.errors, status=status.HTTP_401_UNAUTHORIZED) obtain_auth_token = ObtainAuthToken.as_view() diff --git a/rest_framework/fields.py b/rest_framework/fields.py index dff9123d..ff39fac4 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -10,6 +10,7 @@ from django.core import validators from django.core.exceptions import ObjectDoesNotExist, ValidationError from django.core.urlresolvers import resolve, get_script_prefix from django.conf import settings +from django import forms from django.forms import widgets from django.forms.models import ModelChoiceIterator from django.utils.encoding import is_protected_type, smart_unicode @@ -35,6 +36,7 @@ class Field(object): empty = '' type_name = None _use_files = None + form_field_class = forms.CharField def __init__(self, source=None): self.parent = None @@ -55,7 +57,7 @@ class Field(object): self.root = parent.root or parent self.context = self.root.context if self.root.partial: - self.required = False + self.required = False def field_from_native(self, data, files, field_name, into): """ @@ -223,7 +225,7 @@ class ModelField(WritableField): getattr(self.model_field, 'min_length', None)) self.max_length = kwargs.pop('max_length', getattr(self.model_field, 'max_length', None)) - + super(ModelField, self).__init__(*args, **kwargs) if self.min_length is not None: @@ -394,6 +396,7 @@ class PrimaryKeyRelatedField(RelatedField): Represents a to-one relationship as a pk value. """ default_read_only = False + form_field_class = forms.ChoiceField # TODO: Remove these field hacks... def prepare_value(self, obj): @@ -440,6 +443,7 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField): Represents a to-many relationship as a pk value. """ default_read_only = False + form_field_class = forms.MultipleChoiceField def prepare_value(self, obj): return self.to_native(obj.pk) @@ -483,6 +487,7 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField): class SlugRelatedField(RelatedField): default_read_only = False + form_field_class = forms.ChoiceField def __init__(self, *args, **kwargs): self.slug_field = kwargs.pop('slug_field', None) @@ -504,7 +509,7 @@ class SlugRelatedField(RelatedField): class ManySlugRelatedField(ManyRelatedMixin, SlugRelatedField): - pass + form_field_class = forms.MultipleChoiceField ### Hyperlinked relationships @@ -517,6 +522,7 @@ class HyperlinkedRelatedField(RelatedField): slug_field = 'slug' slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden default_read_only = False + form_field_class = forms.ChoiceField def __init__(self, *args, **kwargs): try: @@ -616,7 +622,7 @@ class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField): """ Represents a to-many relationship, using hyperlinking. """ - pass + form_field_class = forms.MultipleChoiceField class HyperlinkedIdentityField(Field): @@ -674,6 +680,7 @@ class HyperlinkedIdentityField(Field): class BooleanField(WritableField): type_name = 'BooleanField' + form_field_class = forms.BooleanField widget = widgets.CheckboxInput default_error_messages = { 'invalid': _(u"'%s' value must be either True or False."), @@ -686,15 +693,16 @@ class BooleanField(WritableField): default = False def from_native(self, value): - if value in ('t', 'True', '1'): + if value in ('true', 't', 'True', '1'): return True - if value in ('f', 'False', '0'): + if value in ('false', 'f', 'False', '0'): return False return bool(value) class CharField(WritableField): type_name = 'CharField' + form_field_class = forms.CharField def __init__(self, max_length=None, min_length=None, *args, **kwargs): self.max_length, self.min_length = max_length, min_length @@ -739,6 +747,7 @@ class SlugField(CharField): class ChoiceField(WritableField): type_name = 'ChoiceField' + form_field_class = forms.ChoiceField widget = widgets.Select default_error_messages = { 'invalid_choice': _('Select a valid choice. %(value)s is not one of the available choices.'), @@ -785,6 +794,7 @@ class ChoiceField(WritableField): class EmailField(CharField): type_name = 'EmailField' + form_field_class = forms.EmailField default_error_messages = { 'invalid': _('Enter a valid e-mail address.'), @@ -836,6 +846,7 @@ class RegexField(CharField): class DateField(WritableField): type_name = 'DateField' widget = widgets.DateInput + form_field_class = forms.DateField default_error_messages = { 'invalid': _(u"'%s' value has an invalid date format. It must be " @@ -874,6 +885,7 @@ class DateField(WritableField): class DateTimeField(WritableField): type_name = 'DateTimeField' widget = widgets.DateTimeInput + form_field_class = forms.DateTimeField default_error_messages = { 'invalid': _(u"'%s' value has an invalid format. It must be in " @@ -928,6 +940,7 @@ class DateTimeField(WritableField): class IntegerField(WritableField): type_name = 'IntegerField' + form_field_class = forms.IntegerField default_error_messages = { 'invalid': _('Enter a whole number.'), @@ -957,6 +970,7 @@ class IntegerField(WritableField): class FloatField(WritableField): type_name = 'FloatField' + form_field_class = forms.FloatField default_error_messages = { 'invalid': _("'%s' value must be a float."), @@ -976,6 +990,7 @@ class FloatField(WritableField): class FileField(WritableField): _use_files = True type_name = 'FileField' + form_field_class = forms.FileField widget = widgets.FileInput default_error_messages = { @@ -1018,6 +1033,7 @@ class FileField(WritableField): class ImageField(FileField): _use_files = True + form_field_class = forms.ImageField default_error_messages = { 'invalid_image': _("Upload a valid image. The file you uploaded was either not an image or a corrupted image."), diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py index dae38477..ee2800a6 100644 --- a/rest_framework/negotiation.py +++ b/rest_framework/negotiation.py @@ -2,6 +2,7 @@ from django.http import Http404 from rest_framework import exceptions from rest_framework.settings import api_settings from rest_framework.utils.mediatypes import order_by_precedence, media_type_matches +from rest_framework.utils.mediatypes import _MediaType class BaseContentNegotiation(object): @@ -48,7 +49,8 @@ class DefaultContentNegotiation(BaseContentNegotiation): for media_type in media_type_set: if media_type_matches(renderer.media_type, media_type): # Return the most specific media type as accepted. - if len(renderer.media_type) > len(media_type): + if (_MediaType(renderer.media_type).precedence > + _MediaType(media_type).precedence): # Eg client requests '*/*' # Accepted media type is 'application/json' return renderer, renderer.media_type diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 550963cb..25a32baa 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -306,26 +306,6 @@ class BrowsableAPIRenderer(BaseRenderer): return True def serializer_to_form_fields(self, serializer): - field_mapping = { - serializers.FloatField: forms.FloatField, - serializers.IntegerField: forms.IntegerField, - serializers.DateTimeField: forms.DateTimeField, - serializers.DateField: forms.DateField, - serializers.EmailField: forms.EmailField, - serializers.RegexField: forms.RegexField, - serializers.CharField: forms.CharField, - serializers.ChoiceField: forms.ChoiceField, - serializers.BooleanField: forms.BooleanField, - serializers.PrimaryKeyRelatedField: forms.ChoiceField, - serializers.ManyPrimaryKeyRelatedField: forms.MultipleChoiceField, - serializers.SlugRelatedField: forms.ChoiceField, - serializers.ManySlugRelatedField: forms.MultipleChoiceField, - serializers.HyperlinkedRelatedField: forms.ChoiceField, - serializers.ManyHyperlinkedRelatedField: forms.MultipleChoiceField, - serializers.FileField: forms.FileField, - serializers.ImageField: forms.ImageField, - } - fields = {} for k, v in serializer.get_fields().items(): if getattr(v, 'read_only', True): @@ -349,13 +329,7 @@ class BrowsableAPIRenderer(BaseRenderer): kwargs['label'] = k - try: - fields[k] = field_mapping[v.__class__](**kwargs) - except KeyError: - if getattr(v, 'choices', None) is not None: - fields[k] = forms.ChoiceField(**kwargs) - else: - fields[k] = forms.CharField(**kwargs) + fields[k] = v.form_field_class(**kwargs) return fields def get_form(self, view, method, request): diff --git a/rest_framework/request.py b/rest_framework/request.py index a1827ba4..39c64321 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -169,6 +169,15 @@ class Request(object): self._user, self._auth = self._authenticate() return self._user + @user.setter + def user(self, value): + """ + Sets the user on the current request. This is necessary to maintain + compatilbility with django.contrib.auth where the user proprety is + set in the login and logout functions. + """ + self._user = value + @property def auth(self): """ diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 775a8a1e..984f3ac5 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -60,7 +60,7 @@ def _get_declared_fields(bases, attrs): # If this class is subclassing another Serializer, add that Serializer's # fields. Note that we loop over the bases in *reverse*. This is necessary - # in order to the correct order of fields. + # in order to maintain the correct order of fields. for base in bases[::-1]: if hasattr(base, 'base_fields'): fields = base.base_fields.items() + fields @@ -94,7 +94,6 @@ class BaseSerializer(Field): def __init__(self, instance=None, data=None, files=None, context=None, partial=False, **kwargs): super(BaseSerializer, self).__init__(**kwargs) self.opts = self._options_class(self.Meta) - self.fields = copy.deepcopy(self.base_fields) self.parent = None self.root = None self.partial = partial @@ -104,7 +103,7 @@ class BaseSerializer(Field): self.init_data = data self.init_files = files self.object = instance - self.default_fields = self.get_default_fields() + self.fields = self.get_fields() self._data = None self._files = None @@ -140,13 +139,15 @@ class BaseSerializer(Field): ret = SortedDict() # Get the explicitly declared fields - for key, field in self.fields.items(): + base_fields = copy.deepcopy(self.base_fields) + for key, field in base_fields.items(): ret[key] = field # Set up the field field.initialize(parent=self, field_name=key) # Add in the default fields - for key, val in self.default_fields.items(): + default_fields = self.get_default_fields() + for key, val in default_fields.items(): if key not in ret: ret[key] = val @@ -193,8 +194,7 @@ class BaseSerializer(Field): ret = self._dict_class() ret.fields = {} - fields = self.get_fields() - for field_name, field in fields.items(): + for field_name, field in self.fields.items(): key = self.get_field_key(field_name) value = field.field_to_native(obj, field_name) ret[key] = value @@ -206,9 +206,8 @@ class BaseSerializer(Field): Core of deserialization, together with `restore_object`. Converts a dictionary of data into a dictionary of deserialized fields. """ - fields = self.get_fields() reverted_data = {} - for field_name, field in fields.items(): + for field_name, field in self.fields.items(): try: field.field_from_native(data, files, field_name, reverted_data) except ValidationError as err: @@ -220,10 +219,7 @@ class BaseSerializer(Field): """ Run `validate_<fieldname>()` and `validate()` methods on the serializer """ - # TODO: refactor this so we're not determining the fields again - fields = self.get_fields() - - for field_name, field in fields.items(): + for field_name, field in self.fields.items(): try: validate_method = getattr(self, 'validate_%s' % field_name, None) if validate_method: @@ -294,10 +290,18 @@ class BaseSerializer(Field): Override default so that we can apply ModelSerializer as a nested field to relationships. """ - obj = getattr(obj, self.source or field_name) - if is_simple_callable(obj): - obj = obj() + if self.source: + value = obj + for component in self.source.split('.'): + value = getattr(value, component) + if is_simple_callable(value): + value = value() + obj = value + else: + value = getattr(obj, field_name) + if is_simple_callable(value): + obj = value() # If the object has an "all" method, assume it's a relationship if is_simple_callable(getattr(obj, 'all', None)): diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index 96ca9f52..802bc6c1 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -167,14 +167,14 @@ class TokenAuthTests(TestCase): client = Client(enforce_csrf_checks=True) response = client.post('/auth-token/login/', json.dumps({'username': self.username, 'password': "badpass"}), 'application/json') - self.assertEqual(response.status_code, 400) + self.assertEqual(response.status_code, 401) def test_token_login_json_missing_fields(self): """Ensure token login view using JSON POST fails if missing fields.""" client = Client(enforce_csrf_checks=True) response = client.post('/auth-token/login/', json.dumps({'username': self.username}), 'application/json') - self.assertEqual(response.status_code, 400) + self.assertEqual(response.status_code, 401) def test_token_login_form(self): """Ensure token login view using form POST works.""" diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 9a59e841..428bf130 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -124,8 +124,21 @@ class ActionItem(RESTFrameworkModel): # Models for reverse relations +class Person(RESTFrameworkModel): + name = models.CharField(max_length=10) + age = models.IntegerField(null=True, blank=True) + + @property + def info(self): + return { + 'name': self.name, + 'age': self.age, + } + + class BlogPost(RESTFrameworkModel): title = models.CharField(max_length=100) + writer = models.ForeignKey(Person, null=True, blank=True) def get_first_comment(self): return self.blogpostcomment_set.all()[0] @@ -145,18 +158,6 @@ class Photo(RESTFrameworkModel): album = models.ForeignKey(Album) -class Person(RESTFrameworkModel): - name = models.CharField(max_length=10) - age = models.IntegerField(null=True, blank=True) - - @property - def info(self): - return { - 'name': self.name, - 'age': self.age, - } - - # Model for issue #324 class BlankFieldModel(RESTFrameworkModel): title = models.CharField(max_length=100, blank=True, null=True) diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py index ff48f3fa..2850992d 100644 --- a/rest_framework/tests/request.py +++ b/rest_framework/tests/request.py @@ -3,6 +3,8 @@ Tests for content parsing, and form-overloaded content parsing. """ from django.conf.urls.defaults import patterns from django.contrib.auth.models import User +from django.contrib.auth import authenticate, login, logout +from django.contrib.sessions.middleware import SessionMiddleware from django.test import TestCase, Client from django.utils import simplejson as json @@ -276,3 +278,29 @@ class TestContentParsingWithAuthentication(TestCase): # response = self.csrf_client.post('/', content) # self.assertEqual(status.OK, response.status_code, "POST data is malformed") + + +class TestUserSetter(TestCase): + + def setUp(self): + # Pass request object through session middleware so session is + # available to login and logout functions + self.request = Request(factory.get('/')) + SessionMiddleware().process_request(self.request) + + User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow') + self.user = authenticate(username='ringo', password='yellow') + + def test_user_can_be_set(self): + self.request.user = self.user + self.assertEqual(self.request.user, self.user) + + def test_user_can_login(self): + login(self.request, self.user) + self.assertEqual(self.request.user, self.user) + + def test_user_can_logout(self): + self.request.user = self.user + self.assertFalse(self.request.user.is_anonymous()) + logout(self.request) + self.assertTrue(self.request.user.is_anonymous()) diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index bdf72a91..18e24b71 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -577,6 +577,47 @@ class ManyRelatedTests(TestCase): self.assertEqual(serializer.data, expected) +class RelatedTraversalTest(TestCase): + def test_nested_traversal(self): + user = Person.objects.create(name="django") + post = BlogPost.objects.create(title="Test blog post", writer=user) + post.blogpostcomment_set.create(text="I love this blog post") + + from rest_framework.tests.models import BlogPostComment + + class PersonSerializer(serializers.ModelSerializer): + class Meta: + model = Person + fields = ("name", "age") + + class BlogPostCommentSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPostComment + fields = ("text", "post_owner") + + text = serializers.CharField() + post_owner = PersonSerializer(source='blog_post.writer') + + class BlogPostSerializer(serializers.Serializer): + title = serializers.CharField() + comments = BlogPostCommentSerializer(source='blogpostcomment_set') + + serializer = BlogPostSerializer(instance=post) + + expected = { + 'title': u'Test blog post', + 'comments': [{ + 'text': u'I love this blog post', + 'post_owner': { + "name": u"django", + "age": None + } + }] + } + + self.assertEqual(serializer.data, expected) + + class SerializerMethodFieldTests(TestCase): def setUp(self): |
