diff options
| author | Tom Christie | 2014-09-11 20:22:32 +0100 | 
|---|---|---|
| committer | Tom Christie | 2014-09-11 20:22:32 +0100 | 
| commit | ab40780dc2f341a271c2f489659dcd48eb47c07d (patch) | |
| tree | b15c9999a535e9ae9a268ecf44f4fdbb4eef2888 | |
| parent | 3318f75a7166cbac76a40d0461ca7b3e4640d3a2 (diff) | |
| download | django-rest-framework-ab40780dc2f341a271c2f489659dcd48eb47c07d.tar.bz2 | |
Tidy up lookup_class
| -rw-r--r-- | rest_framework/serializers.py | 21 | 
1 files changed, 11 insertions, 10 deletions
| diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 8fe999ae..4322f213 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -317,17 +317,17 @@ class ModelSerializerOptions(object):          self.depth = getattr(meta, 'depth', 0) -def lookup_class(mapping, obj): +def lookup_class(mapping, instance):      """      Takes a dictionary with classes as keys, and an object.      Traverses the object's inheritance hierarchy in method      resolution order, and returns the first matching value -    from the dictionary or None. +    from the dictionary or raises a KeyError if nothing matches.      """ -    return next( -        (mapping[cls] for cls in inspect.getmro(obj.__class__) if cls in mapping), -        None -    ) +    for cls in inspect.getmro(instance.__class__): +        if cls in mapping: +            return mapping[cls] +    raise KeyError('Class %s not found in lookup.', cls.__name__)  class ModelSerializer(Serializer): @@ -341,6 +341,7 @@ class ModelSerializer(Serializer):          models.DateTimeField: DateTimeField,          models.DecimalField: DecimalField,          models.EmailField: EmailField, +        models.Field: ModelField,          models.FileField: FileField,          models.FloatField: FloatField,          models.ImageField: ImageField, @@ -484,6 +485,7 @@ class ModelSerializer(Serializer):          """          Creates a default instance of a basic non-relational field.          """ +        serializer_cls = lookup_class(self.field_mapping, model_field)          kwargs = {}          validator_kwarg = model_field.validators @@ -602,11 +604,10 @@ class ModelSerializer(Serializer):          if validator_kwarg:              kwargs['validators'] = validator_kwarg -        cls = lookup_class(self.field_mapping, model_field) -        if cls is None: -            cls = ModelField +        if issubclass(serializer_cls, ModelField):              kwargs['model_field'] = model_field -        return cls(**kwargs) + +        return serializer_cls(**kwargs)  class HyperlinkedModelSerializerOptions(ModelSerializerOptions): | 
