diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/relations.py | 11 | ||||
| -rw-r--r-- | rest_framework/renderers.py | 161 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 68 | ||||
| -rw-r--r-- | rest_framework/templates/rest_framework/base.html | 10 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations_nested.py | 351 | 
5 files changed, 448 insertions, 153 deletions
diff --git a/rest_framework/relations.py b/rest_framework/relations.py index edaf76d6..3ad16ee5 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -134,9 +134,9 @@ class RelatedField(WritableField):              value = obj              for component in source.split('.'): -                value = get_component(value, component)                  if value is None:                      break +                value = get_component(value, component)          except ObjectDoesNotExist:              return None @@ -244,6 +244,8 @@ class PrimaryKeyRelatedField(RelatedField):                  source = self.source or field_name                  queryset = obj                  for component in source.split('.'): +                    if queryset is None: +                        return []                      queryset = get_component(queryset, component)              # Forward relationship @@ -567,8 +569,13 @@ class HyperlinkedIdentityField(Field):          May raise a `NoReverseMatch` if the `view_name` and `lookup_field`          attributes are not configured to correctly match the URL conf.          """ -        lookup_field = getattr(obj, self.lookup_field) +        lookup_field = getattr(obj, self.lookup_field, None)          kwargs = {self.lookup_field: lookup_field} + +        # Handle unsaved object case +        if lookup_field is None: +            return None +          try:              return reverse(view_name, kwargs=kwargs, request=request, format=format)          except NoReverseMatch: diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index c87014e2..c07b1652 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -318,6 +318,71 @@ class StaticHTMLRenderer(TemplateHTMLRenderer):          return data +class HTMLFormRenderer(BaseRenderer): +    """ +    Renderers serializer data into an HTML form. + +    If the serializer was instantiated without an object then this will +    return an HTML form not bound to any object, +    otherwise it will return an HTML form with the appropriate initial data +    populated from the object. +    """ +    media_type = 'text/html' +    format = 'form' +    template = 'rest_framework/form.html' +    charset = 'utf-8' + +    def data_to_form_fields(self, data): +        fields = {} +        for key, val in data.fields.items(): +            if getattr(val, 'read_only', True): +                continue + +            kwargs = {} +            kwargs['required'] = val.required + +            #if getattr(v, 'queryset', None): +            #    kwargs['queryset'] = v.queryset + +            if getattr(val, 'choices', None) is not None: +                kwargs['choices'] = val.choices + +            if getattr(val, 'regex', None) is not None: +                kwargs['regex'] = val.regex + +            if getattr(val, 'widget', None): +                widget = copy.deepcopy(val.widget) +                kwargs['widget'] = widget + +            if getattr(val, 'default', None) is not None: +                kwargs['initial'] = val.default + +            if getattr(val, 'label', None) is not None: +                kwargs['label'] = val.label + +            if getattr(val, 'help_text', None) is not None: +                kwargs['help_text'] = val.help_text + +            fields[key] = val.form_field_class(**kwargs) + +        return fields + +    def render(self, data, accepted_media_type=None, renderer_context=None): +        self.renderer_context = renderer_context or {} +        request = renderer_context['request'] + +        # Creating an on the fly form see: +        # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python +        fields = self.data_to_form_fields(data) +        DynamicForm = type(str('DynamicForm'), (forms.Form,), fields) +        data = None if data.empty else data + +        template = loader.get_template(self.template) +        context = RequestContext(request, {'form': DynamicForm(data)}) + +        return template.render(context) + +  class BrowsableAPIRenderer(BaseRenderer):      """      HTML renderer used to self-document the API. @@ -326,6 +391,7 @@ class BrowsableAPIRenderer(BaseRenderer):      format = 'api'      template = 'rest_framework/api.html'      charset = 'utf-8' +    form_renderer_class = HTMLFormRenderer      def get_default_renderer(self, view):          """ @@ -376,54 +442,7 @@ class BrowsableAPIRenderer(BaseRenderer):              return False  # Doesn't have permissions          return True -    def serializer_to_form_fields(self, serializer): -        fields = {} -        for k, v in serializer.get_fields().items(): -            if getattr(v, 'read_only', True): -                continue - -            kwargs = {} -            kwargs['required'] = v.required - -            #if getattr(v, 'queryset', None): -            #    kwargs['queryset'] = v.queryset - -            if getattr(v, 'choices', None) is not None: -                kwargs['choices'] = v.choices - -            if getattr(v, 'regex', None) is not None: -                kwargs['regex'] = v.regex - -            if getattr(v, 'widget', None): -                widget = copy.deepcopy(v.widget) -                kwargs['widget'] = widget - -            if getattr(v, 'default', None) is not None: -                kwargs['initial'] = v.default - -            if getattr(v, 'label', None) is not None: -                kwargs['label'] = v.label - -            if getattr(v, 'help_text', None) is not None: -                kwargs['help_text'] = v.help_text - -            fields[k] = v.form_field_class(**kwargs) - -        return fields - -    def _get_form(self, view, method, request): -        # We need to impersonate a request with the correct method, -        # so that eg. any dynamic get_serializer_class methods return the -        # correct form for each method. -        restore = view.request -        request = clone_request(request, method) -        view.request = request -        try: -            return self.get_form(view, method, request) -        finally: -            view.request = restore - -    def _get_raw_data_form(self, view, method, request, media_types): +    def _get_rendered_html_form(self, view, method, request):          # We need to impersonate a request with the correct method,          # so that eg. any dynamic get_serializer_class methods return the          # correct form for each method. @@ -431,15 +450,16 @@ class BrowsableAPIRenderer(BaseRenderer):          request = clone_request(request, method)          view.request = request          try: -            return self.get_raw_data_form(view, method, request, media_types) +            return self.get_rendered_html_form(view, method, request)          finally:              view.request = restore -    def get_form(self, view, method, request): +    def get_rendered_html_form(self, view, method, request):          """ -        Get a form, possibly bound to either the input or output data. -        In the absence on of the Resource having an associated form then -        provide a form that can be used to submit arbitrary content. +        Return a string representing a rendered HTML form, possibly bound to +        either the input or output data. + +        In the absence of the View having an associated form then return None.          """          obj = getattr(view, 'object', None)          if not self.show_form_for_method(view, method, request, obj): @@ -452,14 +472,21 @@ class BrowsableAPIRenderer(BaseRenderer):              return          serializer = view.get_serializer(instance=obj) -        fields = self.serializer_to_form_fields(serializer) +        data = serializer.data +        form_renderer = self.form_renderer_class() +        return form_renderer.render(data, self.accepted_media_type, self.renderer_context) -        # Creating an on the fly form see: -        # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python -        OnTheFlyForm = type(str("OnTheFlyForm"), (forms.Form,), fields) -        data = (obj is not None) and serializer.data or None -        form_instance = OnTheFlyForm(data) -        return form_instance +    def _get_raw_data_form(self, view, method, request, media_types): +        # We need to impersonate a request with the correct method, +        # so that eg. any dynamic get_serializer_class methods return the +        # correct form for each method. +        restore = view.request +        request = clone_request(request, method) +        view.request = request +        try: +            return self.get_raw_data_form(view, method, request, media_types) +        finally: +            view.request = restore      def get_raw_data_form(self, view, method, request, media_types):          """ @@ -514,8 +541,8 @@ class BrowsableAPIRenderer(BaseRenderer):          """          Render the HTML for the browsable API representation.          """ -        accepted_media_type = accepted_media_type or '' -        renderer_context = renderer_context or {} +        self.accepted_media_type = accepted_media_type or '' +        self.renderer_context = renderer_context or {}          view = renderer_context['view']          request = renderer_context['request'] @@ -525,11 +552,11 @@ class BrowsableAPIRenderer(BaseRenderer):          renderer = self.get_default_renderer(view)          content = self.get_content(renderer, data, accepted_media_type, renderer_context) -        put_form = self._get_form(view, 'PUT', request) -        post_form = self._get_form(view, 'POST', request) -        patch_form = self._get_form(view, 'PATCH', request) -        delete_form = self._get_form(view, 'DELETE', request) -        options_form = self._get_form(view, 'OPTIONS', request) +        put_form = self._get_rendered_html_form(view, 'PUT', request) +        post_form = self._get_rendered_html_form(view, 'POST', request) +        patch_form = self._get_rendered_html_form(view, 'PATCH', request) +        delete_form = self._get_rendered_html_form(view, 'DELETE', request) +        options_form = self._get_rendered_html_form(view, 'OPTIONS', request)          raw_data_put_form = self._get_raw_data_form(view, 'PUT', request, media_types)          raw_data_post_form = self._get_raw_data_form(view, 'POST', request, media_types) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 31cfa344..97e0a005 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -32,6 +32,9 @@ from rest_framework.relations import *  from rest_framework.fields import * +class RelationsList(list): +    _deleted = [] +  class NestedValidationError(ValidationError):      """      The default ValidationError behavior is to stringify each item in the list @@ -161,7 +164,6 @@ class BaseSerializer(WritableField):          self._data = None          self._files = None          self._errors = None -        self._deleted = None          if many and instance is not None and not hasattr(instance, '__iter__'):              raise ValueError('instance should be a queryset or other iterable with many=True') @@ -298,7 +300,8 @@ class BaseSerializer(WritableField):          Serialize objects -> primitives.          """          ret = self._dict_class() -        ret.fields = {} +        ret.fields = self._dict_class() +        ret.empty = obj is None          for field_name, field in self.fields.items():              field.initialize(parent=self, field_name=field_name) @@ -336,9 +339,9 @@ class BaseSerializer(WritableField):              value = obj              for component in source.split('.'): -                value = get_component(value, component)                  if value is None: -                    break +                    return self.to_native(None) +                value = get_component(value, component)          except ObjectDoesNotExist:              return None @@ -378,6 +381,7 @@ class BaseSerializer(WritableField):          # Set the serializer object if it exists          obj = getattr(self.parent.object, field_name) if self.parent.object else None +        obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj          if self.source == '*':              if value: @@ -391,7 +395,8 @@ class BaseSerializer(WritableField):                      'data': value,                      'context': self.context,                      'partial': self.partial, -                    'many': self.many +                    'many': self.many, +                    'allow_add_remove': self.allow_add_remove                  }                  serializer = self.__class__(**kwargs) @@ -434,7 +439,7 @@ class BaseSerializer(WritableField):                                    DeprecationWarning, stacklevel=3)              if many: -                ret = [] +                ret = RelationsList()                  errors = []                  update = self.object is not None @@ -461,8 +466,8 @@ class BaseSerializer(WritableField):                          ret.append(self.from_native(item, None))                          errors.append(self._errors) -                    if update: -                        self._deleted = identity_to_objects.values() +                    if update and self.allow_add_remove: +                        ret._deleted = identity_to_objects.values()                      self._errors = any(errors) and errors or []                  else: @@ -514,12 +519,12 @@ class BaseSerializer(WritableField):          """          if isinstance(self.object, list):              [self.save_object(item, **kwargs) for item in self.object] + +            if self.object._deleted: +                [self.delete_object(item) for item in self.object._deleted]          else:              self.save_object(self.object, **kwargs) -        if self.allow_add_remove and self._deleted: -            [self.delete_object(item) for item in self._deleted] -          return self.object      def metadata(self): @@ -795,9 +800,12 @@ class ModelSerializer(Serializer):          cls = self.opts.model          opts = get_concrete_model(cls)._meta          exclusions = [field.name for field in opts.fields + opts.many_to_many] +          for field_name, field in self.fields.items():              field_name = field.source or field_name -            if field_name in exclusions and not field.read_only: +            if field_name in exclusions \ +                and not field.read_only \ +                and not isinstance(field, Serializer):                  exclusions.remove(field_name)          return exclusions @@ -823,6 +831,7 @@ class ModelSerializer(Serializer):          """          m2m_data = {}          related_data = {} +        nested_forward_relations = {}          meta = self.opts.model._meta          # Reverse fk or one-to-one relations @@ -842,6 +851,12 @@ class ModelSerializer(Serializer):              if field.name in attrs:                  m2m_data[field.name] = attrs.pop(field.name) +        # Nested forward relations - These need to be marked so we can save +        # them before saving the parent model instance. +        for field_name in attrs.keys(): +            if isinstance(self.fields.get(field_name, None), Serializer): +                nested_forward_relations[field_name] = attrs[field_name] +          # Update an existing instance...          if instance is not None:              for key, val in attrs.items(): @@ -857,6 +872,7 @@ class ModelSerializer(Serializer):          # at the point of save.          instance._related_data = related_data          instance._m2m_data = m2m_data +        instance._nested_forward_relations = nested_forward_relations          return instance @@ -872,6 +888,14 @@ class ModelSerializer(Serializer):          """          Save the deserialized object and return it.          """ +        if getattr(obj, '_nested_forward_relations', None): +            # Nested relationships need to be saved before we can save the +            # parent instance. +            for field_name, sub_object in obj._nested_forward_relations.items(): +                if sub_object: +                    self.save_object(sub_object) +                setattr(obj, field_name, sub_object) +          obj.save(**kwargs)          if getattr(obj, '_m2m_data', None): @@ -881,7 +905,25 @@ class ModelSerializer(Serializer):          if getattr(obj, '_related_data', None):              for accessor_name, related in obj._related_data.items(): -                setattr(obj, accessor_name, related) +                if isinstance(related, RelationsList): +                    # Nested reverse fk relationship +                    for related_item in related: +                        fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name +                        setattr(related_item, fk_field, obj) +                        self.save_object(related_item) + +                    # Delete any removed objects +                    if related._deleted: +                        [self.delete_object(item) for item in related._deleted] + +                elif isinstance(related, models.Model): +                    # Nested reverse one-one relationship +                    fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name +                    setattr(related, fk_field, obj) +                    self.save_object(related) +                else: +                    # Reverse FK or reverse one-one +                    setattr(obj, accessor_name, related)              del(obj._related_data) diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 51f9c291..6ae47563 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -136,9 +136,9 @@                          {% if post_form %}                          <div class="tab-pane" id="object-form">                              {% with form=post_form %} -                            <form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal"> +                            <form action="{{ request.get_full_path }}" method="POST" enctype="multipart/form-data" class="form-horizontal">                                  <fieldset> -                                    {% include "rest_framework/form.html" %} +                                    {{ post_form }}                                      <div class="form-actions">                                          <button class="btn btn-primary" title="Make a POST request on the {{ name }} resource">POST</button>                                      </div> @@ -174,16 +174,14 @@                      <div class="well tab-content">                          {% if put_form %}                          <div class="tab-pane" id="object-form"> -                            {% with form=put_form %} -                            <form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal"> +                            <form action="{{ request.get_full_path }}" method="POST" enctype="multipart/form-data" class="form-horizontal">                                  <fieldset> -                                    {% include "rest_framework/form.html" %} +                                    {{ put_form }}                                      <div class="form-actions">                                          <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button>                                      </div>                                  </fieldset>                              </form> -                            {% endwith %}                          </div>                          {% endif %}                          <div {% if put_form %}class="tab-pane"{% endif %} id="generic-content-form"> diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py index f6d006b3..d393b0c3 100644 --- a/rest_framework/tests/test_relations_nested.py +++ b/rest_framework/tests/test_relations_nested.py @@ -1,107 +1,328 @@  from __future__ import unicode_literals +from django.db import models  from django.test import TestCase  from rest_framework import serializers -from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource -class ForeignKeySourceSerializer(serializers.ModelSerializer): -    class Meta: -        model = ForeignKeySource -        fields = ('id', 'name', 'target') -        depth = 1 +class OneToOneTarget(models.Model): +    name = models.CharField(max_length=100) -class ForeignKeyTargetSerializer(serializers.ModelSerializer): -    class Meta: -        model = ForeignKeyTarget -        fields = ('id', 'name', 'sources') -        depth = 1 +class OneToOneSource(models.Model): +    name = models.CharField(max_length=100) +    target = models.OneToOneField(OneToOneTarget, related_name='source', +                                  null=True, blank=True) -class NullableForeignKeySourceSerializer(serializers.ModelSerializer): -    class Meta: -        model = NullableForeignKeySource -        fields = ('id', 'name', 'target') -        depth = 1 +class OneToManyTarget(models.Model): +    name = models.CharField(max_length=100) -class NullableOneToOneTargetSerializer(serializers.ModelSerializer): -    class Meta: -        model = OneToOneTarget -        fields = ('id', 'name', 'nullable_source') -        depth = 1 +class OneToManySource(models.Model): +    name = models.CharField(max_length=100) +    target = models.ForeignKey(OneToManyTarget, related_name='sources') -class ReverseForeignKeyTests(TestCase): +class ReverseNestedOneToOneTests(TestCase):      def setUp(self): -        target = ForeignKeyTarget(name='target-1') -        target.save() -        new_target = ForeignKeyTarget(name='target-2') -        new_target.save() +        class OneToOneSourceSerializer(serializers.ModelSerializer): +            class Meta: +                model = OneToOneSource +                fields = ('id', 'name') + +        class OneToOneTargetSerializer(serializers.ModelSerializer): +            source = OneToOneSourceSerializer() + +            class Meta: +                model = OneToOneTarget +                fields = ('id', 'name', 'source') + +        self.Serializer = OneToOneTargetSerializer +          for idx in range(1, 4): -            source = ForeignKeySource(name='source-%d' % idx, target=target) +            target = OneToOneTarget(name='target-%d' % idx) +            target.save() +            source = OneToOneSource(name='source-%d' % idx, target=target)              source.save() -    def test_foreign_key_retrieve(self): -        queryset = ForeignKeySource.objects.all() -        serializer = ForeignKeySourceSerializer(queryset, many=True) +    def test_one_to_one_retrieve(self): +        queryset = OneToOneTarget.objects.all() +        serializer = self.Serializer(queryset, many=True)          expected = [ -            {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, -            {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, -            {'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}}, +            {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, +            {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, +            {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}          ]          self.assertEqual(serializer.data, expected) -    def test_reverse_foreign_key_retrieve(self): -        queryset = ForeignKeyTarget.objects.all() -        serializer = ForeignKeyTargetSerializer(queryset, many=True) +    def test_one_to_one_create(self): +        data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} +        serializer = self.Serializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-4') + +        # Ensure (target 4, target_source 4, source 4) are added, and +        # everything else is as expected. +        queryset = OneToOneTarget.objects.all() +        serializer = self.Serializer(queryset, many=True)          expected = [ -            {'id': 1, 'name': 'target-1', 'sources': [ -                {'id': 1, 'name': 'source-1', 'target': 1}, -                {'id': 2, 'name': 'source-2', 'target': 1}, -                {'id': 3, 'name': 'source-3', 'target': 1}, -            ]}, -            {'id': 2, 'name': 'target-2', 'sources': [ -            ]} +            {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, +            {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, +            {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}, +            {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}          ]          self.assertEqual(serializer.data, expected) +    def test_one_to_one_create_with_invalid_data(self): +        data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}} +        serializer = self.Serializer(data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]}) -class NestedNullableForeignKeyTests(TestCase): +    def test_one_to_one_update(self): +        data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} +        instance = OneToOneTarget.objects.get(pk=3) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-3-updated') + +        # Ensure (target 3, target_source 3, source 3) are updated, +        # and everything else is as expected. +        queryset = OneToOneTarget.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, +            {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, +            {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} +        ] +        self.assertEqual(serializer.data, expected) + + +class ForwardNestedOneToOneTests(TestCase):      def setUp(self): -        target = ForeignKeyTarget(name='target-1') -        target.save() +        class OneToOneTargetSerializer(serializers.ModelSerializer): +            class Meta: +                model = OneToOneTarget +                fields = ('id', 'name') + +        class OneToOneSourceSerializer(serializers.ModelSerializer): +            target = OneToOneTargetSerializer() + +            class Meta: +                model = OneToOneSource +                fields = ('id', 'name', 'target') + +        self.Serializer = OneToOneSourceSerializer +          for idx in range(1, 4): -            if idx == 3: -                target = None -            source = NullableForeignKeySource(name='source-%d' % idx, target=target) +            target = OneToOneTarget(name='target-%d' % idx) +            target.save() +            source = OneToOneSource(name='source-%d' % idx, target=target)              source.save() -    def test_foreign_key_retrieve_with_null(self): -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +    def test_one_to_one_retrieve(self): +        queryset = OneToOneSource.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, +            {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, +            {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_one_create(self): +        data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} +        serializer = self.Serializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure (target 4, target_source 4, source 4) are added, and +        # everything else is as expected. +        queryset = OneToOneSource.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, +            {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, +            {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}, +            {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_one_create_with_invalid_data(self): +        data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}} +        serializer = self.Serializer(data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]}) + +    def test_one_to_one_update(self): +        data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} +        instance = OneToOneSource.objects.get(pk=3) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-3-updated') + +        # Ensure (target 3, target_source 3, source 3) are updated, +        # and everything else is as expected. +        queryset = OneToOneSource.objects.all() +        serializer = self.Serializer(queryset, many=True)          expected = [              {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, -            {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, -            {'id': 3, 'name': 'source-3', 'target': None}, +            {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, +            {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}          ]          self.assertEqual(serializer.data, expected) +    def test_one_to_one_update_to_null(self): +        data = {'id': 3, 'name': 'source-3-updated', 'target': None} +        instance = OneToOneSource.objects.get(pk=3) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() -class NestedNullableOneToOneTests(TestCase): +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-3-updated') +        self.assertEqual(obj.target, None) + +        queryset = OneToOneSource.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, +            {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, +            {'id': 3, 'name': 'source-3-updated', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    # TODO: Nullable 1-1 tests +    # def test_one_to_one_delete(self): +    #     data = {'id': 3, 'name': 'target-3', 'target_source': None} +    #     instance = OneToOneTarget.objects.get(pk=3) +    #     serializer = self.Serializer(instance, data=data) +    #     self.assertTrue(serializer.is_valid()) +    #     serializer.save() + +    #     # Ensure (target_source 3, source 3) are deleted, +    #     # and everything else is as expected. +    #     queryset = OneToOneTarget.objects.all() +    #     serializer = self.Serializer(queryset) +    #     expected = [ +    #         {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, +    #         {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, +    #         {'id': 3, 'name': 'target-3', 'source': None} +    #     ] +    #     self.assertEqual(serializer.data, expected) + + +class ReverseNestedOneToManyTests(TestCase):      def setUp(self): -        target = OneToOneTarget(name='target-1') +        class OneToManySourceSerializer(serializers.ModelSerializer): +            class Meta: +                model = OneToManySource +                fields = ('id', 'name') + +        class OneToManyTargetSerializer(serializers.ModelSerializer): +            sources = OneToManySourceSerializer(many=True, allow_add_remove=True) + +            class Meta: +                model = OneToManyTarget +                fields = ('id', 'name', 'sources') + +        self.Serializer = OneToManyTargetSerializer + +        target = OneToManyTarget(name='target-1')          target.save() -        new_target = OneToOneTarget(name='target-2') -        new_target.save() -        source = NullableOneToOneSource(name='source-1', target=target) -        source.save() +        for idx in range(1, 4): +            source = OneToManySource(name='source-%d' % idx, target=target) +            source.save() -    def test_reverse_foreign_key_retrieve_with_null(self): -        queryset = OneToOneTarget.objects.all() -        serializer = NullableOneToOneTargetSerializer(queryset, many=True) +    def test_one_to_many_retrieve(self): +        queryset = OneToManyTarget.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                      {'id': 2, 'name': 'source-2'}, +                                                      {'id': 3, 'name': 'source-3'}]}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_many_create(self): +        data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                         {'id': 2, 'name': 'source-2'}, +                                                         {'id': 3, 'name': 'source-3'}, +                                                         {'id': 4, 'name': 'source-4'}]} +        instance = OneToManyTarget.objects.get(pk=1) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-1') + +        # Ensure source 4 is added, and everything else is as +        # expected. +        queryset = OneToManyTarget.objects.all() +        serializer = self.Serializer(queryset, many=True)          expected = [ -            {'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}}, -            {'id': 2, 'name': 'target-2', 'nullable_source': None}, +            {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                      {'id': 2, 'name': 'source-2'}, +                                                      {'id': 3, 'name': 'source-3'}, +                                                      {'id': 4, 'name': 'source-4'}]} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_many_create_with_invalid_data(self): +        data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                         {'id': 2, 'name': 'source-2'}, +                                                         {'id': 3, 'name': 'source-3'}, +                                                         {'id': 4}]} +        serializer = self.Serializer(data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]}) + +    def test_one_to_many_update(self): +        data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, +                                                                 {'id': 2, 'name': 'source-2'}, +                                                                 {'id': 3, 'name': 'source-3'}]} +        instance = OneToManyTarget.objects.get(pk=1) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-1-updated') + +        # Ensure (target 1, source 1) are updated, +        # and everything else is as expected. +        queryset = OneToManyTarget.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, +                                                              {'id': 2, 'name': 'source-2'}, +                                                              {'id': 3, 'name': 'source-3'}]} + +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_many_delete(self): +        data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                         {'id': 3, 'name': 'source-3'}]} +        instance = OneToManyTarget.objects.get(pk=1) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        serializer.save() + +        # Ensure source 2 is deleted, and everything else is as +        # expected. +        queryset = OneToManyTarget.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                      {'id': 3, 'name': 'source-3'}]} +          ]          self.assertEqual(serializer.data, expected)  | 
