aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/relations.py')
-rw-r--r--rest_framework/relations.py34
1 files changed, 20 insertions, 14 deletions
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 4785c009..3b234dd5 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -33,6 +33,7 @@ class RelatedField(WritableField):
many_widget = widgets.SelectMultiple
form_field_class = forms.ChoiceField
many_form_field_class = forms.MultipleChoiceField
+ null_values = (None, '', 'None')
cache_choices = False
empty_label = None
@@ -50,6 +51,8 @@ class RelatedField(WritableField):
super(RelatedField, self).__init__(*args, **kwargs)
if not self.required:
+ # Accessed in ModelChoiceIterator django/forms/models.py:1034
+ # If set adds empty choice.
self.empty_label = BLANK_CHOICE_DASH[0][1]
self.queryset = queryset
@@ -57,16 +60,11 @@ class RelatedField(WritableField):
def initialize(self, parent, field_name):
super(RelatedField, self).initialize(parent, field_name)
if self.queryset is None and not self.read_only:
- try:
- manager = getattr(self.parent.opts.model, self.source or field_name)
- if hasattr(manager, 'related'): # Forward
- self.queryset = manager.related.model._default_manager.all()
- else: # Reverse
- self.queryset = manager.field.rel.to._default_manager.all()
- except Exception:
- msg = ('Serializer related fields must include a `queryset`' +
- ' argument or set `read_only=True')
- raise Exception(msg)
+ manager = getattr(self.parent.opts.model, self.source or field_name)
+ if hasattr(manager, 'related'): # Forward
+ self.queryset = manager.related.model._default_manager.all()
+ else: # Reverse
+ self.queryset = manager.field.rel.to._default_manager.all()
### We need this stuff to make form choices work...
@@ -115,6 +113,14 @@ class RelatedField(WritableField):
choices = property(_get_choices, _set_choices)
+ ### Default value handling
+
+ def get_default_value(self):
+ default = super(RelatedField, self).get_default_value()
+ if self.many and default is None:
+ return []
+ return default
+
### Regular serializer stuff...
def field_to_native(self, obj, field_name):
@@ -163,11 +169,11 @@ class RelatedField(WritableField):
except KeyError:
if self.partial:
return
- value = [] if self.many else None
+ value = self.get_default_value()
- if value in (None, '') and self.required:
- raise ValidationError(self.error_messages['required'])
- elif value in (None, ''):
+ if value in self.null_values:
+ if self.required:
+ raise ValidationError(self.error_messages['required'])
into[(self.source or field_name)] = None
elif self.many:
into[(self.source or field_name)] = [self.from_native(item) for item in value]