diff options
| author | Tom Christie | 2012-10-09 17:49:04 +0100 |
|---|---|---|
| committer | Tom Christie | 2012-10-09 17:49:04 +0100 |
| commit | 9bbc1cc403e3cf171710ae02255e2b7e6185f823 (patch) | |
| tree | 63f81aeb23e0586d34dc4f8439666c9d3446613c /rest_framework | |
| parent | b0c370dd2b42db9074c2580ca4a48d7dda088abf (diff) | |
| download | django-rest-framework-9bbc1cc403e3cf171710ae02255e2b7e6185f823.tar.bz2 | |
Add flag in get_related_field
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/serializers.py | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 2141619f..0faad703 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -323,7 +323,9 @@ class ModelSerializer(Serializer): elif model_field.rel and nested: field = self.get_nested_field(model_field) elif model_field.rel: - field = self.get_related_field(model_field) + to_many = isinstance(model_field, + models.fields.related.ManyToManyField) + field = self.get_related_field(model_field, to_many=to_many) else: field = self.get_field(model_field) @@ -345,14 +347,14 @@ class ModelSerializer(Serializer): """ return ModelSerializer() - def get_related_field(self, model_field): + def get_related_field(self, model_field, to_many=False): """ Creates a default instance of a flat relational field. """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) queryset = model_field.rel.to._default_manager - if isinstance(model_field, models.fields.related.ManyToManyField): + if to_many: return ManyPrimaryKeyRelatedField(queryset=queryset) return PrimaryKeyRelatedField(queryset=queryset) @@ -446,7 +448,7 @@ class HyperlinkedModelSerializer(ModelSerializer): def get_pk_field(self, model_field): return None - def get_related_field(self, model_field): + def get_related_field(self, model_field, to_many): """ Creates a default instance of a flat relational field. """ @@ -458,6 +460,6 @@ class HyperlinkedModelSerializer(ModelSerializer): 'queryset': queryset, 'view_name': self._get_default_view_name(rel) } - if isinstance(model_field, models.fields.related.ManyToManyField): + if to_many: return ManyHyperlinkedRelatedField(**kwargs) return HyperlinkedRelatedField(**kwargs) |
