diff options
Diffstat (limited to 'rest_framework')
28 files changed, 476 insertions, 175 deletions
diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index 103abb27..b75c2e25 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -1,5 +1,4 @@ from rest_framework.views import APIView -from rest_framework import status from rest_framework import parsers from rest_framework import renderers from rest_framework.response import Response @@ -12,16 +11,13 @@ class ObtainAuthToken(APIView): permission_classes = () parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,) renderer_classes = (renderers.JSONRenderer,) - serializer_class = AuthTokenSerializer - model = Token def post(self, request): - serializer = self.serializer_class(data=request.data) - if serializer.is_valid(): - user = serializer.validated_data['user'] - token, created = Token.objects.get_or_create(user=user) - return Response({'token': token.key}) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + serializer = AuthTokenSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + user = serializer.validated_data['user'] + token, created = Token.objects.get_or_create(user=user) + return Response({'token': token.key}) obtain_auth_token = ObtainAuthToken.as_view() diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index d28d6e22..325435b3 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -12,12 +12,14 @@ from rest_framework.views import APIView import types -def api_view(http_method_names): +def api_view(http_method_names=None): """ Decorator that converts a function-based view into an APIView subclass. Takes a list of allowed methods for the view as an argument. """ + if http_method_names is None: + http_method_names = ['GET'] def decorator(func): diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 0b06d6e6..906de3b0 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -5,20 +5,44 @@ In addition Django's built in 403 and 404 exceptions are handled. (`django.http.Http404` and `django.core.exceptions.PermissionDenied`) """ from __future__ import unicode_literals + +from django.utils.translation import ugettext_lazy as _ +from django.utils.translation import ungettext_lazy from rest_framework import status +from rest_framework.compat import force_text import math +def _force_text_recursive(data): + """ + Descend into a nested data structure, forcing any + lazy translation strings into plain text. + """ + if isinstance(data, list): + return [ + _force_text_recursive(item) for item in data + ] + elif isinstance(data, dict): + return dict([ + (key, _force_text_recursive(value)) + for key, value in data.items() + ]) + return force_text(data) + + class APIException(Exception): """ Base class for REST framework exceptions. Subclasses should provide `.status_code` and `.default_detail` properties. """ status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - default_detail = 'A server error occured' + default_detail = _('A server error occured') def __init__(self, detail=None): - self.detail = detail or self.default_detail + if detail is not None: + self.detail = force_text(detail) + else: + self.detail = force_text(self.default_detail) def __str__(self): return self.detail @@ -39,7 +63,7 @@ class ValidationError(APIException): # The details should always be coerced to a list if not already. if not isinstance(detail, dict) and not isinstance(detail, list): detail = [detail] - self.detail = detail + self.detail = _force_text_recursive(detail) def __str__(self): return str(self.detail) @@ -47,59 +71,77 @@ class ValidationError(APIException): class ParseError(APIException): status_code = status.HTTP_400_BAD_REQUEST - default_detail = 'Malformed request.' + default_detail = _('Malformed request.') class AuthenticationFailed(APIException): status_code = status.HTTP_401_UNAUTHORIZED - default_detail = 'Incorrect authentication credentials.' + default_detail = _('Incorrect authentication credentials.') class NotAuthenticated(APIException): status_code = status.HTTP_401_UNAUTHORIZED - default_detail = 'Authentication credentials were not provided.' + default_detail = _('Authentication credentials were not provided.') class PermissionDenied(APIException): status_code = status.HTTP_403_FORBIDDEN - default_detail = 'You do not have permission to perform this action.' + default_detail = _('You do not have permission to perform this action.') class MethodNotAllowed(APIException): status_code = status.HTTP_405_METHOD_NOT_ALLOWED - default_detail = "Method '%s' not allowed." + default_detail = _("Method '%s' not allowed.") def __init__(self, method, detail=None): - self.detail = detail or (self.default_detail % method) + if detail is not None: + self.detail = force_text(detail) + else: + self.detail = force_text(self.default_detail) % method class NotAcceptable(APIException): status_code = status.HTTP_406_NOT_ACCEPTABLE - default_detail = "Could not satisfy the request Accept header" + default_detail = _('Could not satisfy the request Accept header') def __init__(self, detail=None, available_renderers=None): - self.detail = detail or self.default_detail + if detail is not None: + self.detail = force_text(detail) + else: + self.detail = force_text(self.default_detail) self.available_renderers = available_renderers class UnsupportedMediaType(APIException): status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE - default_detail = "Unsupported media type '%s' in request." + default_detail = _("Unsupported media type '%s' in request.") def __init__(self, media_type, detail=None): - self.detail = detail or (self.default_detail % media_type) + if detail is not None: + self.detail = force_text(detail) + else: + self.detail = force_text(self.default_detail) % media_type class Throttled(APIException): status_code = status.HTTP_429_TOO_MANY_REQUESTS - default_detail = 'Request was throttled.' - extra_detail = " Expected available in %d second%s." + default_detail = _('Request was throttled.') + extra_detail = ungettext_lazy( + 'Expected available in %(wait)d second.', + 'Expected available in %(wait)d seconds.', + 'wait' + ) def __init__(self, wait=None, detail=None): + if detail is not None: + self.detail = force_text(detail) + else: + self.detail = force_text(self.default_detail) + if wait is None: - self.detail = detail or self.default_detail self.wait = None else: - format = (detail or self.default_detail) + self.extra_detail - self.detail = format % (wait, wait != 1 and 's' or '') self.wait = math.ceil(wait) + self.detail += ' ' + force_text( + self.extra_detail % {'wait': self.wait} + ) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 58482db5..ca9c479f 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -68,8 +68,8 @@ def get_attribute(instance, attrs): return instance[attr] except (KeyError, TypeError, AttributeError): raise exc - if is_simple_callable(instance): - return instance() + if is_simple_callable(instance): + instance = instance() return instance @@ -117,6 +117,17 @@ class CreateOnlyDefault: return '%s(%s)' % (self.__class__.__name__, repr(self.default)) +class CurrentUserDefault: + def set_context(self, serializer_field): + self.user = serializer_field.context['request'].user + + def __call__(self): + return self.user + + def __repr__(self): + return '%s()' % self.__class__.__name__ + + class SkipField(Exception): pass @@ -170,6 +181,9 @@ class Field(object): self.style = {} if style is None else style self.allow_null = allow_null + if allow_null and self.default_empty_html is empty: + self.default_empty_html = None + if validators is not None: self.validators = validators[:] @@ -248,7 +262,11 @@ class Field(object): if html.is_html_input(dictionary): # HTML forms will represent empty fields as '', and cannot # represent None or False values directly. - ret = dictionary.get(self.field_name, '') + if self.field_name not in dictionary: + if getattr(self.root, 'partial', False): + return empty + return self.default_empty_html + ret = dictionary[self.field_name] return self.default_empty_html if (ret == '') else ret return dictionary.get(self.field_name, empty) @@ -303,7 +321,6 @@ class Field(object): value = self.to_internal_value(data) self.run_validators(value) - self.validate(value) return value def run_validators(self, value): @@ -330,9 +347,6 @@ class Field(object): if errors: raise ValidationError(errors) - def validate(self, value): - pass - def to_internal_value(self, data): """ Transform the *incoming* primitive data into a native value. @@ -484,6 +498,7 @@ class CharField(Field): } initial = '' coerce_blank_to_null = False + default_empty_html = '' def __init__(self, **kwargs): self.allow_blank = kwargs.pop('allow_blank', False) @@ -936,6 +951,8 @@ class ChoiceField(Field): self.fail('invalid_choice', input=data) def to_representation(self, value): + if value in ('', None): + return value return self.choice_strings_to_values[six.text_type(value)] diff --git a/rest_framework/metadata.py b/rest_framework/metadata.py index 90d3f2e0..de829d00 100644 --- a/rest_framework/metadata.py +++ b/rest_framework/metadata.py @@ -121,7 +121,10 @@ class SimpleMetadata(BaseMetadata): if hasattr(field, 'choices'): field_info['choices'] = [ - {'value': choice_value, 'display_name': choice_name} + { + 'value': choice_value, + 'display_name': force_text(choice_name, strings_only=True) + } for choice_value, choice_name in field.choices.items() ] diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index ccb82f03..d229abec 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -256,23 +256,24 @@ class FileUploadParser(BaseParser): chunks = ChunkIter(stream, chunk_size) counters = [0] * len(upload_handlers) - for handler in upload_handlers: + for index, handler in enumerate(upload_handlers): try: handler.new_file(None, filename, content_type, content_length, encoding) except StopFutureHandlers: + upload_handlers = upload_handlers[:index + 1] break for chunk in chunks: - for i, handler in enumerate(upload_handlers): + for index, handler in enumerate(upload_handlers): chunk_length = len(chunk) - chunk = handler.receive_data_chunk(chunk, counters[i]) - counters[i] += chunk_length + chunk = handler.receive_data_chunk(chunk, counters[index]) + counters[index] += chunk_length if chunk is None: break - for i, handler in enumerate(upload_handlers): - file_obj = handler.file_complete(counters[i]) + for index, handler in enumerate(upload_handlers): + file_obj = handler.file_complete(counters[index]) if file_obj: return DataAndFiles(None, {'file': file_obj}) raise ParseError("FileUpload parse error - " diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 48ddf41e..d1ea497a 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -10,9 +10,17 @@ from django.utils.translation import ugettext_lazy as _ class PKOnlyObject(object): + """ + This is a mock object, used for when we only need the pk of the object + instance, but still want to return an object with a .pk attribute, + in order to keep the same interface as a regular model instance. + """ def __init__(self, pk): self.pk = pk + +# We assume that 'validators' are intended for the child serializer, +# rather than the parent serializer. MANY_RELATION_KWARGS = ( 'read_only', 'write_only', 'required', 'default', 'initial', 'source', 'label', 'help_text', 'style', 'error_messages' @@ -34,15 +42,34 @@ class RelatedField(Field): def __new__(cls, *args, **kwargs): # We override this method in order to automagically create - # `ManyRelation` classes instead when `many=True` is set. + # `ManyRelatedField` classes instead when `many=True` is set. if kwargs.pop('many', False): - list_kwargs = {'child_relation': cls(*args, **kwargs)} - for key in kwargs.keys(): - if key in MANY_RELATION_KWARGS: - list_kwargs[key] = kwargs[key] - return ManyRelation(**list_kwargs) + return cls.many_init(*args, **kwargs) return super(RelatedField, cls).__new__(cls, *args, **kwargs) + @classmethod + def many_init(cls, *args, **kwargs): + """ + This method handles creating a parent `ManyRelatedField` instance + when the `many=True` keyword argument is passed. + + Typically you won't need to override this method. + + Note that we're over-cautious in passing most arguments to both parent + and child classes in order to try to cover the general case. If you're + overriding this method you'll probably want something much simpler, eg: + + @classmethod + def many_init(cls, *args, **kwargs): + kwargs['child'] = cls() + return CustomManyRelatedField(*args, **kwargs) + """ + list_kwargs = {'child_relation': cls(*args, **kwargs)} + for key in kwargs.keys(): + if key in MANY_RELATION_KWARGS: + list_kwargs[key] = kwargs[key] + return ManyRelatedField(**list_kwargs) + def run_validation(self, data=empty): # We force empty strings to None values for relational fields. if data == '': @@ -286,15 +313,17 @@ class SlugRelatedField(RelatedField): return getattr(obj, self.slug_field) -class ManyRelation(Field): +class ManyRelatedField(Field): """ Relationships with `many=True` transparently get coerced into instead being - a ManyRelation with a child relationship. + a ManyRelatedField with a child relationship. - The `ManyRelation` class is responsible for handling iterating through + The `ManyRelatedField` class is responsible for handling iterating through the values and passing each one to the child relationship. - You shouldn't need to be using this class directly yourself. + This class is treated as private API. + You shouldn't generally need to be using this class directly yourself, + and should instead simply set 'many=True' on the relationship. """ initial = [] default_empty_html = [] @@ -302,7 +331,7 @@ class ManyRelation(Field): def __init__(self, child_relation=None, *args, **kwargs): self.child_relation = child_relation assert child_relation is not None, '`child_relation` is a required argument.' - super(ManyRelation, self).__init__(*args, **kwargs) + super(ManyRelatedField, self).__init__(*args, **kwargs) self.child_relation.bind(field_name='', parent=self) def get_value(self, dictionary): diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 37d3c47c..31d3ef5f 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -387,7 +387,10 @@ class HTMLFormRenderer(BaseRenderer): serializers.MultipleChoiceField: { 'base_template': 'select_multiple.html', # Also valid: 'checkbox_multiple.html' }, - serializers.ManyRelation: { + serializers.RelatedField: { + 'base_template': 'select.html', # Also valid: 'radio.html' + }, + serializers.ManyRelatedField: { 'base_template': 'select_multiple.html', # Also valid: 'checkbox_multiple.html' }, serializers.Serializer: { @@ -430,7 +433,10 @@ class HTMLFormRenderer(BaseRenderer): style['base_template'] = self.base_template style['renderer'] = self - if 'template' in style: + # This API needs to be finessed and finalized for 3.1 + if 'template' in renderer_context: + template_name = renderer_context['template'] + elif 'template' in style: template_name = style['template'] else: template_name = style['template_pack'].strip('/') + '/' + style['base_template'] @@ -516,13 +522,17 @@ class BrowsableAPIRenderer(BaseRenderer): In the absence of the View having an associated form then return None. """ + # See issue #2089 for refactoring this. serializer = getattr(data, 'serializer', None) if serializer and not getattr(serializer, 'many', False): instance = getattr(serializer, 'instance', None) else: instance = None - if request.method == method: + # If this is valid serializer data, and the form is for the same + # HTTP method as was used in the request then use the existing + # serializer instance, rather than dynamically creating a new one. + if request.method == method and serializer is not None: try: data = request.data except ParseError: @@ -548,11 +558,21 @@ class BrowsableAPIRenderer(BaseRenderer): if existing_serializer is not None: serializer = existing_serializer else: - serializer = view.get_serializer(instance=instance, data=data) + if method in ('PUT', 'PATCH'): + serializer = view.get_serializer(instance=instance, data=data) + else: + serializer = view.get_serializer(data=data) if data is not None: serializer.is_valid() form_renderer = self.form_renderer_class() - return form_renderer.render(serializer.data, self.accepted_media_type, self.renderer_context) + return form_renderer.render( + serializer.data, + self.accepted_media_type, + dict( + list(self.renderer_context.items()) + + [('template', 'rest_framework/api_form.html')] + ) + ) def get_raw_data_form(self, data, view, method, request): """ @@ -560,6 +580,7 @@ class BrowsableAPIRenderer(BaseRenderer): via standard HTML forms. (Which are typically application/x-www-form-urlencoded) """ + # See issue #2089 for refactoring this. serializer = getattr(data, 'serializer', None) if serializer and not getattr(serializer, 'many', False): instance = getattr(serializer, 'instance', None) @@ -585,7 +606,10 @@ class BrowsableAPIRenderer(BaseRenderer): # View has a serializer defined and parser class has a # corresponding renderer that can be used to render the data. - serializer = view.get_serializer(instance=instance) + if method in ('PUT', 'PATCH'): + serializer = view.get_serializer(instance=instance) + else: + serializer = view.get_serializer() # Render the raw data content renderer = renderer_class() diff --git a/rest_framework/request.py b/rest_framework/request.py index 096b3042..d7e74674 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -310,7 +310,7 @@ class Request(object): def _load_data_and_files(self): """ - Parses the request content into self.DATA and self.FILES. + Parses the request content into `self.data`. """ if not _hasattr(self, '_content_type'): self._load_method_and_content_type() diff --git a/rest_framework/response.py b/rest_framework/response.py index 0a7d313f..d6ca1aad 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -5,7 +5,6 @@ it is initialized with unrendered data, instead of a pre-rendered string. The appropriate renderer is called during Django's template response rendering. """ from __future__ import unicode_literals -import django from django.core.handlers.wsgi import STATUS_CODE_TEXT from django.template.response import SimpleTemplateResponse from django.utils import six @@ -16,9 +15,6 @@ class Response(SimpleTemplateResponse): An HttpResponse that allows its data to be rendered into arbitrary media types. """ - # TODO: remove that once Django 1.3 isn't supported - if django.VERSION >= (1, 4): - rendering_attrs = SimpleTemplateResponse.rendering_attrs + ['_closable_objects'] def __init__(self, data=None, status=None, template_name=None, headers=None, diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index d83367f4..af8aeb48 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -11,6 +11,7 @@ python primitives. response content is handled by parsers and renderers. """ from django.core.exceptions import ImproperlyConfigured +from django.core.exceptions import ValidationError as DjangoValidationError from django.db import models from django.db.models.fields import FieldDoesNotExist from django.utils import six @@ -46,6 +47,9 @@ import warnings from rest_framework.relations import * # NOQA from rest_framework.fields import * # NOQA + +# We assume that 'validators' are intended for the child serializer, +# rather than the parent serializer. LIST_SERIALIZER_KWARGS = ( 'read_only', 'write_only', 'required', 'default', 'initial', 'source', 'label', 'help_text', 'style', 'error_messages', @@ -73,13 +77,36 @@ class BaseSerializer(Field): # We override this method in order to automagically create # `ListSerializer` classes instead when `many=True` is set. if kwargs.pop('many', False): - list_kwargs = {'child': cls(*args, **kwargs)} - for key in kwargs.keys(): - if key in LIST_SERIALIZER_KWARGS: - list_kwargs[key] = kwargs[key] - return ListSerializer(*args, **list_kwargs) + return cls.many_init(*args, **kwargs) return super(BaseSerializer, cls).__new__(cls, *args, **kwargs) + @classmethod + def many_init(cls, *args, **kwargs): + """ + This method implements the creation of a `ListSerializer` parent + class when `many=True` is used. You can customize it if you need to + control which keyword arguments are passed to the parent, and + which are passed to the child. + + Note that we're over-cautious in passing most arguments to both parent + and child classes in order to try to cover the general case. If you're + overriding this method you'll probably want something much simpler, eg: + + @classmethod + def many_init(cls, *args, **kwargs): + kwargs['child'] = cls() + return CustomListSerializer(*args, **kwargs) + """ + child_serializer = cls(*args, **kwargs) + list_kwargs = {'child': child_serializer} + list_kwargs.update(dict([ + (key, value) for key, value in kwargs.items() + if key in LIST_SERIALIZER_KWARGS + ])) + meta = getattr(cls, 'Meta', None) + list_serializer_class = getattr(meta, 'list_serializer_class', ListSerializer) + return list_serializer_class(*args, **list_kwargs) + def to_internal_value(self, data): raise NotImplementedError('`to_internal_value()` must be implemented.') @@ -93,6 +120,21 @@ class BaseSerializer(Field): raise NotImplementedError('`create()` must be implemented.') def save(self, **kwargs): + assert not hasattr(self, 'save_object'), ( + 'Serializer `%s.%s` has old-style version 2 `.save_object()` ' + 'that is no longer compatible with REST framework 3. ' + 'Use the new-style `.create()` and `.update()` methods instead.' % + (self.__class__.__module__, self.__class__.__name__) + ) + + assert hasattr(self, '_errors'), ( + 'You must call `.is_valid()` before calling `.save()`.' + ) + + assert not self.errors, ( + 'You cannot call `.save()` on a serializer with invalid data.' + ) + validated_data = dict( list(self.validated_data.items()) + list(kwargs.items()) @@ -230,18 +272,18 @@ class Serializer(BaseSerializer): def get_initial(self): if self._initial_data is not None: - return ReturnDict([ + return OrderedDict([ (field_name, field.get_value(self._initial_data)) for field_name, field in self.fields.items() if field.get_value(self._initial_data) is not empty and not field.read_only - ], serializer=self) + ]) - return ReturnDict([ + return OrderedDict([ (field.field_name, field.get_initial()) for field in self.fields.values() if not field.read_only - ], serializer=self) + ]) def get_value(self, dictionary): # We override the default field access in order to support @@ -297,6 +339,14 @@ class Serializer(BaseSerializer): raise ValidationError({ api_settings.NON_FIELD_ERRORS_KEY: [exc.detail] }) + except DjangoValidationError as exc: + # Normally you should raise `serializers.ValidationError` + # inside your codebase, but we handle Django's validation + # exception class as well for simpler compat. + # Eg. Calling Model.clean() explictily inside Serializer.validate() + raise ValidationError({ + api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages) + }) return value @@ -304,8 +354,8 @@ class Serializer(BaseSerializer): """ Dict of native values <- Dict of primitive datatypes. """ - ret = {} - errors = ReturnDict(serializer=self) + ret = OrderedDict() + errors = OrderedDict() fields = [ field for field in self.fields.values() if (not field.read_only) or (field.default is not empty) @@ -320,6 +370,8 @@ class Serializer(BaseSerializer): validated_value = validate_method(validated_value) except ValidationError as exc: errors[field.field_name] = exc.detail + except DjangoValidationError as exc: + errors[field.field_name] = list(exc.messages) except SkipField: pass else: @@ -334,20 +386,15 @@ class Serializer(BaseSerializer): """ Object instance -> Dict of primitive datatypes. """ - ret = ReturnDict(serializer=self) + ret = OrderedDict() fields = [field for field in self.fields.values() if not field.write_only] for field in fields: attribute = field.get_attribute(instance) if attribute is None: - value = None + ret[field.field_name] = None else: - value = field.to_representation(attribute) - transform_method = getattr(self, 'transform_' + field.field_name, None) - if transform_method is not None: - value = transform_method(value) - - ret[field.field_name] = value + ret[field.field_name] = field.to_representation(attribute) return ret @@ -373,6 +420,19 @@ class Serializer(BaseSerializer): return NestedBoundField(field, value, error) return BoundField(field, value, error) + # Include a backlink to the serializer class on return objects. + # Allows renderers such as HTMLFormRenderer to get the full field info. + + @property + def data(self): + ret = super(Serializer, self).data + return ReturnDict(ret, serializer=self) + + @property + def errors(self): + ret = super(Serializer, self).errors + return ReturnDict(ret, serializer=self) + # There's some replication of `ListField` here, # but that's probably better than obfuscating the call hierarchy. @@ -395,7 +455,7 @@ class ListSerializer(BaseSerializer): def get_initial(self): if self._initial_data is not None: return self.to_representation(self._initial_data) - return ReturnList(serializer=self) + return [] def get_value(self, dictionary): """ @@ -423,7 +483,7 @@ class ListSerializer(BaseSerializer): }) ret = [] - errors = ReturnList(serializer=self) + errors = [] for item in data: try: @@ -444,37 +504,64 @@ class ListSerializer(BaseSerializer): List of object instances -> List of dicts of primitive datatypes. """ iterable = data.all() if (hasattr(data, 'all')) else data - return ReturnList( - [self.child.to_representation(item) for item in iterable], - serializer=self + return [ + self.child.to_representation(item) for item in iterable + ] + + def update(self, instance, validated_data): + raise NotImplementedError( + "Serializers with many=True do not support multiple update by " + "default, only multiple create. For updates it is unclear how to " + "deal with insertions and deletions. If you need to support " + "multiple update, use a `ListSerializer` class and override " + "`.update()` so you can specify the behavior exactly." ) + def create(self, validated_data): + return [ + self.child.create(attrs) for attrs in validated_data + ] + def save(self, **kwargs): """ Save and return a list of object instances. """ - assert self.instance is None, ( - "Serializers do not support multiple update by default, only " - "multiple create. For updates it is unclear how to deal with " - "insertions and deletions. If you need to support multiple update, " - "use a `ListSerializer` class and override `.save()` so you can " - "specify the behavior exactly." - ) - validated_data = [ dict(list(attrs.items()) + list(kwargs.items())) for attrs in self.validated_data ] - self.instance = [ - self.child.create(attrs) for attrs in validated_data - ] + if self.instance is not None: + self.instance = self.update(self.instance, validated_data) + assert self.instance is not None, ( + '`update()` did not return an object instance.' + ) + else: + self.instance = self.create(validated_data) + assert self.instance is not None, ( + '`create()` did not return an object instance.' + ) return self.instance def __repr__(self): return representation.list_repr(self, indent=1) + # Include a backlink to the serializer class on return objects. + # Allows renderers such as HTMLFormRenderer to get the full field info. + + @property + def data(self): + ret = super(ListSerializer, self).data + return ReturnList(ret, serializer=self) + + @property + def errors(self): + ret = super(ListSerializer, self).errors + if isinstance(ret, dict): + return ReturnDict(ret, serializer=self) + return ReturnList(ret, serializer=self) + # ModelSerializer & HyperlinkedModelSerializer # -------------------------------------------- @@ -486,6 +573,14 @@ class ModelSerializer(Serializer): * A set of default fields are automatically populated. * A set of default validators are automatically populated. * Default `.create()` and `.update()` implementations are provided. + + The process of automatically determining a set of serializer fields + based on the model fields is reasonably complex, but you almost certainly + don't need to dig into the implemention. + + If the `ModelSerializer` class *doesn't* generate the set of fields that + you need you should either declare the extra/differing fields explicitly on + the serializer class, or simply use a `Serializer` class. """ _field_mapping = ClassLookupDict({ models.AutoField: IntegerField, @@ -513,13 +608,33 @@ class ModelSerializer(Serializer): }) _related_class = PrimaryKeyRelatedField - def create(self, validated_attrs): + def create(self, validated_data): + """ + We have a bit of extra checking around this in order to provide + descriptive messages when something goes wrong, but this method is + essentially just: + + return ExampleModel.objects.create(**validated_data) + + If there are many to many fields present on the instance then they + cannot be set until the model is instantiated, in which case the + implementation is like so: + + example_relationship = validated_data.pop('example_relationship') + instance = ExampleModel.objects.create(**validated_data) + instance.example_relationship = example_relationship + return instance + + The default implementation also does not handle nested relationships. + If you want to support writable nested relationships you'll need + to write an explicit `.create()` method. + """ # Check that the user isn't trying to handle a writable nested field. # If we don't do this explicitly they'd likely get a confusing # error at the point of calling `Model.objects.create()`. assert not any( - isinstance(field, BaseSerializer) and not field.read_only - for field in self.fields.values() + isinstance(field, BaseSerializer) and (key in validated_attrs) + for key, field in self.fields.items() ), ( 'The `.create()` method does not suport nested writable fields ' 'by default. Write an explicit `.create()` method for serializer ' @@ -529,16 +644,33 @@ class ModelSerializer(Serializer): ModelClass = self.Meta.model - # Remove many-to-many relationships from validated_attrs. + # Remove many-to-many relationships from validated_data. # They are not valid arguments to the default `.create()` method, # as they require that the instance has already been saved. info = model_meta.get_field_info(ModelClass) many_to_many = {} for field_name, relation_info in info.relations.items(): - if relation_info.to_many and (field_name in validated_attrs): - many_to_many[field_name] = validated_attrs.pop(field_name) + if relation_info.to_many and (field_name in validated_data): + many_to_many[field_name] = validated_data.pop(field_name) - instance = ModelClass.objects.create(**validated_attrs) + try: + instance = ModelClass.objects.create(**validated_data) + except TypeError as exc: + msg = ( + 'Got a `TypeError` when calling `%s.objects.create()`. ' + 'This may be because you have a writable field on the ' + 'serializer class that is not a valid argument to ' + '`%s.objects.create()`. You may need to make the field ' + 'read-only, or override the %s.create() method to handle ' + 'this correctly.\nOriginal exception text was: %s.' % + ( + ModelClass.__name__, + ModelClass.__name__, + self.__class__.__name__, + exc + ) + ) + raise TypeError(msg) # Save many-to-many relationships after the instance is created. if many_to_many: @@ -547,10 +679,10 @@ class ModelSerializer(Serializer): return instance - def update(self, instance, validated_attrs): + def update(self, instance, validated_data): assert not any( - isinstance(field, BaseSerializer) and not field.read_only - for field in self.fields.values() + isinstance(field, BaseSerializer) and (key in validated_attrs) + for key, field in self.fields.items() ), ( 'The `.update()` method does not suport nested writable fields ' 'by default. Write an explicit `.update()` method for serializer ' @@ -558,20 +690,25 @@ class ModelSerializer(Serializer): (self.__class__.__module__, self.__class__.__name__) ) - for attr, value in validated_attrs.items(): + for attr, value in validated_data.items(): setattr(instance, attr, value) instance.save() return instance def get_validators(self): + # If the validators have been declared explicitly then use that. + validators = getattr(getattr(self, 'Meta', None), 'validators', None) + if validators is not None: + return validators + + # Determine the default set of validators. + validators = [] + model_class = self.Meta.model field_names = set([ field.source for field in self.fields.values() if (field.source != '*') and ('.' not in field.source) ]) - validators = getattr(getattr(self, 'Meta', None), 'validators', []) - model_class = self.Meta.model - # Note that we make sure to check `unique_together` both on the # base model class, but also on any parent classes. for parent_class in [model_class] + list(model_class._meta.parents.keys()): @@ -658,49 +795,62 @@ class ModelSerializer(Serializer): # Determine if we need any additional `HiddenField` or extra keyword # arguments to deal with `unique_for` dates that are required to # be in the input data in order to validate it. - unique_fields = {} + hidden_fields = {} + unique_constraint_names = set() + for model_field_name, field_name in model_field_mapping.items(): try: model_field = model._meta.get_field(model_field_name) except FieldDoesNotExist: continue - # Deal with each of the `unique_for_*` cases. - for date_field_name in ( + # Include each of the `unique_for_*` field names. + unique_constraint_names |= set([ model_field.unique_for_date, model_field.unique_for_month, model_field.unique_for_year - ): - if date_field_name is None: - continue - - # Get the model field that is refered too. - date_field = model._meta.get_field(date_field_name) - - if date_field.auto_now_add: - default = CreateOnlyDefault(timezone.now) - elif date_field.auto_now: - default = timezone.now - elif date_field.has_default(): - default = model_field.default - else: - default = empty - - if date_field_name in model_field_mapping: - # The corresponding date field is present in the serializer - if date_field_name not in extra_kwargs: - extra_kwargs[date_field_name] = {} - if default is empty: - if 'required' not in extra_kwargs[date_field_name]: - extra_kwargs[date_field_name]['required'] = True - else: - if 'default' not in extra_kwargs[date_field_name]: - extra_kwargs[date_field_name]['default'] = default + ]) + + unique_constraint_names -= set([None]) + + # Include each of the `unique_together` field names, + # so long as all the field names are included on the serializer. + for parent_class in [model] + list(model._meta.parents.keys()): + for unique_together_list in parent_class._meta.unique_together: + if set(fields).issuperset(set(unique_together_list)): + unique_constraint_names |= set(unique_together_list) + + # Now we have all the field names that have uniqueness constraints + # applied, we can add the extra 'required=...' or 'default=...' + # arguments that are appropriate to these fields, or add a `HiddenField` for it. + for unique_constraint_name in unique_constraint_names: + # Get the model field that is refered too. + unique_constraint_field = model._meta.get_field(unique_constraint_name) + + if getattr(unique_constraint_field, 'auto_now_add', None): + default = CreateOnlyDefault(timezone.now) + elif getattr(unique_constraint_field, 'auto_now', None): + default = timezone.now + elif unique_constraint_field.has_default(): + default = unique_constraint_field.default + else: + default = empty + + if unique_constraint_name in model_field_mapping: + # The corresponding field is present in the serializer + if unique_constraint_name not in extra_kwargs: + extra_kwargs[unique_constraint_name] = {} + if default is empty: + if 'required' not in extra_kwargs[unique_constraint_name]: + extra_kwargs[unique_constraint_name]['required'] = True else: - # The corresponding date field is not present in the, - # serializer. We have a default to use for the date, so - # add in a hidden field that populates it. - unique_fields[date_field_name] = HiddenField(default=default) + if 'default' not in extra_kwargs[unique_constraint_name]: + extra_kwargs[unique_constraint_name]['default'] = default + elif default is not empty: + # The corresponding field is not present in the, + # serializer. We have a default to use for it, so + # add in a hidden field that populates it. + hidden_fields[unique_constraint_name] = HiddenField(default=default) # Now determine the fields that should be included on the serializer. for field_name in fields: @@ -776,12 +926,16 @@ class ModelSerializer(Serializer): 'validators', 'queryset' ]: kwargs.pop(attr, None) + + if extras.get('default') and kwargs.get('required') is False: + kwargs.pop('required') + kwargs.update(extras) # Create the serializer field. ret[field_name] = field_cls(**kwargs) - for field_name, field in unique_fields.items(): + for field_name, field in hidden_fields.items(): ret[field_name] = field return ret diff --git a/rest_framework/templates/rest_framework/api_form.html b/rest_framework/templates/rest_framework/api_form.html new file mode 100644 index 00000000..96f924ed --- /dev/null +++ b/rest_framework/templates/rest_framework/api_form.html @@ -0,0 +1,8 @@ +{% load rest_framework %} +{% csrf_token %} +{% for field in form %} + {% if not field.read_only %} + {% render_field field style=style %} + {% endif %} +{% endfor %} +<!-- form.non_field_errors --> diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index e9d99a65..e9668193 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -237,13 +237,6 @@ </div> <!-- END Content --> </div><!-- /.container --> - - <footer> - {% block footer %} - <p>Sponsored by <a href="http://dabapps.com/">DabApps</a>.</p> - {% endblock %} - </footer> - </div><!-- ./wrapper --> {% block script %} diff --git a/rest_framework/templates/rest_framework/horizontal/input.html b/rest_framework/templates/rest_framework/horizontal/input.html index df4aa40f..c41cd523 100644 --- a/rest_framework/templates/rest_framework/horizontal/input.html +++ b/rest_framework/templates/rest_framework/horizontal/input.html @@ -3,7 +3,7 @@ <label class="col-sm-2 control-label {% if style.hide_label %}sr-only{% endif %}">{{ field.label }}</label> {% endif %} <div class="col-sm-10"> - <input name="{{ field.name }}" class="form-control" type="{{ style.input_type }}" {% if style.placeholder %}placeholder="{{ style.placeholder }}"{% endif %} {% if field.value %}value="{{ field.value }}"{% endif %}> + <input name="{{ field.name }}" {% if style.input_type != "file" %}class="form-control"{% endif %} type="{{ style.input_type }}" {% if style.placeholder %}placeholder="{{ style.placeholder }}"{% endif %} {% if field.value %}value="{{ field.value }}"{% endif %}> {% if field.errors %} {% for error in field.errors %}<span class="help-block">{{ error }}</span>{% endfor %} {% endif %} diff --git a/rest_framework/templates/rest_framework/horizontal/list_fieldset.html b/rest_framework/templates/rest_framework/horizontal/list_fieldset.html index a30514c6..a9ff04a6 100644 --- a/rest_framework/templates/rest_framework/horizontal/list_fieldset.html +++ b/rest_framework/templates/rest_framework/horizontal/list_fieldset.html @@ -5,9 +5,12 @@ <legend class="control-label col-sm-2 {% if style.hide_label %}sr-only{% endif %}" style="border-bottom: 0">{{ field.label }}</legend> </div> {% endif %} + <!-- <ul> {% for child in field.value %} <li>TODO</li> {% endfor %} </ul> + --> + <p>Lists are not currently supported in HTML input.</p> </fieldset> diff --git a/rest_framework/templates/rest_framework/horizontal/select.html b/rest_framework/templates/rest_framework/horizontal/select.html index 1d00f424..380b38e9 100644 --- a/rest_framework/templates/rest_framework/horizontal/select.html +++ b/rest_framework/templates/rest_framework/horizontal/select.html @@ -4,6 +4,9 @@ {% endif %} <div class="col-sm-10"> <select class="form-control" name="{{ field.name }}"> + {% if field.allow_null %} + <option value="" {% if not field.value %}selected{% endif %}>--------</option> + {% endif %} {% for key, text in field.choices.items %} <option value="{{ key }}" {% if key == field.value %}selected{% endif %}>{{ text }}</option> {% endfor %} diff --git a/rest_framework/templates/rest_framework/inline/input.html b/rest_framework/templates/rest_framework/inline/input.html index f8ec4faf..de85ba48 100644 --- a/rest_framework/templates/rest_framework/inline/input.html +++ b/rest_framework/templates/rest_framework/inline/input.html @@ -2,5 +2,5 @@ {% if field.label %} <label class="sr-only">{{ field.label }}</label> {% endif %} - <input name="{{ field.name }}" class="form-control" type="{{ style.input_type }}" {% if style.placeholder %}placeholder="{{ style.placeholder }}"{% endif %} {% if field.value %}value="{{ field.value }}"{% endif %}> + <input name="{{ field.name }}" {% if style.input_type != "file" %}class="form-control"{% endif %} type="{{ style.input_type }}" {% if style.placeholder %}placeholder="{{ style.placeholder }}"{% endif %} {% if field.value %}value="{{ field.value }}"{% endif %}> </div> diff --git a/rest_framework/templates/rest_framework/inline/list_fieldset.html b/rest_framework/templates/rest_framework/inline/list_fieldset.html new file mode 100644 index 00000000..2ae56d7c --- /dev/null +++ b/rest_framework/templates/rest_framework/inline/list_fieldset.html @@ -0,0 +1 @@ +<span>Lists are not currently supported in HTML input.</span> diff --git a/rest_framework/templates/rest_framework/inline/select.html b/rest_framework/templates/rest_framework/inline/select.html index e9fcebb4..53af2772 100644 --- a/rest_framework/templates/rest_framework/inline/select.html +++ b/rest_framework/templates/rest_framework/inline/select.html @@ -3,8 +3,11 @@ <label class="sr-only">{{ field.label }}</label> {% endif %} <select class="form-control" name="{{ field.name }}"> + {% if field.allow_null %} + <option value="" {% if not field.value %}selected{% endif %}>--------</option> + {% endif %} {% for key, text in field.choices.items %} - <option value="{{ key }}" {% if key == field.value %}selected{% endif %}>{{ text }}</option> + <option value="{{ key }}" {% if key == field.value %}selected{% endif %}>{{ text }}</option> {% endfor %} </select> </div> diff --git a/rest_framework/templates/rest_framework/login_base.html b/rest_framework/templates/rest_framework/login_base.html index 8ab682ac..e050cbdc 100644 --- a/rest_framework/templates/rest_framework/login_base.html +++ b/rest_framework/templates/rest_framework/login_base.html @@ -22,7 +22,7 @@ <div id="div_id_username" class="clearfix control-group {% if form.username.errors %}error{% endif %}"> <div class="controls"> - <Label class="span4">Username:</label> + <label class="span4">Username:</label> <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="span12 textinput textInput" @@ -36,9 +36,10 @@ </div> </div> <div id="div_id_password" - class="clearfix control-group {% if form.password.errors %}error{% endif %}"> + class="clearfix control-group {% if form.password.errors %}error{% endif %}" + style="margin-top: 10px"> <div class="controls"> - <Label class="span4">Password:</label> + <label class="span4">Password:</label> <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="span12 textinput textInput" id="id_password" required> @@ -55,7 +56,7 @@ <div class="well well-small text-error" style="border: none">{{ error }}</div> {% endfor %} {% endif %} - <div class="form-actions-no-box"> + <div class="form-actions-no-box" style="margin-top: 20px"> <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit"> </div> </form> diff --git a/rest_framework/templates/rest_framework/vertical/input.html b/rest_framework/templates/rest_framework/vertical/input.html index e1e21ca1..43cccd3e 100644 --- a/rest_framework/templates/rest_framework/vertical/input.html +++ b/rest_framework/templates/rest_framework/vertical/input.html @@ -2,7 +2,7 @@ {% if field.label %} <label {% if style.hide_label %}class="sr-only"{% endif %}>{{ field.label }}</label> {% endif %} - <input name="{{ field.name }}" class="form-control" type="{{ style.input_type }}" {% if style.placeholder %}placeholder="{{ style.placeholder }}"{% endif %} {% if field.value %}value="{{ field.value }}"{% endif %}> + <input name="{{ field.name }}" {% if style.input_type != "file" %}class="form-control"{% endif %} type="{{ style.input_type }}" {% if style.placeholder %}placeholder="{{ style.placeholder }}"{% endif %} {% if field.value %}value="{{ field.value }}"{% endif %}> {% if field.errors %} {% for error in field.errors %}<span class="help-block">{{ error }}</span>{% endfor %} {% endif %} diff --git a/rest_framework/templates/rest_framework/vertical/list_fieldset.html b/rest_framework/templates/rest_framework/vertical/list_fieldset.html index 74bbf448..1d86c7f2 100644 --- a/rest_framework/templates/rest_framework/vertical/list_fieldset.html +++ b/rest_framework/templates/rest_framework/vertical/list_fieldset.html @@ -4,4 +4,5 @@ {% for field_item in field.value.field_items.values() %} {{ renderer.render_field(field_item, layout=layout) }} {% endfor %} --> + <p>Lists are not currently supported in HTML input.</p> </fieldset> diff --git a/rest_framework/templates/rest_framework/vertical/select.html b/rest_framework/templates/rest_framework/vertical/select.html index 7c673ebb..de72e1dd 100644 --- a/rest_framework/templates/rest_framework/vertical/select.html +++ b/rest_framework/templates/rest_framework/vertical/select.html @@ -3,8 +3,11 @@ <label {% if style.hide_label %}class="sr-only"{% endif %}>{{ field.label }}</label> {% endif %} <select class="form-control" name="{{ field.name }}"> + {% if field.allow_null %} + <option value="" {% if not field.value %}selected{% endif %}>--------</option> + {% endif %} {% for key, text in field.choices.items %} - <option value="{{ key }}" {% if key == field.value %}selected{% endif %}>{{ text }}</option> + <option value="{{ key }}" {% if key == field.value %}selected{% endif %}>{{ text }}</option> {% endfor %} </select> {% if field.errors %} diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index 24639085..9c187176 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -88,7 +88,7 @@ def get_field_kwargs(field_name, model_field): kwargs['read_only'] = True return kwargs - if model_field.has_default(): + if model_field.has_default() or model_field.blank or model_field.null: kwargs['required'] = False if model_field.flatchoices: @@ -215,7 +215,7 @@ def get_relation_kwargs(field_name, relation_info): # If this field is read-only, then return early. # No further keyword arguments are valid. return kwargs - if model_field.has_default(): + if model_field.has_default() or model_field.null: kwargs['required'] = False if model_field.null: kwargs['allow_null'] = True diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py index 82361edf..c98725c6 100644 --- a/rest_framework/utils/model_meta.py +++ b/rest_framework/utils/model_meta.py @@ -6,6 +6,7 @@ relationships and their associated metadata. Usage: `get_field_info(model)` returns a `FieldInfo` instance. """ from collections import namedtuple +from django.core.exceptions import ImproperlyConfigured from django.db import models from django.utils import six from rest_framework.compat import OrderedDict @@ -43,7 +44,11 @@ def _resolve_model(obj): """ if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: app_name, model_name = obj.split('.') - return models.get_model(app_name, model_name) + resolved_model = models.get_model(app_name, model_name) + if resolved_model is None: + msg = "Django did not return a model for {0}.{1}" + raise ImproperlyConfigured(msg.format(app_name, model_name)) + return resolved_model elif inspect.isclass(obj) and issubclass(obj, models.Model): return obj raise ValueError("{0} is not a Django model".format(obj)) diff --git a/rest_framework/utils/serializer_helpers.py b/rest_framework/utils/serializer_helpers.py index 92d19857..277cf649 100644 --- a/rest_framework/utils/serializer_helpers.py +++ b/rest_framework/utils/serializer_helpers.py @@ -1,3 +1,4 @@ +import collections from rest_framework.compat import OrderedDict @@ -70,7 +71,7 @@ class NestedBoundField(BoundField): return BoundField(field, value, error, prefix=self.name + '.') -class BindingDict(object): +class BindingDict(collections.MutableMapping): """ This dict-like object is used to store fields on a serializer. @@ -92,11 +93,8 @@ class BindingDict(object): def __delitem__(self, key): del self.fields[key] - def items(self): - return self.fields.items() - - def keys(self): - return self.fields.keys() + def __iter__(self): + return iter(self.fields) - def values(self): - return self.fields.values() + def __len__(self): + return len(self.fields) diff --git a/rest_framework/validators.py b/rest_framework/validators.py index fa4f1847..7ca4e6a9 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -93,6 +93,9 @@ class UniqueTogetherValidator: The `UniqueTogetherValidator` always forces an implied 'required' state on the fields it applies to. """ + if self.instance is not None: + return + missing = dict([ (field_name, self.missing_message) for field_name in self.fields @@ -105,8 +108,17 @@ class UniqueTogetherValidator: """ Filter the queryset to all instances matching the given attributes. """ + # If this is an update, then any unprovided field should + # have it's value set based on the existing instance attribute. + if self.instance is not None: + for field_name in self.fields: + if field_name not in attrs: + attrs[field_name] = getattr(self.instance, field_name) + + # Determine the filter keyword arguments and filter the queryset. filter_kwargs = dict([ - (field_name, attrs[field_name]) for field_name in self.fields + (field_name, attrs[field_name]) + for field_name in self.fields ]) return queryset.filter(**filter_kwargs) diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 84b4bd8d..70d14695 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -48,6 +48,12 @@ class ViewSetMixin(object): # eg. 'List' or 'Instance'. cls.suffix = None + # actions must not be empty + if not actions: + raise TypeError("The `actions` argument must be provided when " + "calling `.as_view()` on a ViewSet. For example " + "`.as_view({'get': 'list'})`") + # sanitize keyword arguments for key in initkwargs: if key in cls.http_method_names: |
