aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--rest_framework/fields.py71
-rw-r--r--rest_framework/renderers.py14
-rw-r--r--rest_framework/serializers.py7
-rw-r--r--rest_framework/templates/rest_framework/base.html8
-rw-r--r--rest_framework/tests/serializer.py42
5 files changed, 107 insertions, 35 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 85ee5430..edc77e1a 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -7,7 +7,6 @@ from django.core import validators
from django.core.exceptions import ValidationError
from django.conf import settings
from django.db import DEFAULT_DB_ALIAS
-from django.db.models.related import RelatedObject
from django.utils.encoding import is_protected_type, smart_unicode
from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import parse_date, parse_datetime
@@ -181,6 +180,9 @@ class RelatedField(Field):
Subclass this and override `convert` to define custom behaviour when
serializing related objects.
"""
+ def __init__(self, *args, **kwargs):
+ self.queryset = kwargs.pop('queryset', None)
+ super(RelatedField, self).__init__(*args, **kwargs)
def field_to_native(self, obj, field_name):
obj = getattr(obj, self.source or field_name)
@@ -200,48 +202,61 @@ class RelatedField(Field):
class PrimaryKeyRelatedField(RelatedField):
"""
- Serializes a model related field or related manager to a pk value.
+ Serializes a related field or related object to a pk value.
"""
- # Note the we use ModelRelatedField's implementation, as we want to get the
- # raw database value directly, since that won't involve another
- # database lookup.
- #
- # An alternative implementation would simply be this...
- #
- # class PrimaryKeyRelatedField(RelatedField):
- # def to_native(self, obj):
- # return obj.pk
-
def to_native(self, pk):
"""
- Simply returns the object's pk. You can subclass this method to
- provide different serialization behavior of the pk.
- (For example returning a URL based on the model's pk.)
+ You can subclass this method to provide different serialization
+ behavior based on the pk.
"""
return pk
def field_to_native(self, obj, field_name):
+ # This is only implemented for performance reasons
+ #
+ # We could leave the default `RelatedField.field_to_native()` in place,
+ # and inside just implement `to_native()` as `return obj.pk`
+ #
+ # That would involve an extra database lookup.
try:
- obj = obj.serializable_value(self.source or field_name)
+ pk = obj.serializable_value(self.source or field_name)
except AttributeError:
- field = obj._meta.get_field_by_name(field_name)[0]
+ # RelatedObject (reverse relationship)
obj = getattr(obj, self.source or field_name)
- if obj.__class__.__name__ == 'RelatedManager':
- return [self.to_native(item.pk) for item in obj.all()]
- elif isinstance(field, RelatedObject):
- return self.to_native(obj.pk)
- raise
- if obj.__class__.__name__ == 'ManyRelatedManager':
- return [self.to_native(item.pk) for item in obj.all()]
- return self.to_native(obj)
+ return self.to_native(obj.pk)
+ # Forward relationship
+ return self.to_native(pk)
def field_from_native(self, data, field_name, into):
value = data.get(field_name)
- if hasattr(value, '__iter__'):
- into[field_name] = [self.from_native(item) for item in value]
+ into[field_name + '_id'] = self.from_native(value)
+
+
+class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
+ """
+ Serializes a to-many related field or related manager to a pk value.
+ """
+
+ def field_to_native(self, obj, field_name):
+ try:
+ queryset = obj.serializable_value(self.source or field_name)
+ except AttributeError:
+ # RelatedManager (reverse relationship)
+ queryset = getattr(obj, self.source or field_name)
+ return [self.to_native(item.pk) for item in queryset.all()]
+ # Forward relationship
+ return [self.to_native(item.pk) for item in queryset.all()]
+
+ def field_from_native(self, data, field_name, into):
+ try:
+ value = data.getlist(field_name)
+ except:
+ value = data.get(field_name)
else:
- into[field_name + '_id'] = self.from_native(value)
+ if value == ['']:
+ value = []
+ into[field_name] = [self.from_native(item) for item in value]
class NaturalKeyRelatedField(RelatedField):
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 9484e29b..5bc5d5f8 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -246,7 +246,9 @@ class DocumentingHTMLRenderer(BaseRenderer):
serializers.DateField: forms.DateField,
serializers.EmailField: forms.EmailField,
serializers.CharField: forms.CharField,
- serializers.BooleanField: forms.BooleanField
+ serializers.BooleanField: forms.BooleanField,
+ serializers.PrimaryKeyRelatedField: forms.ModelChoiceField,
+ serializers.ManyPrimaryKeyRelatedField: forms.ModelMultipleChoiceField
}
# Creating an on the fly form see: http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
@@ -257,12 +259,18 @@ class DocumentingHTMLRenderer(BaseRenderer):
serializer = view.get_serializer(instance=obj)
for k, v in serializer.get_fields(True).items():
+ print k, v
if v.readonly:
continue
+
+ kwargs = {}
+ if getattr(v, 'queryset', None):
+ kwargs['queryset'] = getattr(v, 'queryset', None)
+
try:
- fields[k] = field_mapping[v.__class__]()
+ fields[k] = field_mapping[v.__class__](**kwargs)
except KeyError:
- fields[k] = forms.CharField
+ fields[k] = forms.CharField()
OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields)
if obj and not view.request.method == 'DELETE': # Don't fill in the form when the object is deleted
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 683b9efc..03763824 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -351,7 +351,10 @@ class ModelSerializer(RelatedField, Serializer):
"""
Creates a default instance of a flat relational field.
"""
- return PrimaryKeyRelatedField()
+ queryset = model_field.rel.to._default_manager # .using(db).complex_filter(self.rel.limit_choices_to)
+ if isinstance(model_field, models.fields.related.ManyToManyField):
+ return ManyPrimaryKeyRelatedField(queryset=queryset)
+ return PrimaryKeyRelatedField(queryset=queryset)
def get_field(self, model_field):
"""
@@ -365,7 +368,7 @@ class ModelSerializer(RelatedField, Serializer):
models.EmailField: EmailField,
models.CharField: CharField,
models.CommaSeparatedIntegerField: CharField,
- models.BooleanField: BooleanField
+ models.BooleanField: BooleanField,
}
try:
return field_mapping[model_field.__class__]()
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 867051e6..a5e08942 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -122,6 +122,7 @@
{% if response.status_code != 403 %}
{% if post_form %}
+ <div class="well">
<form action="{{ request.get_full_path }}" method="POST" {% if post_form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
<fieldset>
<h2>POST: {{ name }}</h2>
@@ -131,7 +132,7 @@
<div class="control-group {% if field.errors %}error{% endif %}">
{{ field.label_tag|add_class:"control-label" }}
<div class="controls">
- {{ field }}
+ {{ field|add_class:"input-xlarge" }}
<span class="help-inline">{{ field.help_text }}</span>
{{ field.errors|add_class:"help-block" }}
</div>
@@ -142,9 +143,11 @@
</div>
</fieldset>
</form>
+ </div>
{% endif %}
{% if put_form %}
+ <div class="well">
<form action="{{ request.get_full_path }}" method="POST" {% if put_form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
<fieldset>
<h2>PUT: {{ name }}</h2>
@@ -155,7 +158,7 @@
<div class="control-group {% if field.errors %}error{% endif %}">
{{ field.label_tag|add_class:"control-label" }}
<div class="controls">
- {{ field }}
+ {{ field|add_class:"input-xlarge" }}
<span class='help-inline'>{{ field.help_text }}</span>
{{ field.errors|add_class:"help-block" }}
</div>
@@ -167,6 +170,7 @@
</fieldset>
</form>
+ </div>
{% endif %}
{% endif %}
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index f7412a32..db342c9e 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -160,6 +160,48 @@ class ManyToManyTests(TestCase):
self.assertEquals(instance.pk, 1)
self.assertEquals(list(instance.rel.all()), [self.anchor, new_anchor])
+ def test_create_empty_relationship(self):
+ """
+ Create an instance of a model with a ManyToMany relationship,
+ containing no items.
+ """
+ data = {'rel': []}
+ serializer = self.serializer_class(data)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(ManyToManyModel.objects.all()), 2)
+ self.assertEquals(instance.pk, 2)
+ self.assertEquals(list(instance.rel.all()), [])
+
+ def test_update_empty_relationship(self):
+ """
+ Update an instance of a model with a ManyToMany relationship,
+ containing no items.
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {'rel': []}
+ serializer = self.serializer_class(data, instance=self.instance)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(ManyToManyModel.objects.all()), 1)
+ self.assertEquals(instance.pk, 1)
+ self.assertEquals(list(instance.rel.all()), [])
+
+ def test_create_empty_relationship_flat_data(self):
+ """
+ Create an instance of a model with a ManyToMany relationship,
+ containing no items, using a representation that does not support
+ lists (eg form data).
+ """
+ data = {'rel': ''}
+ serializer = self.serializer_class(data)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(ManyToManyModel.objects.all()), 2)
+ self.assertEquals(instance.pk, 2)
+ self.assertEquals(list(instance.rel.all()), [])
+
# def test_deserialization_for_update(self):
# serializer = self.serializer_class(self.data, instance=self.instance)
# expected = self.instance