diff options
Diffstat (limited to 'rest_framework/relations.py')
| -rw-r--r-- | rest_framework/relations.py | 154 |
1 files changed, 133 insertions, 21 deletions
diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 5aa1f8bd..d1ea497a 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,12 +1,32 @@ from rest_framework.compat import smart_text, urlparse -from rest_framework.fields import Field +from rest_framework.fields import get_attribute, empty, Field from rest_framework.reverse import reverse +from rest_framework.utils import html from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch, Resolver404 from django.db.models.query import QuerySet +from django.utils import six from django.utils.translation import ugettext_lazy as _ +class PKOnlyObject(object): + """ + This is a mock object, used for when we only need the pk of the object + instance, but still want to return an object with a .pk attribute, + in order to keep the same interface as a regular model instance. + """ + def __init__(self, pk): + self.pk = pk + + +# We assume that 'validators' are intended for the child serializer, +# rather than the parent serializer. +MANY_RELATION_KWARGS = ( + 'read_only', 'write_only', 'required', 'default', 'initial', 'source', + 'label', 'help_text', 'style', 'error_messages' +) + + class RelatedField(Field): def __init__(self, **kwargs): self.queryset = kwargs.pop('queryset', None) @@ -22,14 +42,40 @@ class RelatedField(Field): def __new__(cls, *args, **kwargs): # We override this method in order to automagically create - # `ManyRelation` classes instead when `many=True` is set. + # `ManyRelatedField` classes instead when `many=True` is set. if kwargs.pop('many', False): - return ManyRelation( - child_relation=cls(*args, **kwargs), - read_only=kwargs.get('read_only', False) - ) + return cls.many_init(*args, **kwargs) return super(RelatedField, cls).__new__(cls, *args, **kwargs) + @classmethod + def many_init(cls, *args, **kwargs): + """ + This method handles creating a parent `ManyRelatedField` instance + when the `many=True` keyword argument is passed. + + Typically you won't need to override this method. + + Note that we're over-cautious in passing most arguments to both parent + and child classes in order to try to cover the general case. If you're + overriding this method you'll probably want something much simpler, eg: + + @classmethod + def many_init(cls, *args, **kwargs): + kwargs['child'] = cls() + return CustomManyRelatedField(*args, **kwargs) + """ + list_kwargs = {'child_relation': cls(*args, **kwargs)} + for key in kwargs.keys(): + if key in MANY_RELATION_KWARGS: + list_kwargs[key] = kwargs[key] + return ManyRelatedField(**list_kwargs) + + def run_validation(self, data=empty): + # We force empty strings to None values for relational fields. + if data == '': + data = None + return super(RelatedField, self).run_validation(data) + def get_queryset(self): queryset = self.queryset if isinstance(queryset, QuerySet): @@ -37,8 +83,22 @@ class RelatedField(Field): queryset = queryset.all() return queryset + def get_iterable(self, instance, source_attrs): + relationship = get_attribute(instance, source_attrs) + return relationship.all() if (hasattr(relationship, 'all')) else relationship + + @property + def choices(self): + return dict([ + ( + str(self.to_representation(item)), + str(item) + ) + for item in self.queryset.all() + ]) + -class StringRelatedField(Field): +class StringRelatedField(RelatedField): """ A read only field that represents its targets using their plain string representation. @@ -49,7 +109,7 @@ class StringRelatedField(Field): super(StringRelatedField, self).__init__(**kwargs) def to_representation(self, value): - return str(value) + return six.text_type(value) class PrimaryKeyRelatedField(RelatedField): @@ -67,6 +127,32 @@ class PrimaryKeyRelatedField(RelatedField): except (TypeError, ValueError): self.fail('incorrect_type', data_type=type(data).__name__) + def get_attribute(self, instance): + # We customize `get_attribute` here for performance reasons. + # For relationships the instance will already have the pk of + # the related object. We return this directly instead of returning the + # object itself, which would require a database lookup. + try: + instance = get_attribute(instance, self.source_attrs[:-1]) + return PKOnlyObject(pk=instance.serializable_value(self.source_attrs[-1])) + except AttributeError: + return get_attribute(instance, self.source_attrs) + + def get_iterable(self, instance, source_attrs): + # For consistency with `get_attribute` we're using `serializable_value()` + # here. Typically there won't be any difference, but some custom field + # types might return a non-primative value for the pk otherwise. + # + # We could try to get smart with `values_list('pk', flat=True)`, which + # would be better in some case, but would actually end up with *more* + # queries if the developer is using `prefetch_related` across the + # relationship. + relationship = super(PrimaryKeyRelatedField, self).get_iterable(instance, source_attrs) + return [ + PKOnlyObject(pk=item.serializable_value('pk')) + for item in relationship + ] + def to_representation(self, value): return value.pk @@ -89,9 +175,9 @@ class HyperlinkedRelatedField(RelatedField): self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) self.format = kwargs.pop('format', None) - # We include these simply for dependancy injection in tests. + # We include these simply for dependency injection in tests. # We can't add them as class attributes or they would expect an - # implict `self` argument to be passed. + # implicit `self` argument to be passed. self.reverse = reverse self.resolve = resolve @@ -227,26 +313,33 @@ class SlugRelatedField(RelatedField): return getattr(obj, self.slug_field) -class ManyRelation(Field): +class ManyRelatedField(Field): """ Relationships with `many=True` transparently get coerced into instead being - a ManyRelation with a child relationship. + a ManyRelatedField with a child relationship. - The `ManyRelation` class is responsible for handling iterating through + The `ManyRelatedField` class is responsible for handling iterating through the values and passing each one to the child relationship. - You shouldn't need to be using this class directly yourself. + This class is treated as private API. + You shouldn't generally need to be using this class directly yourself, + and should instead simply set 'many=True' on the relationship. """ + initial = [] + default_empty_html = [] def __init__(self, child_relation=None, *args, **kwargs): self.child_relation = child_relation assert child_relation is not None, '`child_relation` is a required argument.' - super(ManyRelation, self).__init__(*args, **kwargs) + super(ManyRelatedField, self).__init__(*args, **kwargs) + self.child_relation.bind(field_name='', parent=self) - def bind(self, field_name, parent, root): - # ManyRelation needs to provide the current context to the child relation. - super(ManyRelation, self).bind(field_name, parent, root) - self.child_relation.bind(field_name, parent, root) + def get_value(self, dictionary): + # We override the default field access in order to support + # lists in HTML forms. + if html.is_html_input(dictionary): + return dictionary.getlist(self.field_name) + return dictionary.get(self.field_name, empty) def to_internal_value(self, data): return [ @@ -254,8 +347,27 @@ class ManyRelation(Field): for item in data ] - def to_representation(self, obj): + def get_attribute(self, instance): + return self.child_relation.get_iterable(instance, self.source_attrs) + + def to_representation(self, iterable): return [ self.child_relation.to_representation(value) - for value in obj.all() + for value in iterable ] + + @property + def choices(self): + queryset = self.child_relation.queryset + iterable = queryset.all() if (hasattr(queryset, 'all')) else queryset + items_and_representations = [ + (item, self.child_relation.to_representation(item)) + for item in iterable + ] + return dict([ + ( + str(item_representation), + str(item) + ' - ' + str(item_representation) + ) + for item, item_representation in items_and_representations + ]) |
