aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--rest_framework/serializers.py28
1 files changed, 22 insertions, 6 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index be8ad3f2..b3db3582 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -625,6 +625,20 @@ class ModelSerializerOptions(SerializerOptions):
self.write_only_fields = getattr(meta, 'write_only_fields', ())
+def _get_class_mapping(mapping, obj):
+ """
+ Takes a dictionary with classes as keys, and an object.
+ Traverses the object's inheritance hierarchy in method
+ resolution order, and returns the first matching value
+ from the dictionary or None.
+
+ """
+ return next(
+ (mapping[cls] for cls in inspect.getmro(obj.__class__) if cls in mapping),
+ None
+ )
+
+
class ModelSerializer(Serializer):
"""
A serializer that deals with model instances and querysets.
@@ -899,15 +913,17 @@ class ModelSerializer(Serializer):
models.URLField: ['max_length'],
}
- if model_field.__class__ in attribute_dict:
- attributes = attribute_dict[model_field.__class__]
+ attributes = _get_class_mapping(attribute_dict, model_field)
+ if attributes:
for attribute in attributes:
kwargs.update({attribute: getattr(model_field, attribute)})
- try:
- return self.field_mapping[model_field.__class__](**kwargs)
- except KeyError:
- return ModelField(model_field=model_field, **kwargs)
+ serializer_field_class = _get_class_mapping(
+ self.field_mapping, model_field)
+
+ if serializer_field_class:
+ return serializer_field_class(**kwargs)
+ return ModelField(model_field=model_field, **kwargs)
def get_validation_exclusions(self, instance=None):
"""