diff options
| author | Tom Christie | 2014-09-06 07:20:31 +0100 | 
|---|---|---|
| committer | Tom Christie | 2014-09-06 07:20:31 +0100 | 
| commit | e8fac28d8848dce62a31879e07300842bd1755bd (patch) | |
| tree | 4bfd3d9ab923af4a8addc297a557785b7204cad8 | |
| parent | 5bbfef36f46979591ad599c56126a8698a47513a (diff) | |
| parent | e437520217e20d500d641b95482d49484b1f24a7 (diff) | |
| download | django-rest-framework-e8fac28d8848dce62a31879e07300842bd1755bd.tar.bz2 | |
Merge pull request #1818 from tituomin/serializer-subclass-mapping
Better mapping for custom model fields to serializer fields.
| -rw-r--r-- | rest_framework/serializers.py | 28 | 
1 files changed, 22 insertions, 6 deletions
| diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index be8ad3f2..b3db3582 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -625,6 +625,20 @@ class ModelSerializerOptions(SerializerOptions):          self.write_only_fields = getattr(meta, 'write_only_fields', ()) +def _get_class_mapping(mapping, obj): +    """ +    Takes a dictionary with classes as keys, and an object. +    Traverses the object's inheritance hierarchy in method +    resolution order, and returns the first matching value +    from the dictionary or None. + +    """ +    return next( +        (mapping[cls] for cls in inspect.getmro(obj.__class__) if cls in mapping), +        None +    ) + +  class ModelSerializer(Serializer):      """      A serializer that deals with model instances and querysets. @@ -899,15 +913,17 @@ class ModelSerializer(Serializer):              models.URLField: ['max_length'],          } -        if model_field.__class__ in attribute_dict: -            attributes = attribute_dict[model_field.__class__] +        attributes = _get_class_mapping(attribute_dict, model_field) +        if attributes:              for attribute in attributes:                  kwargs.update({attribute: getattr(model_field, attribute)}) -        try: -            return self.field_mapping[model_field.__class__](**kwargs) -        except KeyError: -            return ModelField(model_field=model_field, **kwargs) +        serializer_field_class = _get_class_mapping( +            self.field_mapping, model_field) + +        if serializer_field_class: +            return serializer_field_class(**kwargs) +        return ModelField(model_field=model_field, **kwargs)      def get_validation_exclusions(self, instance=None):          """ | 
