diff options
Diffstat (limited to 'rest_framework')
32 files changed, 893 insertions, 81 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 83a6f302..2e38d863 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,3 +1,3 @@ -__version__ = '2.1.9' +__version__ = '2.1.13'  VERSION = __version__  # synonym diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 5996b16e..5508f6c0 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -5,6 +5,12 @@ versions of django/python, and compatibility wrappers around optional packages.  # flake8: noqa  import django +# location of patterns, url, include changes in 1.4 onwards +try: +    from django.conf.urls import patterns, url, include +except: +    from django.conf.urls.defaults import patterns, url, include +  # django-filter is optional  try:      import django_filters diff --git a/rest_framework/fields.py b/rest_framework/fields.py index d3ef8f77..dd90c3f8 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -351,7 +351,12 @@ class RelatedField(WritableField):          if self.read_only:              return -        value = data.get(field_name) +        try: +            value = data[field_name] +        except KeyError: +            if self.required: +                raise ValidationError(self.error_messages['required']) +            return          if value in (None, '') and not self.null:              raise ValidationError('Value may not be null') @@ -384,6 +389,7 @@ class ManyRelatedMixin(object):          else:              if value == ['']:                  value = [] +          into[field_name] = [self.from_native(item) for item in value] @@ -795,7 +801,7 @@ class ChoiceField(WritableField):                      if value == smart_unicode(k2):                          return True              else: -                if value == smart_unicode(k): +                if value == smart_unicode(k) or value == k:                      return True          return False diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index b71ab05c..d828078d 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -113,6 +113,10 @@ class UpdateModelMixin(object):              slug_field = self.get_slug_field()              setattr(obj, slug_field, slug) +        # Ensure we clean the attributes so that we don't eg return integer +        # pk using a string representation, as provided by the url conf kwarg. +        obj.full_clean() +  class DestroyModelMixin(object):      """ @@ -120,6 +124,6 @@ class DestroyModelMixin(object):      Should be mixed in with `SingleObjectBaseView`.      """      def destroy(self, request, *args, **kwargs): -        self.object = self.get_object() -        self.object.delete() +        obj = self.get_object() +        obj.delete()          return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 1220bca1..a4ae717d 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -20,7 +20,7 @@ from rest_framework.utils import dict2xml  from rest_framework.utils import encoders  from rest_framework.utils.breadcrumbs import get_breadcrumbs  from rest_framework import VERSION, status -from rest_framework import serializers, parsers +from rest_framework import parsers  class BaseRenderer(object): diff --git a/rest_framework/request.py b/rest_framework/request.py index 39c64321..b7133608 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -188,6 +188,14 @@ class Request(object):              self._user, self._auth = self._authenticate()          return self._auth +    @auth.setter +    def auth(self, value): +        """ +        Sets any non-user authentication information associated with the +        request, such as an authentication token. +        """ +        self._auth = value +      def _load_data_and_files(self):          """          Parses the request content into self.DATA and self.FILES. diff --git a/rest_framework/runtests/runcoverage.py b/rest_framework/runtests/runcoverage.py index 0ce379eb..bcab1d14 100755 --- a/rest_framework/runtests/runcoverage.py +++ b/rest_framework/runtests/runcoverage.py @@ -8,6 +8,9 @@ Useful tool to run the test suite for rest_framework and generate a coverage rep  # http://code.djangoproject.com/svn/django/trunk/tests/runtests.py  import os  import sys + +# fix sys path so we don't need to setup PYTHONPATH +sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))  os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings'  from coverage import coverage @@ -55,6 +58,12 @@ def main():          if 'compat.py' in files:              files.remove('compat.py') +        # Same applies to template tags module. +        # This module has to include branching on Django versions, +        # so it's never possible for it to have full coverage. +        if 'rest_framework.py' in files: +            files.remove('rest_framework.py') +          cov_files.extend([os.path.join(path, file) for file in files if file.endswith('.py')])      cov.report(cov_files) diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py index 729ef26a..505994e2 100755 --- a/rest_framework/runtests/runtests.py +++ b/rest_framework/runtests/runtests.py @@ -5,11 +5,9 @@  # http://code.djangoproject.com/svn/django/trunk/tests/runtests.py  import os  import sys -""" -Need to fix sys path so following works without specifically messing with PYTHONPATH -python ./rest_framework/runtests/runtests.py -""" -sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))   + +# fix sys path so we don't need to setup PYTHONPATH +sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))  os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings'  from django.conf import settings diff --git a/rest_framework/runtests/urls.py b/rest_framework/runtests/urls.py index 4b7da787..ed5baeae 100644 --- a/rest_framework/runtests/urls.py +++ b/rest_framework/runtests/urls.py @@ -1,7 +1,7 @@  """  Blank URLConf just to keep runtests.py happy.  """ -from django.conf.urls.defaults import * +from rest_framework.compat import patterns  urlpatterns = patterns('',  ) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 1d93f777..e8e6735a 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -160,6 +160,9 @@ class BaseSerializer(Field):              for key in self.opts.exclude:                  ret.pop(key, None) +        for key, field in ret.items(): +            field.initialize(parent=self, field_name=key) +          return ret      ##### @@ -174,13 +177,6 @@ class BaseSerializer(Field):          if parent.opts.depth:              self.opts.depth = parent.opts.depth - 1 -        # We need to call initialize here to ensure any nested -        # serializers that will have already called initialize on their -        # descendants get updated with *their* parent. -        # We could be a bit more smart about this, but it'll do for now. -        for key, field in self.fields.items(): -            field.initialize(parent=self, field_name=key) -      #####      # Methods to convert or revert from objects <--> primitive representations. @@ -311,6 +307,9 @@ class BaseSerializer(Field):          if is_simple_callable(getattr(obj, 'all', None)):              return [self.to_native(item) for item in obj.all()] +        if obj is None: +            return None +          return self.to_native(obj)      @property @@ -442,7 +441,7 @@ class ModelSerializer(Serializer):          kwargs['blank'] = model_field.blank -        if model_field.null: +        if model_field.null or model_field.blank:              kwargs['required'] = False          if model_field.has_default(): @@ -497,29 +496,38 @@ class ModelSerializer(Serializer):          Restore the model instance.          """          self.m2m_data = {} +        self.related_data = {}          if instance is not None:              for key, val in attrs.items():                  setattr(instance, key, val) -            return instance -        # Reverse relations -        for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): -            field_name = obj.field.related_query_name() -            if field_name in attrs: -                self.m2m_data[field_name] = attrs.pop(field_name) +        else: +            # Reverse fk relations +            for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model(): +                field_name = obj.field.related_query_name() +                if field_name in attrs: +                    self.related_data[field_name] = attrs.pop(field_name) + +            # Reverse m2m relations +            for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): +                field_name = obj.field.related_query_name() +                if field_name in attrs: +                    self.m2m_data[field_name] = attrs.pop(field_name) + +            # Forward m2m relations +            for field in self.opts.model._meta.many_to_many: +                if field.name in attrs: +                    self.m2m_data[field.name] = attrs.pop(field.name) -        # Forward relations -        for field in self.opts.model._meta.many_to_many: -            if field.name in attrs: -                self.m2m_data[field.name] = attrs.pop(field.name) +            instance = self.opts.model(**attrs) -        instance = self.opts.model(**attrs)          try:              instance.full_clean(exclude=self.get_validation_exclusions())          except ValidationError, err:              self._errors = err.message_dict              return None +          return instance      def save(self, save_m2m=True): @@ -533,6 +541,11 @@ class ModelSerializer(Serializer):                  setattr(self.object, accessor_name, object_list)              self.m2m_data = {} +        if getattr(self, 'related_data', None): +            for accessor_name, object_list in self.related_data.items(): +                setattr(self.object, accessor_name, object_list) +            self.related_data = {} +          return self.object diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index fb0e19f0..42e49cb9 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -1,6 +1,5 @@  {% load url from future %}  {% load rest_framework %} -{% load static %}  <!DOCTYPE html>  <html>      <head> @@ -14,10 +13,10 @@          <title>{% block title %}Django REST framework{% endblock %}</title>          {% block style %} -        <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap.min.css"/> -        <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap-tweaks.css"/> -        <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/prettify.css'/> -        <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/default.css'/> +        <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/> +        <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/> +        <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/prettify.css" %}"/> +        <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/>          {% endblock %}      {% endblock %} @@ -195,10 +194,10 @@      {% endblock %}      {% block script %} -    <script src="{% get_static_prefix %}rest_framework/js/jquery-1.8.1-min.js"></script> -    <script src="{% get_static_prefix %}rest_framework/js/bootstrap.min.js"></script> -    <script src="{% get_static_prefix %}rest_framework/js/prettify-min.js"></script> -    <script src="{% get_static_prefix %}rest_framework/js/default.js"></script> +    <script src="{% static "rest_framework/js/jquery-1.8.1-min.js" %}"></script> +    <script src="{% static "rest_framework/js/bootstrap.min.js" %}"></script> +    <script src="{% static "rest_framework/js/prettify-min.js" %}"></script> +    <script src="{% static "rest_framework/js/default.js" %}"></script>      {% endblock %}    </body>  </html> diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html index c1271399..6e2bd8d4 100644 --- a/rest_framework/templates/rest_framework/login.html +++ b/rest_framework/templates/rest_framework/login.html @@ -1,11 +1,11 @@  {% load url from future %} -{% load static %} +{% load rest_framework %}  <html>      <head> -        <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap.min.css"/> -        <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap-tweaks.css"/> -        <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/default.css'/> +        <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/> +        <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/> +        <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/>      </head>      <body class="container"> diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 4e0181ee..09c658bc 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -11,6 +11,89 @@ import string  register = template.Library() +# Note we don't use 'load staticfiles', because we need a 1.3 compatible +# version, so instead we include the `static` template tag ourselves. + +# When 1.3 becomes unsupported by REST framework, we can instead start to +# use the {% load staticfiles %} tag, remove the following code, +# and add a dependancy that `django.contrib.staticfiles` must be installed. + +# Note: We can't put this into the `compat` module because the compat import +# from rest_framework.compat import ... +# conflicts with this rest_framework template tag module. + +try:  # Django 1.5+ +    from django.contrib.staticfiles.templatetags import StaticFilesNode + +    @register.tag('static') +    def do_static(parser, token): +        return StaticFilesNode.handle_token(parser, token) + +except: +    try:  # Django 1.4 +        from django.contrib.staticfiles.storage import staticfiles_storage + +        @register.simple_tag +        def static(path): +            """ +            A template tag that returns the URL to a file +            using staticfiles' storage backend +            """ +            return staticfiles_storage.url(path) + +    except:  # Django 1.3 +        from urlparse import urljoin +        from django import template +        from django.templatetags.static import PrefixNode + +        class StaticNode(template.Node): +            def __init__(self, varname=None, path=None): +                if path is None: +                    raise template.TemplateSyntaxError( +                        "Static template nodes must be given a path to return.") +                self.path = path +                self.varname = varname + +            def url(self, context): +                path = self.path.resolve(context) +                return self.handle_simple(path) + +            def render(self, context): +                url = self.url(context) +                if self.varname is None: +                    return url +                context[self.varname] = url +                return '' + +            @classmethod +            def handle_simple(cls, path): +                return urljoin(PrefixNode.handle_simple("STATIC_URL"), path) + +            @classmethod +            def handle_token(cls, parser, token): +                """ +                Class method to parse prefix node and return a Node. +                """ +                bits = token.split_contents() + +                if len(bits) < 2: +                    raise template.TemplateSyntaxError( +                        "'%s' takes at least one argument (path to file)" % bits[0]) + +                path = parser.compile_filter(bits[1]) + +                if len(bits) >= 2 and bits[-2] == 'as': +                    varname = bits[3] +                else: +                    varname = None + +                return cls(varname, path) + +        @register.tag('static') +        def do_static_13(parser, token): +            return StaticNode.handle_token(parser, token) + +  def replace_query_param(url, key, val):      """      Given a URL and a key/val pair, set or replace an item in the query diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index d498ae3e..838e081b 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -1,15 +1,13 @@ -from django.conf.urls.defaults import patterns  from django.contrib.auth.models import User +from django.http import HttpResponse  from django.test import Client, TestCase -  from django.utils import simplejson as json -from django.http import HttpResponse -from rest_framework.views import APIView  from rest_framework import permissions -  from rest_framework.authtoken.models import Token  from rest_framework.authentication import TokenAuthentication +from rest_framework.compat import patterns +from rest_framework.views import APIView  import base64 diff --git a/rest_framework/tests/breadcrumbs.py b/rest_framework/tests/breadcrumbs.py index 647ab96d..df891683 100644 --- a/rest_framework/tests/breadcrumbs.py +++ b/rest_framework/tests/breadcrumbs.py @@ -1,5 +1,5 @@ -from django.conf.urls.defaults import patterns, url  from django.test import TestCase +from rest_framework.compat import patterns, url  from rest_framework.utils.breadcrumbs import get_breadcrumbs  from rest_framework.views import APIView diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py index 5e6bce4e..bc44a45b 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/decorators.py @@ -1,4 +1,5 @@  from django.test import TestCase +from django.test.client import RequestFactory  from rest_framework import status  from rest_framework.response import Response  from rest_framework.renderers import JSONRenderer diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 33ac4b32..843017eb 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -1,3 +1,4 @@ +from django.db import models  from django.test import TestCase  from django.utils import simplejson as json  from rest_framework import generics, serializers, status @@ -174,7 +175,7 @@ class TestInstanceView(TestCase):          content = {'text': 'foobar'}          request = factory.put('/1', json.dumps(content),                                content_type='application/json') -        response = self.view(request, pk=1).render() +        response = self.view(request, pk='1').render()          self.assertEquals(response.status_code, status.HTTP_200_OK)          self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})          updated = self.objects.get(id=1) @@ -315,3 +316,36 @@ class TestCreateModelWithAutoNowAddField(TestCase):          self.assertEquals(response.status_code, status.HTTP_201_CREATED)          created = self.objects.get(id=1)          self.assertEquals(created.content, 'foobar') + + +# Test for particularly ugly reression with m2m in browseable API +class ClassB(models.Model): +    name = models.CharField(max_length=255) + + +class ClassA(models.Model): +    name = models.CharField(max_length=255) +    childs = models.ManyToManyField(ClassB, blank=True, null=True) + + +class ClassASerializer(serializers.ModelSerializer): +    childs = serializers.ManyPrimaryKeyRelatedField(source='childs') + +    class Meta: +        model = ClassA + + +class ExampleView(generics.ListCreateAPIView): +    serializer_class = ClassASerializer +    model = ClassA + + +class TestM2MBrowseableAPI(TestCase): +    def test_m2m_in_browseable_api(self): +        """ +        Test for particularly ugly reression with m2m in browseable API +        """ +        request = factory.get('/', HTTP_ACCEPT='text/html') +        view = ExampleView().as_view() +        response = view(request).render() +        self.assertEquals(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/htmlrenderer.py index 4caed59e..54096206 100644 --- a/rest_framework/tests/htmlrenderer.py +++ b/rest_framework/tests/htmlrenderer.py @@ -1,9 +1,9 @@  from django.core.exceptions import PermissionDenied -from django.conf.urls.defaults import patterns, url  from django.http import Http404  from django.test import TestCase  from django.template import TemplateDoesNotExist, Template  import django.template.loader +from rest_framework.compat import patterns, url  from rest_framework.decorators import api_view, renderer_classes  from rest_framework.renderers import TemplateHTMLRenderer  from rest_framework.response import Response diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index 24bf61bf..ee4d8e57 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -1,8 +1,8 @@ -from django.conf.urls.defaults import patterns, url  from django.test import TestCase  from django.test.client import RequestFactory  from django.utils import simplejson as json  from rest_framework import generics, status, serializers +from rest_framework.compat import patterns, url  from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel  factory = RequestFactory() diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index e2b287d0..0759650a 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -51,6 +51,11 @@ class RESTFrameworkModel(models.Model):          abstract = True +class HasPositiveIntegerAsChoice(RESTFrameworkModel): +    some_choices = ((1, 'A'), (2, 'B'), (3, 'C')) +    some_integer = models.PositiveIntegerField(choices=some_choices) + +  class Anchor(RESTFrameworkModel):      text = models.CharField(max_length=100, default='anchor') @@ -160,7 +165,7 @@ class Photo(RESTFrameworkModel):  # Model for issue #324  class BlankFieldModel(RESTFrameworkModel): -    title = models.CharField(max_length=100, blank=True, null=True) +    title = models.CharField(max_length=100, blank=True, null=False)  # Model for issue #380 diff --git a/rest_framework/tests/modelviews.py b/rest_framework/tests/modelviews.py index 1f8468e8..f12e3b97 100644 --- a/rest_framework/tests/modelviews.py +++ b/rest_framework/tests/modelviews.py @@ -1,4 +1,4 @@ -# from django.conf.urls.defaults import patterns, url +# from rest_framework.compat import patterns, url  # from django.forms import ModelForm  # from django.contrib.auth.models import Group, User  # from rest_framework.resources import ModelResource diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py new file mode 100644 index 00000000..24039410 --- /dev/null +++ b/rest_framework/tests/relations_hyperlink.py @@ -0,0 +1,424 @@ +from django.db import models +from django.test import TestCase +from rest_framework import serializers +from rest_framework.compat import patterns, url + + +def dummy_view(request, pk): +    pass + +urlpatterns = patterns('', +    url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'), +    url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), +    url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'), +    url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'), +    url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'), +) + + +# ManyToMany + +class ManyToManyTarget(models.Model): +    name = models.CharField(max_length=100) + + +class ManyToManySource(models.Model): +    name = models.CharField(max_length=100) +    targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') + + +class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): +    sources = serializers.ManyHyperlinkedRelatedField(view_name='manytomanysource-detail') + +    class Meta: +        model = ManyToManyTarget + + +class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = ManyToManySource + + +# ForeignKey + +class ForeignKeyTarget(models.Model): +    name = models.CharField(max_length=100) + + +class ForeignKeySource(models.Model): +    name = models.CharField(max_length=100) +    target = models.ForeignKey(ForeignKeyTarget, related_name='sources') + + +class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): +    sources = serializers.ManyHyperlinkedRelatedField(view_name='foreignkeysource-detail') + +    class Meta: +        model = ForeignKeyTarget + + +class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = ForeignKeySource + + +# Nullable ForeignKey + +class NullableForeignKeySource(models.Model): +    name = models.CharField(max_length=100) +    target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, +                               related_name='nullable_sources') + + +class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = NullableForeignKeySource + + +# TODO: Add test that .data cannot be accessed prior to .is_valid + +class HyperlinkedManyToManyTests(TestCase): +    urls = 'rest_framework.tests.relations_hyperlink' + +    def setUp(self): +        for idx in range(1, 4): +            target = ManyToManyTarget(name='target-%d' % idx) +            target.save() +            source = ManyToManySource(name='source-%d' % idx) +            source.save() +            for target in ManyToManyTarget.objects.all(): +                source.targets.add(target) + +    def test_many_to_many_retrieve(self): +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(queryset) +        expected = [ +                {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']}, +                {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, +                {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_reverse_many_to_many_retrieve(self): +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(queryset) +        expected = [ +            {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']}, +            {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, +            {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_many_to_many_update(self): +        data = {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} +        instance = ManyToManySource.objects.get(pk=1) +        serializer = ManyToManySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.data, data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(queryset) +        expected = [ +                {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}, +                {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, +                {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_reverse_many_to_many_update(self): +        data = {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']} +        instance = ManyToManyTarget.objects.get(pk=1) +        serializer = ManyToManyTargetSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.data, data) +        serializer.save() + +        # Ensure target 1 is updated, and everything else is as expected +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(queryset) +        expected = [ +            {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']}, +            {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, +            {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']} + +        ] +        self.assertEquals(serializer.data, expected) + +    def test_many_to_many_create(self): +        data = {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']} +        serializer = ManyToManySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEquals(serializer.data, data) +        self.assertEqual(obj.name, u'source-4') + +        # Ensure source 4 is added, and everything else is as expected +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(queryset) +        expected = [ +            {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']}, +            {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, +            {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}, +            {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_reverse_many_to_many_create(self): +        data = {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']} +        serializer = ManyToManyTargetSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEquals(serializer.data, data) +        self.assertEqual(obj.name, u'target-4') + +        # Ensure target 4 is added, and everything else is as expected +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(queryset) +        expected = [ +            {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']}, +            {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, +            {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']}, +            {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']} +        ] +        self.assertEquals(serializer.data, expected) + + +class HyperlinkedForeignKeyTests(TestCase): +    urls = 'rest_framework.tests.relations_hyperlink' + +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        new_target = ForeignKeyTarget(name='target-2') +        new_target.save() +        for idx in range(1, 4): +            source = ForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve(self): +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset) +        expected = [ +            {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, +            {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, +            {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_reverse_foreign_key_retrieve(self): +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset) +        expected = [ +            {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']}, +            {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_update(self): +        data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.data, data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset) +        expected = [ +            {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'}, +            {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, +            {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_reverse_foreign_key_update(self): +        data = {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']} +        instance = ForeignKeyTarget.objects.get(pk=2) +        serializer = ForeignKeyTargetSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.data, data) +        serializer.save() + +        # Ensure target 2 is update, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset) +        expected = [ +            {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']}, +            {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_create(self): +        data = {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'} +        serializer = ForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEquals(serializer.data, data) +        self.assertEqual(obj.name, u'source-4') + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset) +        expected = [ +            {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, +            {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, +            {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}, +            {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_reverse_foreign_key_create(self): +        data = {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']} +        serializer = ForeignKeyTargetSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEquals(serializer.data, data) +        self.assertEqual(obj.name, u'target-3') + +        # Ensure target 4 is added, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset) +        expected = [ +            {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']}, +            {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, +            {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_update_with_invalid_null(self): +        data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': None} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) + + +class HyperlinkedNullableForeignKeyTests(TestCase): +    urls = 'rest_framework.tests.relations_hyperlink' + +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        for idx in range(1, 4): +            if idx == 3: +                target = None +            source = NullableForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve_with_null(self): +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, +            {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, +            {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_create_with_valid_null(self): +        data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEquals(serializer.data, data) +        self.assertEqual(obj.name, u'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, +            {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, +            {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, +            {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_create_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': ''} +        expected_data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEquals(serializer.data, expected_data) +        self.assertEqual(obj.name, u'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, +            {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, +            {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, +            {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_update_with_valid_null(self): +        data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.data, data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}, +            {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, +            {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_update_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': ''} +        expected_data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.data, expected_data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}, +            {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, +            {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, +        ] +        self.assertEquals(serializer.data, expected) + +    # reverse foreign keys MUST be read_only +    # In the general case they do not provide .remove() or .clear() +    # and cannot be arbitrarily set. + +    # def test_reverse_foreign_key_update(self): +    #     data = {'id': 1, 'name': u'target-1', 'sources': [1]} +    #     instance = ForeignKeyTarget.objects.get(pk=1) +    #     serializer = ForeignKeyTargetSerializer(instance, data=data) +    #     self.assertTrue(serializer.is_valid()) +    #     self.assertEquals(serializer.data, data) +    #     serializer.save() + +    #     # Ensure target 1 is updated, and everything else is as expected +    #     queryset = ForeignKeyTarget.objects.all() +    #     serializer = ForeignKeyTargetSerializer(queryset) +    #     expected = [ +    #         {'id': 1, 'name': u'target-1', 'sources': [1]}, +    #         {'id': 2, 'name': u'target-2', 'sources': []}, +    #     ] +    #     self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py new file mode 100644 index 00000000..b1147378 --- /dev/null +++ b/rest_framework/tests/relations_nested.py @@ -0,0 +1,102 @@ +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +# ForeignKey + +class ForeignKeyTarget(models.Model): +    name = models.CharField(max_length=100) + + +class ForeignKeySource(models.Model): +    name = models.CharField(max_length=100) +    target = models.ForeignKey(ForeignKeyTarget, related_name='sources') + + +class ForeignKeySourceSerializer(serializers.ModelSerializer): +    class Meta: +        depth = 1 +        model = ForeignKeySource + + +class FlatForeignKeySourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = ForeignKeySource + + +class ForeignKeyTargetSerializer(serializers.ModelSerializer): +    sources = FlatForeignKeySourceSerializer() + +    class Meta: +        model = ForeignKeyTarget + + +# Nullable ForeignKey + +class NullableForeignKeySource(models.Model): +    name = models.CharField(max_length=100) +    target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, +                               related_name='nullable_sources') + + +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): +    class Meta: +        depth = 1 +        model = NullableForeignKeySource + + +class ReverseForeignKeyTests(TestCase): +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        new_target = ForeignKeyTarget(name='target-2') +        new_target.save() +        for idx in range(1, 4): +            source = ForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve(self): +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}}, +            {'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}}, +            {'id': 3, 'name': u'source-3', 'target': {'id': 1, 'name': u'target-1'}}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_reverse_foreign_key_retrieve(self): +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'target-1', 'sources': [ +                {'id': 1, 'name': u'source-1', 'target': 1}, +                {'id': 2, 'name': u'source-2', 'target': 1}, +                {'id': 3, 'name': u'source-3', 'target': 1}, +            ]}, +            {'id': 2, 'name': u'target-2', 'sources': [ +            ]} +        ] +        self.assertEquals(serializer.data, expected) + + +class NestedNullableForeignKeyTests(TestCase): +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        for idx in range(1, 4): +            if idx == 3: +                target = None +            source = NullableForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve_with_null(self): +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}}, +            {'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}}, +            {'id': 3, 'name': u'source-3', 'target': None}, +        ] +        self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/pk_relations.py b/rest_framework/tests/relations_pk.py index e3360939..01109ef9 100644 --- a/rest_framework/tests/pk_relations.py +++ b/rest_framework/tests/relations_pk.py @@ -38,7 +38,7 @@ class ForeignKeySource(models.Model):  class ForeignKeyTargetSerializer(serializers.ModelSerializer): -    sources = serializers.ManyPrimaryKeyRelatedField(read_only=True) +    sources = serializers.ManyPrimaryKeyRelatedField()      class Meta:          model = ForeignKeyTarget @@ -216,6 +216,60 @@ class PKForeignKeyTests(TestCase):          ]          self.assertEquals(serializer.data, expected) +    def test_reverse_foreign_key_update(self): +        data = {'id': 2, 'name': u'target-2', 'sources': [1, 3]} +        instance = ForeignKeyTarget.objects.get(pk=2) +        serializer = ForeignKeyTargetSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.data, data) +        serializer.save() + +        # Ensure target 2 is update, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'target-1', 'sources': [2]}, +            {'id': 2, 'name': u'target-2', 'sources': [1, 3]}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_create(self): +        data = {'id': 4, 'name': u'source-4', 'target': 2} +        serializer = ForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEquals(serializer.data, data) +        self.assertEqual(obj.name, u'source-4') + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': 1}, +            {'id': 2, 'name': u'source-2', 'target': 1}, +            {'id': 3, 'name': u'source-3', 'target': 1}, +            {'id': 4, 'name': u'source-4', 'target': 2}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_reverse_foreign_key_create(self): +        data = {'id': 3, 'name': u'target-3', 'sources': [1, 3]} +        serializer = ForeignKeyTargetSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEquals(serializer.data, data) +        self.assertEqual(obj.name, u'target-3') + +        # Ensure target 4 is added, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'target-1', 'sources': [2]}, +            {'id': 2, 'name': u'target-2', 'sources': []}, +            {'id': 3, 'name': u'target-3', 'sources': [1, 3]}, +        ] +        self.assertEquals(serializer.data, expected) +      def test_foreign_key_update_with_invalid_null(self):          data = {'id': 1, 'name': u'source-1', 'target': None}          instance = ForeignKeySource.objects.get(pk=1) @@ -229,9 +283,21 @@ class PKNullableForeignKeyTests(TestCase):          target = ForeignKeyTarget(name='target-1')          target.save()          for idx in range(1, 4): +            if idx == 3: +                target = None              source = NullableForeignKeySource(name='source-%d' % idx, target=target)              source.save() +    def test_foreign_key_retrieve_with_null(self): +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': 1}, +            {'id': 2, 'name': u'source-2', 'target': 1}, +            {'id': 3, 'name': u'source-3', 'target': None}, +        ] +        self.assertEquals(serializer.data, expected) +      def test_foreign_key_create_with_valid_null(self):          data = {'id': 4, 'name': u'source-4', 'target': None}          serializer = NullableForeignKeySourceSerializer(data=data) @@ -246,7 +312,7 @@ class PKNullableForeignKeyTests(TestCase):          expected = [              {'id': 1, 'name': u'source-1', 'target': 1},              {'id': 2, 'name': u'source-2', 'target': 1}, -            {'id': 3, 'name': u'source-3', 'target': 1}, +            {'id': 3, 'name': u'source-3', 'target': None},              {'id': 4, 'name': u'source-4', 'target': None}          ]          self.assertEquals(serializer.data, expected) @@ -270,7 +336,7 @@ class PKNullableForeignKeyTests(TestCase):          expected = [              {'id': 1, 'name': u'source-1', 'target': 1},              {'id': 2, 'name': u'source-2', 'target': 1}, -            {'id': 3, 'name': u'source-3', 'target': 1}, +            {'id': 3, 'name': u'source-3', 'target': None},              {'id': 4, 'name': u'source-4', 'target': None}          ]          self.assertEquals(serializer.data, expected) @@ -289,7 +355,7 @@ class PKNullableForeignKeyTests(TestCase):          expected = [              {'id': 1, 'name': u'source-1', 'target': None},              {'id': 2, 'name': u'source-2', 'target': 1}, -            {'id': 3, 'name': u'source-3', 'target': 1} +            {'id': 3, 'name': u'source-3', 'target': None}          ]          self.assertEquals(serializer.data, expected) @@ -312,7 +378,7 @@ class PKNullableForeignKeyTests(TestCase):          expected = [              {'id': 1, 'name': u'source-1', 'target': None},              {'id': 2, 'name': u'source-2', 'target': 1}, -            {'id': 3, 'name': u'source-3', 'target': 1} +            {'id': 3, 'name': u'source-3', 'target': None}          ]          self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/renderers.py index 9be4b114..c1b4e624 100644 --- a/rest_framework/tests/renderers.py +++ b/rest_framework/tests/renderers.py @@ -1,13 +1,12 @@  import pickle  import re -from django.conf.urls.defaults import patterns, url, include  from django.core.cache import cache  from django.test import TestCase  from django.test.client import RequestFactory  from rest_framework import status, permissions -from rest_framework.compat import yaml +from rest_framework.compat import yaml, patterns, url, include  from rest_framework.response import Response  from rest_framework.views import APIView  from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ @@ -444,19 +443,19 @@ class CacheRenderTest(TestCase):              return          if state == None:              return -        if isinstance(state,tuple): -            if not isinstance(state[0],dict): -                state=state[1] +        if isinstance(state, tuple): +            if not isinstance(state[0], dict): +                state = state[1]              else: -                state=state[0].update(state[1]) +                state = state[0].update(state[1])          result = {}          for i in state:              try: -                pickle.dumps(state[i],protocol=2) +                pickle.dumps(state[i], protocol=2)              except pickle.PicklingError:                  if not state[i] in seen:                      seen.append(state[i]) -                    result[i] = cls._get_pickling_errors(state[i],seen) +                    result[i] = cls._get_pickling_errors(state[i], seen)          return result      def http_resp(self, http_method, url): diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py index 2850992d..1f05ff8f 100644 --- a/rest_framework/tests/request.py +++ b/rest_framework/tests/request.py @@ -1,16 +1,15 @@  """  Tests for content parsing, and form-overloaded content parsing.  """ -from django.conf.urls.defaults import patterns  from django.contrib.auth.models import User  from django.contrib.auth import authenticate, login, logout  from django.contrib.sessions.middleware import SessionMiddleware  from django.test import TestCase, Client +from django.test.client import RequestFactory  from django.utils import simplejson as json -  from rest_framework import status  from rest_framework.authentication import SessionAuthentication -from django.test.client import RequestFactory +from rest_framework.compat import patterns  from rest_framework.parsers import (      BaseParser,      FormParser, @@ -304,3 +303,11 @@ class TestUserSetter(TestCase):          self.assertFalse(self.request.user.is_anonymous())          logout(self.request)          self.assertTrue(self.request.user.is_anonymous()) + + +class TestAuthSetter(TestCase): + +    def test_auth_can_be_set(self): +        request = Request(factory.get('/')) +        request.auth = 'DUMMY' +        self.assertEqual(request.auth, 'DUMMY') diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py index d7b75450..875f4d42 100644 --- a/rest_framework/tests/response.py +++ b/rest_framework/tests/response.py @@ -1,8 +1,5 @@ -import unittest - -from django.conf.urls.defaults import patterns, url, include  from django.test import TestCase - +from rest_framework.compat import patterns, url, include  from rest_framework.response import Response  from rest_framework.views import APIView  from rest_framework import status diff --git a/rest_framework/tests/reverse.py b/rest_framework/tests/reverse.py index fd9a7d64..8c86e1fb 100644 --- a/rest_framework/tests/reverse.py +++ b/rest_framework/tests/reverse.py @@ -1,6 +1,6 @@ -from django.conf.urls.defaults import patterns, url  from django.test import TestCase  from django.test.client import RequestFactory +from rest_framework.compat import patterns, url  from rest_framework.reverse import reverse  factory = RequestFactory() diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 780177aa..701b2f47 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -2,7 +2,7 @@ import datetime  import pickle  from django.test import TestCase  from rest_framework import serializers -from rest_framework.tests.models import (Album, ActionItem, Anchor, BasicModel, +from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,      BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel,      ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo) @@ -69,6 +69,11 @@ class AlbumsSerializer(serializers.ModelSerializer):          model = Album          fields = ['title']  # lists are also valid options +class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer): +    class Meta: +        model = HasPositiveIntegerAsChoice +        fields = ['some_integer'] +  class BasicTests(TestCase):      def setUp(self): @@ -285,6 +290,12 @@ class ValidationTests(TestCase):          self.assertEquals(serializer.errors, {'info': [u'Ensure this value has at most 12 characters (it has 13).']}) +class PositiveIntegerAsChoiceTests(TestCase): +    def test_positive_integer_in_json_is_correctly_parsed(self): +        data = {'some_integer':1} +        serializer = PositiveIntegerAsChoiceSerializer(data=data) +        self.assertEquals(serializer.is_valid(), True) +  class ModelValidationTests(TestCase):      def test_validate_unique(self):          """ @@ -297,6 +308,38 @@ class ModelValidationTests(TestCase):          self.assertFalse(second_serializer.is_valid())          self.assertEqual(second_serializer.errors,  {'title': [u'Album with this Title already exists.']}) +    def test_foreign_key_with_partial(self): +        """ +        Test ModelSerializer validation with partial=True + +        Specifically test foreign key validation. +        """ + +        album = Album(title='test') +        album.save() + +        class PhotoSerializer(serializers.ModelSerializer): +            class Meta: +                model = Photo + +        photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk}) +        self.assertTrue(photo_serializer.is_valid()) +        photo = photo_serializer.save() + +        # Updating only the album (foreign key) +        photo_serializer = PhotoSerializer(instance=photo, data={'album': album.pk}, partial=True) +        self.assertTrue(photo_serializer.is_valid()) +        self.assertTrue(photo_serializer.save()) + +        # Updating only the description +        photo_serializer = PhotoSerializer(instance=photo, +                                           data={'description': 'new'}, +                                           partial=True) + +        self.assertTrue(photo_serializer.is_valid()) +        self.assertTrue(photo_serializer.save()) + +  class RegexValidationTest(TestCase):      def test_create_failed(self): @@ -688,6 +731,10 @@ class BlankFieldTests(TestCase):          serializer = self.model_serializer_class(data=self.data)          self.assertEquals(serializer.is_valid(), True) +    def test_create_model_null_field(self): +        serializer = self.model_serializer_class(data={'title': None}) +        self.assertEquals(serializer.is_valid(), True) +      def test_create_not_blank_field(self):          """          Test to ensure blank data in a field not marked as blank=True @@ -704,6 +751,10 @@ class BlankFieldTests(TestCase):          serializer = self.not_blank_model_serializer_class(data=self.data)          self.assertEquals(serializer.is_valid(), False) +    def test_create_model_null_field(self): +        serializer = self.model_serializer_class(data={}) +        self.assertEquals(serializer.is_valid(), True) +  #test for issue #460  class SerializerPickleTests(TestCase): diff --git a/rest_framework/tests/testcases.py b/rest_framework/tests/testcases.py index c90224aa..97f492ff 100644 --- a/rest_framework/tests/testcases.py +++ b/rest_framework/tests/testcases.py @@ -6,6 +6,7 @@ from django.test import TestCase  NO_SETTING = ('!', None) +  class TestSettingsManager(object):      """      A class which can modify some Django settings temporarily for a @@ -19,7 +20,7 @@ class TestSettingsManager(object):          self._original_settings = {}      def set(self, **kwargs): -        for k,v in kwargs.iteritems(): +        for k, v in kwargs.iteritems():              self._original_settings.setdefault(k, getattr(settings, k,                                                            NO_SETTING))              setattr(settings, k, v) @@ -31,7 +32,7 @@ class TestSettingsManager(object):          call_command('syncdb', verbosity=0)      def revert(self): -        for k,v in self._original_settings.iteritems(): +        for k, v in self._original_settings.iteritems():              if v == NO_SETTING:                  delattr(settings, k)              else: @@ -57,6 +58,7 @@ class SettingsTestCase(TestCase):      def tearDown(self):          self.settings_manager.revert() +  class TestModelsTestCase(SettingsTestCase):      def setUp(self, *args, **kwargs):          installed_apps = tuple(settings.INSTALLED_APPS) + ('rest_framework.tests',) diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index 0ad926fa..143928c9 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -1,4 +1,4 @@ -from django.conf.urls.defaults import url +from rest_framework.compat import url  from rest_framework.settings import api_settings diff --git a/rest_framework/urls.py b/rest_framework/urls.py index bcdc23e7..fbe4bc07 100644 --- a/rest_framework/urls.py +++ b/rest_framework/urls.py @@ -12,7 +12,7 @@ your authentication settings include `SessionAuthentication`.          url(r'^auth', include('rest_framework.urls', namespace='rest_framework'))      )  """ -from django.conf.urls.defaults import patterns, url +from rest_framework.compat import patterns, url  template_name = {'template_name': 'rest_framework/login.html'}  | 
