aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTimo Tuominen2014-09-01 15:54:33 +0300
committerTimo Tuominen2014-09-01 15:54:33 +0300
commit582f6fdd4b0fb12a7c0d1fefe265499a284c9b79 (patch)
treeadb29baa9968cc5363efffa9e8e7689d18ed98ec
parentae84b8b0e8a99261ea2436f77ab5238f21603c0c (diff)
downloaddjango-rest-framework-582f6fdd4b0fb12a7c0d1fefe265499a284c9b79.tar.bz2
Add utility function to match classes in dictionary.
-rw-r--r--rest_framework/serializers.py28
1 files changed, 22 insertions, 6 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index f37fbf98..5c33300c 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -625,6 +625,21 @@ class ModelSerializerOptions(SerializerOptions):
self.write_only_fields = getattr(meta, 'write_only_fields', ())
+def _get_class_mapping(mapping, obj):
+ """
+ Takes a dictionary with classes as keys, and an object.
+ Traverses the object's inheritance hierarchy in method
+ resolution order, and returns the first matching value
+ from the dictionary or None.
+
+ """
+ for baseclass in inspect.getmro(obj.__class__):
+ val = mapping.get(baseclass)
+ if val:
+ return val
+ return None
+
+
class ModelSerializer(Serializer):
"""
A serializer that deals with model instances and querysets.
@@ -899,15 +914,16 @@ class ModelSerializer(Serializer):
models.URLField: ['max_length'],
}
- if model_field.__class__ in attribute_dict:
- attributes = attribute_dict[model_field.__class__]
+ attributes = _get_class_mapping(attribute_dict, model_field)
+ if attributes:
for attribute in attributes:
kwargs.update({attribute: getattr(model_field, attribute)})
- for model_field_baseclass in inspect.getmro(model_field.__class__):
- serializer_field_class = self.field_mapping.get(model_field_baseclass)
- if serializer_field_class:
- return serializer_field_class(**kwargs)
+ serializer_field_class = _get_class_mapping(
+ self.field_mapping, model_field)
+
+ if serializer_field_class:
+ return serializer_field_class(**kwargs)
return ModelField(model_field=model_field, **kwargs)
def get_validation_exclusions(self, instance=None):