diff options
| author | Timo Tuominen | 2014-09-01 15:54:33 +0300 |
|---|---|---|
| committer | Timo Tuominen | 2014-09-01 15:54:33 +0300 |
| commit | 582f6fdd4b0fb12a7c0d1fefe265499a284c9b79 (patch) | |
| tree | adb29baa9968cc5363efffa9e8e7689d18ed98ec | |
| parent | ae84b8b0e8a99261ea2436f77ab5238f21603c0c (diff) | |
| download | django-rest-framework-582f6fdd4b0fb12a7c0d1fefe265499a284c9b79.tar.bz2 | |
Add utility function to match classes in dictionary.
| -rw-r--r-- | rest_framework/serializers.py | 28 |
1 files changed, 22 insertions, 6 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index f37fbf98..5c33300c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -625,6 +625,21 @@ class ModelSerializerOptions(SerializerOptions): self.write_only_fields = getattr(meta, 'write_only_fields', ()) +def _get_class_mapping(mapping, obj): + """ + Takes a dictionary with classes as keys, and an object. + Traverses the object's inheritance hierarchy in method + resolution order, and returns the first matching value + from the dictionary or None. + + """ + for baseclass in inspect.getmro(obj.__class__): + val = mapping.get(baseclass) + if val: + return val + return None + + class ModelSerializer(Serializer): """ A serializer that deals with model instances and querysets. @@ -899,15 +914,16 @@ class ModelSerializer(Serializer): models.URLField: ['max_length'], } - if model_field.__class__ in attribute_dict: - attributes = attribute_dict[model_field.__class__] + attributes = _get_class_mapping(attribute_dict, model_field) + if attributes: for attribute in attributes: kwargs.update({attribute: getattr(model_field, attribute)}) - for model_field_baseclass in inspect.getmro(model_field.__class__): - serializer_field_class = self.field_mapping.get(model_field_baseclass) - if serializer_field_class: - return serializer_field_class(**kwargs) + serializer_field_class = _get_class_mapping( + self.field_mapping, model_field) + + if serializer_field_class: + return serializer_field_class(**kwargs) return ModelField(model_field=model_field, **kwargs) def get_validation_exclusions(self, instance=None): |
