diff options
| author | Dustin Farris | 2014-01-12 20:28:19 -0500 |
|---|---|---|
| committer | Dustin Farris | 2014-01-12 20:28:19 -0500 |
| commit | 2332382b5109939238801e7d4c018455d159fe91 (patch) | |
| tree | 5fb1fe74cfd2f44618c2c4081a8cfa306d8773fb | |
| parent | bf5b77ce6d171723f2d187aadd29c8ee4cdc3870 (diff) | |
| download | django-rest-framework-2332382b5109939238801e7d4c018455d159fe91.tar.bz2 | |
Add a sanity check to avoid running into unresolved related models.
| -rw-r--r-- | rest_framework/models.py | 23 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 3 | ||||
| -rw-r--r-- | rest_framework/tests/test_models.py | 28 |
3 files changed, 52 insertions, 2 deletions
diff --git a/rest_framework/models.py b/rest_framework/models.py index 5b53a526..249cdd82 100644 --- a/rest_framework/models.py +++ b/rest_framework/models.py @@ -1 +1,22 @@ -# Just to keep things like ./manage.py test happy +import inspect + +from django.db import models + + +def resolve_model(obj): + """ + Resolve supplied `obj` to a Django model class. + + `obj` must be a Django model class, or a string representation + of one. + + String representations should have the format: + 'appname.ModelName' + """ + if type(obj) == str and len(obj.split('.')) == 2: + app_name, model_name = obj.split('.') + return models.get_model(app_name, model_name) + elif inspect.isclass(obj) and issubclass(obj, models.Model): + return obj + else: + raise ValueError("{0} is not a valid Django model".format(obj)) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index b22ca578..6b31c304 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -20,6 +20,7 @@ from django.db import models from django.forms import widgets from django.utils.datastructures import SortedDict from rest_framework.compat import get_concrete_model, six +from rest_framework.models import resolve_model # Note: We do the following so that users of the framework can use this style: # @@ -656,7 +657,7 @@ class ModelSerializer(Serializer): if model_field.rel: to_many = isinstance(model_field, models.fields.related.ManyToManyField) - related_model = model_field.rel.to + related_model = resolve_model(model_field.rel.to) if to_many and not model_field.rel.through._meta.auto_created: has_through_model = True diff --git a/rest_framework/tests/test_models.py b/rest_framework/tests/test_models.py new file mode 100644 index 00000000..5e92d48a --- /dev/null +++ b/rest_framework/tests/test_models.py @@ -0,0 +1,28 @@ +from django.db import models +from django.test import TestCase + +from rest_framework.models import resolve_model +from rest_framework.tests.models import BasicModel + + +class ResolveModelTests(TestCase): + """ + `resolve_model` should return a Django model class given the + provided argument is a Django model class itself, or a properly + formatted string representation of one. + """ + def test_resolve_django_model(self): + resolved_model = resolve_model(BasicModel) + self.assertEqual(resolved_model, BasicModel) + + def test_resolve_string_representation(self): + resolved_model = resolve_model('tests.BasicModel') + self.assertEqual(resolved_model, BasicModel) + + def test_resolve_non_django_model(self): + with self.assertRaises(ValueError): + resolve_model(TestCase) + + def test_resolve_with_improper_string_representation(self): + with self.assertRaises(ValueError): + resolve_model('BasicModel') |
