diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/compat.py | 6 | ||||
| -rw-r--r-- | rest_framework/generics.py | 14 | ||||
| -rw-r--r-- | rest_framework/mixins.py | 4 | ||||
| -rw-r--r-- | rest_framework/tests/decorators.py | 16 | ||||
| -rw-r--r-- | rest_framework/tests/generics.py | 16 | ||||
| -rw-r--r-- | rest_framework/tests/utils.py | 27 | ||||
| -rw-r--r-- | rest_framework/tests/views.py | 4 | 
7 files changed, 79 insertions, 8 deletions
| diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 86952fb8..5508f6c0 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -96,6 +96,12 @@ else:              update_wrapper(view, cls.dispatch, assigned=())              return view +# Taken from @markotibold's attempt at supporting PATCH. +# https://github.com/markotibold/django-rest-framework/tree/patch +http_method_names = set(View.http_method_names) +http_method_names.add('patch') +View.http_method_names = list(http_method_names)  # PATCH method is not implemented by Django +  # PUT, DELETE do not require CSRF until 1.4.  They should.  Make it better.  if django.VERSION >= (1, 4):      from django.middleware.csrf import CsrfViewMiddleware diff --git a/rest_framework/generics.py b/rest_framework/generics.py index dd8dfcf8..14e4430e 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -47,14 +47,14 @@ class GenericAPIView(views.APIView):          return serializer_class -    def get_serializer(self, instance=None, data=None, files=None): +    def get_serializer(self, instance=None, data=None, files=None, partial=False):          """          Return the serializer instance that should be used for validating and          deserializing input, and for serializing output.          """          serializer_class = self.get_serializer_class()          context = self.get_serializer_context() -        return serializer_class(instance, data=data, files=files, context=context) +        return serializer_class(instance, data=data, files=files, partial=partial, context=context)  class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): @@ -169,7 +169,10 @@ class UpdateAPIView(mixins.UpdateModelMixin,      Concrete view for updating a model instance.      """      def put(self, request, *args, **kwargs): -        return self.update(request, *args, **kwargs) +        return self.update(request, partial=False, *args, **kwargs) + +    def patch(self, request, *args, **kwargs): +        return self.update(request, partial=True, *args, **kwargs)  class ListCreateAPIView(mixins.ListModelMixin, @@ -209,7 +212,10 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,          return self.retrieve(request, *args, **kwargs)      def put(self, request, *args, **kwargs): -        return self.update(request, *args, **kwargs) +        return self.update(request, partial=False, *args, **kwargs)      def delete(self, request, *args, **kwargs):          return self.destroy(request, *args, **kwargs) + +    def patch(self, request, *args, **kwargs): +        return self.update(request, partial=True, *args, **kwargs) diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 2700606d..d828078d 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -81,7 +81,7 @@ class UpdateModelMixin(object):      Update a model instance.      Should be mixed in with `SingleObjectBaseView`.      """ -    def update(self, request, *args, **kwargs): +    def update(self, request, partial=False, *args, **kwargs):          try:              self.object = self.get_object()              created = False @@ -89,7 +89,7 @@ class UpdateModelMixin(object):              self.object = None              created = True -        serializer = self.get_serializer(self.object, data=request.DATA, files=request.FILES) +        serializer = self.get_serializer(self.object, data=request.DATA, files=request.FILES, partial=partial)          if serializer.is_valid():              self.pre_save(serializer.object) diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py index 8079c8cb..bc44a45b 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/decorators.py @@ -17,6 +17,8 @@ from rest_framework.decorators import (      permission_classes,  ) +from rest_framework.tests.utils import RequestFactory +  class DecoratorTestCase(TestCase): @@ -63,6 +65,20 @@ class DecoratorTestCase(TestCase):          response = view(request)          self.assertEqual(response.status_code, 405) +    def test_calling_patch_method(self): + +        @api_view(['GET', 'PATCH']) +        def view(request): +            return Response({}) + +        request = self.factory.patch('/') +        response = view(request) +        self.assertEqual(response.status_code, 200) + +        request = self.factory.post('/') +        response = view(request) +        self.assertEqual(response.status_code, 405) +      def test_renderer_classes(self):          @api_view(['GET']) diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 7c24d84e..843017eb 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -1,8 +1,8 @@  from django.db import models  from django.test import TestCase -from django.test.client import RequestFactory  from django.utils import simplejson as json  from rest_framework import generics, serializers, status +from rest_framework.tests.utils import RequestFactory  from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel @@ -181,6 +181,20 @@ class TestInstanceView(TestCase):          updated = self.objects.get(id=1)          self.assertEquals(updated.text, 'foobar') +    def test_patch_instance_view(self): +        """ +        PATCH requests to RetrieveUpdateDestroyAPIView should update an object. +        """ +        content = {'text': 'foobar'} +        request = factory.patch('/1', json.dumps(content), +                              content_type='application/json') + +        response = self.view(request, pk=1).render() +        self.assertEquals(response.status_code, status.HTTP_200_OK) +        self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) +        updated = self.objects.get(id=1) +        self.assertEquals(updated.text, 'foobar') +      def test_delete_instance_view(self):          """          DELETE requests to RetrieveUpdateDestroyAPIView should delete an object. diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py new file mode 100644 index 00000000..3906adb9 --- /dev/null +++ b/rest_framework/tests/utils.py @@ -0,0 +1,27 @@ +from django.test.client import RequestFactory, FakePayload +from django.test.client import MULTIPART_CONTENT +from urlparse import urlparse + + +class RequestFactory(RequestFactory): + +    def __init__(self, **defaults): +        super(RequestFactory, self).__init__(**defaults) + +    def patch(self, path, data={}, content_type=MULTIPART_CONTENT, +            **extra): +        "Construct a PATCH request." + +        patch_data = self._encode_data(data, content_type) + +        parsed = urlparse(path) +        r = { +            'CONTENT_LENGTH': len(patch_data), +            'CONTENT_TYPE':   content_type, +            'PATH_INFO':      self._get_path(parsed), +            'QUERY_STRING':   parsed[4], +            'REQUEST_METHOD': 'PATCH', +            'wsgi.input':     FakePayload(patch_data), +        } +        r.update(extra) +        return self.request(**r) diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py index 43365e07..7cd82656 100644 --- a/rest_framework/tests/views.py +++ b/rest_framework/tests/views.py @@ -18,7 +18,7 @@ class BasicView(APIView):          return Response({'method': 'POST', 'data': request.DATA}) -@api_view(['GET', 'POST', 'PUT']) +@api_view(['GET', 'POST', 'PUT', 'PATCH'])  def basic_view(request):      if request.method == 'GET':          return {'method': 'GET'} @@ -26,6 +26,8 @@ def basic_view(request):          return {'method': 'POST', 'data': request.DATA}      elif request.method == 'PUT':          return {'method': 'PUT', 'data': request.DATA} +    elif request.method == 'PATCH': +        return {'method': 'PATCH', 'data': request.DATA}  def sanitise_json_error(error_dict): | 
