aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTom Christie2014-09-25 11:40:32 +0100
committerTom Christie2014-09-25 11:40:32 +0100
commit64632da3718f501cb8174243385d38b547c2fefd (patch)
treecb1f7eba968fe61a302639008ba6c7e32678c329
parentb22c9602fa0f717b688fdb35e4f6f42c189af3f3 (diff)
downloaddjango-rest-framework-64632da3718f501cb8174243385d38b547c2fefd.tar.bz2
Clean up bind - no longer needs to be called multiple times in nested fields
-rw-r--r--rest_framework/fields.py21
-rw-r--r--rest_framework/relations.py5
-rw-r--r--rest_framework/serializers.py26
-rw-r--r--tests/test_relations.py8
4 files changed, 29 insertions, 31 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index d1aebbaf..446732c3 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -109,7 +109,8 @@ class Field(object):
def __init__(self, read_only=False, write_only=False,
required=None, default=empty, initial=None, source=None,
label=None, help_text=None, style=None,
- error_messages=None, validators=[], allow_null=False):
+ error_messages=None, validators=[], allow_null=False,
+ context=None):
self._creation_counter = Field._creation_counter
Field._creation_counter += 1
@@ -135,6 +136,11 @@ class Field(object):
self.validators = validators or self.default_validators[:]
self.allow_null = allow_null
+ # These are set up by `.bind()` when the field is added to a serializer.
+ self.field_name = None
+ self.parent = None
+ self._context = {} if (context is None) else context
+
# Collect default error message from self and parent classes
messages = {}
for cls in reversed(self.__class__.__mro__):
@@ -157,7 +163,14 @@ class Field(object):
kwargs = copy.deepcopy(self._kwargs)
return self.__class__(*args, **kwargs)
- def bind(self, field_name, parent, root):
+ @property
+ def context(self):
+ root = self
+ while root.parent is not None:
+ root = root.parent
+ return root._context
+
+ def bind(self, field_name, parent):
"""
Setup the context for the field instance.
"""
@@ -174,10 +187,8 @@ class Field(object):
self.field_name = field_name
self.parent = parent
- self.root = root
- self.context = parent.context
- # `self.label` should deafult to being based on the field name.
+ # `self.label` should default to being based on the field name.
if self.label is None:
self.label = field_name.replace('_', ' ').capitalize()
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 5aa1f8bd..b37a6fed 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -243,11 +243,6 @@ class ManyRelation(Field):
assert child_relation is not None, '`child_relation` is a required argument.'
super(ManyRelation, self).__init__(*args, **kwargs)
- def bind(self, field_name, parent, root):
- # ManyRelation needs to provide the current context to the child relation.
- super(ManyRelation, self).bind(field_name, parent, root)
- self.child_relation.bind(field_name, parent, root)
-
def to_internal_value(self, data):
return [
self.child_relation.to_internal_value(item)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 12e38090..04721c7a 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -150,13 +150,20 @@ class SerializerMetaclass(type):
class BindingDict(object):
+ """
+ This dict-like object is used to store fields on a serializer.
+
+ This ensures that whenever fields are added to the serializer we call
+ `field.bind()` so that the `field_name` and `parent` attributes
+ can be set correctly.
+ """
def __init__(self, serializer):
self.serializer = serializer
self.fields = SortedDict()
def __setitem__(self, key, field):
self.fields[key] = field
- field.bind(field_name=key, parent=self.serializer, root=self.serializer)
+ field.bind(field_name=key, parent=self.serializer)
def __getitem__(self, key):
return self.fields[key]
@@ -174,7 +181,6 @@ class BindingDict(object):
@six.add_metaclass(SerializerMetaclass)
class Serializer(BaseSerializer):
def __init__(self, *args, **kwargs):
- self.context = kwargs.pop('context', {})
kwargs.pop('partial', None)
kwargs.pop('many', None)
@@ -198,13 +204,6 @@ class Serializer(BaseSerializer):
def _get_base_fields(self):
return copy.deepcopy(self._declared_fields)
- def bind(self, field_name, parent, root):
- # If the serializer is used as a field then when it becomes bound
- # it also needs to bind all its child fields.
- super(Serializer, self).bind(field_name, parent, root)
- for field_name, field in self.fields.items():
- field.bind(field_name, self, root)
-
def get_initial(self):
return dict([
(field.field_name, field.get_initial())
@@ -290,17 +289,10 @@ class ListSerializer(BaseSerializer):
self.child = kwargs.pop('child', copy.deepcopy(self.child))
assert self.child is not None, '`child` is a required argument.'
assert not inspect.isclass(self.child), '`child` has not been instantiated.'
- self.context = kwargs.pop('context', {})
kwargs.pop('partial', None)
super(ListSerializer, self).__init__(*args, **kwargs)
- self.child.bind('', self, self)
-
- def bind(self, field_name, parent, root):
- # If the list is used as a field then it needs to provide
- # the current context to the child serializer.
- super(ListSerializer, self).bind(field_name, parent, root)
- self.child.bind(field_name, self, root)
+ self.child.bind(field_name='', parent=self)
def get_value(self, dictionary):
# We override the default field access in order to support
diff --git a/tests/test_relations.py b/tests/test_relations.py
index c29618ce..2d11672b 100644
--- a/tests/test_relations.py
+++ b/tests/test_relations.py
@@ -51,7 +51,7 @@ class TestHyperlinkedIdentityField(APISimpleTestCase):
self.instance = MockObject(pk=1, name='foo')
self.field = serializers.HyperlinkedIdentityField(view_name='example')
self.field.reverse = mock_reverse
- self.field.context = {'request': True}
+ self.field._context = {'request': True}
def test_representation(self):
representation = self.field.to_representation(self.instance)
@@ -62,7 +62,7 @@ class TestHyperlinkedIdentityField(APISimpleTestCase):
assert representation is None
def test_representation_with_format(self):
- self.field.context['format'] = 'xml'
+ self.field._context['format'] = 'xml'
representation = self.field.to_representation(self.instance)
assert representation == 'http://example.org/example/1.xml/'
@@ -91,14 +91,14 @@ class TestHyperlinkedIdentityFieldWithFormat(APISimpleTestCase):
self.instance = MockObject(pk=1, name='foo')
self.field = serializers.HyperlinkedIdentityField(view_name='example', format='json')
self.field.reverse = mock_reverse
- self.field.context = {'request': True}
+ self.field._context = {'request': True}
def test_representation(self):
representation = self.field.to_representation(self.instance)
assert representation == 'http://example.org/example/1/'
def test_representation_with_format(self):
- self.field.context['format'] = 'xml'
+ self.field._context['format'] = 'xml'
representation = self.field.to_representation(self.instance)
assert representation == 'http://example.org/example/1.json/'