diff options
| author | Tom Christie | 2014-09-12 17:03:42 +0100 | 
|---|---|---|
| committer | Tom Christie | 2014-09-12 17:03:42 +0100 | 
| commit | b73a205cc021983d9a508b447f30e144a1ce4129 (patch) | |
| tree | 2a414943c68c31867f25b67770142f04abc01dc3 | |
| parent | f95e7fae38968f58e742b93842bda9110a61b9f7 (diff) | |
| download | django-rest-framework-b73a205cc021983d9a508b447f30e144a1ce4129.tar.bz2 | |
Tests for relational fields (not including many=True)
| -rw-r--r-- | rest_framework/relations.py | 143 | ||||
| -rw-r--r-- | tests/test_relations.py | 139 | ||||
| -rw-r--r-- | tests/utils.py | 53 | 
3 files changed, 286 insertions, 49 deletions
| diff --git a/rest_framework/relations.py b/rest_framework/relations.py index e23a4152..75ec89a8 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,19 +1,24 @@ +from rest_framework.compat import smart_text, urlparse  from rest_framework.fields import Field  from rest_framework.reverse import reverse -from django.core.exceptions import ObjectDoesNotExist +from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured  from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch  from django.db.models.query import QuerySet -from rest_framework.compat import urlparse +from django.utils.translation import ugettext_lazy as _  class RelatedField(Field):      def __init__(self, **kwargs):          self.queryset = kwargs.pop('queryset', None)          self.many = kwargs.pop('many', False) -        assert self.queryset is not None or kwargs.get('read_only', False), ( +        assert self.queryset is not None or kwargs.get('read_only', None), (              'Relational field must provide a `queryset` argument, '              'or set read_only=`True`.'          ) +        assert not (self.queryset is not None and kwargs.get('read_only', None)), ( +            'Relational fields should not provide a `queryset` argument, ' +            'when setting read_only=`True`.' +        )          super(RelatedField, self).__init__(**kwargs)      def get_queryset(self): @@ -25,6 +30,11 @@ class RelatedField(Field):  class StringRelatedField(Field): +    """ +    A read only field that represents its targets using their +    plain string representation. +    """ +      def __init__(self, **kwargs):          kwargs['read_only'] = True          super(StringRelatedField, self).__init__(**kwargs) @@ -34,10 +44,10 @@ class StringRelatedField(Field):  class PrimaryKeyRelatedField(RelatedField): -    MESSAGES = { +    default_error_messages = {          'required': 'This field is required.',          'does_not_exist': "Invalid pk '{pk_value}' - object does not exist.", -        'incorrect_type': 'Incorrect type.  Expected pk value, received {data_type}.', +        'incorrect_type': 'Incorrect type. Expected pk value, received {data_type}.',      }      def to_internal_value(self, data): @@ -48,22 +58,33 @@ class PrimaryKeyRelatedField(RelatedField):          except (TypeError, ValueError):              self.fail('incorrect_type', data_type=type(data).__name__) +    def to_representation(self, value): +        return value.pk +  class HyperlinkedRelatedField(RelatedField):      lookup_field = 'pk' -    MESSAGES = { +    default_error_messages = {          'required': 'This field is required.',          'no_match': 'Invalid hyperlink - No URL match',          'incorrect_match': 'Invalid hyperlink - Incorrect URL match.', -        'does_not_exist': "Invalid hyperlink - Object does not exist.", -        'incorrect_type': 'Incorrect type.  Expected URL string, received {data_type}.', +        'does_not_exist': 'Invalid hyperlink - Object does not exist.', +        'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.',      } -    def __init__(self, **kwargs): -        self.view_name = kwargs.pop('view_name') +    def __init__(self, view_name, **kwargs): +        self.view_name = view_name          self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)          self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) +        self.format = kwargs.pop('format', None) + +        # We include these simply for dependancy injection in tests. +        # We can't add them as class attributes or they would expect an +        # implict `self` argument to be passed. +        self.reverse = reverse +        self.resolve = resolve +          super(HyperlinkedRelatedField, self).__init__(**kwargs)      def get_object(self, view_name, view_args, view_kwargs): @@ -77,21 +98,36 @@ class HyperlinkedRelatedField(RelatedField):          lookup_kwargs = {self.lookup_field: lookup_value}          return self.get_queryset().get(**lookup_kwargs) -    def to_internal_value(self, value): +    def get_url(self, obj, view_name, request, format): +        """ +        Given an object, return the URL that hyperlinks to the object. + +        May raise a `NoReverseMatch` if the `view_name` and `lookup_field` +        attributes are not configured to correctly match the URL conf. +        """ +        # Unsaved objects will not yet have a valid URL. +        if obj.pk is None: +            return None + +        lookup_value = getattr(obj, self.lookup_field) +        kwargs = {self.lookup_url_kwarg: lookup_value} +        return self.reverse(view_name, kwargs=kwargs, request=request, format=format) + +    def to_internal_value(self, data):          try: -            http_prefix = value.startswith(('http:', 'https:')) +            http_prefix = data.startswith(('http:', 'https:'))          except AttributeError: -            self.fail('incorrect_type', data_type=type(value).__name__) +            self.fail('incorrect_type', data_type=type(data).__name__)          if http_prefix:              # If needed convert absolute URLs to relative path -            value = urlparse.urlparse(value).path +            data = urlparse.urlparse(data).path              prefix = get_script_prefix() -            if value.startswith(prefix): -                value = '/' + value[len(prefix):] +            if data.startswith(prefix): +                data = '/' + data[len(prefix):]          try: -            match = resolve(value) +            match = self.resolve(data)          except Exception:              self.fail('no_match') @@ -103,41 +139,14 @@ class HyperlinkedRelatedField(RelatedField):          except (ObjectDoesNotExist, TypeError, ValueError):              self.fail('does_not_exist') - -class HyperlinkedIdentityField(RelatedField): -    lookup_field = 'pk' - -    def __init__(self, **kwargs): -        kwargs['read_only'] = True -        kwargs['source'] = '*' -        self.view_name = kwargs.pop('view_name') -        self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) -        self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) -        super(HyperlinkedIdentityField, self).__init__(**kwargs) - -    def get_url(self, obj, view_name, request, format): -        """ -        Given an object, return the URL that hyperlinks to the object. - -        May raise a `NoReverseMatch` if the `view_name` and `lookup_field` -        attributes are not configured to correctly match the URL conf. -        """ -        # Unsaved objects will not yet have a valid URL. -        if obj.pk is None: -            return None - -        lookup_value = getattr(obj, self.lookup_field) -        kwargs = {self.lookup_url_kwarg: lookup_value} -        return reverse(view_name, kwargs=kwargs, request=request, format=format) -      def to_representation(self, value):          request = self.context.get('request', None)          format = self.context.get('format', None)          assert request is not None, ( -            "`HyperlinkedIdentityField` requires the request in the serializer" +            "`%s` requires the request in the serializer"              " context. Add `context={'request': request}` when instantiating " -            "the serializer." +            "the serializer." % self.__class__.__name__          )          # By default use whatever format is given for the current context @@ -162,9 +171,45 @@ class HyperlinkedIdentityField(RelatedField):                  'model in your API, or incorrectly configured the '                  '`lookup_field` attribute on this field.'              ) -            raise Exception(msg % self.view_name) +            raise ImproperlyConfigured(msg % self.view_name) + + +class HyperlinkedIdentityField(HyperlinkedRelatedField): +    """ +    A read-only field that represents the identity URL for an object, itself. + +    This is in contrast to `HyperlinkedRelatedField` which represents the +    URL of relationships to other objects. +    """ + +    def __init__(self, view_name, **kwargs): +        kwargs['read_only'] = True +        kwargs['source'] = '*' +        super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)  class SlugRelatedField(RelatedField): -    def __init__(self, **kwargs): -        self.slug_field = kwargs.pop('slug_field', None) +    """ +    A read-write field the represents the target of the relationship +    by a unique 'slug' attribute. +    """ + +    default_error_messages = { +        'does_not_exist': _("Object with {slug_name}={value} does not exist."), +        'invalid': _('Invalid value.'), +    } + +    def __init__(self, slug_field, **kwargs): +        self.slug_field = slug_field +        super(SlugRelatedField, self).__init__(**kwargs) + +    def to_internal_value(self, data): +        try: +            return self.get_queryset().get(**{self.slug_field: data}) +        except ObjectDoesNotExist: +            self.fail('does_not_exist', slug_name=self.slug_field, value=smart_text(data)) +        except (TypeError, ValueError): +            self.fail('invalid') + +    def to_representation(self, obj): +        return getattr(obj, self.slug_field) diff --git a/tests/test_relations.py b/tests/test_relations.py index b1bc66b6..a3672117 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -1,3 +1,142 @@ +from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset +from django.core.exceptions import ImproperlyConfigured, ValidationError +from rest_framework import serializers +from rest_framework.test import APISimpleTestCase +import pytest + + +class TestStringRelatedField(APISimpleTestCase): +    def setUp(self): +        self.instance = MockObject(pk=1, name='foo') +        self.field = serializers.StringRelatedField() + +    def test_string_related_representation(self): +        representation = self.field.to_representation(self.instance) +        assert representation == '<MockObject name=foo, pk=1>' + + +class TestPrimaryKeyRelatedField(APISimpleTestCase): +    def setUp(self): +        self.queryset = MockQueryset([ +            MockObject(pk=1, name='foo'), +            MockObject(pk=2, name='bar'), +            MockObject(pk=3, name='baz') +        ]) +        self.instance = self.queryset.items[2] +        self.field = serializers.PrimaryKeyRelatedField(queryset=self.queryset) + +    def test_pk_related_lookup_exists(self): +        instance = self.field.to_internal_value(self.instance.pk) +        assert instance is self.instance + +    def test_pk_related_lookup_does_not_exist(self): +        with pytest.raises(ValidationError) as excinfo: +            self.field.to_internal_value(4) +        msg = excinfo.value.message +        assert msg == "Invalid pk '4' - object does not exist." + +    def test_pk_related_lookup_invalid_type(self): +        with pytest.raises(ValidationError) as excinfo: +            self.field.to_internal_value(BadType()) +        msg = excinfo.value.message +        assert msg == 'Incorrect type. Expected pk value, received BadType.' + +    def test_pk_representation(self): +        representation = self.field.to_representation(self.instance) +        assert representation == self.instance.pk + + +class TestHyperlinkedIdentityField(APISimpleTestCase): +    def setUp(self): +        self.instance = MockObject(pk=1, name='foo') +        self.field = serializers.HyperlinkedIdentityField(view_name='example') +        self.field.reverse = mock_reverse +        self.field.context = {'request': True} + +    def test_representation(self): +        representation = self.field.to_representation(self.instance) +        assert representation == 'http://example.org/example/1/' + +    def test_representation_unsaved_object(self): +        representation = self.field.to_representation(MockObject(pk=None)) +        assert representation is None + +    def test_representation_with_format(self): +        self.field.context['format'] = 'xml' +        representation = self.field.to_representation(self.instance) +        assert representation == 'http://example.org/example/1.xml/' + +    def test_improperly_configured(self): +        """ +        If a matching view cannot be reversed with the given instance, +        the the user has misconfigured something, as the URL conf and the +        hyperlinked field do not match. +        """ +        self.field.reverse = fail_reverse +        with pytest.raises(ImproperlyConfigured): +            self.field.to_representation(self.instance) + + +class TestHyperlinkedIdentityFieldWithFormat(APISimpleTestCase): +    """ +    Tests for a hyperlinked identity field that has a `format` set, +    which enforces that alternate formats are never linked too. + +    Eg. If your API includes some endpoints that accept both `.xml` and `.json`, +    but other endpoints that only accept `.json`, we allow for hyperlinked +    relationships that enforce only a single suffix type. +    """ + +    def setUp(self): +        self.instance = MockObject(pk=1, name='foo') +        self.field = serializers.HyperlinkedIdentityField(view_name='example', format='json') +        self.field.reverse = mock_reverse +        self.field.context = {'request': True} + +    def test_representation(self): +        representation = self.field.to_representation(self.instance) +        assert representation == 'http://example.org/example/1/' + +    def test_representation_with_format(self): +        self.field.context['format'] = 'xml' +        representation = self.field.to_representation(self.instance) +        assert representation == 'http://example.org/example/1.json/' + + +class TestSlugRelatedField(APISimpleTestCase): +    def setUp(self): +        self.queryset = MockQueryset([ +            MockObject(pk=1, name='foo'), +            MockObject(pk=2, name='bar'), +            MockObject(pk=3, name='baz') +        ]) +        self.instance = self.queryset.items[2] +        self.field = serializers.SlugRelatedField( +            slug_field='name', queryset=self.queryset +        ) + +    def test_slug_related_lookup_exists(self): +        instance = self.field.to_internal_value(self.instance.name) +        assert instance is self.instance + +    def test_slug_related_lookup_does_not_exist(self): +        with pytest.raises(ValidationError) as excinfo: +            self.field.to_internal_value('doesnotexist') +        msg = excinfo.value.message +        assert msg == 'Object with name=doesnotexist does not exist.' + +    def test_slug_related_lookup_invalid_type(self): +        with pytest.raises(ValidationError) as excinfo: +            self.field.to_internal_value(BadType()) +        msg = excinfo.value.message +        assert msg == 'Invalid value.' + +    def test_representation(self): +        representation = self.field.to_representation(self.instance) +        assert representation == self.instance.name + +# Older tests, for review... +  # """  # General tests for relational fields.  # """ diff --git a/tests/utils.py b/tests/utils.py index 28be81bd..5e902ba9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,6 @@  from contextlib import contextmanager +from django.core.exceptions import ObjectDoesNotExist +from django.core.urlresolvers import NoReverseMatch  from django.utils import six  from rest_framework.settings import api_settings @@ -23,3 +25,54 @@ def temporary_setting(setting, value, module=None):      if module is not None:          six.moves.reload_module(module) + + +class MockObject(object): +    def __init__(self, **kwargs): +        self._kwargs = kwargs +        for key, val in kwargs.items(): +            setattr(self, key, val) + +    def __str__(self): +        kwargs_str = ', '.join([ +            '%s=%s' % (key, value) +            for key, value in sorted(self._kwargs.items()) +        ]) +        return '<MockObject %s>' % kwargs_str + + +class MockQueryset(object): +    def __init__(self, iterable): +        self.items = iterable + +    def get(self, **lookup): +        for item in self.items: +            if all([ +                getattr(item, key, None) == value +                for key, value in lookup.items() +            ]): +                return item +        raise ObjectDoesNotExist() + + +class BadType(object): +    """ +    When used as a lookup with a `MockQueryset`, these objects +    will raise a `TypeError`, as occurs in Django when making +    queryset lookups with an incorrect type for the lookup value. +    """ +    def __eq__(self): +        raise TypeError() + + +def mock_reverse(view_name, args=None, kwargs=None, request=None, format=None): +    args = args or [] +    kwargs = kwargs or {} +    value = (args + list(kwargs.values()) + ['-'])[0] +    prefix = 'http://example.org' if request else '' +    suffix = ('.' + format) if (format is not None) else '' +    return '%s/%s/%s%s/' % (prefix, view_name, value, suffix) + + +def fail_reverse(view_name, args=None, kwargs=None, request=None, format=None): +    raise NoReverseMatch() | 
