aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDustin Farris2014-01-12 20:28:19 -0500
committerDustin Farris2014-01-12 20:28:19 -0500
commit2332382b5109939238801e7d4c018455d159fe91 (patch)
tree5fb1fe74cfd2f44618c2c4081a8cfa306d8773fb
parentbf5b77ce6d171723f2d187aadd29c8ee4cdc3870 (diff)
downloaddjango-rest-framework-2332382b5109939238801e7d4c018455d159fe91.tar.bz2
Add a sanity check to avoid running into unresolved related models.
-rw-r--r--rest_framework/models.py23
-rw-r--r--rest_framework/serializers.py3
-rw-r--r--rest_framework/tests/test_models.py28
3 files changed, 52 insertions, 2 deletions
diff --git a/rest_framework/models.py b/rest_framework/models.py
index 5b53a526..249cdd82 100644
--- a/rest_framework/models.py
+++ b/rest_framework/models.py
@@ -1 +1,22 @@
-# Just to keep things like ./manage.py test happy
+import inspect
+
+from django.db import models
+
+
+def resolve_model(obj):
+ """
+ Resolve supplied `obj` to a Django model class.
+
+ `obj` must be a Django model class, or a string representation
+ of one.
+
+ String representations should have the format:
+ 'appname.ModelName'
+ """
+ if type(obj) == str and len(obj.split('.')) == 2:
+ app_name, model_name = obj.split('.')
+ return models.get_model(app_name, model_name)
+ elif inspect.isclass(obj) and issubclass(obj, models.Model):
+ return obj
+ else:
+ raise ValueError("{0} is not a valid Django model".format(obj))
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index b22ca578..6b31c304 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -20,6 +20,7 @@ from django.db import models
from django.forms import widgets
from django.utils.datastructures import SortedDict
from rest_framework.compat import get_concrete_model, six
+from rest_framework.models import resolve_model
# Note: We do the following so that users of the framework can use this style:
#
@@ -656,7 +657,7 @@ class ModelSerializer(Serializer):
if model_field.rel:
to_many = isinstance(model_field,
models.fields.related.ManyToManyField)
- related_model = model_field.rel.to
+ related_model = resolve_model(model_field.rel.to)
if to_many and not model_field.rel.through._meta.auto_created:
has_through_model = True
diff --git a/rest_framework/tests/test_models.py b/rest_framework/tests/test_models.py
new file mode 100644
index 00000000..5e92d48a
--- /dev/null
+++ b/rest_framework/tests/test_models.py
@@ -0,0 +1,28 @@
+from django.db import models
+from django.test import TestCase
+
+from rest_framework.models import resolve_model
+from rest_framework.tests.models import BasicModel
+
+
+class ResolveModelTests(TestCase):
+ """
+ `resolve_model` should return a Django model class given the
+ provided argument is a Django model class itself, or a properly
+ formatted string representation of one.
+ """
+ def test_resolve_django_model(self):
+ resolved_model = resolve_model(BasicModel)
+ self.assertEqual(resolved_model, BasicModel)
+
+ def test_resolve_string_representation(self):
+ resolved_model = resolve_model('tests.BasicModel')
+ self.assertEqual(resolved_model, BasicModel)
+
+ def test_resolve_non_django_model(self):
+ with self.assertRaises(ValueError):
+ resolve_model(TestCase)
+
+ def test_resolve_with_improper_string_representation(self):
+ with self.assertRaises(ValueError):
+ resolve_model('BasicModel')