aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorAlex2013-11-12 23:40:24 +0000
committerAlex2013-11-12 23:40:24 +0000
commit5136798a040fc306a37b562c7cd629ab34bc02e3 (patch)
treeae51dec9e5ee4bcebbd22d1383602d66f06a996b /rest_framework
parentd1dc68d7550e90ba56a3122f8de1f38bb5aa1e3a (diff)
parent8552e79d7b62ca7f0edd131d3049ca220d879d48 (diff)
downloaddjango-rest-framework-5136798a040fc306a37b562c7cd629ab34bc02e3.tar.bz2
Merge branch 'master' into allow-aggregate-ordering
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/fields.py8
-rw-r--r--rest_framework/generics.py21
-rw-r--r--rest_framework/mixins.py3
-rw-r--r--rest_framework/parsers.py4
-rw-r--r--rest_framework/request.py2
-rw-r--r--rest_framework/serializers.py26
-rw-r--r--rest_framework/templates/rest_framework/base.html5
-rw-r--r--rest_framework/tests/test_fields.py74
-rw-r--r--rest_framework/tests/test_pagination.py85
-rw-r--r--rest_framework/tests/test_renderers.py13
-rw-r--r--rest_framework/tests/test_request.py32
-rw-r--r--rest_framework/tests/test_serializer.py27
-rw-r--r--rest_framework/tests/test_serializer_nested.py67
-rw-r--r--rest_framework/utils/encoders.py3
-rw-r--r--rest_framework/views.py4
-rw-r--r--rest_framework/viewsets.py2
16 files changed, 332 insertions, 44 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index e23fc001..6c07dbb3 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -497,6 +497,7 @@ class ChoiceField(WritableField):
}
def __init__(self, choices=(), *args, **kwargs):
+ self.empty = kwargs.pop('empty', '')
super(ChoiceField, self).__init__(*args, **kwargs)
self.choices = choices
if not self.required:
@@ -537,9 +538,10 @@ class ChoiceField(WritableField):
return False
def from_native(self, value):
- if value in validators.EMPTY_VALUES:
- return None
- return super(ChoiceField, self).from_native(value)
+ value = super(ChoiceField, self).from_native(value)
+ if value == self.empty or value in validators.EMPTY_VALUES:
+ return self.empty
+ return value
class EmailField(CharField):
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 4f134bce..7cb80a84 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -54,6 +54,7 @@ class GenericAPIView(views.APIView):
# If you want to use object lookups other than pk, set this attribute.
# For more complex lookup requirements override `get_object()`.
lookup_field = 'pk'
+ lookup_url_kwarg = None
# Pagination settings
paginate_by = api_settings.PAGINATE_BY
@@ -147,8 +148,8 @@ class GenericAPIView(views.APIView):
page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)
page = page_kwarg or page_query_param or 1
try:
- page_number = strict_positive_int(page)
- except ValueError:
+ page_number = paginator.validate_number(page)
+ except InvalidPage:
if page == 'last':
page_number = paginator.num_pages
else:
@@ -174,6 +175,14 @@ class GenericAPIView(views.APIView):
method if you want to apply the configured filtering backend to the
default queryset.
"""
+ for backend in self.get_filter_backends():
+ queryset = backend().filter_queryset(self.request, queryset, self)
+ return queryset
+
+ def get_filter_backends(self):
+ """
+ Returns the list of filter backends that this view requires.
+ """
filter_backends = self.filter_backends or []
if not filter_backends and self.filter_backend:
warnings.warn(
@@ -184,10 +193,8 @@ class GenericAPIView(views.APIView):
PendingDeprecationWarning, stacklevel=2
)
filter_backends = [self.filter_backend]
+ return filter_backends
- for backend in filter_backends:
- queryset = backend().filter_queryset(self.request, queryset, self)
- return queryset
########################
### The following methods provide default implementations
@@ -278,9 +285,11 @@ class GenericAPIView(views.APIView):
pass # Deprecation warning
# Perform the lookup filtering.
+ # Note that `pk` and `slug` are deprecated styles of lookup filtering.
+ lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
+ lookup = self.kwargs.get(lookup_url_kwarg, None)
pk = self.kwargs.get(self.pk_url_kwarg, None)
slug = self.kwargs.get(self.slug_url_kwarg, None)
- lookup = self.kwargs.get(self.lookup_field, None)
if lookup is not None:
filter_kwargs = {self.lookup_field: lookup}
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 426865ff..4606c78b 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -158,7 +158,8 @@ class UpdateModelMixin(object):
Set any attributes on the object that are implicit in the request.
"""
# pk and/or slug attributes are implicit in the URL.
- lookup = self.kwargs.get(self.lookup_field, None)
+ lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
+ lookup = self.kwargs.get(lookup_url_kwarg, None)
pk = self.kwargs.get(self.pk_url_kwarg, None)
slug = self.kwargs.get(self.slug_url_kwarg, None)
slug_field = slug and self.slug_field or None
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index 98fc0341..f1b3e38d 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -83,7 +83,7 @@ class YAMLParser(BaseParser):
data = stream.read().decode(encoding)
return yaml.safe_load(data)
except (ValueError, yaml.parser.ParserError) as exc:
- raise ParseError('YAML parse error - %s' % six.u(exc))
+ raise ParseError('YAML parse error - %s' % six.text_type(exc))
class FormParser(BaseParser):
@@ -153,7 +153,7 @@ class XMLParser(BaseParser):
try:
tree = etree.parse(stream, parser=parser, forbid_dtd=True)
except (etree.ParseError, ValueError) as exc:
- raise ParseError('XML parse error - %s' % six.u(exc))
+ raise ParseError('XML parse error - %s' % six.text_type(exc))
data = self._xml_convert(tree.getroot())
return data
diff --git a/rest_framework/request.py b/rest_framework/request.py
index 977d4d96..b883d0d4 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -334,7 +334,7 @@ class Request(object):
self._CONTENT_PARAM in self._data and
self._CONTENTTYPE_PARAM in self._data):
self._content_type = self._data[self._CONTENTTYPE_PARAM]
- self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(HTTP_HEADER_ENCODING))
+ self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding']))
self._data, self._files = (Empty, Empty)
def _parse(self):
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 33db82ee..163abf4f 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -6,8 +6,8 @@ form encoded input.
Serialization in REST framework is a two-phase process:
1. Serializers marshal between complex types like model instances, and
-python primatives.
-2. The process of marshalling between python primatives and request and
+python primitives.
+2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers.
"""
from __future__ import unicode_literals
@@ -42,6 +42,7 @@ def pretty_name(name):
class RelationsList(list):
_deleted = []
+
class NestedValidationError(ValidationError):
"""
The default ValidationError behavior is to stringify each item in the list
@@ -56,9 +57,13 @@ class NestedValidationError(ValidationError):
def __init__(self, message):
if isinstance(message, dict):
- self.messages = [message]
+ self._messages = [message]
else:
- self.messages = message
+ self._messages = message
+
+ @property
+ def messages(self):
+ return self._messages
class DictWithMetadata(dict):
@@ -262,10 +267,13 @@ class BaseSerializer(WritableField):
for field_name, field in self.fields.items():
if field_name in self._errors:
continue
+
+ source = field.source or field_name
+ if self.partial and source not in attrs:
+ continue
try:
validate_method = getattr(self, 'validate_%s' % field_name, None)
if validate_method:
- source = field.source or field_name
attrs = validate_method(attrs, source)
except ValidationError as err:
self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages)
@@ -403,7 +411,7 @@ class BaseSerializer(WritableField):
return
# Set the serializer object if it exists
- obj = getattr(self.parent.object, field_name) if self.parent.object else None
+ obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None
obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj
if self.source == '*':
@@ -791,6 +799,8 @@ class ModelSerializer(Serializer):
# TODO: TypedChoiceField?
if model_field.flatchoices: # This ModelField contains choices
kwargs['choices'] = model_field.flatchoices
+ if model_field.null:
+ kwargs['empty'] = None
return ChoiceField(**kwargs)
# put this below the ChoiceField because min_value isn't a valid initializer
@@ -868,7 +878,7 @@ class ModelSerializer(Serializer):
# Reverse m2m relations
for (obj, model) in meta.get_all_related_m2m_objects_with_model():
- field_name = obj.field.related_query_name()
+ field_name = obj.get_accessor_name()
if field_name in attrs:
m2m_data[field_name] = attrs.pop(field_name)
@@ -912,7 +922,7 @@ class ModelSerializer(Serializer):
def save_object(self, obj, **kwargs):
"""
- Save the deserialized object and return it.
+ Save the deserialized object.
"""
if getattr(obj, '_nested_forward_relations', None):
# Nested relationships need to be saved before we can save the
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 33be36db..495163b6 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -110,7 +110,9 @@
<div class="content-main">
<div class="page-header"><h1>{{ name }}</h1></div>
+ {% block description %}
{{ description }}
+ {% endblock %}
<div class="request-info" style="clear: both" >
<pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre>
</div>
@@ -219,9 +221,6 @@
</div><!-- ./wrapper -->
{% block footer %}
- <!--<div id="footer">
- <a class="powered-by" href='http://django-rest-framework.org'>Django REST framework</a>
- </div>-->
{% endblock %}
{% block script %}
diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py
index 34fbab9c..ab2cceac 100644
--- a/rest_framework/tests/test_fields.py
+++ b/rest_framework/tests/test_fields.py
@@ -42,6 +42,31 @@ class TimeFieldModelSerializer(serializers.ModelSerializer):
model = TimeFieldModel
+SAMPLE_CHOICES = [
+ ('red', 'Red'),
+ ('green', 'Green'),
+ ('blue', 'Blue'),
+]
+
+
+class ChoiceFieldModel(models.Model):
+ choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, max_length=255)
+
+
+class ChoiceFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChoiceFieldModel
+
+
+class ChoiceFieldModelWithNull(models.Model):
+ choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, null=True, max_length=255)
+
+
+class ChoiceFieldModelWithNullSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChoiceFieldModelWithNull
+
+
class BasicFieldTests(TestCase):
def test_auto_now_fields_read_only(self):
"""
@@ -667,34 +692,53 @@ class ChoiceFieldTests(TestCase):
"""
Tests for the ChoiceField options generator
"""
-
- SAMPLE_CHOICES = [
- ('red', 'Red'),
- ('green', 'Green'),
- ('blue', 'Blue'),
- ]
-
def test_choices_required(self):
"""
Make sure proper choices are rendered if field is required
"""
- f = serializers.ChoiceField(required=True, choices=self.SAMPLE_CHOICES)
- self.assertEqual(f.choices, self.SAMPLE_CHOICES)
+ f = serializers.ChoiceField(required=True, choices=SAMPLE_CHOICES)
+ self.assertEqual(f.choices, SAMPLE_CHOICES)
def test_choices_not_required(self):
"""
Make sure proper choices (plus blank) are rendered if the field isn't required
"""
- f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES)
- self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES)
+ f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
+ self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES)
+
+ def test_invalid_choice_model(self):
+ s = ChoiceFieldModelSerializer(data={'choice' : 'wrong_value'})
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']})
+ self.assertEqual(s.data['choice'], '')
+
+ def test_empty_choice_model(self):
+ """
+ Test that the 'empty' value is correctly passed and used depending on the 'null' property on the model field.
+ """
+ s = ChoiceFieldModelSerializer(data={'choice' : ''})
+ self.assertTrue(s.is_valid())
+ self.assertEqual(s.data['choice'], '')
+
+ s = ChoiceFieldModelWithNullSerializer(data={'choice' : ''})
+ self.assertTrue(s.is_valid())
+ self.assertEqual(s.data['choice'], None)
def test_from_native_empty(self):
"""
- Make sure from_native() returns None on empty param.
+ Make sure from_native() returns an empty string on empty param by default.
"""
- f = serializers.ChoiceField(choices=self.SAMPLE_CHOICES)
- result = f.from_native('')
- self.assertEqual(result, None)
+ f = serializers.ChoiceField(choices=SAMPLE_CHOICES)
+ self.assertEqual(f.from_native(''), '')
+ self.assertEqual(f.from_native(None), '')
+
+ def test_from_native_empty_override(self):
+ """
+ Make sure you can override from_native() behavior regarding empty values.
+ """
+ f = serializers.ChoiceField(choices=SAMPLE_CHOICES, empty=None)
+ self.assertEqual(f.from_native(''), None)
+ self.assertEqual(f.from_native(None), None)
class EmailFieldTests(TestCase):
diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py
index d6bc7895..cadb515f 100644
--- a/rest_framework/tests/test_pagination.py
+++ b/rest_framework/tests/test_pagination.py
@@ -430,3 +430,88 @@ class TestCustomPaginationSerializer(TestCase):
'objects': ['john', 'paul']
}
self.assertEqual(serializer.data, expected)
+
+
+class NonIntegerPage(object):
+
+ def __init__(self, paginator, object_list, prev_token, token, next_token):
+ self.paginator = paginator
+ self.object_list = object_list
+ self.prev_token = prev_token
+ self.token = token
+ self.next_token = next_token
+
+ def has_next(self):
+ return not not self.next_token
+
+ def next_page_number(self):
+ return self.next_token
+
+ def has_previous(self):
+ return not not self.prev_token
+
+ def previous_page_number(self):
+ return self.prev_token
+
+
+class NonIntegerPaginator(object):
+
+ def __init__(self, object_list, per_page):
+ self.object_list = object_list
+ self.per_page = per_page
+
+ def count(self):
+ # pretend like we don't know how many pages we have
+ return None
+
+ def page(self, token=None):
+ if token:
+ try:
+ first = self.object_list.index(token)
+ except ValueError:
+ first = 0
+ else:
+ first = 0
+ n = len(self.object_list)
+ last = min(first + self.per_page, n)
+ prev_token = self.object_list[last - (2 * self.per_page)] if first else None
+ next_token = self.object_list[last] if last < n else None
+ return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token)
+
+
+class TestNonIntegerPagination(TestCase):
+
+
+ def test_custom_pagination_serializer(self):
+ objects = ['john', 'paul', 'george', 'ringo']
+ paginator = NonIntegerPaginator(objects, 2)
+
+ request = APIRequestFactory().get('/foobar')
+ serializer = CustomPaginationSerializer(
+ instance=paginator.page(),
+ context={'request': request}
+ )
+ expected = {
+ 'links': {
+ 'next': 'http://testserver/foobar?page={0}'.format(objects[2]),
+ 'prev': None
+ },
+ 'total_results': None,
+ 'objects': objects[:2]
+ }
+ self.assertEqual(serializer.data, expected)
+
+ request = APIRequestFactory().get('/foobar')
+ serializer = CustomPaginationSerializer(
+ instance=paginator.page('george'),
+ context={'request': request}
+ )
+ expected = {
+ 'links': {
+ 'next': None,
+ 'prev': 'http://testserver/foobar?page={0}'.format(objects[0]),
+ },
+ 'total_results': None,
+ 'objects': objects[2:]
+ }
+ self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py
index df6f4aa6..76299a89 100644
--- a/rest_framework/tests/test_renderers.py
+++ b/rest_framework/tests/test_renderers.py
@@ -328,7 +328,7 @@ if yaml:
class YAMLRendererTests(TestCase):
"""
- Tests specific to the JSON Renderer
+ Tests specific to the YAML Renderer
"""
def test_render(self):
@@ -354,6 +354,17 @@ if yaml:
data = parser.parse(StringIO(content))
self.assertEqual(obj, data)
+ def test_render_decimal(self):
+ """
+ Test YAML decimal rendering.
+ """
+ renderer = YAMLRenderer()
+ content = renderer.render({'field': Decimal('111.2')}, 'application/yaml')
+ self.assertYAMLContains(content, "field: '111.2'")
+
+ def assertYAMLContains(self, content, string):
+ self.assertTrue(string in content, '%r not in %r' % (string, content))
+
class XMLRendererTestCase(TestCase):
"""
diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py
index 969d8024..a60e7615 100644
--- a/rest_framework/tests/test_request.py
+++ b/rest_framework/tests/test_request.py
@@ -5,6 +5,7 @@ from __future__ import unicode_literals
from django.contrib.auth.models import User
from django.contrib.auth import authenticate, login, logout
from django.contrib.sessions.middleware import SessionMiddleware
+from django.core.handlers.wsgi import WSGIRequest
from django.test import TestCase
from rest_framework import status
from rest_framework.authentication import SessionAuthentication
@@ -15,12 +16,13 @@ from rest_framework.parsers import (
MultiPartParser,
JSONParser
)
-from rest_framework.request import Request
+from rest_framework.request import Request, Empty
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory, APIClient
from rest_framework.views import APIView
from rest_framework.compat import six
+from io import BytesIO
import json
@@ -146,6 +148,34 @@ class TestContentParsing(TestCase):
request.parsers = (JSONParser(), )
self.assertEqual(request.DATA, json_data)
+ def test_form_POST_unicode(self):
+ """
+ JSON POST via default web interface with unicode data
+ """
+ # Note: environ and other variables here have simplified content compared to real Request
+ CONTENT = b'_content_type=application%2Fjson&_content=%7B%22request%22%3A+4%2C+%22firm%22%3A+1%2C+%22text%22%3A+%22%D0%9F%D1%80%D0%B8%D0%B2%D0%B5%D1%82%21%22%7D'
+ environ = {
+ 'REQUEST_METHOD': 'POST',
+ 'CONTENT_TYPE': 'application/x-www-form-urlencoded',
+ 'CONTENT_LENGTH': len(CONTENT),
+ 'wsgi.input': BytesIO(CONTENT),
+ }
+ wsgi_request = WSGIRequest(environ=environ)
+ wsgi_request._load_post_and_files()
+ parsers = (JSONParser(), FormParser(), MultiPartParser())
+ parser_context = {
+ 'encoding': 'utf-8',
+ 'kwargs': {},
+ 'args': (),
+ }
+ request = Request(wsgi_request, parsers=parsers, parser_context=parser_context)
+ method = request.method
+ self.assertEqual(method, 'POST')
+ self.assertEqual(request._content_type, 'application/json')
+ self.assertEqual(request._stream.getvalue(), b'{"request": 4, "firm": 1, "text": "\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82!"}')
+ self.assertEqual(request._data, Empty)
+ self.assertEqual(request._files, Empty)
+
# def test_accessing_post_after_data_form(self):
# """
# Ensures request.POST can be accessed after request.DATA in
diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py
index 29c08fbc..1f85a474 100644
--- a/rest_framework/tests/test_serializer.py
+++ b/rest_framework/tests/test_serializer.py
@@ -511,6 +511,33 @@ class CustomValidationTests(TestCase):
self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'email': ['Enter a valid email address.']})
+ def test_partial_update(self):
+ """
+ Make sure that validate_email isn't called when partial=True and email
+ isn't found in data.
+ """
+ initial_data = {
+ 'email': 'tom@example.com',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = self.CommentSerializerWithFieldValidator(data=initial_data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.object
+
+ new_content = 'An *updated* test comment'
+ partial_data = {
+ 'content': new_content
+ }
+
+ serializer = self.CommentSerializerWithFieldValidator(instance=instance,
+ data=partial_data,
+ partial=True)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.object
+ self.assertEqual(instance.content, new_content)
+
class PositiveIntegerAsChoiceTests(TestCase):
def test_positive_integer_in_json_is_correctly_parsed(self):
diff --git a/rest_framework/tests/test_serializer_nested.py b/rest_framework/tests/test_serializer_nested.py
index 71d0e24b..029f8bff 100644
--- a/rest_framework/tests/test_serializer_nested.py
+++ b/rest_framework/tests/test_serializer_nested.py
@@ -244,3 +244,70 @@ class WritableNestedSerializerObjectTests(TestCase):
serializer = self.AlbumSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.object, expected_object)
+
+
+class ForeignKeyNestedSerializerUpdateTests(TestCase):
+ def setUp(self):
+ class Artist(object):
+ def __init__(self, name):
+ self.name = name
+
+ def __eq__(self, other):
+ return self.name == other.name
+
+ class Album(object):
+ def __init__(self, name, artist):
+ self.name, self.artist = name, artist
+
+ def __eq__(self, other):
+ return self.name == other.name and self.artist == other.artist
+
+ class ArtistSerializer(serializers.Serializer):
+ name = serializers.CharField()
+
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.name = attrs['name']
+ else:
+ instance = Artist(attrs['name'])
+ return instance
+
+ class AlbumSerializer(serializers.Serializer):
+ name = serializers.CharField()
+ by = ArtistSerializer(source='artist')
+
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.name = attrs['name']
+ instance.artist = attrs['artist']
+ else:
+ instance = Album(attrs['name'], attrs['artist'])
+ return instance
+
+ self.Artist = Artist
+ self.Album = Album
+ self.AlbumSerializer = AlbumSerializer
+
+ def test_create_via_foreign_key_with_source(self):
+ """
+ Check that we can both *create* and *update* into objects across
+ ForeignKeys that have a `source` specified.
+ Regression test for #1170
+ """
+ data = {
+ 'name': 'Discovery',
+ 'by': {'name': 'Daft Punk'},
+ }
+
+ expected = self.Album(artist=self.Artist('Daft Punk'), name='Discovery')
+
+ # create
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
+
+ # update
+ original = self.Album(artist=self.Artist('The Bats'), name='Free All the Monsters')
+ serializer = self.AlbumSerializer(instance=original, data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index 7efd5417..35ad206b 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -89,6 +89,9 @@ else:
node.flow_style = best_style
return node
+ SafeDumper.add_representer(decimal.Decimal,
+ SafeDumper.represent_decimal)
+
SafeDumper.add_representer(SortedDict,
yaml.representer.SafeRepresenter.represent_dict)
SafeDumper.add_representer(DictWithMetadata,
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 853e6461..e863af6d 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -154,8 +154,8 @@ class APIView(View):
Returns a dict that is passed through to Parser.parse(),
as the `parser_context` keyword argument.
"""
- # Note: Additionally `request` will also be added to the context
- # by the Request object.
+ # Note: Additionally `request` and `encoding` will also be added
+ # to the context by the Request object.
return {
'view': self,
'args': getattr(self, 'args', ()),
diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py
index d91323f2..7eb29f99 100644
--- a/rest_framework/viewsets.py
+++ b/rest_framework/viewsets.py
@@ -9,7 +9,7 @@ Actions are only bound to methods at the point of instantiating the views.
user_detail = UserViewSet.as_view({'get': 'retrieve'})
Typically, rather than instantiate views from viewsets directly, you'll
-regsiter the viewset with a router and let the URL conf be determined
+register the viewset with a router and let the URL conf be determined
automatically.
router = DefaultRouter()