diff options
Diffstat (limited to 'rest_framework/mixins.py')
| -rw-r--r-- | rest_framework/mixins.py | 31 | 
1 files changed, 17 insertions, 14 deletions
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index cd104a7c..1edcfa5c 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -15,20 +15,20 @@ class CreateModelMixin(object):      Should be mixed in with any `BaseView`.      """      def create(self, request, *args, **kwargs): -        serializer = self.get_serializer(data=request.DATA) +        serializer = self.get_serializer(data=request.DATA, files=request.FILES)          if serializer.is_valid():              self.pre_save(serializer.object)              self.object = serializer.save()              headers = self.get_success_headers(serializer.data)              return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)          return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) -     +      def get_success_headers(self, data): -        if 'url' in data: -            return {'Location': data.get('url')} -        else: +        try: +            return {'Location': data['url']} +        except (TypeError, KeyError):              return {} -     +      def pre_save(self, obj):          pass @@ -41,14 +41,16 @@ class ListModelMixin(object):      empty_error = u"Empty list and '%(class_name)s.allow_empty' is False."      def list(self, request, *args, **kwargs): -        self.object_list = self.get_filtered_queryset() +        queryset = self.get_queryset() +        self.object_list = self.filter_queryset(queryset)          # Default is to allow empty querysets.  This can be altered by setting          # `.allow_empty = False`, to raise 404 errors on empty querysets.          allow_empty = self.get_allow_empty() -        if not allow_empty and len(self.object_list) == 0: -            error_args = {'class_name': self.__class__.__name__} -            raise Http404(self.empty_error % error_args) +        if not allow_empty and not self.object_list: +            class_name = self.__class__.__name__ +            error_msg = self.empty_error % {'class_name': class_name} +            raise Http404(error_msg)          # Pagination size is set by the `.paginate_by` attribute,          # which may be `None` to disable pagination. @@ -82,17 +84,18 @@ class UpdateModelMixin(object):      def update(self, request, *args, **kwargs):          try:              self.object = self.get_object() -            success_status = status.HTTP_200_OK +            created = False          except Http404:              self.object = None -            success_status = status.HTTP_201_CREATED +            created = True -        serializer = self.get_serializer(self.object, data=request.DATA) +        serializer = self.get_serializer(self.object, data=request.DATA, files=request.FILES)          if serializer.is_valid():              self.pre_save(serializer.object)              self.object = serializer.save() -            return Response(serializer.data, status=success_status) +            status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK +            return Response(serializer.data, status=status_code)          return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)  | 
