From f27a28682bdb1b4eea0ec9afca2eb2835c735f55 Mon Sep 17 00:00:00 2001 From: Greg Doermann Date: Wed, 20 Aug 2014 11:00:37 -0600 Subject: Frameworks throws AssertionError saying you cannot set required=True and read_only=True on editable=False model fields. We should not make the field required if editable=False. --- rest_framework/serializers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index be8ad3f2..27af7ef3 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -831,7 +831,7 @@ class ModelSerializer(Serializer): } if model_field: - kwargs['required'] = not(model_field.null or model_field.blank) + kwargs['required'] = not(model_field.null or model_field.blank) and model_field.editable if model_field.help_text is not None: kwargs['help_text'] = model_field.help_text if model_field.verbose_name is not None: @@ -854,7 +854,7 @@ class ModelSerializer(Serializer): """ kwargs = {} - if model_field.null or model_field.blank: + if model_field.null or model_field.blank and model_field.editable: kwargs['required'] = False if isinstance(model_field, models.AutoField) or not model_field.editable: @@ -1110,7 +1110,7 @@ class HyperlinkedModelSerializer(ModelSerializer): } if model_field: - kwargs['required'] = not(model_field.null or model_field.blank) + kwargs['required'] = not(model_field.null or model_field.blank) and model_field.editable if model_field.help_text is not None: kwargs['help_text'] = model_field.help_text if model_field.verbose_name is not None: -- cgit v1.2.3 From f62c874ea9621ae67fb56e7e453dca8fd5039051 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 10:48:40 +0100 Subject: Remove `filter_backend`. Closes #1775. --- rest_framework/generics.py | 20 +------------------- rest_framework/settings.py | 10 ---------- 2 files changed, 1 insertion(+), 29 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index a6f68657..8bacf470 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -83,7 +83,6 @@ class GenericAPIView(views.APIView): slug_url_kwarg = 'slug' slug_field = 'slug' allow_empty = True - filter_backend = api_settings.FILTER_BACKEND def get_serializer_context(self): """ @@ -191,24 +190,7 @@ class GenericAPIView(views.APIView): """ Returns the list of filter backends that this view requires. """ - if self.filter_backends is None: - filter_backends = [] - else: - # Note that we are returning a *copy* of the class attribute, - # so that it is safe for the view to mutate it if needed. - filter_backends = list(self.filter_backends) - - if not filter_backends and self.filter_backend: - warnings.warn( - 'The `filter_backend` attribute and `FILTER_BACKEND` setting ' - 'are deprecated in favor of a `filter_backends` ' - 'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take ' - 'a *list* of filter backend classes.', - DeprecationWarning, stacklevel=2 - ) - filter_backends = [self.filter_backend] - - return filter_backends + return list(self.filter_backends) # The following methods provide default implementations # that you may want to override for more complex cases. diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 644751f8..bbe7a56a 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -111,9 +111,6 @@ DEFAULTS = { ), 'TIME_FORMAT': None, - # Pending deprecation - 'FILTER_BACKEND': None, - } @@ -129,7 +126,6 @@ IMPORT_STRINGS = ( 'DEFAULT_PAGINATION_SERIALIZER_CLASS', 'DEFAULT_FILTER_BACKENDS', 'EXCEPTION_HANDLER', - 'FILTER_BACKEND', 'TEST_REQUEST_RENDERER_CLASSES', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', @@ -196,15 +192,9 @@ class APISettings(object): if val and attr in self.import_strings: val = perform_import(val, attr) - self.validate_setting(attr, val) - # Cache the result setattr(self, attr, val) return val - def validate_setting(self, attr, val): - if attr == 'FILTER_BACKEND' and val is not None: - # Make sure we can initialize the class - val() api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS) -- cgit v1.2.3 From 0f8fdf4e72b67ff46474c13c8b532bf319a58099 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 10:57:24 +0100 Subject: Remove `allow_empty`. Closes #1774. --- rest_framework/generics.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 8bacf470..cb8061b7 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -82,7 +82,6 @@ class GenericAPIView(views.APIView): pk_url_kwarg = 'pk' slug_url_kwarg = 'slug' slug_field = 'slug' - allow_empty = True def get_serializer_context(self): """ @@ -140,16 +139,7 @@ class GenericAPIView(views.APIView): if not page_size: return None - if not self.allow_empty: - warnings.warn( - 'The `allow_empty` parameter is deprecated. ' - 'To use `allow_empty=False` style behavior, You should override ' - '`get_queryset()` and explicitly raise a 404 on empty querysets.', - DeprecationWarning, stacklevel=2 - ) - - paginator = self.paginator_class(queryset, page_size, - allow_empty_first_page=self.allow_empty) + paginator = self.paginator_class(queryset, page_size) page_kwarg = self.kwargs.get(self.page_kwarg) page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) page = page_kwarg or page_query_param or 1 -- cgit v1.2.3 From b3bbf416707cf8c71861b0fd6e966a557acef412 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 11:09:35 +0100 Subject: Remove `allow_empty` --- rest_framework/mixins.py | 15 --------------- 1 file changed, 15 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 2cc87eef..dc4c9f35 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -70,24 +70,9 @@ class ListModelMixin(object): """ List a queryset. """ - empty_error = "Empty list and '%(class_name)s.allow_empty' is False." - def list(self, request, *args, **kwargs): self.object_list = self.filter_queryset(self.get_queryset()) - # Default is to allow empty querysets. This can be altered by setting - # `.allow_empty = False`, to raise 404 errors on empty querysets. - if not self.allow_empty and not self.object_list: - warnings.warn( - 'The `allow_empty` parameter is deprecated. ' - 'To use `allow_empty=False` style behavior, You should override ' - '`get_queryset()` and explicitly raise a 404 on empty querysets.', - DeprecationWarning - ) - class_name = self.__class__.__name__ - error_msg = self.empty_error % {'class_name': class_name} - raise Http404(error_msg) - # Switch between paginated or standard style responses page = self.paginate_queryset(self.object_list) if page is not None: -- cgit v1.2.3 From e5e6329a222def3b0745f90fc55ee36de95ada83 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 11:29:26 +0100 Subject: Remove `pk_url_field`, `slug_url_field`, `slug_field`. Closes #1773. --- rest_framework/generics.py | 36 ++------------ rest_framework/mixins.py | 36 +++----------- rest_framework/relations.py | 117 ++------------------------------------------ 3 files changed, 15 insertions(+), 174 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index cb8061b7..e21dc5c7 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -76,13 +76,6 @@ class GenericAPIView(views.APIView): model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS paginator_class = Paginator - ###################################### - # These are pending deprecation... - - pk_url_kwarg = 'pk' - slug_url_kwarg = 'slug' - slug_field = 'slug' - def get_serializer_context(self): """ Extra context provided to the serializer class. @@ -270,7 +263,7 @@ class GenericAPIView(views.APIView): error_format = "'%s' must define 'queryset' or 'model'" raise ImproperlyConfigured(error_format % self.__class__.__name__) - def get_object(self, queryset=None): + def get_object(self): """ Returns the object the view is displaying. @@ -278,36 +271,14 @@ class GenericAPIView(views.APIView): queryset lookups. Eg if objects are referenced using multiple keyword arguments in the url conf. """ - # Determine the base queryset to use. - if queryset is None: - queryset = self.filter_queryset(self.get_queryset()) - else: - pass # Deprecation warning + queryset = self.filter_queryset(self.get_queryset()) # Perform the lookup filtering. # Note that `pk` and `slug` are deprecated styles of lookup filtering. lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field lookup = self.kwargs.get(lookup_url_kwarg, None) - pk = self.kwargs.get(self.pk_url_kwarg, None) - slug = self.kwargs.get(self.slug_url_kwarg, None) - if lookup is not None: - filter_kwargs = {self.lookup_field: lookup} - elif pk is not None and self.lookup_field == 'pk': - warnings.warn( - 'The `pk_url_kwarg` attribute is deprecated. ' - 'Use the `lookup_field` attribute instead', - DeprecationWarning - ) - filter_kwargs = {'pk': pk} - elif slug is not None and self.lookup_field == 'pk': - warnings.warn( - 'The `slug_url_kwarg` attribute is deprecated. ' - 'Use the `lookup_field` attribute instead', - DeprecationWarning - ) - filter_kwargs = {self.slug_field: slug} - else: + if lookup is None: raise ImproperlyConfigured( 'Expected view %s to be called with a URL keyword argument ' 'named "%s". Fix your URL conf, or set the `.lookup_field` ' @@ -315,6 +286,7 @@ class GenericAPIView(views.APIView): (self.__class__.__name__, self.lookup_field) ) + filter_kwargs = {self.lookup_field: lookup} obj = get_object_or_404(queryset, **filter_kwargs) # May raise a permission denied diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index dc4c9f35..ac59d979 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -12,10 +12,9 @@ from rest_framework import status from rest_framework.response import Response from rest_framework.request import clone_request from rest_framework.settings import api_settings -import warnings -def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None): +def _get_validation_exclusions(obj, lookup_field=None): """ Given a model instance, and an optional pk and slug field, return the full list of all other field names on that model. @@ -23,23 +22,13 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None) For use when performing full_clean on a model instance, so we only clean the required fields. """ - include = [] - - if pk: - # Deprecated + if lookup_field == 'pk': pk_field = obj._meta.pk while pk_field.rel: pk_field = pk_field.rel.to._meta.pk - include.append(pk_field.name) - - if slug_field: - # Deprecated - include.append(slug_field) - - if lookup_field and lookup_field != 'pk': - include.append(lookup_field) + lookup_field = pk_field.name - return [field.name for field in obj._meta.fields if field.name not in include] + return [field.name for field in obj._meta.fields if field.name != lookup_field] class CreateModelMixin(object): @@ -146,26 +135,15 @@ class UpdateModelMixin(object): """ Set any attributes on the object that are implicit in the request. """ - # pk and/or slug attributes are implicit in the URL. lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - lookup = self.kwargs.get(lookup_url_kwarg, None) - pk = self.kwargs.get(self.pk_url_kwarg, None) - slug = self.kwargs.get(self.slug_url_kwarg, None) - slug_field = slug and self.slug_field or None - - if lookup: - setattr(obj, self.lookup_field, lookup) - - if pk: - setattr(obj, 'pk', pk) + lookup_value = self.kwargs[lookup_url_kwarg] - if slug: - setattr(obj, slug_field, slug) + setattr(obj, self.lookup_field, lookup_value) # Ensure we clean the attributes so that we don't eg return integer # pk using a string representation, as provided by the url conf kwarg. if hasattr(obj, 'full_clean'): - exclude = _get_validation_exclusions(obj, pk, slug_field, self.lookup_field) + exclude = _get_validation_exclusions(obj, self.lookup_field) obj.full_clean(exclude) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 1acbdce2..56870b40 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -16,7 +16,6 @@ from rest_framework.fields import Field, WritableField, get_component, is_simple from rest_framework.reverse import reverse from rest_framework.compat import urlparse from rest_framework.compat import smart_text -import warnings # Relational fields @@ -320,11 +319,6 @@ class HyperlinkedRelatedField(RelatedField): 'incorrect_type': _('Incorrect type. Expected url string, received %s.'), } - # These are all deprecated - pk_url_kwarg = 'pk' - slug_field = 'slug' - slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden - def __init__(self, *args, **kwargs): try: self.view_name = kwargs.pop('view_name') @@ -334,22 +328,6 @@ class HyperlinkedRelatedField(RelatedField): self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) self.format = kwargs.pop('format', None) - # These are deprecated - if 'pk_url_kwarg' in kwargs: - msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.' - warnings.warn(msg, DeprecationWarning, stacklevel=2) - if 'slug_url_kwarg' in kwargs: - msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.' - warnings.warn(msg, DeprecationWarning, stacklevel=2) - if 'slug_field' in kwargs: - msg = 'slug_field is deprecated. Use lookup_field instead.' - warnings.warn(msg, DeprecationWarning, stacklevel=2) - - self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) - self.slug_field = kwargs.pop('slug_field', self.slug_field) - default_slug_kwarg = self.slug_url_kwarg or self.slug_field - self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) - super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) def get_url(self, obj, view_name, request, format): @@ -361,39 +339,7 @@ class HyperlinkedRelatedField(RelatedField): """ lookup_field = getattr(obj, self.lookup_field) kwargs = {self.lookup_field: lookup_field} - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except NoReverseMatch: - pass - - if self.pk_url_kwarg != 'pk': - # Only try pk if it has been explicitly set. - # Otherwise, the default `lookup_field = 'pk'` has us covered. - pk = obj.pk - kwargs = {self.pk_url_kwarg: pk} - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except NoReverseMatch: - pass - - slug = getattr(obj, self.slug_field, None) - if slug is not None: - # Only try slug if it corresponds to an attribute on the object. - kwargs = {self.slug_url_kwarg: slug} - try: - ret = reverse(view_name, kwargs=kwargs, request=request, format=format) - if self.slug_field == 'slug' and self.slug_url_kwarg == 'slug': - # If the lookup succeeds using the default slug params, - # then `slug_field` is being used implicitly, and we - # we need to warn about the pending deprecation. - msg = 'Implicit slug field hyperlinked fields are deprecated.' \ - 'You should set `lookup_field=slug` on the HyperlinkedRelatedField.' - warnings.warn(msg, DeprecationWarning, stacklevel=2) - return ret - except NoReverseMatch: - pass - - raise NoReverseMatch() + return reverse(view_name, kwargs=kwargs, request=request, format=format) def get_object(self, queryset, view_name, view_args, view_kwargs): """ @@ -402,19 +348,8 @@ class HyperlinkedRelatedField(RelatedField): Takes the matched URL conf arguments, and the queryset, and should return an object instance, or raise an `ObjectDoesNotExist` exception. """ - lookup = view_kwargs.get(self.lookup_field, None) - pk = view_kwargs.get(self.pk_url_kwarg, None) - slug = view_kwargs.get(self.slug_url_kwarg, None) - - if lookup is not None: - filter_kwargs = {self.lookup_field: lookup} - elif pk is not None: - filter_kwargs = {'pk': pk} - elif slug is not None: - filter_kwargs = {self.slug_field: slug} - else: - raise ObjectDoesNotExist() - + lookup_value = view_kwargs[self.lookup_field] + filter_kwargs = {self.lookup_field: lookup_value} return queryset.get(**filter_kwargs) def to_native(self, obj): @@ -486,11 +421,6 @@ class HyperlinkedIdentityField(Field): lookup_field = 'pk' read_only = True - # These are all deprecated - pk_url_kwarg = 'pk' - slug_field = 'slug' - slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden - def __init__(self, *args, **kwargs): try: self.view_name = kwargs.pop('view_name') @@ -502,22 +432,6 @@ class HyperlinkedIdentityField(Field): lookup_field = kwargs.pop('lookup_field', None) self.lookup_field = lookup_field or self.lookup_field - # These are deprecated - if 'pk_url_kwarg' in kwargs: - msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.' - warnings.warn(msg, DeprecationWarning, stacklevel=2) - if 'slug_url_kwarg' in kwargs: - msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.' - warnings.warn(msg, DeprecationWarning, stacklevel=2) - if 'slug_field' in kwargs: - msg = 'slug_field is deprecated. Use lookup_field instead.' - warnings.warn(msg, DeprecationWarning, stacklevel=2) - - self.slug_field = kwargs.pop('slug_field', self.slug_field) - default_slug_kwarg = self.slug_url_kwarg or self.slug_field - self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) - self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) - super(HyperlinkedIdentityField, self).__init__(*args, **kwargs) def field_to_native(self, obj, field_name): @@ -569,27 +483,4 @@ class HyperlinkedIdentityField(Field): if lookup_field is None: return None - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except NoReverseMatch: - pass - - if self.pk_url_kwarg != 'pk': - # Only try pk lookup if it has been explicitly set. - # Otherwise, the default `lookup_field = 'pk'` has us covered. - kwargs = {self.pk_url_kwarg: obj.pk} - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except NoReverseMatch: - pass - - slug = getattr(obj, self.slug_field, None) - if slug: - # Only use slug lookup if a slug field exists on the model - kwargs = {self.slug_url_kwarg: slug} - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except NoReverseMatch: - pass - - raise NoReverseMatch() + return reverse(view_name, kwargs=kwargs, request=request, format=format) -- cgit v1.2.3 From b8c8d10a18741b76355ed7035655d0101c1d778a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 11:38:54 +0100 Subject: Remove `page_size` argument. `paginate_queryset` no longer takes an optional `page_size` argument. --- rest_framework/generics.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index e21dc5c7..09035303 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -111,26 +111,14 @@ class GenericAPIView(views.APIView): context = self.get_serializer_context() return pagination_serializer_class(instance=page, context=context) - def paginate_queryset(self, queryset, page_size=None): + def paginate_queryset(self, queryset): """ Paginate a queryset if required, either returning a page object, or `None` if pagination is not configured for this view. """ - deprecated_style = False - if page_size is not None: - warnings.warn('The `page_size` parameter to `paginate_queryset()` ' - 'is deprecated. ' - 'Note that the return style of this method is also ' - 'changed, and will simply return a page object ' - 'when called without a `page_size` argument.', - DeprecationWarning, stacklevel=2) - deprecated_style = True - else: - # Determine the required page size. - # If pagination is not configured, simply return None. - page_size = self.get_paginate_by() - if not page_size: - return None + page_size = self.get_paginate_by() + if not page_size: + return None paginator = self.paginator_class(queryset, page_size) page_kwarg = self.kwargs.get(self.page_kwarg) @@ -152,8 +140,6 @@ class GenericAPIView(views.APIView): 'message': str(exc) }) - if deprecated_style: - return (paginator, page, page.object_list, page.has_other_pages()) return page def filter_queryset(self, queryset): -- cgit v1.2.3 From b3253b42836acd123224e88c0927f1ee6a031d94 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 12:35:53 +0100 Subject: Remove `.model` usage in tests. Remove the shortcut `.model` view attribute usage from test cases. --- rest_framework/generics.py | 49 ++++++++++++---------------------------------- 1 file changed, 12 insertions(+), 37 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 09035303..68222864 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -51,11 +51,6 @@ class GenericAPIView(views.APIView): queryset = None serializer_class = None - # This shortcut may be used instead of setting either or both - # of the `queryset`/`serializer_class` attributes, although using - # the explicit style is generally preferred. - model = None - # If you want to use object lookups other than pk, set this attribute. # For more complex lookup requirements override `get_object()`. lookup_field = 'pk' @@ -71,9 +66,8 @@ class GenericAPIView(views.APIView): # The filter backend classes to use for queryset filtering filter_backends = api_settings.DEFAULT_FILTER_BACKENDS - # The following attributes may be subject to change, + # The following attribute may be subject to change, # and should be considered private API. - model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS paginator_class = Paginator def get_serializer_context(self): @@ -199,26 +193,13 @@ class GenericAPIView(views.APIView): (Eg. admins get full serialization, others get basic serialization) """ - serializer_class = self.serializer_class - if serializer_class is not None: - return serializer_class - - warnings.warn( - 'The `.model` attribute on view classes is now deprecated in favor ' - 'of the more explicit `serializer_class` and `queryset` attributes.', - DeprecationWarning, stacklevel=2 - ) - - assert self.model is not None, \ - "'%s' should either include a 'serializer_class' attribute, " \ - "or use the 'model' attribute as a shortcut for " \ - "automatically generating a serializer class." \ + assert self.serializer_class is not None, ( + "'%s' should either include a `serializer_class` attribute, " + "or override the `get_serializer_class()` method." % self.__class__.__name__ + ) - class DefaultSerializer(self.model_serializer_class): - class Meta: - model = self.model - return DefaultSerializer + return self.serializer_class def get_queryset(self): """ @@ -235,19 +216,13 @@ class GenericAPIView(views.APIView): (Eg. return a list of items that is specific to the user) """ - if self.queryset is not None: - return self.queryset._clone() - - if self.model is not None: - warnings.warn( - 'The `.model` attribute on view classes is now deprecated in favor ' - 'of the more explicit `serializer_class` and `queryset` attributes.', - DeprecationWarning, stacklevel=2 - ) - return self.model._default_manager.all() + assert self.queryset is not None, ( + "'%s' should either include a `queryset` attribute, " + "or override the `get_queryset()` method." + % self.__class__.__name__ + ) - error_format = "'%s' must define 'queryset' or 'model'" - raise ImproperlyConfigured(error_format % self.__class__.__name__) + return self.queryset._clone() def get_object(self): """ -- cgit v1.2.3 From 72c0811576feb89decf6fc6dc4ee5e25eca0aece Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 12:48:04 +0100 Subject: Minor tidy up. --- rest_framework/generics.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 68222864..d0adeaec 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -3,7 +3,7 @@ Generic views that provide commonly needed behaviour. """ from __future__ import unicode_literals -from django.core.exceptions import ImproperlyConfigured, PermissionDenied +from django.core.exceptions import PermissionDenied from django.core.paginator import Paginator, InvalidPage from django.http import Http404 from django.shortcuts import get_object_or_404 as _get_object_or_404 @@ -235,19 +235,16 @@ class GenericAPIView(views.APIView): queryset = self.filter_queryset(self.get_queryset()) # Perform the lookup filtering. - # Note that `pk` and `slug` are deprecated styles of lookup filtering. lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - lookup = self.kwargs.get(lookup_url_kwarg, None) - if lookup is None: - raise ImproperlyConfigured( - 'Expected view %s to be called with a URL keyword argument ' - 'named "%s". Fix your URL conf, or set the `.lookup_field` ' - 'attribute on the view correctly.' % - (self.__class__.__name__, self.lookup_field) - ) + assert lookup_url_kwarg in self.kwargs, ( + 'Expected view %s to be called with a URL keyword argument ' + 'named "%s". Fix your URL conf, or set the `.lookup_field` ' + 'attribute on the view correctly.' % + (self.__class__.__name__, lookup_url_kwarg) + ) - filter_kwargs = {self.lookup_field: lookup} + filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]} obj = get_object_or_404(queryset, **filter_kwargs) # May raise a permission denied -- cgit v1.2.3 From ce7b2cded94abc12ae1be076642de96684d0927b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 12:48:49 +0100 Subject: Remove deprecated generic views. `MultipleObjectAPIView` and `SingleObjectAPIView` are no longer required. --- rest_framework/generics.py | 22 ---------------------- 1 file changed, 22 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index d0adeaec..e6cbfca9 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -442,25 +442,3 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) - - -# Deprecated classes - -class MultipleObjectAPIView(GenericAPIView): - def __init__(self, *args, **kwargs): - warnings.warn( - 'Subclassing `MultipleObjectAPIView` is deprecated. ' - 'You should simply subclass `GenericAPIView` instead.', - DeprecationWarning, stacklevel=2 - ) - super(MultipleObjectAPIView, self).__init__(*args, **kwargs) - - -class SingleObjectAPIView(GenericAPIView): - def __init__(self, *args, **kwargs): - warnings.warn( - 'Subclassing `SingleObjectAPIView` is deprecated. ' - 'You should simply subclass `GenericAPIView` instead.', - DeprecationWarning, stacklevel=2 - ) - super(SingleObjectAPIView, self).__init__(*args, **kwargs) -- cgit v1.2.3 From f87d32558eb3b36f14a798ec48e4943d25380b92 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 12:53:45 +0100 Subject: Remove `.link()` and `.action()` decorators. --- rest_framework/decorators.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 449ba0a2..cc5d92c2 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -130,37 +130,3 @@ def list_route(methods=['get'], **kwargs): func.kwargs = kwargs return func return decorator - - -# These are now pending deprecation, in favor of `detail_route` and `list_route`. - -def link(**kwargs): - """ - Used to mark a method on a ViewSet that should be routed for detail GET requests. - """ - msg = 'link is pending deprecation. Use detail_route instead.' - warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) - - def decorator(func): - func.bind_to_methods = ['get'] - func.detail = True - func.kwargs = kwargs - return func - - return decorator - - -def action(methods=['post'], **kwargs): - """ - Used to mark a method on a ViewSet that should be routed for detail POST requests. - """ - msg = 'action is pending deprecation. Use detail_route instead.' - warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) - - def decorator(func): - func.bind_to_methods = methods - func.detail = True - func.kwargs = kwargs - return func - - return decorator -- cgit v1.2.3 From b552b62540e5144272c9c13c28f120ffe5fcbe45 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 12:54:03 +0100 Subject: `get_paginate_by` no longer takes optional `.queryset` --- rest_framework/generics.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index e6cbfca9..40c49844 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -158,7 +158,7 @@ class GenericAPIView(views.APIView): # The following methods provide default implementations # that you may want to override for more complex cases. - def get_paginate_by(self, queryset=None): + def get_paginate_by(self): """ Return the size of pages to use with pagination. @@ -167,11 +167,6 @@ class GenericAPIView(views.APIView): Otherwise defaults to using `self.paginate_by`. """ - if queryset is not None: - warnings.warn('The `queryset` parameter to `get_paginate_by()` ' - 'is deprecated.', - DeprecationWarning, stacklevel=2) - if self.paginate_by_param: try: return strict_positive_int( -- cgit v1.2.3 From 371d30aa8737c4b3aaf28ee10cc2b77a9c4d1fd9 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 12:54:52 +0100 Subject: Remove unused imports. --- rest_framework/decorators.py | 1 - rest_framework/generics.py | 1 - 2 files changed, 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index cc5d92c2..d28d6e22 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -10,7 +10,6 @@ from __future__ import unicode_literals from django.utils import six from rest_framework.views import APIView import types -import warnings def api_view(http_method_names): diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 40c49844..b3bd6ce9 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -11,7 +11,6 @@ from django.utils.translation import ugettext as _ from rest_framework import views, mixins, exceptions from rest_framework.request import clone_request from rest_framework.settings import api_settings -import warnings def strict_positive_int(integer_string, cutoff=None): -- cgit v1.2.3 From 4ac4676a40b121d27cfd1173ff548d96b8d3de2f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 29 Aug 2014 16:46:26 +0100 Subject: First pass --- rest_framework/fields.py | 1166 +++++++------------------------------- rest_framework/generics.py | 10 +- rest_framework/mixins.py | 24 +- rest_framework/pagination.py | 22 +- rest_framework/relations.py | 486 ---------------- rest_framework/renderers.py | 10 +- rest_framework/serializers.py | 1096 ++++++++++------------------------- rest_framework/utils/encoders.py | 18 +- rest_framework/utils/html.py | 86 +++ 9 files changed, 632 insertions(+), 2286 deletions(-) create mode 100644 rest_framework/utils/html.py (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 9d707c9b..a83bf94c 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,1038 +1,308 @@ -""" -Serializer fields perform validation on incoming data. +from rest_framework.utils import html -They are very similar to Django's form fields. -""" -from __future__ import unicode_literals -import copy -import datetime -import inspect -import re -import warnings -from decimal import Decimal, DecimalException -from django import forms -from django.core import validators -from django.core.exceptions import ValidationError -from django.conf import settings -from django.db.models.fields import BLANK_CHOICE_DASH -from django.http import QueryDict -from django.forms import widgets -from django.utils import six, timezone -from django.utils.encoding import is_protected_type -from django.utils.translation import ugettext_lazy as _ -from django.utils.datastructures import SortedDict -from django.utils.dateparse import parse_date, parse_datetime, parse_time -from rest_framework import ISO_8601 -from rest_framework.compat import ( - BytesIO, smart_text, - force_text, is_non_str_iterable -) -from rest_framework.settings import api_settings - - -def is_simple_callable(obj): +class empty: """ - True if the object is a callable that takes no arguments. - """ - function = inspect.isfunction(obj) - method = inspect.ismethod(obj) - - if not (function or method): - return False + This class is used to represent no data being provided for a given input + or output value. - args, _, _, defaults = inspect.getargspec(obj) - len_args = len(args) if function else len(args) - 1 - len_defaults = len(defaults) if defaults else 0 - return len_args <= len_defaults + It is required because `None` may be a valid input or output value. + """ + pass -def get_component(obj, attr_name): +def get_attribute(instance, attrs): """ - Given an object, and an attribute name, - return that attribute on the object. + Similar to Python's built in `getattr(instance, attr)`, + but takes a list of nested attributes, instead of a single attribute. """ - if isinstance(obj, dict): - val = obj.get(attr_name) - else: - val = getattr(obj, attr_name) - - if is_simple_callable(val): - return val() - return val - - -def readable_datetime_formats(formats): - format = ', '.join(formats).replace( - ISO_8601, - 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]' - ) - return humanize_strptime(format) - - -def readable_date_formats(formats): - format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]') - return humanize_strptime(format) + for attr in attrs: + instance = getattr(instance, attr) + return instance -def readable_time_formats(formats): - format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]') - return humanize_strptime(format) - - -def humanize_strptime(format_string): - # Note that we're missing some of the locale specific mappings that - # don't really make sense. - mapping = { - "%Y": "YYYY", - "%y": "YY", - "%m": "MM", - "%b": "[Jan-Dec]", - "%B": "[January-December]", - "%d": "DD", - "%H": "hh", - "%I": "hh", # Requires '%p' to differentiate from '%H'. - "%M": "mm", - "%S": "ss", - "%f": "uuuuuu", - "%a": "[Mon-Sun]", - "%A": "[Monday-Sunday]", - "%p": "[AM|PM]", - "%z": "[+HHMM|-HHMM]" - } - for key, val in mapping.items(): - format_string = format_string.replace(key, val) - return format_string - - -def strip_multiple_choice_msg(help_text): +def set_value(dictionary, keys, value): """ - Remove the 'Hold down "control" ...' message that is Django enforces in - select multiple fields on ModelForms. (Required for 1.5 and earlier) + Similar to Python's built in `dictionary[key] = value`, + but takes a list of nested keys instead of a single key. - See https://code.djangoproject.com/ticket/9321 + set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2} + set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2} + set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}} """ - multiple_choice_msg = _(' Hold down "Control", or "Command" on a Mac, to select more than one.') - multiple_choice_msg = force_text(multiple_choice_msg) + if not keys: + dictionary.update(value) + return - return help_text.replace(multiple_choice_msg, '') + for key in keys[:-1]: + if key not in dictionary: + dictionary[key] = {} + dictionary = dictionary[key] + dictionary[keys[-1]] = value -class Field(object): - read_only = True - creation_counter = 0 - empty = '' - type_name = None - partial = False - use_files = False - form_field_class = forms.CharField - type_label = 'field' - widget = None - def __init__(self, source=None, label=None, help_text=None): - self.parent = None +class ValidationError(Exception): + pass - self.creation_counter = Field.creation_counter - Field.creation_counter += 1 - self.source = source +class SkipField(Exception): + pass - if label is not None: - self.label = smart_text(label) - else: - self.label = None - if help_text is not None: - self.help_text = strip_multiple_choice_msg(smart_text(help_text)) - else: - self.help_text = None +class Field(object): + _creation_counter = 0 - self._errors = [] - self._value = None - self._name = None + MESSAGES = { + 'required': 'This field is required.' + } - @property - def errors(self): - return self._errors + _NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`' + _NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' + _NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`' + _NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' + _MISSING_ERROR_MESSAGE = ( + 'ValidationError raised by `{class_name}`, but error key `{key}` does ' + 'not exist in the `MESSAGES` dictionary.' + ) - def widget_html(self): - if not self.widget: - return '' + def __init__(self, read_only=False, write_only=False, + required=None, default=empty, initial=None, source=None, + label=None, style=None): + self._creation_counter = Field._creation_counter + Field._creation_counter += 1 - attrs = {} - if 'id' not in self.widget.attrs: - attrs['id'] = self._name + # If `required` is unset, then use `True` unless a default is provided. + if required is None: + required = default is empty and not read_only - return self.widget.render(self._name, self._value, attrs=attrs) + # Some combinations of keyword arguments do not make sense. + assert not (read_only and write_only), self._NOT_READ_ONLY_WRITE_ONLY + assert not (read_only and required), self._NOT_READ_ONLY_REQUIRED + assert not (read_only and default is not empty), self._NOT_READ_ONLY_DEFAULT + assert not (required and default is not empty), self._NOT_REQUIRED_DEFAULT - def label_tag(self): - return '' % (self._name, self.label) + self.read_only = read_only + self.write_only = write_only + self.required = required + self.default = default + self.source = source + self.initial = initial + self.label = label + self.style = {} if style is None else style - def initialize(self, parent, field_name): + def bind(self, field_name, parent, root): """ - Called to set up a field prior to field_to_native or field_from_native. - - parent - The parent serializer. - field_name - The name of the field being initialized. + Setup the context for the field instance. """ + self.field_name = field_name self.parent = parent - self.root = parent.root or parent - self.context = self.root.context - self.partial = self.root.partial - if self.partial: - self.required = False + self.root = root - def field_from_native(self, data, files, field_name, into): - """ - Given a dictionary and a field name, updates the dictionary `into`, - with the field and it's deserialized value. - """ - return + # `self.label` should deafult to being based on the field name. + if self.label is None: + self.label = self.field_name.replace('_', ' ').capitalize() - def field_to_native(self, obj, field_name): - """ - Given an object and a field name, returns the value that should be - serialized for that field. - """ - if obj is None: - return self.empty + # self.source should default to being the same as the field name. + if self.source is None: + self.source = field_name + # self.source_attrs is a list of attributes that need to be looked up + # when serializing the instance, or populating the validated data. if self.source == '*': - return self.to_native(obj) - - source = self.source or field_name - value = obj - - for component in source.split('.'): - value = get_component(value, component) - if value is None: - break - - return self.to_native(value) + self.source_attrs = [] + else: + self.source_attrs = self.source.split('.') - def to_native(self, value): + def get_initial(self): """ - Converts the field's value into it's simple representation. + Return a value to use when the field is being returned as a primative + value, without any object instance. """ - if is_simple_callable(value): - value = value() - - if is_protected_type(value): - return value - elif (is_non_str_iterable(value) and - not isinstance(value, (dict, six.string_types))): - return [self.to_native(item) for item in value] - elif isinstance(value, dict): - # Make sure we preserve field ordering, if it exists - ret = SortedDict() - for key, val in value.items(): - ret[key] = self.to_native(val) - return ret - return force_text(value) + return self.initial - def attributes(self): + def get_value(self, dictionary): """ - Returns a dictionary of attributes to be used when serializing to xml. + Given the *incoming* primative data, return the value for this field + that should be validated and transformed to a native value. """ - if self.type_name: - return {'type': self.type_name} - return {} - - def metadata(self): - metadata = SortedDict() - metadata['type'] = self.type_label - metadata['required'] = getattr(self, 'required', False) - optional_attrs = ['read_only', 'label', 'help_text', - 'min_length', 'max_length'] - for attr in optional_attrs: - value = getattr(self, attr, None) - if value is not None and value != '': - metadata[attr] = force_text(value, strings_only=True) - return metadata - - -class WritableField(Field): - """ - Base for read/write fields. - """ - write_only = False - default_validators = [] - default_error_messages = { - 'required': _('This field is required.'), - 'invalid': _('Invalid value.'), - } - widget = widgets.TextInput - default = None - - def __init__(self, source=None, label=None, help_text=None, - read_only=False, write_only=False, required=None, - validators=[], error_messages=None, widget=None, - default=None, blank=None): - - super(WritableField, self).__init__(source=source, label=label, help_text=help_text) - - self.read_only = read_only - self.write_only = write_only - - assert not (read_only and write_only), "Cannot set read_only=True and write_only=True" - - if required is None: - self.required = not(read_only) - else: - assert not (read_only and required), "Cannot set required=True and read_only=True" - self.required = required - - messages = {} - for c in reversed(self.__class__.__mro__): - messages.update(getattr(c, 'default_error_messages', {})) - messages.update(error_messages or {}) - self.error_messages = messages + return dictionary.get(self.field_name, empty) - self.validators = self.default_validators + validators - self.default = default if default is not None else self.default - - # Widgets are only used for HTML forms. - widget = widget or self.widget - if isinstance(widget, type): - widget = widget() - self.widget = widget + def get_attribute(self, instance): + """ + Given the *outgoing* object instance, return the value for this field + that should be returned as a primative value. + """ + return get_attribute(instance, self.source_attrs) - def __deepcopy__(self, memo): - result = copy.copy(self) - memo[id(self)] = result - result.validators = self.validators[:] - return result + def get_default(self): + """ + Return the default value to use when validating data if no input + is provided for this field. - def get_default_value(self): - if is_simple_callable(self.default): - return self.default() + If a default has not been set for this field then this will simply + return `empty`, indicating that no value should be set in the + validated data for this field. + """ + if self.default is empty: + raise SkipField() return self.default - def validate(self, value): - if value in validators.EMPTY_VALUES and self.required: - raise ValidationError(self.error_messages['required']) + def validate(self, data=empty): + """ + Validate a simple representation and return the internal value. - def run_validators(self, value): - if value in validators.EMPTY_VALUES: - return - errors = [] - for v in self.validators: - try: - v(value) - except ValidationError as e: - if hasattr(e, 'code') and e.code in self.error_messages: - message = self.error_messages[e.code] - if e.params: - message = message % e.params - errors.append(message) - else: - errors.extend(e.messages) - if errors: - raise ValidationError(errors) + The provided data may be `empty` if no representation was included. + May return `empty` if the field should not be included in the + validated data. + """ + if data is empty: + if self.required: + self.fail('required') + return self.get_default() - def field_to_native(self, obj, field_name): - if self.write_only: - return None - return super(WritableField, self).field_to_native(obj, field_name) + return self.to_native(data) - def field_from_native(self, data, files, field_name, into): + def to_native(self, data): """ - Given a dictionary and a field name, updates the dictionary `into`, - with the field and it's deserialized value. + Transform the *incoming* primative data into a native value. """ - if self.read_only: - return - - try: - data = data or {} - if self.use_files: - files = files or {} - try: - native = files[field_name] - except KeyError: - native = data[field_name] - else: - native = data[field_name] - except KeyError: - if self.default is not None and not self.partial: - # Note: partial updates shouldn't set defaults - native = self.get_default_value() - else: - if self.required: - raise ValidationError(self.error_messages['required']) - return - - value = self.from_native(native) - if self.source == '*': - if value: - into.update(value) - else: - self.validate(value) - self.run_validators(value) - into[self.source or field_name] = value + return data - def from_native(self, value): + def to_primative(self, value): """ - Reverts a simple representation back to the field's value. + Transform the *outgoing* native value into primative data. """ return value - -class ModelField(WritableField): - """ - A generic field that can be used against an arbitrary model field. - """ - def __init__(self, *args, **kwargs): + def fail(self, key, **kwargs): + """ + A helper method that simply raises a validation error. + """ try: - self.model_field = kwargs.pop('model_field') + raise ValidationError(self.MESSAGES[key].format(**kwargs)) except KeyError: - raise ValueError("ModelField requires 'model_field' kwarg") - - self.min_length = kwargs.pop('min_length', - getattr(self.model_field, 'min_length', None)) - self.max_length = kwargs.pop('max_length', - getattr(self.model_field, 'max_length', None)) - self.min_value = kwargs.pop('min_value', - getattr(self.model_field, 'min_value', None)) - self.max_value = kwargs.pop('max_value', - getattr(self.model_field, 'max_value', None)) - - super(ModelField, self).__init__(*args, **kwargs) - - if self.min_length is not None: - self.validators.append(validators.MinLengthValidator(self.min_length)) - if self.max_length is not None: - self.validators.append(validators.MaxLengthValidator(self.max_length)) - if self.min_value is not None: - self.validators.append(validators.MinValueValidator(self.min_value)) - if self.max_value is not None: - self.validators.append(validators.MaxValueValidator(self.max_value)) - - def from_native(self, value): - rel = getattr(self.model_field, "rel", None) - if rel is not None: - return rel.to._meta.get_field(rel.field_name).to_python(value) - else: - return self.model_field.to_python(value) + class_name = self.__class__.__name__ + msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) + raise AssertionError(msg) - def field_to_native(self, obj, field_name): - value = self.model_field._get_val_from_obj(obj) - if is_protected_type(value): - return value - return self.model_field.value_to_string(obj) - def attributes(self): - return { - "type": self.model_field.get_internal_type() - } - - -# Typed Fields - -class BooleanField(WritableField): - type_name = 'BooleanField' - type_label = 'boolean' - form_field_class = forms.BooleanField - widget = widgets.CheckboxInput - default_error_messages = { - 'invalid': _("'%s' value must be either True or False."), +class BooleanField(Field): + MESSAGES = { + 'required': 'This field is required.', + 'invalid_value': '`{input}` is not a valid boolean.' } - empty = False - - def field_from_native(self, data, files, field_name, into): - # HTML checkboxes do not explicitly represent unchecked as `False` - # we deal with that here... - if isinstance(data, QueryDict) and self.default is None: - self.default = False - - return super(BooleanField, self).field_from_native( - data, files, field_name, into - ) - - def from_native(self, value): - if value in ('true', 't', 'True', '1'): + TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True} + FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False} + + def get_value(self, dictionary): + if html.is_html_input(dictionary): + # HTML forms do not send a `False` value on an empty checkbox, + # so we override the default empty value to be False. + return dictionary.get(self.field_name, False) + return dictionary.get(self.field_name, empty) + + def to_native(self, data): + if data in self.TRUE_VALUES: return True - if value in ('false', 'f', 'False', '0'): + elif data in self.FALSE_VALUES: return False - return bool(value) + self.fail('invalid_value', input=data) -class CharField(WritableField): - type_name = 'CharField' - type_label = 'string' - form_field_class = forms.CharField +class CharField(Field): + MESSAGES = { + 'required': 'This field is required.', + 'blank': 'This field may not be blank.' + } - def __init__(self, max_length=None, min_length=None, allow_none=False, *args, **kwargs): - self.max_length, self.min_length = max_length, min_length - self.allow_none = allow_none + def __init__(self, *args, **kwargs): + self.allow_blank = kwargs.pop('allow_blank', False) super(CharField, self).__init__(*args, **kwargs) - if min_length is not None: - self.validators.append(validators.MinLengthValidator(min_length)) - if max_length is not None: - self.validators.append(validators.MaxLengthValidator(max_length)) - - def from_native(self, value): - if isinstance(value, six.string_types): - return value - - if value is None and not self.allow_none: - return '' - - return smart_text(value) - -class URLField(CharField): - type_name = 'URLField' - type_label = 'url' - - def __init__(self, **kwargs): - if 'validators' not in kwargs: - kwargs['validators'] = [validators.URLValidator()] - super(URLField, self).__init__(**kwargs) + def to_native(self, data): + if data == '' and not self.allow_blank: + self.fail('blank') + return str(data) -class SlugField(CharField): - type_name = 'SlugField' - type_label = 'slug' - form_field_class = forms.SlugField - - default_error_messages = { - 'invalid': _("Enter a valid 'slug' consisting of letters, numbers," - " underscores or hyphens."), +class ChoiceField(Field): + MESSAGES = { + 'required': 'This field is required.', + 'invalid_choice': '`{input}` is not a valid choice.' } - default_validators = [validators.validate_slug] + coerce_to_type = str def __init__(self, *args, **kwargs): - super(SlugField, self).__init__(*args, **kwargs) - + choices = kwargs.pop('choices') + + assert choices, '`choices` argument is required and may not be empty' + + # Allow either single or paired choices style: + # choices = [1, 2, 3] + # choices = [(1, 'First'), (2, 'Second'), (3, 'Third')] + pairs = [ + isinstance(item, (list, tuple)) and len(item) == 2 + for item in choices + ] + if all(pairs): + self.choices = {key: val for key, val in choices} + else: + self.choices = {item: item for item in choices} -class ChoiceField(WritableField): - type_name = 'ChoiceField' - type_label = 'choice' - form_field_class = forms.ChoiceField - widget = widgets.Select - default_error_messages = { - 'invalid_choice': _('Select a valid choice. %(value)s is not one of ' - 'the available choices.'), - } + # Map the string representation of choices to the underlying value. + # Allows us to deal with eg. integer choices while supporting either + # integer or string input, but still get the correct datatype out. + self.choice_strings_to_values = { + str(key): key for key in self.choices.keys() + } - def __init__(self, choices=(), blank_display_value=None, *args, **kwargs): - self.empty = kwargs.pop('empty', '') super(ChoiceField, self).__init__(*args, **kwargs) - self.choices = choices - if not self.required: - if blank_display_value is None: - blank_choice = BLANK_CHOICE_DASH - else: - blank_choice = [('', blank_display_value)] - self.choices = blank_choice + self.choices - - def _get_choices(self): - return self._choices - - def _set_choices(self, value): - # Setting choices also sets the choices on the widget. - # choices can be any iterable, but we call list() on it because - # it will be consumed more than once. - self._choices = self.widget.choices = list(value) - - choices = property(_get_choices, _set_choices) - - def metadata(self): - data = super(ChoiceField, self).metadata() - data['choices'] = [{'value': v, 'display_name': n} for v, n in self.choices] - return data - - def validate(self, value): - """ - Validates that the input is in self.choices. - """ - super(ChoiceField, self).validate(value) - if value and not self.valid_value(value): - raise ValidationError(self.error_messages['invalid_choice'] % {'value': value}) - - def valid_value(self, value): - """ - Check to see if the provided value is a valid choice. - """ - for k, v in self.choices: - if isinstance(v, (list, tuple)): - # This is an optgroup, so look inside the group for options - for k2, v2 in v: - if value == smart_text(k2): - return True - else: - if value == smart_text(k) or value == k: - return True - return False - - def from_native(self, value): - value = super(ChoiceField, self).from_native(value) - if value == self.empty or value in validators.EMPTY_VALUES: - return self.empty - return value - - -class EmailField(CharField): - type_name = 'EmailField' - type_label = 'email' - form_field_class = forms.EmailField - - default_error_messages = { - 'invalid': _('Enter a valid email address.'), - } - default_validators = [validators.validate_email] - - def from_native(self, value): - ret = super(EmailField, self).from_native(value) - if ret is None: - return None - return ret.strip() - - -class RegexField(CharField): - type_name = 'RegexField' - type_label = 'regex' - form_field_class = forms.RegexField - - def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs): - super(RegexField, self).__init__(max_length, min_length, *args, **kwargs) - self.regex = regex - - def _get_regex(self): - return self._regex - - def _set_regex(self, regex): - if isinstance(regex, six.string_types): - regex = re.compile(regex) - self._regex = regex - if hasattr(self, '_regex_validator') and self._regex_validator in self.validators: - self.validators.remove(self._regex_validator) - self._regex_validator = validators.RegexValidator(regex=regex) - self.validators.append(self._regex_validator) - - regex = property(_get_regex, _set_regex) - - -class DateField(WritableField): - type_name = 'DateField' - type_label = 'date' - widget = widgets.DateInput - form_field_class = forms.DateField - - default_error_messages = { - 'invalid': _("Date has wrong format. Use one of these formats instead: %s"), - } - empty = None - input_formats = api_settings.DATE_INPUT_FORMATS - format = api_settings.DATE_FORMAT - - def __init__(self, input_formats=None, format=None, *args, **kwargs): - self.input_formats = input_formats if input_formats is not None else self.input_formats - self.format = format if format is not None else self.format - super(DateField, self).__init__(*args, **kwargs) - - def from_native(self, value): - if value in validators.EMPTY_VALUES: - return None - - if isinstance(value, datetime.datetime): - if timezone and settings.USE_TZ and timezone.is_aware(value): - # Convert aware datetimes to the default time zone - # before casting them to dates (#17742). - default_timezone = timezone.get_default_timezone() - value = timezone.make_naive(value, default_timezone) - return value.date() - if isinstance(value, datetime.date): - return value - - for format in self.input_formats: - if format.lower() == ISO_8601: - try: - parsed = parse_date(value) - except (ValueError, TypeError): - pass - else: - if parsed is not None: - return parsed - else: - try: - parsed = datetime.datetime.strptime(value, format) - except (ValueError, TypeError): - pass - else: - return parsed.date() - - msg = self.error_messages['invalid'] % readable_date_formats(self.input_formats) - raise ValidationError(msg) - - def to_native(self, value): - if value is None or self.format is None: - return value - - if isinstance(value, datetime.datetime): - value = value.date() - - if self.format.lower() == ISO_8601: - return value.isoformat() - return value.strftime(self.format) - - -class DateTimeField(WritableField): - type_name = 'DateTimeField' - type_label = 'datetime' - widget = widgets.DateTimeInput - form_field_class = forms.DateTimeField - - default_error_messages = { - 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"), - } - empty = None - input_formats = api_settings.DATETIME_INPUT_FORMATS - format = api_settings.DATETIME_FORMAT - - def __init__(self, input_formats=None, format=None, *args, **kwargs): - self.input_formats = input_formats if input_formats is not None else self.input_formats - self.format = format if format is not None else self.format - super(DateTimeField, self).__init__(*args, **kwargs) - - def from_native(self, value): - if value in validators.EMPTY_VALUES: - return None - - if isinstance(value, datetime.datetime): - return value - if isinstance(value, datetime.date): - value = datetime.datetime(value.year, value.month, value.day) - if settings.USE_TZ: - # For backwards compatibility, interpret naive datetimes in - # local time. This won't work during DST change, but we can't - # do much about it, so we let the exceptions percolate up the - # call stack. - warnings.warn("DateTimeField received a naive datetime (%s)" - " while time zone support is active." % value, - RuntimeWarning) - default_timezone = timezone.get_default_timezone() - value = timezone.make_aware(value, default_timezone) - return value - - for format in self.input_formats: - if format.lower() == ISO_8601: - try: - parsed = parse_datetime(value) - except (ValueError, TypeError): - pass - else: - if parsed is not None: - return parsed - else: - try: - parsed = datetime.datetime.strptime(value, format) - except (ValueError, TypeError): - pass - else: - return parsed - - msg = self.error_messages['invalid'] % readable_datetime_formats(self.input_formats) - raise ValidationError(msg) - - def to_native(self, value): - if value is None or self.format is None: - return value - - if self.format.lower() == ISO_8601: - ret = value.isoformat() - if ret.endswith('+00:00'): - ret = ret[:-6] + 'Z' - return ret - return value.strftime(self.format) - - -class TimeField(WritableField): - type_name = 'TimeField' - type_label = 'time' - widget = widgets.TimeInput - form_field_class = forms.TimeField - - default_error_messages = { - 'invalid': _("Time has wrong format. Use one of these formats instead: %s"), - } - empty = None - input_formats = api_settings.TIME_INPUT_FORMATS - format = api_settings.TIME_FORMAT - - def __init__(self, input_formats=None, format=None, *args, **kwargs): - self.input_formats = input_formats if input_formats is not None else self.input_formats - self.format = format if format is not None else self.format - super(TimeField, self).__init__(*args, **kwargs) - - def from_native(self, value): - if value in validators.EMPTY_VALUES: - return None - - if isinstance(value, datetime.time): - return value - - for format in self.input_formats: - if format.lower() == ISO_8601: - try: - parsed = parse_time(value) - except (ValueError, TypeError): - pass - else: - if parsed is not None: - return parsed - else: - try: - parsed = datetime.datetime.strptime(value, format) - except (ValueError, TypeError): - pass - else: - return parsed.time() - - msg = self.error_messages['invalid'] % readable_time_formats(self.input_formats) - raise ValidationError(msg) - - def to_native(self, value): - if value is None or self.format is None: - return value - - if isinstance(value, datetime.datetime): - value = value.time() - - if self.format.lower() == ISO_8601: - return value.isoformat() - return value.strftime(self.format) - - -class IntegerField(WritableField): - type_name = 'IntegerField' - type_label = 'integer' - form_field_class = forms.IntegerField - empty = 0 - - default_error_messages = { - 'invalid': _('Enter a whole number.'), - 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), - 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), - } - - def __init__(self, max_value=None, min_value=None, *args, **kwargs): - self.max_value, self.min_value = max_value, min_value - super(IntegerField, self).__init__(*args, **kwargs) - - if max_value is not None: - self.validators.append(validators.MaxValueValidator(max_value)) - if min_value is not None: - self.validators.append(validators.MinValueValidator(min_value)) - - def from_native(self, value): - if value in validators.EMPTY_VALUES: - return None - - try: - value = int(str(value)) - except (ValueError, TypeError): - raise ValidationError(self.error_messages['invalid']) - return value - - -class FloatField(WritableField): - type_name = 'FloatField' - type_label = 'float' - form_field_class = forms.FloatField - empty = 0 - - default_error_messages = { - 'invalid': _("'%s' value must be a float."), - } - - def from_native(self, value): - if value in validators.EMPTY_VALUES: - return None + def to_native(self, data): try: - return float(value) - except (TypeError, ValueError): - msg = self.error_messages['invalid'] % value - raise ValidationError(msg) - + return self.choice_strings_to_values[str(data)] + except KeyError: + self.fail('invalid_choice', input=data) -class DecimalField(WritableField): - type_name = 'DecimalField' - type_label = 'decimal' - form_field_class = forms.DecimalField - empty = Decimal('0') - default_error_messages = { - 'invalid': _('Enter a number.'), - 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), - 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), - 'max_digits': _('Ensure that there are no more than %s digits in total.'), - 'max_decimal_places': _('Ensure that there are no more than %s decimal places.'), - 'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.') +class MultipleChoiceField(ChoiceField): + MESSAGES = { + 'required': 'This field is required.', + 'invalid_choice': '`{input}` is not a valid choice.', + 'not_a_list': 'Expected a list of items but got type `{input_type}`' } - def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): - self.max_value, self.min_value = max_value, min_value - self.max_digits, self.decimal_places = max_digits, decimal_places - super(DecimalField, self).__init__(*args, **kwargs) - - if max_value is not None: - self.validators.append(validators.MaxValueValidator(max_value)) - if min_value is not None: - self.validators.append(validators.MinValueValidator(min_value)) - - def from_native(self, value): - """ - Validates that the input is a decimal number. Returns a Decimal - instance. Returns None for empty values. Ensures that there are no more - than max_digits in the number, and no more than decimal_places digits - after the decimal point. - """ - if value in validators.EMPTY_VALUES: - return None - value = smart_text(value).strip() - try: - value = Decimal(value) - except DecimalException: - raise ValidationError(self.error_messages['invalid']) - return value - - def validate(self, value): - super(DecimalField, self).validate(value) - if value in validators.EMPTY_VALUES: - return - # Check for NaN, Inf and -Inf values. We can't compare directly for NaN, - # since it is never equal to itself. However, NaN is the only value that - # isn't equal to itself, so we can use this to identify NaN - if value != value or value == Decimal("Inf") or value == Decimal("-Inf"): - raise ValidationError(self.error_messages['invalid']) - sign, digittuple, exponent = value.as_tuple() - decimals = abs(exponent) - # digittuple doesn't include any leading zeros. - digits = len(digittuple) - if decimals > digits: - # We have leading zeros up to or past the decimal point. Count - # everything past the decimal point as a digit. We do not count - # 0 before the decimal point as a digit since that would mean - # we would not allow max_digits = decimal_places. - digits = decimals - whole_digits = digits - decimals - - if self.max_digits is not None and digits > self.max_digits: - raise ValidationError(self.error_messages['max_digits'] % self.max_digits) - if self.decimal_places is not None and decimals > self.decimal_places: - raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places) - if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): - raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places)) - return value - + def to_native(self, data): + if not hasattr(data, '__iter__'): + self.fail('not_a_list', input_type=type(data).__name__) + return set([ + super(MultipleChoiceField, self).to_native(item) + for item in data + ]) -class FileField(WritableField): - use_files = True - type_name = 'FileField' - type_label = 'file upload' - form_field_class = forms.FileField - widget = widgets.FileInput - default_error_messages = { - 'invalid': _("No file was submitted. Check the encoding type on the form."), - 'missing': _("No file was submitted."), - 'empty': _("The submitted file is empty."), - 'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'), - 'contradiction': _('Please either submit a file or check the clear checkbox, not both.') +class IntegerField(Field): + MESSAGES = { + 'required': 'This field is required.', + 'invalid_integer': 'A valid integer is required.' } - def __init__(self, *args, **kwargs): - self.max_length = kwargs.pop('max_length', None) - self.allow_empty_file = kwargs.pop('allow_empty_file', False) - super(FileField, self).__init__(*args, **kwargs) - - def from_native(self, data): - if data in validators.EMPTY_VALUES: - return None - - # UploadedFile objects should have name and size attributes. + def to_native(self, data): try: - file_name = data.name - file_size = data.size - except AttributeError: - raise ValidationError(self.error_messages['invalid']) - - if self.max_length is not None and len(file_name) > self.max_length: - error_values = {'max': self.max_length, 'length': len(file_name)} - raise ValidationError(self.error_messages['max_length'] % error_values) - if not file_name: - raise ValidationError(self.error_messages['invalid']) - if not self.allow_empty_file and not file_size: - raise ValidationError(self.error_messages['empty']) - + data = int(str(data)) + except (ValueError, TypeError): + self.fail('invalid_integer') return data - def to_native(self, value): - return value.name - - -class ImageField(FileField): - use_files = True - type_name = 'ImageField' - type_label = 'image upload' - form_field_class = forms.ImageField - - default_error_messages = { - 'invalid_image': _("Upload a valid image. The file you uploaded was " - "either not an image or a corrupted image."), - } - - def from_native(self, data): - """ - Checks that the file-upload field data contains a valid image (GIF, JPG, - PNG, possibly others -- whatever the Python Imaging Library supports). - """ - f = super(ImageField, self).from_native(data) - if f is None: - return None - - from rest_framework.compat import Image - assert Image is not None, 'Either Pillow or PIL must be installed for ImageField support.' - - # We need to get a file object for PIL. We might have a path or we might - # have to read the data into memory. - if hasattr(data, 'temporary_file_path'): - file = data.temporary_file_path() - else: - if hasattr(data, 'read'): - file = BytesIO(data.read()) - else: - file = BytesIO(data['content']) - - try: - # load() could spot a truncated JPEG, but it loads the entire - # image in memory, which is a DoS vector. See #3848 and #18520. - # verify() must be called immediately after the constructor. - Image.open(file).verify() - except ImportError: - # Under PyPy, it is possible to import PIL. However, the underlying - # _imaging C module isn't available, so an ImportError will be - # raised. Catch and re-raise. - raise - except Exception: # Python Imaging Library doesn't recognize it as an image - raise ValidationError(self.error_messages['invalid_image']) - if hasattr(f, 'seek') and callable(f.seek): - f.seek(0) - return f - -class SerializerMethodField(Field): - """ - A field that gets its value by calling a method on the serializer it's attached to. - """ - - def __init__(self, method_name, *args, **kwargs): - self.method_name = method_name - super(SerializerMethodField, self).__init__(*args, **kwargs) - - def field_to_native(self, obj, field_name): - value = getattr(self.parent, self.method_name)(obj) - return self.to_native(value) +class MethodField(Field): + def __init__(self, **kwargs): + kwargs['source'] = '*' + kwargs['read_only'] = True + super(MethodField, self).__init__(**kwargs) + + def to_primative(self, value): + attr = 'get_{field_name}'.format(field_name=self.field_name) + method = getattr(self.parent, attr) + return method(value) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index b3bd6ce9..6705cbb2 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -79,18 +79,16 @@ class GenericAPIView(views.APIView): 'view': self } - def get_serializer(self, instance=None, data=None, files=None, many=False, - partial=False, allow_add_remove=False): + def get_serializer(self, instance=None, data=None, many=False, partial=False): """ Return the serializer instance that should be used for validating and deserializing input, and for serializing output. """ serializer_class = self.get_serializer_class() context = self.get_serializer_context() - return serializer_class(instance, data=data, files=files, - many=many, partial=partial, - allow_add_remove=allow_add_remove, - context=context) + return serializer_class( + instance, data=data, many=many, partial=partial, context=context + ) def get_pagination_serializer(self, page): """ diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index ac59d979..ee01cabc 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -36,12 +36,10 @@ class CreateModelMixin(object): Create a model instance. """ def create(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.DATA, files=request.FILES) + serializer = self.get_serializer(data=request.DATA) if serializer.is_valid(): - self.pre_save(serializer.object) - self.object = serializer.save(force_insert=True) - self.post_save(self.object, created=True) + self.object = serializer.save() headers = self.get_success_headers(serializer.data) return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) @@ -90,26 +88,20 @@ class UpdateModelMixin(object): partial = kwargs.pop('partial', False) self.object = self.get_object_or_none() - serializer = self.get_serializer(self.object, data=request.DATA, - files=request.FILES, partial=partial) + serializer = self.get_serializer(self.object, data=request.DATA, partial=partial) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - try: - self.pre_save(serializer.object) - except ValidationError as err: - # full_clean on model instance may be called in pre_save, - # so we have to handle eventual errors. - return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST) + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + lookup_value = self.kwargs[lookup_url_kwarg] + extras = {self.lookup_field: lookup_value} if self.object is None: - self.object = serializer.save(force_insert=True) - self.post_save(self.object, created=True) + self.object = serializer.save(extras=extras) return Response(serializer.data, status=status.HTTP_201_CREATED) - self.object = serializer.save(force_update=True) - self.post_save(self.object, created=False) + self.object = serializer.save(extras=extras) return Response(serializer.data, status=status.HTTP_200_OK) def partial_update(self, request, *args, **kwargs): diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index d51ea929..83ef97c5 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -48,17 +48,17 @@ class DefaultObjectSerializer(serializers.Field): super(DefaultObjectSerializer, self).__init__(source=source) -class PaginationSerializerOptions(serializers.SerializerOptions): - """ - An object that stores the options that may be provided to a - pagination serializer by using the inner `Meta` class. +# class PaginationSerializerOptions(serializers.SerializerOptions): +# """ +# An object that stores the options that may be provided to a +# pagination serializer by using the inner `Meta` class. - Accessible on the instance as `serializer.opts`. - """ - def __init__(self, meta): - super(PaginationSerializerOptions, self).__init__(meta) - self.object_serializer_class = getattr(meta, 'object_serializer_class', - DefaultObjectSerializer) +# Accessible on the instance as `serializer.opts`. +# """ +# def __init__(self, meta): +# super(PaginationSerializerOptions, self).__init__(meta) +# self.object_serializer_class = getattr(meta, 'object_serializer_class', +# DefaultObjectSerializer) class BasePaginationSerializer(serializers.Serializer): @@ -66,7 +66,7 @@ class BasePaginationSerializer(serializers.Serializer): A base class for pagination serializers to inherit from, to make implementing custom serializers more easy. """ - _options_class = PaginationSerializerOptions + # _options_class = PaginationSerializerOptions results_field = 'results' def __init__(self, *args, **kwargs): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 56870b40..e69de29b 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,486 +0,0 @@ -""" -Serializer fields that deal with relationships. - -These fields allow you to specify the style that should be used to represent -model relationships, including hyperlinks, primary keys, or slugs. -""" -from __future__ import unicode_literals -from django.core.exceptions import ObjectDoesNotExist, ValidationError -from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch -from django import forms -from django.db.models.fields import BLANK_CHOICE_DASH -from django.forms import widgets -from django.forms.models import ModelChoiceIterator -from django.utils.translation import ugettext_lazy as _ -from rest_framework.fields import Field, WritableField, get_component, is_simple_callable -from rest_framework.reverse import reverse -from rest_framework.compat import urlparse -from rest_framework.compat import smart_text - - -# Relational fields - -# Not actually Writable, but subclasses may need to be. -class RelatedField(WritableField): - """ - Base class for related model fields. - - This represents a relationship using the unicode representation of the target. - """ - widget = widgets.Select - many_widget = widgets.SelectMultiple - form_field_class = forms.ChoiceField - many_form_field_class = forms.MultipleChoiceField - null_values = (None, '', 'None') - - cache_choices = False - empty_label = None - read_only = True - many = False - - def __init__(self, *args, **kwargs): - queryset = kwargs.pop('queryset', None) - self.many = kwargs.pop('many', self.many) - if self.many: - self.widget = self.many_widget - self.form_field_class = self.many_form_field_class - - kwargs['read_only'] = kwargs.pop('read_only', self.read_only) - super(RelatedField, self).__init__(*args, **kwargs) - - if not self.required: - # Accessed in ModelChoiceIterator django/forms/models.py:1034 - # If set adds empty choice. - self.empty_label = BLANK_CHOICE_DASH[0][1] - - self.queryset = queryset - - def initialize(self, parent, field_name): - super(RelatedField, self).initialize(parent, field_name) - if self.queryset is None and not self.read_only: - manager = getattr(self.parent.opts.model, self.source or field_name) - if hasattr(manager, 'related'): # Forward - self.queryset = manager.related.model._default_manager.all() - else: # Reverse - self.queryset = manager.field.rel.to._default_manager.all() - - # We need this stuff to make form choices work... - - def prepare_value(self, obj): - return self.to_native(obj) - - def label_from_instance(self, obj): - """ - Return a readable representation for use with eg. select widgets. - """ - desc = smart_text(obj) - ident = smart_text(self.to_native(obj)) - if desc == ident: - return desc - return "%s - %s" % (desc, ident) - - def _get_queryset(self): - return self._queryset - - def _set_queryset(self, queryset): - self._queryset = queryset - self.widget.choices = self.choices - - queryset = property(_get_queryset, _set_queryset) - - def _get_choices(self): - # If self._choices is set, then somebody must have manually set - # the property self.choices. In this case, just return self._choices. - if hasattr(self, '_choices'): - return self._choices - - # Otherwise, execute the QuerySet in self.queryset to determine the - # choices dynamically. Return a fresh ModelChoiceIterator that has not been - # consumed. Note that we're instantiating a new ModelChoiceIterator *each* - # time _get_choices() is called (and, thus, each time self.choices is - # accessed) so that we can ensure the QuerySet has not been consumed. This - # construct might look complicated but it allows for lazy evaluation of - # the queryset. - return ModelChoiceIterator(self) - - def _set_choices(self, value): - # Setting choices also sets the choices on the widget. - # choices can be any iterable, but we call list() on it because - # it will be consumed more than once. - self._choices = self.widget.choices = list(value) - - choices = property(_get_choices, _set_choices) - - # Default value handling - - def get_default_value(self): - default = super(RelatedField, self).get_default_value() - if self.many and default is None: - return [] - return default - - # Regular serializer stuff... - - def field_to_native(self, obj, field_name): - try: - if self.source == '*': - return self.to_native(obj) - - source = self.source or field_name - value = obj - - for component in source.split('.'): - if value is None: - break - value = get_component(value, component) - except ObjectDoesNotExist: - return None - - if value is None: - return None - - if self.many: - if is_simple_callable(getattr(value, 'all', None)): - return [self.to_native(item) for item in value.all()] - else: - # Also support non-queryset iterables. - # This allows us to also support plain lists of related items. - return [self.to_native(item) for item in value] - return self.to_native(value) - - def field_from_native(self, data, files, field_name, into): - if self.read_only: - return - - try: - if self.many: - try: - # Form data - value = data.getlist(field_name) - if value == [''] or value == []: - raise KeyError - except AttributeError: - # Non-form data - value = data[field_name] - else: - value = data[field_name] - except KeyError: - if self.partial: - return - value = self.get_default_value() - - if value in self.null_values: - if self.required: - raise ValidationError(self.error_messages['required']) - into[(self.source or field_name)] = None - elif self.many: - into[(self.source or field_name)] = [self.from_native(item) for item in value] - else: - into[(self.source or field_name)] = self.from_native(value) - - -# PrimaryKey relationships - -class PrimaryKeyRelatedField(RelatedField): - """ - Represents a relationship as a pk value. - """ - read_only = False - - default_error_messages = { - 'does_not_exist': _("Invalid pk '%s' - object does not exist."), - 'incorrect_type': _('Incorrect type. Expected pk value, received %s.'), - } - - # TODO: Remove these field hacks... - def prepare_value(self, obj): - return self.to_native(obj.pk) - - def label_from_instance(self, obj): - """ - Return a readable representation for use with eg. select widgets. - """ - desc = smart_text(obj) - ident = smart_text(self.to_native(obj.pk)) - if desc == ident: - return desc - return "%s - %s" % (desc, ident) - - # TODO: Possibly change this to just take `obj`, through prob less performant - def to_native(self, pk): - return pk - - def from_native(self, data): - if self.queryset is None: - raise Exception('Writable related fields must include a `queryset` argument') - - try: - return self.queryset.get(pk=data) - except ObjectDoesNotExist: - msg = self.error_messages['does_not_exist'] % smart_text(data) - raise ValidationError(msg) - except (TypeError, ValueError): - received = type(data).__name__ - msg = self.error_messages['incorrect_type'] % received - raise ValidationError(msg) - - def field_to_native(self, obj, field_name): - if self.many: - # To-many relationship - - queryset = None - if not self.source: - # Prefer obj.serializable_value for performance reasons - try: - queryset = obj.serializable_value(field_name) - except AttributeError: - pass - if queryset is None: - # RelatedManager (reverse relationship) - source = self.source or field_name - queryset = obj - for component in source.split('.'): - if queryset is None: - return [] - queryset = get_component(queryset, component) - - # Forward relationship - if is_simple_callable(getattr(queryset, 'all', None)): - return [self.to_native(item.pk) for item in queryset.all()] - else: - # Also support non-queryset iterables. - # This allows us to also support plain lists of related items. - return [self.to_native(item.pk) for item in queryset] - - # To-one relationship - try: - # Prefer obj.serializable_value for performance reasons - pk = obj.serializable_value(self.source or field_name) - except AttributeError: - # RelatedObject (reverse relationship) - try: - pk = getattr(obj, self.source or field_name).pk - except (ObjectDoesNotExist, AttributeError): - return None - - # Forward relationship - return self.to_native(pk) - - -# Slug relationships - -class SlugRelatedField(RelatedField): - """ - Represents a relationship using a unique field on the target. - """ - read_only = False - - default_error_messages = { - 'does_not_exist': _("Object with %s=%s does not exist."), - 'invalid': _('Invalid value.'), - } - - def __init__(self, *args, **kwargs): - self.slug_field = kwargs.pop('slug_field', None) - assert self.slug_field, 'slug_field is required' - super(SlugRelatedField, self).__init__(*args, **kwargs) - - def to_native(self, obj): - return getattr(obj, self.slug_field) - - def from_native(self, data): - if self.queryset is None: - raise Exception('Writable related fields must include a `queryset` argument') - - try: - return self.queryset.get(**{self.slug_field: data}) - except ObjectDoesNotExist: - raise ValidationError(self.error_messages['does_not_exist'] % - (self.slug_field, smart_text(data))) - except (TypeError, ValueError): - msg = self.error_messages['invalid'] - raise ValidationError(msg) - - -# Hyperlinked relationships - -class HyperlinkedRelatedField(RelatedField): - """ - Represents a relationship using hyperlinking. - """ - read_only = False - lookup_field = 'pk' - - default_error_messages = { - 'no_match': _('Invalid hyperlink - No URL match'), - 'incorrect_match': _('Invalid hyperlink - Incorrect URL match'), - 'configuration_error': _('Invalid hyperlink due to configuration error'), - 'does_not_exist': _("Invalid hyperlink - object does not exist."), - 'incorrect_type': _('Incorrect type. Expected url string, received %s.'), - } - - def __init__(self, *args, **kwargs): - try: - self.view_name = kwargs.pop('view_name') - except KeyError: - raise ValueError("Hyperlinked field requires 'view_name' kwarg") - - self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) - self.format = kwargs.pop('format', None) - - super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) - - def get_url(self, obj, view_name, request, format): - """ - Given an object, return the URL that hyperlinks to the object. - - May raise a `NoReverseMatch` if the `view_name` and `lookup_field` - attributes are not configured to correctly match the URL conf. - """ - lookup_field = getattr(obj, self.lookup_field) - kwargs = {self.lookup_field: lookup_field} - return reverse(view_name, kwargs=kwargs, request=request, format=format) - - def get_object(self, queryset, view_name, view_args, view_kwargs): - """ - Return the object corresponding to a matched URL. - - Takes the matched URL conf arguments, and the queryset, and should - return an object instance, or raise an `ObjectDoesNotExist` exception. - """ - lookup_value = view_kwargs[self.lookup_field] - filter_kwargs = {self.lookup_field: lookup_value} - return queryset.get(**filter_kwargs) - - def to_native(self, obj): - view_name = self.view_name - request = self.context.get('request', None) - format = self.format or self.context.get('format', None) - - assert request is not None, ( - "`HyperlinkedRelatedField` requires the request in the serializer " - "context. Add `context={'request': request}` when instantiating " - "the serializer." - ) - - # If the object has not yet been saved then we cannot hyperlink to it. - if getattr(obj, 'pk', None) is None: - return - - # Return the hyperlink, or error if incorrectly configured. - try: - return self.get_url(obj, view_name, request, format) - except NoReverseMatch: - msg = ( - 'Could not resolve URL for hyperlinked relationship using ' - 'view name "%s". You may have failed to include the related ' - 'model in your API, or incorrectly configured the ' - '`lookup_field` attribute on this field.' - ) - raise Exception(msg % view_name) - - def from_native(self, value): - # Convert URL -> model instance pk - # TODO: Use values_list - queryset = self.queryset - if queryset is None: - raise Exception('Writable related fields must include a `queryset` argument') - - try: - http_prefix = value.startswith(('http:', 'https:')) - except AttributeError: - msg = self.error_messages['incorrect_type'] - raise ValidationError(msg % type(value).__name__) - - if http_prefix: - # If needed convert absolute URLs to relative path - value = urlparse.urlparse(value).path - prefix = get_script_prefix() - if value.startswith(prefix): - value = '/' + value[len(prefix):] - - try: - match = resolve(value) - except Exception: - raise ValidationError(self.error_messages['no_match']) - - if match.view_name != self.view_name: - raise ValidationError(self.error_messages['incorrect_match']) - - try: - return self.get_object(queryset, match.view_name, - match.args, match.kwargs) - except (ObjectDoesNotExist, TypeError, ValueError): - raise ValidationError(self.error_messages['does_not_exist']) - - -class HyperlinkedIdentityField(Field): - """ - Represents the instance, or a property on the instance, using hyperlinking. - """ - lookup_field = 'pk' - read_only = True - - def __init__(self, *args, **kwargs): - try: - self.view_name = kwargs.pop('view_name') - except KeyError: - msg = "HyperlinkedIdentityField requires 'view_name' argument" - raise ValueError(msg) - - self.format = kwargs.pop('format', None) - lookup_field = kwargs.pop('lookup_field', None) - self.lookup_field = lookup_field or self.lookup_field - - super(HyperlinkedIdentityField, self).__init__(*args, **kwargs) - - def field_to_native(self, obj, field_name): - request = self.context.get('request', None) - format = self.context.get('format', None) - view_name = self.view_name - - assert request is not None, ( - "`HyperlinkedIdentityField` requires the request in the serializer" - " context. Add `context={'request': request}` when instantiating " - "the serializer." - ) - - # By default use whatever format is given for the current context - # unless the target is a different type to the source. - # - # Eg. Consider a HyperlinkedIdentityField pointing from a json - # representation to an html property of that representation... - # - # '/snippets/1/' should link to '/snippets/1/highlight/' - # ...but... - # '/snippets/1/.json' should link to '/snippets/1/highlight/.html' - if format and self.format and self.format != format: - format = self.format - - # Return the hyperlink, or error if incorrectly configured. - try: - return self.get_url(obj, view_name, request, format) - except NoReverseMatch: - msg = ( - 'Could not resolve URL for hyperlinked relationship using ' - 'view name "%s". You may have failed to include the related ' - 'model in your API, or incorrectly configured the ' - '`lookup_field` attribute on this field.' - ) - raise Exception(msg % view_name) - - def get_url(self, obj, view_name, request, format): - """ - Given an object, return the URL that hyperlinks to the object. - - May raise a `NoReverseMatch` if the `view_name` and `lookup_field` - attributes are not configured to correctly match the URL conf. - """ - lookup_field = getattr(obj, self.lookup_field, None) - kwargs = {self.lookup_field: lookup_field} - - # Handle unsaved object case - if lookup_field is None: - return None - - return reverse(view_name, kwargs=kwargs, request=request, format=format) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 748ebac9..e8935b01 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -458,7 +458,7 @@ class BrowsableAPIRenderer(BaseRenderer): ): return - serializer = view.get_serializer(instance=obj, data=data, files=files) + serializer = view.get_serializer(instance=obj, data=data) serializer.is_valid() data = serializer.data @@ -579,10 +579,10 @@ class BrowsableAPIRenderer(BaseRenderer): 'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes], 'response_headers': response_headers, - 'put_form': self.get_rendered_html_form(view, 'PUT', request), - 'post_form': self.get_rendered_html_form(view, 'POST', request), - 'delete_form': self.get_rendered_html_form(view, 'DELETE', request), - 'options_form': self.get_rendered_html_form(view, 'OPTIONS', request), + #'put_form': self.get_rendered_html_form(view, 'PUT', request), + #'post_form': self.get_rendered_html_form(view, 'POST', request), + #'delete_form': self.get_rendered_html_form(view, 'DELETE', request), + #'options_form': self.get_rendered_html_form(view, 'OPTIONS', request), 'raw_data_put_form': raw_data_put_form, 'raw_data_post_form': raw_data_post_form, diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index be8ad3f2..d121812d 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -10,21 +10,14 @@ python primitives. 2. The process of marshalling between python primitives and request and response content is handled by parsers and renderers. """ -from __future__ import unicode_literals -import copy -import datetime -import inspect -import types -from decimal import Decimal -from django.contrib.contenttypes.generic import GenericForeignKey -from django.core.paginator import Page from django.db import models -from django.forms import widgets from django.utils import six -from django.utils.datastructures import SortedDict -from django.core.exceptions import ObjectDoesNotExist +from collections import namedtuple, OrderedDict +from rest_framework.fields import empty, set_value, Field, SkipField, ValidationError from rest_framework.settings import api_settings - +from rest_framework.utils import html +import copy +import inspect # Note: We do the following so that users of the framework can use this style: # @@ -37,635 +30,339 @@ from rest_framework.relations import * # NOQA from rest_framework.fields import * # NOQA -def _resolve_model(obj): - """ - Resolve supplied `obj` to a Django model class. +FieldResult = namedtuple('FieldResult', ['field', 'value', 'error']) - `obj` must be a Django model class itself, or a string - representation of one. Useful in situtations like GH #1225 where - Django may not have resolved a string-based reference to a model in - another model's foreign key definition. - - String representations should have the format: - 'appname.ModelName' - """ - if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: - app_name, model_name = obj.split('.') - return models.get_model(app_name, model_name) - elif inspect.isclass(obj) and issubclass(obj, models.Model): - return obj - else: - raise ValueError("{0} is not a Django model".format(obj)) - - -def pretty_name(name): - """Converts 'first_name' to 'First name'""" - if not name: - return '' - return name.replace('_', ' ').capitalize() +class BaseSerializer(Field): + def __init__(self, instance=None, data=None, **kwargs): + super(BaseSerializer, self).__init__(**kwargs) + self.instance = instance + self._initial_data = data -class RelationsList(list): - _deleted = [] + def to_native(self, data): + raise NotImplementedError() + def to_primative(self, instance): + raise NotImplementedError() -class NestedValidationError(ValidationError): - """ - The default ValidationError behavior is to stringify each item in the list - if the messages are a list of error messages. + def update(self, instance): + raise NotImplementedError() - In the case of nested serializers, where the parent has many children, - then the child's `serializer.errors` will be a list of dicts. In the case - of a single child, the `serializer.errors` will be a dict. + def create(self): + raise NotImplementedError() - We need to override the default behavior to get properly nested error dicts. - """ + def save(self, extras=None): + if extras is not None: + self._validated_data.update(extras) - def __init__(self, message): - if isinstance(message, dict): - self._messages = [message] + if self.instance is not None: + self.update(self.instance) else: - self._messages = message - - @property - def messages(self): - return self._messages + self.instance = self.create() + return self.instance -class DictWithMetadata(dict): - """ - A dict-like object, that can have additional properties attached. - """ - def __getstate__(self): - """ - Used by pickle (e.g., caching). - Overridden to remove the metadata from the dict, since it shouldn't be - pickled and may in some instances be unpickleable. - """ - return dict(self) - + def is_valid(self): + try: + self._validated_data = self.to_native(self._initial_data) + except ValidationError as exc: + self._validated_data = {} + self._errors = exc.args[0] + return False + self._errors = {} + return True -class SortedDictWithMetadata(SortedDict): - """ - A sorted dict-like object, that can have additional properties attached. - """ - def __getstate__(self): - """ - Used by pickle (e.g., caching). - Overriden to remove the metadata from the dict, since it shouldn't be - pickle and may in some instances be unpickleable. - """ - return SortedDict(self).__dict__ + @property + def data(self): + if not hasattr(self, '_data'): + if self.instance is not None: + self._data = self.to_primative(self.instance) + elif self._initial_data is not None: + self._data = { + field_name: field.get_value(self._initial_data) + for field_name, field in self.fields.items() + } + else: + self._data = self.get_initial() + return self._data + @property + def errors(self): + if not hasattr(self, '_errors'): + msg = 'You must call `.is_valid()` before accessing `.errors`.' + raise AssertionError(msg) + return self._errors -def _is_protected_type(obj): - """ - True if the object is a native datatype that does not need to - be serialized further. - """ - return isinstance(obj, ( - types.NoneType, - int, long, - datetime.datetime, datetime.date, datetime.time, - float, Decimal, - basestring) - ) + @property + def validated_data(self): + if not hasattr(self, '_validated_data'): + msg = 'You must call `.is_valid()` before accessing `.validated_data`.' + raise AssertionError(msg) + return self._validated_data -def _get_declared_fields(bases, attrs): +class SerializerMetaclass(type): """ - Create a list of serializer field instances from the passed in 'attrs', - plus any fields on the base classes (in 'bases'). + This metaclass sets a dictionary named `base_fields` on the class. - Note that all fields from the base classes are used. + Any fields included as attributes on either the class or it's superclasses + will be include in the `base_fields` dictionary. """ - fields = [(field_name, attrs.pop(field_name)) - for field_name, obj in list(six.iteritems(attrs)) - if isinstance(obj, Field)] - fields.sort(key=lambda x: x[1].creation_counter) - # If this class is subclassing another Serializer, add that Serializer's - # fields. Note that we loop over the bases in *reverse*. This is necessary - # in order to maintain the correct order of fields. - for base in bases[::-1]: - if hasattr(base, 'base_fields'): - fields = list(base.base_fields.items()) + fields + @classmethod + def _get_fields(cls, bases, attrs): + fields = [(field_name, attrs.pop(field_name)) + for field_name, obj in list(attrs.items()) + if isinstance(obj, Field)] + fields.sort(key=lambda x: x[1]._creation_counter) - return SortedDict(fields) + # If this class is subclassing another Serializer, add that Serializer's + # fields. Note that we loop over the bases in *reverse*. This is necessary + # in order to maintain the correct order of fields. + for base in bases[::-1]: + if hasattr(base, 'base_fields'): + fields = list(base.base_fields.items()) + fields + return OrderedDict(fields) -class SerializerMetaclass(type): def __new__(cls, name, bases, attrs): - attrs['base_fields'] = _get_declared_fields(bases, attrs) + attrs['base_fields'] = cls._get_fields(bases, attrs) return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) -class SerializerOptions(object): - """ - Meta class options for Serializer - """ - def __init__(self, meta): - self.depth = getattr(meta, 'depth', 0) - self.fields = getattr(meta, 'fields', ()) - self.exclude = getattr(meta, 'exclude', ()) +@six.add_metaclass(SerializerMetaclass) +class Serializer(BaseSerializer): + def __new__(cls, *args, **kwargs): + many = kwargs.pop('many', False) + if many: + class DynamicListSerializer(ListSerializer): + child = cls() + return DynamicListSerializer(*args, **kwargs) + return super(Serializer, cls).__new__(cls) -class BaseSerializer(WritableField): - """ - This is the Serializer implementation. - We need to implement it as `BaseSerializer` due to metaclass magicks. - """ - class Meta(object): - pass - - _options_class = SerializerOptions - _dict_class = SortedDictWithMetadata - - def __init__(self, instance=None, data=None, files=None, - context=None, partial=False, many=False, - allow_add_remove=False, **kwargs): - super(BaseSerializer, self).__init__(**kwargs) - self.opts = self._options_class(self.Meta) - self.parent = None - self.root = None - self.partial = partial - self.many = many - self.allow_add_remove = allow_add_remove + def __init__(self, *args, **kwargs): + kwargs.pop('context', None) + kwargs.pop('partial', None) + kwargs.pop('many', False) - self.context = context or {} + super(Serializer, self).__init__(*args, **kwargs) - self.init_data = data - self.init_files = files - self.object = instance + # Every new serializer is created with a clone of the field instances. + # This allows users to dynamically modify the fields on a serializer + # instance without affecting every other serializer class. self.fields = self.get_fields() - self._data = None - self._files = None - self._errors = None - - if many and instance is not None and not hasattr(instance, '__iter__'): - raise ValueError('instance should be a queryset or other iterable with many=True') - - if allow_add_remove and not many: - raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True') - - ##### - # Methods to determine which fields to use when (de)serializing objects. - - def get_default_fields(self): - """ - Return the complete set of default fields for the object, as a dict. - """ - return {} - - def get_fields(self): - """ - Returns the complete set of fields for the object as a dict. - - This will be the set of any explicitly declared fields, - plus the set of fields returned by get_default_fields(). - """ - ret = SortedDict() - - # Get the explicitly declared fields - base_fields = copy.deepcopy(self.base_fields) - for key, field in base_fields.items(): - ret[key] = field - - # Add in the default fields - default_fields = self.get_default_fields() - for key, val in default_fields.items(): - if key not in ret: - ret[key] = val - - # If 'fields' is specified, use those fields, in that order. - if self.opts.fields: - assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple' - new = SortedDict() - for key in self.opts.fields: - new[key] = ret[key] - ret = new - - # Remove anything in 'exclude' - if self.opts.exclude: - assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple' - for key in self.opts.exclude: - ret.pop(key, None) - - for key, field in ret.items(): - field.initialize(parent=self, field_name=key) - - return ret - - ##### - # Methods to convert or revert from objects <--> primitive representations. - - def get_field_key(self, field_name): - """ - Return the key that should be used for a given field. - """ - return field_name - - def restore_fields(self, data, files): - """ - Core of deserialization, together with `restore_object`. - Converts a dictionary of data into a dictionary of deserialized fields. - """ - reverted_data = {} - - if data is not None and not isinstance(data, dict): - self._errors['non_field_errors'] = ['Invalid data'] - return None - + # Setup all the child fields, to provide them with the current context. for field_name, field in self.fields.items(): - field.initialize(parent=self, field_name=field_name) - try: - field.field_from_native(data, files, field_name, reverted_data) - except ValidationError as err: - self._errors[field_name] = list(err.messages) + field.bind(field_name, self, self) - return reverted_data + def get_fields(self): + return copy.deepcopy(self.base_fields) - def perform_validation(self, attrs): - """ - Run `validate_()` and `validate()` methods on the serializer - """ + def bind(self, field_name, parent, root): + # If the serializer is used as a field then when it becomes bound + # it also needs to bind all its child fields. + super(Serializer, self).bind(field_name, parent, root) for field_name, field in self.fields.items(): - if field_name in self._errors: - continue + field.bind(field_name, self, root) - source = field.source or field_name - if self.partial and source not in attrs: - continue - try: - validate_method = getattr(self, 'validate_%s' % field_name, None) - if validate_method: - attrs = validate_method(attrs, source) - except ValidationError as err: - self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) - - # If there are already errors, we don't run .validate() because - # field-validation failed and thus `attrs` may not be complete. - # which in turn can cause inconsistent validation errors. - if not self._errors: - try: - attrs = self.validate(attrs) - except ValidationError as err: - if hasattr(err, 'message_dict'): - for field_name, error_messages in err.message_dict.items(): - self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages) - elif hasattr(err, 'messages'): - self._errors['non_field_errors'] = err.messages - - return attrs + def get_initial(self): + return { + field.field_name: field.get_initial() + for field in self.fields.values() + } - def validate(self, attrs): - """ - Stub method, to be overridden in Serializer subclasses - """ - return attrs + def get_value(self, dictionary): + # We override the default field access in order to support + # nested HTML forms. + if html.is_html_input(dictionary): + return html.parse_html_dict(dictionary, prefix=self.field_name) + return dictionary.get(self.field_name, empty) - def restore_object(self, attrs, instance=None): + def to_native(self, data): """ - Deserialize a dictionary of attributes into an object instance. - You should override this method to control how deserialized objects - are instantiated. + Dict of native values <- Dict of primitive datatypes. """ - if instance is not None: - instance.update(attrs) - return instance - return attrs + ret = {} + errors = {} + fields = [field for field in self.fields.values() if not field.read_only] - def to_native(self, obj): - """ - Serialize objects -> primitives. - """ - ret = self._dict_class() - ret.fields = self._dict_class() + for field in fields: + primitive_value = field.get_value(data) + try: + validated_value = field.validate(primitive_value) + except ValidationError as exc: + errors[field.field_name] = str(exc) + except SkipField: + pass + else: + set_value(ret, field.source_attrs, validated_value) - for field_name, field in self.fields.items(): - if field.read_only and obj is None: - continue - field.initialize(parent=self, field_name=field_name) - key = self.get_field_key(field_name) - value = field.field_to_native(obj, field_name) - method = getattr(self, 'transform_%s' % field_name, None) - if callable(method): - value = method(obj, value) - if not getattr(field, 'write_only', False): - ret[key] = value - ret.fields[key] = self.augment_field(field, field_name, key, value) + if errors: + raise ValidationError(errors) return ret - def from_native(self, data, files=None): - """ - Deserialize primitives -> objects. - """ - self._errors = {} - - if data is not None or files is not None: - attrs = self.restore_fields(data, files) - if attrs is not None: - attrs = self.perform_validation(attrs) - else: - self._errors['non_field_errors'] = ['No input provided'] - - if not self._errors: - return self.restore_object(attrs, instance=getattr(self, 'object', None)) - - def augment_field(self, field, field_name, key, value): - # This horrible stuff is to manage serializers rendering to HTML - field._errors = self._errors.get(key) if self._errors else None - field._name = field_name - field._value = self.init_data.get(key) if self._errors and self.init_data else value - if not field.label: - field.label = pretty_name(key) - return field - - def field_to_native(self, obj, field_name): + def to_primative(self, instance): """ - Override default so that the serializer can be used as a nested field - across relationships. + Object instance -> Dict of primitive datatypes. """ - if self.write_only: - return None + ret = OrderedDict() + fields = [field for field in self.fields.values() if not field.write_only] - if self.source == '*': - return self.to_native(obj) + for field in fields: + native_value = field.get_attribute(instance) + ret[field.field_name] = field.to_primative(native_value) - # Get the raw field value - try: - source = self.source or field_name - value = obj - - for component in source.split('.'): - if value is None: - break - value = get_component(value, component) - except ObjectDoesNotExist: - return None + return ret - if is_simple_callable(getattr(value, 'all', None)): - return [self.to_native(item) for item in value.all()] + def __iter__(self): + errors = self.errors if hasattr(self, '_errors') else {} + for field in self.fields.values(): + value = self.data.get(field.field_name) if self.data else None + error = errors.get(field.field_name) + yield FieldResult(field, value, error) - if value is None: - return None - if self.many: - return [self.to_native(item) for item in value] - return self.to_native(value) +class ListSerializer(BaseSerializer): + child = None + initial = [] - def field_from_native(self, data, files, field_name, into): - """ - Override default so that the serializer can be used as a writable - nested field across relationships. - """ - if self.read_only: - return + def __init__(self, *args, **kwargs): + self.child = kwargs.pop('child', copy.deepcopy(self.child)) + assert self.child is not None, '`child` is a required argument.' - try: - value = data[field_name] - except KeyError: - if self.default is not None and not self.partial: - # Note: partial updates shouldn't set defaults - value = copy.deepcopy(self.default) - else: - if self.required: - raise ValidationError(self.error_messages['required']) - return - - if self.source == '*': - if value: - reverted_data = self.restore_fields(value, {}) - if not self._errors: - into.update(reverted_data) - else: - if value in (None, ''): - into[(self.source or field_name)] = None - else: - # Set the serializer object if it exists - obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None - - # If we have a model manager or similar object then we need - # to iterate through each instance. - if ( - self.many and - not hasattr(obj, '__iter__') and - is_simple_callable(getattr(obj, 'all', None)) - ): - obj = obj.all() - - kwargs = { - 'instance': obj, - 'data': value, - 'context': self.context, - 'partial': self.partial, - 'many': self.many, - 'allow_add_remove': self.allow_add_remove - } - serializer = self.__class__(**kwargs) + kwargs.pop('context', None) + kwargs.pop('partial', None) - if serializer.is_valid(): - into[self.source or field_name] = serializer.object - else: - # Propagate errors up to our parent - raise NestedValidationError(serializer.errors) + super(ListSerializer, self).__init__(*args, **kwargs) + self.child.bind('', self, self) - def get_identity(self, data): - """ - This hook is required for bulk update. - It is used to determine the canonical identity of a given object. + def bind(self, field_name, parent, root): + # If the list is used as a field then it needs to provide + # the current context to the child serializer. + super(ListSerializer, self).bind(field_name, parent, root) + self.child.bind(field_name, self, root) - Note that the data has not been validated at this point, so we need - to make sure that we catch any cases of incorrect datatypes being - passed to this method. - """ - try: - return data.get('id', None) - except AttributeError: - return None + def get_value(self, dictionary): + # We override the default field access in order to support + # lists in HTML forms. + if is_html_input(dictionary): + return html.parse_html_list(dictionary, prefix=self.field_name) + return dictionary.get(self.field_name, empty) - @property - def errors(self): + def to_native(self, data): """ - Run deserialization and return error data, - setting self.object if no errors occurred. + List of dicts of native values <- List of dicts of primitive datatypes. """ - if self._errors is None: - data, files = self.init_data, self.init_files + if html.is_html_input(data): + data = html.parse_html_list(data) - if self.many is not None: - many = self.many - else: - many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type)) - if many: - warnings.warn('Implicit list/queryset serialization is deprecated. ' - 'Use the `many=True` flag when instantiating the serializer.', - DeprecationWarning, stacklevel=3) - - if many: - ret = RelationsList() - errors = [] - update = self.object is not None - - if update: - # If this is a bulk update we need to map all the objects - # to a canonical identity so we can determine which - # individual object is being updated for each item in the - # incoming data - objects = self.object - identities = [self.get_identity(self.to_native(obj)) for obj in objects] - identity_to_objects = dict(zip(identities, objects)) - - if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)): - for item in data: - if update: - # Determine which object we're updating - identity = self.get_identity(item) - self.object = identity_to_objects.pop(identity, None) - if self.object is None and not self.allow_add_remove: - ret.append(None) - errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}) - continue - - ret.append(self.from_native(item, None)) - errors.append(self._errors) - - if update and self.allow_add_remove: - ret._deleted = identity_to_objects.values() - - self._errors = any(errors) and errors or [] - else: - self._errors = {'non_field_errors': ['Expected a list of items.']} - else: - ret = self.from_native(data, files) - - if not self._errors: - self.object = ret - - return self._errors - - def is_valid(self): - return not self.errors + return [self.child.validate(item) for item in data] - @property - def data(self): + def to_primative(self, data): """ - Returns the serialized data on the serializer. + List of object instances -> List of dicts of primitive datatypes. """ - if self._data is None: - obj = self.object + return [self.child.to_primative(item) for item in data] - if self.many is not None: - many = self.many - else: - many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict)) - if many: - warnings.warn('Implicit list/queryset serialization is deprecated. ' - 'Use the `many=True` flag when instantiating the serializer.', - DeprecationWarning, stacklevel=2) - - if many: - self._data = [self.to_native(item) for item in obj] - else: - self._data = self.to_native(obj) + def create(self, attrs_list): + return [self.child.create(attrs) for attrs in attrs_list] - return self._data + def save(self): + if self.instance is not None: + self.update(self.instance, self.validated_data) + self.instance = self.create(self.validated_data) + return self.instance - def save_object(self, obj, **kwargs): - obj.save(**kwargs) - def delete_object(self, obj): - obj.delete() - - def save(self, **kwargs): - """ - Save the deserialized object and return it. - """ - # Clear cached _data, which may be invalidated by `save()` - self._data = None - - if isinstance(self.object, list): - [self.save_object(item, **kwargs) for item in self.object] - - if self.object._deleted: - [self.delete_object(item) for item in self.object._deleted] - else: - self.save_object(self.object, **kwargs) - - return self.object - - def metadata(self): - """ - Return a dictionary of metadata about the fields on the serializer. - Useful for things like responding to OPTIONS requests, or generating - API schemas for auto-documentation. - """ - return SortedDict( - [ - (field_name, field.metadata()) - for field_name, field in six.iteritems(self.fields) - ] - ) +def _resolve_model(obj): + """ + Resolve supplied `obj` to a Django model class. + `obj` must be a Django model class itself, or a string + representation of one. Useful in situtations like GH #1225 where + Django may not have resolved a string-based reference to a model in + another model's foreign key definition. -class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)): - pass + String representations should have the format: + 'appname.ModelName' + """ + if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: + app_name, model_name = obj.split('.') + return models.get_model(app_name, model_name) + elif inspect.isclass(obj) and issubclass(obj, models.Model): + return obj + else: + raise ValueError("{0} is not a Django model".format(obj)) -class ModelSerializerOptions(SerializerOptions): +class ModelSerializerOptions(object): """ Meta class options for ModelSerializer """ def __init__(self, meta): - super(ModelSerializerOptions, self).__init__(meta) - self.model = getattr(meta, 'model', None) - self.read_only_fields = getattr(meta, 'read_only_fields', ()) - self.write_only_fields = getattr(meta, 'write_only_fields', ()) + self.model = getattr(meta, 'model') + self.fields = getattr(meta, 'fields', ()) + self.depth = getattr(meta, 'depth', 0) class ModelSerializer(Serializer): - """ - A serializer that deals with model instances and querysets. - """ - _options_class = ModelSerializerOptions - field_mapping = { models.AutoField: IntegerField, - models.FloatField: FloatField, + # models.FloatField: FloatField, models.IntegerField: IntegerField, models.PositiveIntegerField: IntegerField, models.SmallIntegerField: IntegerField, models.PositiveSmallIntegerField: IntegerField, - models.DateTimeField: DateTimeField, - models.DateField: DateField, - models.TimeField: TimeField, - models.DecimalField: DecimalField, - models.EmailField: EmailField, + # models.DateTimeField: DateTimeField, + # models.DateField: DateField, + # models.TimeField: TimeField, + # models.DecimalField: DecimalField, + # models.EmailField: EmailField, models.CharField: CharField, - models.URLField: URLField, - models.SlugField: SlugField, + # models.URLField: URLField, + # models.SlugField: SlugField, models.TextField: CharField, models.CommaSeparatedIntegerField: CharField, models.BooleanField: BooleanField, models.NullBooleanField: BooleanField, - models.FileField: FileField, - models.ImageField: ImageField, + # models.FileField: FileField, + # models.ImageField: ImageField, } + _options_class = ModelSerializerOptions + + def __init__(self, *args, **kwargs): + self.opts = self._options_class(self.Meta) + super(ModelSerializer, self).__init__(*args, **kwargs) + + def get_fields(self): + # Get the explicitly declared fields. + fields = copy.deepcopy(self.base_fields) + + # Add in the default fields. + for key, val in self.get_default_fields().items(): + if key not in fields: + fields[key] = val + + # If `fields` is set on the `Meta` class, + # then use only those fields, and in that order. + if self.opts.fields: + fields = OrderedDict([ + (key, fields[key]) for key in self.opts.fields + ]) + + return fields + def get_default_fields(self): """ Return all the fields that should be serialized for the model. """ - cls = self.opts.model - assert cls is not None, ( - "Serializer class '%s' is missing 'model' Meta option" % - self.__class__.__name__ - ) opts = cls._meta.concrete_model._meta - ret = SortedDict() + ret = OrderedDict() nested = bool(self.opts.depth) # Deal with adding the primary key field @@ -694,29 +391,9 @@ class ModelSerializer(Serializer): has_through_model = True if model_field.rel and nested: - if len(inspect.getargspec(self.get_nested_field).args) == 2: - warnings.warn( - 'The `get_nested_field(model_field)` call signature ' - 'is deprecated. ' - 'Use `get_nested_field(model_field, related_model, ' - 'to_many) instead', - DeprecationWarning - ) - field = self.get_nested_field(model_field) - else: - field = self.get_nested_field(model_field, related_model, to_many) + field = self.get_nested_field(model_field, related_model, to_many) elif model_field.rel: - if len(inspect.getargspec(self.get_nested_field).args) == 3: - warnings.warn( - 'The `get_related_field(model_field, to_many)` call ' - 'signature is deprecated. ' - 'Use `get_related_field(model_field, related_model, ' - 'to_many) instead', - DeprecationWarning - ) - field = self.get_related_field(model_field, to_many=to_many) - else: - field = self.get_related_field(model_field, related_model, to_many) + field = self.get_related_field(model_field, related_model, to_many) else: field = self.get_field(model_field) @@ -763,38 +440,6 @@ class ModelSerializer(Serializer): ret[accessor_name] = field - # Ensure that 'read_only_fields' is an iterable - assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple' - - # Add the `read_only` flag to any fields that have been specified - # in the `read_only_fields` option - for field_name in self.opts.read_only_fields: - assert field_name not in self.base_fields.keys(), ( - "field '%s' on serializer '%s' specified in " - "`read_only_fields`, but also added " - "as an explicit field. Remove it from `read_only_fields`." % - (field_name, self.__class__.__name__)) - assert field_name in ret, ( - "Non-existant field '%s' specified in `read_only_fields` " - "on serializer '%s'." % - (field_name, self.__class__.__name__)) - ret[field_name].read_only = True - - # Ensure that 'write_only_fields' is an iterable - assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple' - - for field_name in self.opts.write_only_fields: - assert field_name not in self.base_fields.keys(), ( - "field '%s' on serializer '%s' specified in " - "`write_only_fields`, but also added " - "as an explicit field. Remove it from `write_only_fields`." % - (field_name, self.__class__.__name__)) - assert field_name in ret, ( - "Non-existant field '%s' specified in `write_only_fields` " - "on serializer '%s'." % - (field_name, self.__class__.__name__)) - ret[field_name].write_only = True - return ret def get_pk_field(self, model_field): @@ -825,28 +470,24 @@ class ModelSerializer(Serializer): # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) - kwargs = { - 'queryset': related_model._default_manager, - 'many': to_many - } + kwargs = {} + # 'queryset': related_model._default_manager, + # 'many': to_many + # } if model_field: kwargs['required'] = not(model_field.null or model_field.blank) - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text + # if model_field.help_text is not None: + # kwargs['help_text'] = model_field.help_text if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name - if not model_field.editable: kwargs['read_only'] = True - if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text - - return PrimaryKeyRelatedField(**kwargs) + return IntegerField(**kwargs) + # TODO: return PrimaryKeyRelatedField(**kwargs) def get_field(self, model_field): """ @@ -869,8 +510,8 @@ class ModelSerializer(Serializer): if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text + # if model_field.help_text is not None: + # kwargs['help_text'] = model_field.help_text # TODO: TypedChoiceField? if model_field.flatchoices: # This ModelField contains choices @@ -880,7 +521,7 @@ class ModelSerializer(Serializer): return ChoiceField(**kwargs) # put this below the ChoiceField because min_value isn't a valid initializer - if issubclass(model_field.__class__, models.PositiveIntegerField) or\ + if issubclass(model_field.__class__, models.PositiveIntegerField) or \ issubclass(model_field.__class__, models.PositiveSmallIntegerField): kwargs['min_value'] = 0 @@ -888,170 +529,27 @@ class ModelSerializer(Serializer): issubclass(model_field.__class__, (models.CharField, models.TextField)): kwargs['allow_none'] = True - attribute_dict = { - models.CharField: ['max_length'], - models.CommaSeparatedIntegerField: ['max_length'], - models.DecimalField: ['max_digits', 'decimal_places'], - models.EmailField: ['max_length'], - models.FileField: ['max_length'], - models.ImageField: ['max_length'], - models.SlugField: ['max_length'], - models.URLField: ['max_length'], - } - - if model_field.__class__ in attribute_dict: - attributes = attribute_dict[model_field.__class__] - for attribute in attributes: - kwargs.update({attribute: getattr(model_field, attribute)}) + # attribute_dict = { + # models.CharField: ['max_length'], + # models.CommaSeparatedIntegerField: ['max_length'], + # models.DecimalField: ['max_digits', 'decimal_places'], + # models.EmailField: ['max_length'], + # models.FileField: ['max_length'], + # models.ImageField: ['max_length'], + # models.SlugField: ['max_length'], + # models.URLField: ['max_length'], + # } + + # if model_field.__class__ in attribute_dict: + # attributes = attribute_dict[model_field.__class__] + # for attribute in attributes: + # kwargs.update({attribute: getattr(model_field, attribute)}) try: return self.field_mapping[model_field.__class__](**kwargs) except KeyError: - return ModelField(model_field=model_field, **kwargs) - - def get_validation_exclusions(self, instance=None): - """ - Return a list of field names to exclude from model validation. - """ - cls = self.opts.model - opts = cls._meta.concrete_model._meta - exclusions = [field.name for field in opts.fields + opts.many_to_many] - - for field_name, field in self.fields.items(): - field_name = field.source or field_name - if ( - field_name in exclusions - and not field.read_only - and (field.required or hasattr(instance, field_name)) - and not isinstance(field, Serializer) - ): - exclusions.remove(field_name) - return exclusions - - def full_clean(self, instance): - """ - Perform Django's full_clean, and populate the `errors` dictionary - if any validation errors occur. - - Note that we don't perform this inside the `.restore_object()` method, - so that subclasses can override `.restore_object()`, and still get - the full_clean validation checking. - """ - try: - instance.full_clean(exclude=self.get_validation_exclusions(instance)) - except ValidationError as err: - self._errors = err.message_dict - return None - return instance - - def restore_object(self, attrs, instance=None): - """ - Restore the model instance. - """ - m2m_data = {} - related_data = {} - nested_forward_relations = {} - meta = self.opts.model._meta - - # Reverse fk or one-to-one relations - for (obj, model) in meta.get_all_related_objects_with_model(): - field_name = obj.get_accessor_name() - if field_name in attrs: - related_data[field_name] = attrs.pop(field_name) - - # Reverse m2m relations - for (obj, model) in meta.get_all_related_m2m_objects_with_model(): - field_name = obj.get_accessor_name() - if field_name in attrs: - m2m_data[field_name] = attrs.pop(field_name) - - # Forward m2m relations - for field in meta.many_to_many + meta.virtual_fields: - if isinstance(field, GenericForeignKey): - continue - if field.name in attrs: - m2m_data[field.name] = attrs.pop(field.name) - - # Nested forward relations - These need to be marked so we can save - # them before saving the parent model instance. - for field_name in attrs.keys(): - if isinstance(self.fields.get(field_name, None), Serializer): - nested_forward_relations[field_name] = attrs[field_name] - - # Create an empty instance of the model - if instance is None: - instance = self.opts.model() - - for key, val in attrs.items(): - try: - setattr(instance, key, val) - except ValueError: - self._errors[key] = [self.error_messages['required']] - - # Any relations that cannot be set until we've - # saved the model get hidden away on these - # private attributes, so we can deal with them - # at the point of save. - instance._related_data = related_data - instance._m2m_data = m2m_data - instance._nested_forward_relations = nested_forward_relations - - return instance - - def from_native(self, data, files): - """ - Override the default method to also include model field validation. - """ - instance = super(ModelSerializer, self).from_native(data, files) - if not self._errors: - return self.full_clean(instance) - - def save_object(self, obj, **kwargs): - """ - Save the deserialized object. - """ - if getattr(obj, '_nested_forward_relations', None): - # Nested relationships need to be saved before we can save the - # parent instance. - for field_name, sub_object in obj._nested_forward_relations.items(): - if sub_object: - self.save_object(sub_object) - setattr(obj, field_name, sub_object) - - obj.save(**kwargs) - - if getattr(obj, '_m2m_data', None): - for accessor_name, object_list in obj._m2m_data.items(): - setattr(obj, accessor_name, object_list) - del(obj._m2m_data) - - if getattr(obj, '_related_data', None): - related_fields = dict([ - (field.get_accessor_name(), field) - for field, model - in obj._meta.get_all_related_objects_with_model() - ]) - for accessor_name, related in obj._related_data.items(): - if isinstance(related, RelationsList): - # Nested reverse fk relationship - for related_item in related: - fk_field = related_fields[accessor_name].field.name - setattr(related_item, fk_field, obj) - self.save_object(related_item) - - # Delete any removed objects - if related._deleted: - [self.delete_object(item) for item in related._deleted] - - elif isinstance(related, models.Model): - # Nested reverse one-one relationship - fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - setattr(related, fk_field, obj) - self.save_object(related) - else: - # Reverse FK or reverse one-one - setattr(obj, accessor_name, related) - del(obj._related_data) + # TODO: Change this to `return ModelField(model_field=model_field, **kwargs)` + return CharField(**kwargs) class HyperlinkedModelSerializerOptions(ModelSerializerOptions): @@ -1066,14 +564,10 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): class HyperlinkedModelSerializer(ModelSerializer): - """ - A subclass of ModelSerializer that uses hyperlinked relationships, - instead of primary key relationships. - """ _options_class = HyperlinkedModelSerializerOptions _default_view_name = '%(model_name)s-detail' - _hyperlink_field_class = HyperlinkedRelatedField - _hyperlink_identify_field_class = HyperlinkedIdentityField + #_hyperlink_field_class = HyperlinkedRelatedField + #_hyperlink_identify_field_class = HyperlinkedIdentityField def get_default_fields(self): fields = super(HyperlinkedModelSerializer, self).get_default_fields() @@ -1081,15 +575,15 @@ class HyperlinkedModelSerializer(ModelSerializer): if self.opts.view_name is None: self.opts.view_name = self._get_default_view_name(self.opts.model) - if self.opts.url_field_name not in fields: - url_field = self._hyperlink_identify_field_class( - view_name=self.opts.view_name, - lookup_field=self.opts.lookup_field - ) - ret = self._dict_class() - ret[self.opts.url_field_name] = url_field - ret.update(fields) - fields = ret + # if self.opts.url_field_name not in fields: + # url_field = self._hyperlink_identify_field_class( + # view_name=self.opts.view_name, + # lookup_field=self.opts.lookup_field + # ) + # ret = self._dict_class() + # ret[self.opts.url_field_name] = url_field + # ret.update(fields) + # fields = ret return fields @@ -1103,33 +597,25 @@ class HyperlinkedModelSerializer(ModelSerializer): """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) - kwargs = { - 'queryset': related_model._default_manager, - 'view_name': self._get_default_view_name(related_model), - 'many': to_many - } + # kwargs = { + # 'queryset': related_model._default_manager, + # 'view_name': self._get_default_view_name(related_model), + # 'many': to_many + # } + kwargs = {} if model_field: kwargs['required'] = not(model_field.null or model_field.blank) - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text + # if model_field.help_text is not None: + # kwargs['help_text'] = model_field.help_text if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name - if self.opts.lookup_field: - kwargs['lookup_field'] = self.opts.lookup_field - - return self._hyperlink_field_class(**kwargs) + return IntegerField(**kwargs) + # if self.opts.lookup_field: + # kwargs['lookup_field'] = self.opts.lookup_field - def get_identity(self, data): - """ - This hook is required for bulk update. - We need to override the default, to use the url as the identity. - """ - try: - return data.get(self.opts.url_field_name, None) - except AttributeError: - return None + # return self._hyperlink_field_class(**kwargs) def _get_default_view_name(self, model): """ diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 00ffdfba..6a2f6126 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -7,7 +7,7 @@ from django.db.models.query import QuerySet from django.utils.datastructures import SortedDict from django.utils.functional import Promise from rest_framework.compat import force_text -from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata +# from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata import datetime import decimal import types @@ -106,14 +106,14 @@ else: SortedDict, yaml.representer.SafeRepresenter.represent_dict ) - SafeDumper.add_representer( - DictWithMetadata, - yaml.representer.SafeRepresenter.represent_dict - ) - SafeDumper.add_representer( - SortedDictWithMetadata, - yaml.representer.SafeRepresenter.represent_dict - ) + # SafeDumper.add_representer( + # DictWithMetadata, + # yaml.representer.SafeRepresenter.represent_dict + # ) + # SafeDumper.add_representer( + # SortedDictWithMetadata, + # yaml.representer.SafeRepresenter.represent_dict + # ) SafeDumper.add_representer( types.GeneratorType, yaml.representer.SafeRepresenter.represent_list diff --git a/rest_framework/utils/html.py b/rest_framework/utils/html.py new file mode 100644 index 00000000..bf17050d --- /dev/null +++ b/rest_framework/utils/html.py @@ -0,0 +1,86 @@ +""" +Helpers for dealing with HTML input. +""" + +def is_html_input(dictionary): + # MultiDict type datastructures are used to represent HTML form input, + # which may have more than one value for each key. + return hasattr(dictionary, 'getlist') + + +def parse_html_list(dictionary, prefix=''): + """ + Used to suport list values in HTML forms. + Supports lists of primitives and/or dictionaries. + + * List of primitives. + + { + '[0]': 'abc', + '[1]': 'def', + '[2]': 'hij' + } + --> + [ + 'abc', + 'def', + 'hij' + ] + + * List of dictionaries. + + { + '[0]foo': 'abc', + '[0]bar': 'def', + '[1]foo': 'hij', + '[2]bar': 'klm', + } + --> + [ + {'foo': 'abc', 'bar': 'def'}, + {'foo': 'hij', 'bar': 'klm'} + ] + """ + Dict = type(dictionary) + ret = {} + regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix)) + for field, value in dictionary.items(): + match = regex.match(field) + if not match: + continue + index, key = match.groups() + index = int(index) + if not key: + ret[index] = value + elif isinstance(ret.get(index), dict): + ret[index][key] = value + else: + ret[index] = Dict({key: value}) + return [ret[item] for item in sorted(ret.keys())] + + +def parse_html_dict(dictionary, prefix): + """ + Used to support dictionary values in HTML forms. + + { + 'profile.username': 'example', + 'profile.email': 'example@example.com', + } + --> + { + 'profile': { + 'username': 'example, + 'email': 'example@example.com' + } + } + """ + ret = {} + regex = re.compile(r'^%s\.(.+)$' % re.escape(prefix)) + for field, value in dictionary.items(): + match = regex.match(field) + if not match: + continue + key = match.groups()[0] + ret[key] = value + return ret -- cgit v1.2.3 From ec096a1caceff6a4f5c75a152dd1c7bea9ed281d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 2 Sep 2014 15:07:56 +0100 Subject: Add relations and get tests running --- rest_framework/fields.py | 30 ++++++++++-- rest_framework/mixins.py | 1 - rest_framework/relations.py | 111 ++++++++++++++++++++++++++++++++++++++++++ rest_framework/renderers.py | 14 +++--- rest_framework/serializers.py | 8 +-- rest_framework/utils/html.py | 2 + 6 files changed, 149 insertions(+), 17 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index a83bf94c..3e0f7ca4 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -68,7 +68,7 @@ class Field(object): def __init__(self, read_only=False, write_only=False, required=None, default=empty, initial=None, source=None, - label=None, style=None): + label=None, style=None, error_messages=None): self._creation_counter = Field._creation_counter Field._creation_counter += 1 @@ -216,9 +216,11 @@ class CharField(Field): 'blank': 'This field may not be blank.' } - def __init__(self, *args, **kwargs): + def __init__(self, **kwargs): self.allow_blank = kwargs.pop('allow_blank', False) - super(CharField, self).__init__(*args, **kwargs) + self.max_length = kwargs.pop('max_length', None) + self.min_length = kwargs.pop('min_length', None) + super(CharField, self).__init__(**kwargs) def to_native(self, data): if data == '' and not self.allow_blank: @@ -233,7 +235,7 @@ class ChoiceField(Field): } coerce_to_type = str - def __init__(self, *args, **kwargs): + def __init__(self, **kwargs): choices = kwargs.pop('choices') assert choices, '`choices` argument is required and may not be empty' @@ -257,7 +259,7 @@ class ChoiceField(Field): str(key): key for key in self.choices.keys() } - super(ChoiceField, self).__init__(*args, **kwargs) + super(ChoiceField, self).__init__(**kwargs) def to_native(self, data): try: @@ -296,6 +298,24 @@ class IntegerField(Field): return data +class EmailField(CharField): + pass # TODO + + +class RegexField(CharField): + def __init__(self, **kwargs): + self.regex = kwargs.pop('regex') + super(CharField, self).__init__(**kwargs) + + +class DateTimeField(CharField): + pass # TODO + + +class FileField(Field): + pass # TODO + + class MethodField(Field): def __init__(self, **kwargs): kwargs['source'] = '*' diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index ee01cabc..3e9c9bb3 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -6,7 +6,6 @@ which allows mixin classes to be composed in interesting ways. """ from __future__ import unicode_literals -from django.core.exceptions import ValidationError from django.http import Http404 from rest_framework import status from rest_framework.response import Response diff --git a/rest_framework/relations.py b/rest_framework/relations.py index e69de29b..42d2c121 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -0,0 +1,111 @@ +from rest_framework.fields import Field +from django.core.exceptions import ObjectDoesNotExist +from django.core.urlresolvers import resolve, get_script_prefix +from rest_framework.compat import urlparse + + +def get_default_queryset(serializer_class, field_name): + manager = getattr(serializer_class.opts.model, field_name) + if hasattr(manager, 'related'): + # Forward relationships + return manager.related.model._default_manager.all() + # Reverse relationships + return manager.field.rel.to._default_manager.all() + + +class RelatedField(Field): + def __init__(self, **kwargs): + self.queryset = kwargs.pop('queryset', None) + self.many = kwargs.pop('many', False) + super(RelatedField, self).__init__(**kwargs) + + def bind(self, field_name, parent, root): + super(RelatedField, self).bind(field_name, parent, root) + if self.queryset is None and not self.read_only: + self.queryset = get_default_queryset(parent, self.source) + + +class PrimaryKeyRelatedField(RelatedField): + MESSAGES = { + 'required': 'This field is required.', + 'does_not_exist': "Invalid pk '{pk_value}' - object does not exist.", + 'incorrect_type': 'Incorrect type. Expected pk value, received {data_type}.', + } + + def from_native(self, data): + try: + return self.queryset.get(pk=data) + except ObjectDoesNotExist: + self.fail('does_not_exist', pk_value=data) + except (TypeError, ValueError): + self.fail('incorrect_type', data_type=type(data).__name__) + + +class HyperlinkedRelatedField(RelatedField): + lookup_field = 'pk' + + MESSAGES = { + 'required': 'This field is required.', + 'no_match': 'Invalid hyperlink - No URL match', + 'incorrect_match': 'Invalid hyperlink - Incorrect URL match.', + 'does_not_exist': "Invalid hyperlink - Object does not exist.", + 'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.', + } + + def __init__(self, **kwargs): + self.view_name = kwargs.pop('view_name') + self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) + self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) + super(HyperlinkedRelatedField, self).__init__(**kwargs) + + def get_object(self, view_name, view_args, view_kwargs): + """ + Return the object corresponding to a matched URL. + + Takes the matched URL conf arguments, and should return an + object instance, or raise an `ObjectDoesNotExist` exception. + """ + lookup_value = view_kwargs[self.lookup_url_kwarg] + lookup_kwargs = {self.lookup_field: lookup_value} + return self.queryset.get(**lookup_kwargs) + + def from_native(self, value): + try: + http_prefix = value.startswith(('http:', 'https:')) + except AttributeError: + self.fail('incorrect_type', type(value).__name__) + + if http_prefix: + # If needed convert absolute URLs to relative path + value = urlparse.urlparse(value).path + prefix = get_script_prefix() + if value.startswith(prefix): + value = '/' + value[len(prefix):] + + try: + match = resolve(value) + except Exception: + self.fail('no_match') + + if match.view_name != self.view_name: + self.fail('incorrect_match') + + try: + return self.get_object(match.view_name, match.args, match.kwargs) + except (ObjectDoesNotExist, TypeError, ValueError): + self.fail('does_not_exist') + + +class HyperlinkedIdentityField(RelatedField): + lookup_field = 'pk' + + def __init__(self, **kwargs): + self.view_name = kwargs.pop('view_name') + self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) + self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) + super(HyperlinkedIdentityField, self).__init__(**kwargs) + + +class SlugRelatedField(RelatedField): + def __init__(self, **kwargs): + self.slug_field = kwargs.pop('slug_field', None) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index e8935b01..dfc5a39f 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -436,13 +436,13 @@ class BrowsableAPIRenderer(BaseRenderer): if request.method == method: try: data = request.DATA - files = request.FILES + # files = request.FILES except ParseError: data = None - files = None + # files = None else: data = None - files = None + # files = None with override_method(view, request, method) as request: obj = getattr(view, 'object', None) @@ -579,10 +579,10 @@ class BrowsableAPIRenderer(BaseRenderer): 'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes], 'response_headers': response_headers, - #'put_form': self.get_rendered_html_form(view, 'PUT', request), - #'post_form': self.get_rendered_html_form(view, 'POST', request), - #'delete_form': self.get_rendered_html_form(view, 'DELETE', request), - #'options_form': self.get_rendered_html_form(view, 'OPTIONS', request), + # 'put_form': self.get_rendered_html_form(view, 'PUT', request), + # 'post_form': self.get_rendered_html_form(view, 'POST', request), + # 'delete_form': self.get_rendered_html_form(view, 'DELETE', request), + # 'options_form': self.get_rendered_html_form(view, 'OPTIONS', request), 'raw_data_put_form': raw_data_put_form, 'raw_data_post_form': raw_data_post_form, diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index d121812d..2f23b4d9 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -477,8 +477,8 @@ class ModelSerializer(Serializer): if model_field: kwargs['required'] = not(model_field.null or model_field.blank) - # if model_field.help_text is not None: - # kwargs['help_text'] = model_field.help_text + # if model_field.help_text is not None: + # kwargs['help_text'] = model_field.help_text if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name if not model_field.editable: @@ -566,8 +566,8 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): class HyperlinkedModelSerializer(ModelSerializer): _options_class = HyperlinkedModelSerializerOptions _default_view_name = '%(model_name)s-detail' - #_hyperlink_field_class = HyperlinkedRelatedField - #_hyperlink_identify_field_class = HyperlinkedIdentityField + # _hyperlink_field_class = HyperlinkedRelatedField + # _hyperlink_identify_field_class = HyperlinkedIdentityField def get_default_fields(self): fields = super(HyperlinkedModelSerializer, self).get_default_fields() diff --git a/rest_framework/utils/html.py b/rest_framework/utils/html.py index bf17050d..edc591e9 100644 --- a/rest_framework/utils/html.py +++ b/rest_framework/utils/html.py @@ -1,6 +1,8 @@ """ Helpers for dealing with HTML input. """ +import re + def is_html_input(dictionary): # MultiDict type datastructures are used to represent HTML form input, -- cgit v1.2.3 From f2852811f93863f2eed04d51eeb7ef27716b2409 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 2 Sep 2014 17:41:23 +0100 Subject: Getting tests passing --- rest_framework/authtoken/serializers.py | 5 ++-- rest_framework/authtoken/views.py | 3 +- rest_framework/fields.py | 50 ++++++++++++++++++++++++++++++- rest_framework/mixins.py | 41 +++---------------------- rest_framework/pagination.py | 36 ++++++++-------------- rest_framework/relations.py | 2 +- rest_framework/serializers.py | 53 ++++++++++++++++++++------------- 7 files changed, 103 insertions(+), 87 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index 99e99ae3..edeae857 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -19,11 +19,12 @@ class AuthTokenSerializer(serializers.Serializer): if not user.is_active: msg = _('User account is disabled.') raise serializers.ValidationError(msg) - attrs['user'] = user - return attrs else: msg = _('Unable to login with provided credentials.') raise serializers.ValidationError(msg) else: msg = _('Must include "username" and "password"') raise serializers.ValidationError(msg) + + attrs['user'] = user + return attrs diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index 7c03cb76..94e6f061 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -18,7 +18,8 @@ class ObtainAuthToken(APIView): def post(self, request): serializer = self.serializer_class(data=request.DATA) if serializer.is_valid(): - token, created = Token.objects.get_or_create(user=serializer.object['user']) + user = serializer.validated_data['user'] + token, created = Token.objects.get_or_create(user=user) return Response({'token': token.key}) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 3e0f7ca4..838aa3b0 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,4 +1,5 @@ from rest_framework.utils import html +import inspect class empty: @@ -11,6 +12,22 @@ class empty: pass +def is_simple_callable(obj): + """ + True if the object is a callable that takes no arguments. + """ + function = inspect.isfunction(obj) + method = inspect.ismethod(obj) + + if not (function or method): + return False + + args, _, _, defaults = inspect.getargspec(obj) + len_args = len(args) if function else len(args) - 1 + len_defaults = len(defaults) if defaults else 0 + return len_args <= len_defaults + + def get_attribute(instance, attrs): """ Similar to Python's built in `getattr(instance, attr)`, @@ -98,6 +115,7 @@ class Field(object): self.field_name = field_name self.parent = parent self.root = root + self.context = parent.context # `self.label` should deafult to being based on the field name. if self.label is None: @@ -297,25 +315,55 @@ class IntegerField(Field): self.fail('invalid_integer') return data + def to_primative(self, value): + if value is None: + return None + return int(value) + class EmailField(CharField): pass # TODO +class URLField(CharField): + pass # TODO + + class RegexField(CharField): def __init__(self, **kwargs): self.regex = kwargs.pop('regex') super(CharField, self).__init__(**kwargs) +class DateField(CharField): + def __init__(self, **kwargs): + self.input_formats = kwargs.pop('input_formats', None) + super(DateField, self).__init__(**kwargs) + + +class TimeField(CharField): + def __init__(self, **kwargs): + self.input_formats = kwargs.pop('input_formats', None) + super(TimeField, self).__init__(**kwargs) + + class DateTimeField(CharField): - pass # TODO + def __init__(self, **kwargs): + self.input_formats = kwargs.pop('input_formats', None) + super(DateTimeField, self).__init__(**kwargs) class FileField(Field): pass # TODO +class ReadOnlyField(Field): + def to_primative(self, value): + if is_simple_callable(value): + return value() + return value + + class MethodField(Field): def __init__(self, **kwargs): kwargs['source'] = '*' diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 3e9c9bb3..359740ce 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -13,23 +13,6 @@ from rest_framework.request import clone_request from rest_framework.settings import api_settings -def _get_validation_exclusions(obj, lookup_field=None): - """ - Given a model instance, and an optional pk and slug field, - return the full list of all other field names on that model. - - For use when performing full_clean on a model instance, - so we only clean the required fields. - """ - if lookup_field == 'pk': - pk_field = obj._meta.pk - while pk_field.rel: - pk_field = pk_field.rel.to._meta.pk - lookup_field = pk_field.name - - return [field.name for field in obj._meta.fields if field.name != lookup_field] - - class CreateModelMixin(object): """ Create a model instance. @@ -92,15 +75,14 @@ class UpdateModelMixin(object): if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - lookup_value = self.kwargs[lookup_url_kwarg] - extras = {self.lookup_field: lookup_value} - if self.object is None: + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + lookup_value = self.kwargs[lookup_url_kwarg] + extras = {self.lookup_field: lookup_value} self.object = serializer.save(extras=extras) return Response(serializer.data, status=status.HTTP_201_CREATED) - self.object = serializer.save(extras=extras) + self.object = serializer.save() return Response(serializer.data, status=status.HTTP_200_OK) def partial_update(self, request, *args, **kwargs): @@ -122,21 +104,6 @@ class UpdateModelMixin(object): # return a 404 response. raise - def pre_save(self, obj): - """ - Set any attributes on the object that are implicit in the request. - """ - lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - lookup_value = self.kwargs[lookup_url_kwarg] - - setattr(obj, self.lookup_field, lookup_value) - - # Ensure we clean the attributes so that we don't eg return integer - # pk using a string representation, as provided by the url conf kwarg. - if hasattr(obj, 'full_clean'): - exclude = _get_validation_exclusions(obj, self.lookup_field) - obj.full_clean(exclude) - class DestroyModelMixin(object): """ diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 83ef97c5..478d32b4 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -13,7 +13,7 @@ class NextPageField(serializers.Field): """ page_field = 'page' - def to_native(self, value): + def to_primative(self, value): if not value.has_next(): return None page = value.next_page_number() @@ -28,7 +28,7 @@ class PreviousPageField(serializers.Field): """ page_field = 'page' - def to_native(self, value): + def to_primative(self, value): if not value.has_previous(): return None page = value.previous_page_number() @@ -48,25 +48,11 @@ class DefaultObjectSerializer(serializers.Field): super(DefaultObjectSerializer, self).__init__(source=source) -# class PaginationSerializerOptions(serializers.SerializerOptions): -# """ -# An object that stores the options that may be provided to a -# pagination serializer by using the inner `Meta` class. - -# Accessible on the instance as `serializer.opts`. -# """ -# def __init__(self, meta): -# super(PaginationSerializerOptions, self).__init__(meta) -# self.object_serializer_class = getattr(meta, 'object_serializer_class', -# DefaultObjectSerializer) - - class BasePaginationSerializer(serializers.Serializer): """ A base class for pagination serializers to inherit from, to make implementing custom serializers more easy. """ - # _options_class = PaginationSerializerOptions results_field = 'results' def __init__(self, *args, **kwargs): @@ -75,14 +61,16 @@ class BasePaginationSerializer(serializers.Serializer): """ super(BasePaginationSerializer, self).__init__(*args, **kwargs) results_field = self.results_field - object_serializer = self.opts.object_serializer_class - - if 'context' in kwargs: - context_kwarg = {'context': kwargs['context']} - else: - context_kwarg = {} - - self.fields[results_field] = object_serializer(source='object_list', **context_kwarg) + try: + object_serializer = self.Meta.object_serializer_class + except AttributeError: + object_serializer = DefaultObjectSerializer + + self.fields[results_field] = serializers.ListSerializer( + child=object_serializer(), + source='object_list' + ) + self.fields[results_field].bind(results_field, self, self) # TODO: Support automatic binding class PaginationSerializer(BasePaginationSerializer): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 42d2c121..0b01394a 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -73,7 +73,7 @@ class HyperlinkedRelatedField(RelatedField): try: http_prefix = value.startswith(('http:', 'https:')) except AttributeError: - self.fail('incorrect_type', type(value).__name__) + self.fail('incorrect_type', data_type=type(value).__name__) if http_prefix: # If needed convert absolute URLs to relative path diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 2f23b4d9..c38d8968 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -142,7 +142,7 @@ class Serializer(BaseSerializer): return super(Serializer, cls).__new__(cls) def __init__(self, *args, **kwargs): - kwargs.pop('context', None) + self.context = kwargs.pop('context', {}) kwargs.pop('partial', None) kwargs.pop('many', False) @@ -202,7 +202,7 @@ class Serializer(BaseSerializer): if errors: raise ValidationError(errors) - return ret + return self.validate(ret) def to_primative(self, instance): """ @@ -217,6 +217,9 @@ class Serializer(BaseSerializer): return ret + def validate(self, attrs): + return attrs + def __iter__(self): errors = self.errors if hasattr(self, '_errors') else {} for field in self.fields.values(): @@ -232,8 +235,7 @@ class ListSerializer(BaseSerializer): def __init__(self, *args, **kwargs): self.child = kwargs.pop('child', copy.deepcopy(self.child)) assert self.child is not None, '`child` is a required argument.' - - kwargs.pop('context', None) + self.context = kwargs.pop('context', {}) kwargs.pop('partial', None) super(ListSerializer, self).__init__(*args, **kwargs) @@ -316,19 +318,19 @@ class ModelSerializer(Serializer): models.PositiveIntegerField: IntegerField, models.SmallIntegerField: IntegerField, models.PositiveSmallIntegerField: IntegerField, - # models.DateTimeField: DateTimeField, - # models.DateField: DateField, - # models.TimeField: TimeField, + models.DateTimeField: DateTimeField, + models.DateField: DateField, + models.TimeField: TimeField, # models.DecimalField: DecimalField, - # models.EmailField: EmailField, + models.EmailField: EmailField, models.CharField: CharField, - # models.URLField: URLField, + models.URLField: URLField, # models.SlugField: SlugField, models.TextField: CharField, models.CommaSeparatedIntegerField: CharField, models.BooleanField: BooleanField, models.NullBooleanField: BooleanField, - # models.FileField: FileField, + models.FileField: FileField, # models.ImageField: ImageField, } @@ -338,6 +340,15 @@ class ModelSerializer(Serializer): self.opts = self._options_class(self.Meta) super(ModelSerializer, self).__init__(*args, **kwargs) + def create(self): + ModelClass = self.opts.model + return ModelClass.objects.create(**self.validated_data) + + def update(self, obj): + for attr, value in self.validated_data.items(): + setattr(obj, attr, value) + obj.save() + def get_fields(self): # Get the explicitly declared fields. fields = copy.deepcopy(self.base_fields) @@ -566,8 +577,8 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): class HyperlinkedModelSerializer(ModelSerializer): _options_class = HyperlinkedModelSerializerOptions _default_view_name = '%(model_name)s-detail' - # _hyperlink_field_class = HyperlinkedRelatedField - # _hyperlink_identify_field_class = HyperlinkedIdentityField + _hyperlink_field_class = HyperlinkedRelatedField + _hyperlink_identify_field_class = HyperlinkedIdentityField def get_default_fields(self): fields = super(HyperlinkedModelSerializer, self).get_default_fields() @@ -575,15 +586,15 @@ class HyperlinkedModelSerializer(ModelSerializer): if self.opts.view_name is None: self.opts.view_name = self._get_default_view_name(self.opts.model) - # if self.opts.url_field_name not in fields: - # url_field = self._hyperlink_identify_field_class( - # view_name=self.opts.view_name, - # lookup_field=self.opts.lookup_field - # ) - # ret = self._dict_class() - # ret[self.opts.url_field_name] = url_field - # ret.update(fields) - # fields = ret + if self.opts.url_field_name not in fields: + url_field = self._hyperlink_identify_field_class( + view_name=self.opts.view_name, + lookup_field=self.opts.lookup_field + ) + ret = fields.__class__() + ret[self.opts.url_field_name] = url_field + ret.update(fields) + fields = ret return fields -- cgit v1.2.3 From c1036c17533a3091401ff90f825571f0e6125eca Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 3 Sep 2014 16:34:09 +0100 Subject: More test passing --- rest_framework/relations.py | 56 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 0b01394a..661a1249 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,6 +1,7 @@ from rest_framework.fields import Field +from rest_framework.reverse import reverse from django.core.exceptions import ObjectDoesNotExist -from django.core.urlresolvers import resolve, get_script_prefix +from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch from rest_framework.compat import urlparse @@ -100,11 +101,64 @@ class HyperlinkedIdentityField(RelatedField): lookup_field = 'pk' def __init__(self, **kwargs): + kwargs['read_only'] = True self.view_name = kwargs.pop('view_name') self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) super(HyperlinkedIdentityField, self).__init__(**kwargs) + def get_attribute(self, instance): + return instance + + def to_primative(self, value): + request = self.context.get('request', None) + format = self.context.get('format', None) + + assert request is not None, ( + "`HyperlinkedIdentityField` requires the request in the serializer" + " context. Add `context={'request': request}` when instantiating " + "the serializer." + ) + + # By default use whatever format is given for the current context + # unless the target is a different type to the source. + # + # Eg. Consider a HyperlinkedIdentityField pointing from a json + # representation to an html property of that representation... + # + # '/snippets/1/' should link to '/snippets/1/highlight/' + # ...but... + # '/snippets/1/.json' should link to '/snippets/1/highlight/.html' + if format and self.format and self.format != format: + format = self.format + + # Return the hyperlink, or error if incorrectly configured. + try: + return self.get_url(value, self.view_name, request, format) + except NoReverseMatch: + msg = ( + 'Could not resolve URL for hyperlinked relationship using ' + 'view name "%s". You may have failed to include the related ' + 'model in your API, or incorrectly configured the ' + '`lookup_field` attribute on this field.' + ) + raise Exception(msg % self.view_name) + + def get_url(self, obj, view_name, request, format): + """ + Given an object, return the URL that hyperlinks to the object. + + May raise a `NoReverseMatch` if the `view_name` and `lookup_field` + attributes are not configured to correctly match the URL conf. + """ + # Unsaved objects will not yet have a valid URL. + if obj.pk is None: + return None + + lookup_value = getattr(obj, self.lookup_field) + kwargs = {self.lookup_url_kwarg: lookup_value} + return reverse(view_name, kwargs=kwargs, request=request, format=format) + class SlugRelatedField(RelatedField): def __init__(self, **kwargs): -- cgit v1.2.3 From d934824bff21e4a11226af61efba319be227f4f0 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 5 Sep 2014 16:29:46 +0100 Subject: Workin on --- rest_framework/exceptions.py | 9 +++-- rest_framework/fields.py | 26 +++++++++++--- rest_framework/generics.py | 30 +--------------- rest_framework/mixins.py | 83 ++++++++++++++++++++++++------------------- rest_framework/serializers.py | 64 +++++++++++++++++++++------------ 5 files changed, 118 insertions(+), 94 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index ad52d172..852a08b1 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -15,7 +15,7 @@ class APIException(Exception): Subclasses should provide `.status_code` and `.default_detail` properties. """ status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - default_detail = '' + default_detail = 'A server error occured' def __init__(self, detail=None): self.detail = detail or self.default_detail @@ -29,6 +29,11 @@ class ParseError(APIException): default_detail = 'Malformed request.' +class ValidationError(APIException): + status_code = status.HTTP_400_BAD_REQUEST + default_detail = 'Invalid data in request.' + + class AuthenticationFailed(APIException): status_code = status.HTTP_401_UNAUTHORIZED default_detail = 'Incorrect authentication credentials.' @@ -54,7 +59,7 @@ class MethodNotAllowed(APIException): class NotAcceptable(APIException): status_code = status.HTTP_406_NOT_ACCEPTABLE - default_detail = "Could not satisfy the request's Accept header" + default_detail = "Could not satisfy the request Accept header" def __init__(self, detail=None, available_renderers=None): self.detail = detail or self.default_detail diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 838aa3b0..d18551b3 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,3 +1,4 @@ +from rest_framework.exceptions import ValidationError from rest_framework.utils import html import inspect @@ -59,10 +60,6 @@ def set_value(dictionary, keys, value): dictionary[keys[-1]] = value -class ValidationError(Exception): - pass - - class SkipField(Exception): pass @@ -204,6 +201,22 @@ class Field(object): msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) raise AssertionError(msg) + def __new__(cls, *args, **kwargs): + instance = super(Field, cls).__new__(cls) + instance._args = args + instance._kwargs = kwargs + return instance + + def __repr__(self): + arg_string = ', '.join([repr(val) for val in self._args]) + kwarg_string = ', '.join([ + '%s=%s' % (key, repr(val)) for key, val in self._kwargs.items() + ]) + if arg_string and kwarg_string: + arg_string += ', ' + class_name = self.__class__.__name__ + return "%s(%s%s)" % (class_name, arg_string, kwarg_string) + class BooleanField(Field): MESSAGES = { @@ -308,6 +321,11 @@ class IntegerField(Field): 'invalid_integer': 'A valid integer is required.' } + def __init__(self, **kwargs): + self.max_value = kwargs.pop('max_value') + self.min_value = kwargs.pop('min_value') + super(CharField, self).__init__(**kwargs) + def to_native(self, data): try: data = int(str(data)) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 6705cbb2..c2c59154 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -27,7 +27,7 @@ def strict_positive_int(integer_string, cutoff=None): def get_object_or_404(queryset, *filter_args, **filter_kwargs): """ - Same as Django's standard shortcut, but make sure to raise 404 + Same as Django's standard shortcut, but make sure to also raise 404 if the filter_kwargs don't match the required types. """ try: @@ -249,34 +249,6 @@ class GenericAPIView(views.APIView): # # The are not called by GenericAPIView directly, # but are used by the mixin methods. - - def pre_save(self, obj): - """ - Placeholder method for calling before saving an object. - - May be used to set attributes on the object that are implicit - in either the request, or the url. - """ - pass - - def post_save(self, obj, created=False): - """ - Placeholder method for calling after saving an object. - """ - pass - - def pre_delete(self, obj): - """ - Placeholder method for calling before deleting an object. - """ - pass - - def post_delete(self, obj): - """ - Placeholder method for calling after deleting an object. - """ - pass - def metadata(self, request): """ Return a dictionary of metadata about the view. diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 359740ce..14a6b44b 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -19,14 +19,10 @@ class CreateModelMixin(object): """ def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.DATA) - - if serializer.is_valid(): - self.object = serializer.save() - headers = self.get_success_headers(serializer.data) - return Response(serializer.data, status=status.HTTP_201_CREATED, - headers=headers) - - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + serializer.is_valid(raise_exception=True) + serializer.save() + headers = self.get_success_headers(serializer.data) + return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) def get_success_headers(self, data): try: @@ -40,15 +36,12 @@ class ListModelMixin(object): List a queryset. """ def list(self, request, *args, **kwargs): - self.object_list = self.filter_queryset(self.get_queryset()) - - # Switch between paginated or standard style responses - page = self.paginate_queryset(self.object_list) + instance = self.filter_queryset(self.get_queryset()) + page = self.paginate_queryset(instance) if page is not None: serializer = self.get_pagination_serializer(page) else: - serializer = self.get_serializer(self.object_list, many=True) - + serializer = self.get_serializer(instance, many=True) return Response(serializer.data) @@ -57,8 +50,8 @@ class RetrieveModelMixin(object): Retrieve a model instance. """ def retrieve(self, request, *args, **kwargs): - self.object = self.get_object() - serializer = self.get_serializer(self.object) + instance = self.get_object() + serializer = self.get_serializer(instance) return Response(serializer.data) @@ -68,22 +61,52 @@ class UpdateModelMixin(object): """ def update(self, request, *args, **kwargs): partial = kwargs.pop('partial', False) - self.object = self.get_object_or_none() + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.DATA, partial=partial) + serializer.is_valid(raise_exception=True) + serializer.save() + return Response(serializer.data) + + def partial_update(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + - serializer = self.get_serializer(self.object, data=request.DATA, partial=partial) +class DestroyModelMixin(object): + """ + Destroy a model instance. + """ + def destroy(self, request, *args, **kwargs): + instance = self.get_object() + instance.delete() + return Response(status=status.HTTP_204_NO_CONTENT) - if not serializer.is_valid(): - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - if self.object is None: +# The AllowPUTAsCreateMixin was previously the default behaviour +# for PUT requests. This has now been removed and must be *explictly* +# included if it is the behavior that you want. +# For more info see: ... + +class AllowPUTAsCreateMixin(object): + """ + The following mixin class may be used in order to support PUT-as-create + behavior for incoming requests. + """ + def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + instance = self.get_object_or_none() + serializer = self.get_serializer(instance, data=request.DATA, partial=partial) + serializer.is_valid(raise_exception=True) + + if instance is None: lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field lookup_value = self.kwargs[lookup_url_kwarg] extras = {self.lookup_field: lookup_value} - self.object = serializer.save(extras=extras) + serializer.save(extras=extras) return Response(serializer.data, status=status.HTTP_201_CREATED) - self.object = serializer.save() - return Response(serializer.data, status=status.HTTP_200_OK) + serializer.save() + return Response(serializer.data) def partial_update(self, request, *args, **kwargs): kwargs['partial'] = True @@ -103,15 +126,3 @@ class UpdateModelMixin(object): # PATCH requests where the object does not exist should still # return a 404 response. raise - - -class DestroyModelMixin(object): - """ - Destroy a model instance. - """ - def destroy(self, request, *args, **kwargs): - obj = self.get_object() - self.pre_delete(obj) - obj.delete() - self.post_delete(obj) - return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index c38d8968..49eb6ce9 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -13,7 +13,8 @@ response content is handled by parsers and renderers. from django.db import models from django.utils import six from collections import namedtuple, OrderedDict -from rest_framework.fields import empty, set_value, Field, SkipField, ValidationError +from rest_framework.exceptions import ValidationError +from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.settings import api_settings from rest_framework.utils import html import copy @@ -34,43 +35,53 @@ FieldResult = namedtuple('FieldResult', ['field', 'value', 'error']) class BaseSerializer(Field): + """ + The BaseSerializer class provides a minimal class which may be used + for writing custom serializer implementations. + """ + def __init__(self, instance=None, data=None, **kwargs): super(BaseSerializer, self).__init__(**kwargs) self.instance = instance self._initial_data = data def to_native(self, data): - raise NotImplementedError() + raise NotImplementedError('`to_native()` must be implemented.') def to_primative(self, instance): - raise NotImplementedError() + raise NotImplementedError('`to_primative()` must be implemented.') - def update(self, instance): - raise NotImplementedError() + def update(self, instance, attrs): + raise NotImplementedError('`update()` must be implemented.') - def create(self): - raise NotImplementedError() + def create(self, attrs): + raise NotImplementedError('`create()` must be implemented.') def save(self, extras=None): if extras is not None: - self._validated_data.update(extras) + self.validated_data.update(extras) if self.instance is not None: - self.update(self.instance) + self.update(self.instance, self._validated_data) else: - self.instance = self.create() + self.instance = self.create(self._validated_data) return self.instance - def is_valid(self): - try: - self._validated_data = self.to_native(self._initial_data) - except ValidationError as exc: - self._validated_data = {} - self._errors = exc.args[0] - return False - self._errors = {} - return True + def is_valid(self, raise_exception=False): + if not hasattr(self, '_validated_data'): + try: + self._validated_data = self.to_native(self._initial_data) + except ValidationError as exc: + self._validated_data = {} + self._errors = exc.detail + else: + self._errors = {} + + if self._errors and raise_exception: + raise ValidationError(self._errors) + + return not bool(self._errors) @property def data(self): @@ -184,14 +195,20 @@ class Serializer(BaseSerializer): """ Dict of native values <- Dict of primitive datatypes. """ + if not isinstance(data, dict): + raise ValidationError({'non_field_errors': ['Invalid data']}) + ret = {} errors = {} fields = [field for field in self.fields.values() if not field.read_only] for field in fields: + validate_method = getattr(self, 'validate_' + field.field_name, None) primitive_value = field.get_value(data) try: validated_value = field.validate(primitive_value) + if validate_method is not None: + validated_value = validate_method(validated_value) except ValidationError as exc: errors[field.field_name] = str(exc) except SkipField: @@ -202,6 +219,7 @@ class Serializer(BaseSerializer): if errors: raise ValidationError(errors) + # TODO: 'Non field errors' return self.validate(ret) def to_primative(self, instance): @@ -340,12 +358,12 @@ class ModelSerializer(Serializer): self.opts = self._options_class(self.Meta) super(ModelSerializer, self).__init__(*args, **kwargs) - def create(self): + def create(self, attrs): ModelClass = self.opts.model - return ModelClass.objects.create(**self.validated_data) + return ModelClass.objects.create(**attrs) - def update(self, obj): - for attr, value in self.validated_data.items(): + def update(self, obj, attrs): + for attr, value in attrs.items(): setattr(obj, attr, value) obj.save() -- cgit v1.2.3 From 21980b800d04a1d82a6003823abfdf4ab80ae979 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 8 Sep 2014 14:24:05 +0100 Subject: More test sorting --- rest_framework/exceptions.py | 5 --- rest_framework/fields.py | 85 ++++++++++++++++++++++++++++++++++++++++--- rest_framework/serializers.py | 25 ++++++++----- rest_framework/views.py | 9 ++++- 4 files changed, 101 insertions(+), 23 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 852a08b1..06b5e8a2 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -29,11 +29,6 @@ class ParseError(APIException): default_detail = 'Malformed request.' -class ValidationError(APIException): - status_code = status.HTTP_400_BAD_REQUEST - default_detail = 'Invalid data in request.' - - class AuthenticationFailed(APIException): status_code = status.HTTP_401_UNAUTHORIZED default_detail = 'Incorrect authentication credentials.' diff --git a/rest_framework/fields.py b/rest_framework/fields.py index d18551b3..250c0579 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,4 +1,6 @@ -from rest_framework.exceptions import ValidationError +from django.core import validators +from django.core.exceptions import ValidationError +from django.utils.encoding import is_protected_type from rest_framework.utils import html import inspect @@ -33,9 +35,14 @@ def get_attribute(instance, attrs): """ Similar to Python's built in `getattr(instance, attr)`, but takes a list of nested attributes, instead of a single attribute. + + Also accepts either attribute lookup on objects or dictionary lookups. """ for attr in attrs: - instance = getattr(instance, attr) + try: + instance = getattr(instance, attr) + except AttributeError: + return instance[attr] return instance @@ -80,9 +87,11 @@ class Field(object): 'not exist in the `MESSAGES` dictionary.' ) + default_validators = [] + def __init__(self, read_only=False, write_only=False, required=None, default=empty, initial=None, source=None, - label=None, style=None, error_messages=None): + label=None, style=None, error_messages=None, validators=[]): self._creation_counter = Field._creation_counter Field._creation_counter += 1 @@ -104,6 +113,7 @@ class Field(object): self.initial = initial self.label = label self.style = {} if style is None else style + self.validators = self.default_validators + validators def bind(self, field_name, parent, root): """ @@ -176,8 +186,21 @@ class Field(object): self.fail('required') return self.get_default() + self.run_validators(data) return self.to_native(data) + def run_validators(self, value): + if value in validators.EMPTY_VALUES: + return + errors = [] + for validator in self.validators: + try: + validator(value) + except ValidationError as exc: + errors.extend(exc.messages) + if errors: + raise ValidationError(errors) + def to_native(self, data): """ Transform the *incoming* primative data into a native value. @@ -322,9 +345,13 @@ class IntegerField(Field): } def __init__(self, **kwargs): - self.max_value = kwargs.pop('max_value') - self.min_value = kwargs.pop('min_value') - super(CharField, self).__init__(**kwargs) + max_value = kwargs.pop('max_value', None) + min_value = kwargs.pop('min_value', None) + super(IntegerField, self).__init__(**kwargs) + if max_value is not None: + self.validators.append(validators.MaxValueValidator(max_value)) + if min_value is not None: + self.validators.append(validators.MinValueValidator(min_value)) def to_native(self, data): try: @@ -392,3 +419,49 @@ class MethodField(Field): attr = 'get_{field_name}'.format(field_name=self.field_name) method = getattr(self.parent, attr) return method(value) + + +class ModelField(Field): + """ + A generic field that can be used against an arbitrary model field. + """ + def __init__(self, *args, **kwargs): + try: + self.model_field = kwargs.pop('model_field') + except KeyError: + raise ValueError("ModelField requires 'model_field' kwarg") + + self.min_length = kwargs.pop('min_length', + getattr(self.model_field, 'min_length', None)) + self.max_length = kwargs.pop('max_length', + getattr(self.model_field, 'max_length', None)) + self.min_value = kwargs.pop('min_value', + getattr(self.model_field, 'min_value', None)) + self.max_value = kwargs.pop('max_value', + getattr(self.model_field, 'max_value', None)) + + super(ModelField, self).__init__(*args, **kwargs) + + if self.min_length is not None: + self.validators.append(validators.MinLengthValidator(self.min_length)) + if self.max_length is not None: + self.validators.append(validators.MaxLengthValidator(self.max_length)) + if self.min_value is not None: + self.validators.append(validators.MinValueValidator(self.min_value)) + if self.max_value is not None: + self.validators.append(validators.MaxValueValidator(self.max_value)) + + def get_attribute(self, instance): + return get_attribute(instance, self.source_attrs[:-1]) + + def to_native(self, data): + rel = getattr(self.model_field, 'rel', None) + if rel is not None: + return rel.to._meta.get_field(rel.field_name).to_python(data) + return self.model_field.to_python(data) + + def to_primative(self, obj): + value = self.model_field._get_val_from_obj(obj) + if is_protected_type(value): + return value + return self.model_field.value_to_string(obj) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 49eb6ce9..93226d32 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -10,10 +10,10 @@ python primitives. 2. The process of marshalling between python primitives and request and response content is handled by parsers and renderers. """ +from django.core.exceptions import ValidationError from django.db import models from django.utils import six from collections import namedtuple, OrderedDict -from rest_framework.exceptions import ValidationError from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.settings import api_settings from rest_framework.utils import html @@ -58,13 +58,14 @@ class BaseSerializer(Field): raise NotImplementedError('`create()` must be implemented.') def save(self, extras=None): + attrs = self.validated_data if extras is not None: - self.validated_data.update(extras) + attrs = dict(list(attrs.items()) + list(extras.items())) if self.instance is not None: - self.update(self.instance, self._validated_data) + self.update(self.instance, attrs) else: - self.instance = self.create(self._validated_data) + self.instance = self.create(attrs) return self.instance @@ -74,7 +75,7 @@ class BaseSerializer(Field): self._validated_data = self.to_native(self._initial_data) except ValidationError as exc: self._validated_data = {} - self._errors = exc.detail + self._errors = exc.message_dict else: self._errors = {} @@ -210,7 +211,7 @@ class Serializer(BaseSerializer): if validate_method is not None: validated_value = validate_method(validated_value) except ValidationError as exc: - errors[field.field_name] = str(exc) + errors[field.field_name] = exc.messages except SkipField: pass else: @@ -219,8 +220,10 @@ class Serializer(BaseSerializer): if errors: raise ValidationError(errors) - # TODO: 'Non field errors' - return self.validate(ret) + try: + return self.validate(ret) + except ValidationError, exc: + raise ValidationError({'non_field_errors': exc.messages}) def to_primative(self, instance): """ @@ -539,6 +542,9 @@ class ModelSerializer(Serializer): if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name + if model_field.validators is not None: + kwargs['validators'] = model_field.validators + # if model_field.help_text is not None: # kwargs['help_text'] = model_field.help_text @@ -577,8 +583,7 @@ class ModelSerializer(Serializer): try: return self.field_mapping[model_field.__class__](**kwargs) except KeyError: - # TODO: Change this to `return ModelField(model_field=model_field, **kwargs)` - return CharField(**kwargs) + return ModelField(model_field=model_field, **kwargs) class HyperlinkedModelSerializerOptions(ModelSerializerOptions): diff --git a/rest_framework/views.py b/rest_framework/views.py index 23df3443..079e9285 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,7 +3,7 @@ Provides an APIView class that is the base of all views in REST framework. """ from __future__ import unicode_literals -from django.core.exceptions import PermissionDenied +from django.core.exceptions import PermissionDenied, ValidationError from django.http import Http404 from django.utils.datastructures import SortedDict from django.views.decorators.csrf import csrf_exempt @@ -51,7 +51,8 @@ def exception_handler(exc): Returns the response that should be used for any given exception. By default we handle the REST framework `APIException`, and also - Django's builtin `Http404` and `PermissionDenied` exceptions. + Django's built-in `ValidationError`, `Http404` and `PermissionDenied` + exceptions. Any unhandled exceptions may return `None`, which will cause a 500 error to be raised. @@ -68,6 +69,10 @@ def exception_handler(exc): status=exc.status_code, headers=headers) + elif isinstance(exc, ValidationError): + return Response(exc.message_dict, + status=status.HTTP_400_BAD_REQUEST) + elif isinstance(exc, Http404): return Response({'detail': 'Not found'}, status=status.HTTP_404_NOT_FOUND) -- cgit v1.2.3 From b1c07670ca65084c5fef2bbb63d1f4163763014b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 9 Sep 2014 17:46:28 +0100 Subject: Fleshing out serializer fields --- rest_framework/fields.py | 591 +++++++++++++++++++++++------- rest_framework/serializers.py | 380 +++++++++---------- rest_framework/utils/humanize_datetime.py | 47 +++ rest_framework/utils/modelinfo.py | 97 +++++ rest_framework/utils/representation.py | 72 ++++ 5 files changed, 872 insertions(+), 315 deletions(-) create mode 100644 rest_framework/utils/humanize_datetime.py create mode 100644 rest_framework/utils/modelinfo.py create mode 100644 rest_framework/utils/representation.py (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 250c0579..043a44ed 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,8 +1,18 @@ +from django.conf import settings from django.core import validators from django.core.exceptions import ValidationError +from django.utils import timezone +from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.encoding import is_protected_type -from rest_framework.utils import html +from django.utils.translation import ugettext_lazy as _ +from rest_framework import ISO_8601 +from rest_framework.compat import smart_text +from rest_framework.settings import api_settings +from rest_framework.utils import html, representation, humanize_datetime +import datetime +import decimal import inspect +import warnings class empty: @@ -71,22 +81,22 @@ class SkipField(Exception): pass +NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`' +NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' +NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`' +NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' +MISSING_ERROR_MESSAGE = ( + 'ValidationError raised by `{class_name}`, but error key `{key}` does ' + 'not exist in the `error_messages` dictionary.' +) + + class Field(object): _creation_counter = 0 - MESSAGES = { - 'required': 'This field is required.' + default_error_messages = { + 'required': _('This field is required.') } - - _NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`' - _NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' - _NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`' - _NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' - _MISSING_ERROR_MESSAGE = ( - 'ValidationError raised by `{class_name}`, but error key `{key}` does ' - 'not exist in the `MESSAGES` dictionary.' - ) - default_validators = [] def __init__(self, read_only=False, write_only=False, @@ -100,10 +110,10 @@ class Field(object): required = default is empty and not read_only # Some combinations of keyword arguments do not make sense. - assert not (read_only and write_only), self._NOT_READ_ONLY_WRITE_ONLY - assert not (read_only and required), self._NOT_READ_ONLY_REQUIRED - assert not (read_only and default is not empty), self._NOT_READ_ONLY_DEFAULT - assert not (required and default is not empty), self._NOT_REQUIRED_DEFAULT + assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY + assert not (read_only and required), NOT_READ_ONLY_REQUIRED + assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT + assert not (required and default is not empty), NOT_REQUIRED_DEFAULT self.read_only = read_only self.write_only = write_only @@ -113,7 +123,14 @@ class Field(object): self.initial = initial self.label = label self.style = {} if style is None else style - self.validators = self.default_validators + validators + self.validators = validators or self.default_validators[:] + + # Collect default error message from self and parent classes + messages = {} + for cls in reversed(self.__class__.__mro__): + messages.update(getattr(cls, 'default_error_messages', {})) + messages.update(error_messages or {}) + self.error_messages = messages def bind(self, field_name, parent, root): """ @@ -186,12 +203,14 @@ class Field(object): self.fail('required') return self.get_default() - self.run_validators(data) - return self.to_native(data) + value = self.to_native(data) + self.run_validators(value) + return value def run_validators(self, value): if value in validators.EMPTY_VALUES: return + errors = [] for validator in self.validators: try: @@ -218,33 +237,32 @@ class Field(object): A helper method that simply raises a validation error. """ try: - raise ValidationError(self.MESSAGES[key].format(**kwargs)) + msg = self.error_messages[key] except KeyError: class_name = self.__class__.__name__ - msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) + msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) raise AssertionError(msg) + raise ValidationError(msg.format(**kwargs)) def __new__(cls, *args, **kwargs): + """ + When a field is instantiated, we store the arguments that were used, + so that we can present a helpful representation of the object. + """ instance = super(Field, cls).__new__(cls) instance._args = args instance._kwargs = kwargs return instance def __repr__(self): - arg_string = ', '.join([repr(val) for val in self._args]) - kwarg_string = ', '.join([ - '%s=%s' % (key, repr(val)) for key, val in self._kwargs.items() - ]) - if arg_string and kwarg_string: - arg_string += ', ' - class_name = self.__class__.__name__ - return "%s(%s%s)" % (class_name, arg_string, kwarg_string) + return representation.field_repr(self) +# Boolean types... + class BooleanField(Field): - MESSAGES = { - 'required': 'This field is required.', - 'invalid_value': '`{input}` is not a valid boolean.' + default_error_messages = { + 'invalid': _('`{input}` is not a valid boolean.') } TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True} FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False} @@ -261,13 +279,23 @@ class BooleanField(Field): return True elif data in self.FALSE_VALUES: return False - self.fail('invalid_value', input=data) + self.fail('invalid', input=data) + + def to_primative(self, value): + if value is None: + return None + if value in self.TRUE_VALUES: + return True + elif value in self.FALSE_VALUES: + return False + return bool(value) +# String types... + class CharField(Field): - MESSAGES = { - 'required': 'This field is required.', - 'blank': 'This field may not be blank.' + default_error_messages = { + 'blank': _('This field may not be blank.') } def __init__(self, **kwargs): @@ -281,19 +309,364 @@ class CharField(Field): self.fail('blank') return str(data) + def to_primative(self, value): + if value is None: + return None + return str(value) -class ChoiceField(Field): - MESSAGES = { - 'required': 'This field is required.', - 'invalid_choice': '`{input}` is not a valid choice.' + +class EmailField(CharField): + default_error_messages = { + 'invalid': _('Enter a valid email address.') + } + default_validators = [validators.validate_email] + + def to_native(self, data): + ret = super(EmailField, self).to_native(data) + if ret is None: + return None + return ret.strip() + + def to_primative(self, value): + ret = super(EmailField, self).to_primative(value) + if ret is None: + return None + return ret.strip() + + +class RegexField(CharField): + def __init__(self, regex, **kwargs): + kwargs['validators'] = ( + [validators.RegexValidator(regex)] + + kwargs.get('validators', []) + ) + super(RegexField, self).__init__(**kwargs) + + +class SlugField(CharField): + default_error_messages = { + 'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.") + } + default_validators = [validators.validate_slug] + + +class URLField(CharField): + default_error_messages = { + 'invalid': _("Enter a valid URL.") + } + default_validators = [validators.URLValidator()] + + +# Number types... + +class IntegerField(Field): + default_error_messages = { + 'invalid': _('A valid integer is required.') + } + + def __init__(self, **kwargs): + max_value = kwargs.pop('max_value', None) + min_value = kwargs.pop('min_value', None) + super(IntegerField, self).__init__(**kwargs) + if max_value is not None: + self.validators.append(validators.MaxValueValidator(max_value)) + if min_value is not None: + self.validators.append(validators.MinValueValidator(min_value)) + print self.__class__.__name__, self.validators + + def to_native(self, data): + try: + data = int(str(data)) + except (ValueError, TypeError): + self.fail('invalid') + return data + + def to_primative(self, value): + if value is None: + return None + return int(value) + + +class FloatField(Field): + default_error_messages = { + 'invalid': _("'%s' value must be a float."), } - coerce_to_type = str def __init__(self, **kwargs): - choices = kwargs.pop('choices') + max_value = kwargs.pop('max_value', None) + min_value = kwargs.pop('min_value', None) + super(FloatField, self).__init__(**kwargs) + if max_value is not None: + self.validators.append(validators.MaxValueValidator(max_value)) + if min_value is not None: + self.validators.append(validators.MinValueValidator(min_value)) + + def to_primative(self, value): + if value is None: + return None + try: + return float(value) + except (TypeError, ValueError): + self.fail('invalid', value=value) + + def to_native(self, value): + if value is None: + return None + return float(value) + + +class DecimalField(Field): + default_error_messages = { + 'invalid': _('Enter a number.'), + 'max_value': _('Ensure this value is less than or equal to {max_value}.'), + 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), + 'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'), + 'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'), + 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.') + } + + def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, **kwargs): + self.max_value, self.min_value = max_value, min_value + self.max_digits, self.max_decimal_places = max_digits, decimal_places + super(DecimalField, self).__init__(**kwargs) + if max_value is not None: + self.validators.append(validators.MaxValueValidator(max_value)) + if min_value is not None: + self.validators.append(validators.MinValueValidator(min_value)) + + def from_native(self, value): + """ + Validates that the input is a decimal number. Returns a Decimal + instance. Returns None for empty values. Ensures that there are no more + than max_digits in the number, and no more than decimal_places digits + after the decimal point. + """ + if value in validators.EMPTY_VALUES: + return None + + value = smart_text(value).strip() + try: + value = decimal.Decimal(value) + except decimal.DecimalException: + self.fail('invalid') + + # Check for NaN. It is the only value that isn't equal to itself, + # so we can use this to identify NaN values. + if value != value: + self.fail('invalid') + + # Check for infinity and negative infinity. + if value in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')): + self.fail('invalid') + + sign, digittuple, exponent = value.as_tuple() + decimals = abs(exponent) + # digittuple doesn't include any leading zeros. + digits = len(digittuple) + if decimals > digits: + # We have leading zeros up to or past the decimal point. Count + # everything past the decimal point as a digit. We do not count + # 0 before the decimal point as a digit since that would mean + # we would not allow max_digits = decimal_places. + digits = decimals + whole_digits = digits - decimals + + if self.max_digits is not None and digits > self.max_digits: + self.fail('max_digits', max_digits=self.max_digits) + if self.decimal_places is not None and decimals > self.decimal_places: + self.fail('max_decimal_places', max_decimal_places=self.max_decimal_places) + if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): + self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places) + + return value + + +# Date & time fields... + +class DateField(Field): + default_error_messages = { + 'invalid': _("Date has wrong format. Use one of these formats instead: %s"), + } + input_formats = api_settings.DATE_INPUT_FORMATS + format = api_settings.DATE_FORMAT + + def __init__(self, input_formats=None, format=None, *args, **kwargs): + self.input_formats = input_formats if input_formats is not None else self.input_formats + self.format = format if format is not None else self.format + super(DateField, self).__init__(*args, **kwargs) + + def from_native(self, value): + if value in validators.EMPTY_VALUES: + return None + + if isinstance(value, datetime.datetime): + if timezone and settings.USE_TZ and timezone.is_aware(value): + # Convert aware datetimes to the default time zone + # before casting them to dates (#17742). + default_timezone = timezone.get_default_timezone() + value = timezone.make_naive(value, default_timezone) + return value.date() + if isinstance(value, datetime.date): + return value + + for format in self.input_formats: + if format.lower() == ISO_8601: + try: + parsed = parse_date(value) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = datetime.datetime.strptime(value, format) + except (ValueError, TypeError): + pass + else: + return parsed.date() + + humanized_format = humanize_datetime.date_formats(self.input_formats) + msg = self.error_messages['invalid'] % humanized_format + raise ValidationError(msg) + + def to_primative(self, value): + if value is None or self.format is None: + return value + + if isinstance(value, datetime.datetime): + value = value.date() + + if self.format.lower() == ISO_8601: + return value.isoformat() + return value.strftime(self.format) + + +class DateTimeField(Field): + default_error_messages = { + 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"), + } + input_formats = api_settings.DATETIME_INPUT_FORMATS + format = api_settings.DATETIME_FORMAT + + def __init__(self, input_formats=None, format=None, *args, **kwargs): + self.input_formats = input_formats if input_formats is not None else self.input_formats + self.format = format if format is not None else self.format + super(DateTimeField, self).__init__(*args, **kwargs) + + def from_native(self, value): + if value in validators.EMPTY_VALUES: + return None + + if isinstance(value, datetime.datetime): + return value + if isinstance(value, datetime.date): + value = datetime.datetime(value.year, value.month, value.day) + if settings.USE_TZ: + # For backwards compatibility, interpret naive datetimes in + # local time. This won't work during DST change, but we can't + # do much about it, so we let the exceptions percolate up the + # call stack. + warnings.warn("DateTimeField received a naive datetime (%s)" + " while time zone support is active." % value, + RuntimeWarning) + default_timezone = timezone.get_default_timezone() + value = timezone.make_aware(value, default_timezone) + return value + + for format in self.input_formats: + if format.lower() == ISO_8601: + try: + parsed = parse_datetime(value) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = datetime.datetime.strptime(value, format) + except (ValueError, TypeError): + pass + else: + return parsed + + humanized_format = humanize_datetime.datetime_formats(self.input_formats) + msg = self.error_messages['invalid'] % humanized_format + raise ValidationError(msg) + + def to_primative(self, value): + if value is None or self.format is None: + return value + + if self.format.lower() == ISO_8601: + ret = value.isoformat() + if ret.endswith('+00:00'): + ret = ret[:-6] + 'Z' + return ret + return value.strftime(self.format) - assert choices, '`choices` argument is required and may not be empty' +class TimeField(Field): + default_error_messages = { + 'invalid': _("Time has wrong format. Use one of these formats instead: %s"), + } + input_formats = api_settings.TIME_INPUT_FORMATS + format = api_settings.TIME_FORMAT + + def __init__(self, input_formats=None, format=None, *args, **kwargs): + self.input_formats = input_formats if input_formats is not None else self.input_formats + self.format = format if format is not None else self.format + super(TimeField, self).__init__(*args, **kwargs) + + def from_native(self, value): + if value in validators.EMPTY_VALUES: + return None + + if isinstance(value, datetime.time): + return value + + for format in self.input_formats: + if format.lower() == ISO_8601: + try: + parsed = parse_time(value) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = datetime.datetime.strptime(value, format) + except (ValueError, TypeError): + pass + else: + return parsed.time() + + humanized_format = humanize_datetime.time_formats(self.input_formats) + msg = self.error_messages['invalid'] % humanized_format + raise ValidationError(msg) + + def to_primative(self, value): + if value is None or self.format is None: + return value + + if isinstance(value, datetime.datetime): + value = value.time() + + if self.format.lower() == ISO_8601: + return value.isoformat() + return value.strftime(self.format) + + +# Choice types... + +class ChoiceField(Field): + default_error_messages = { + 'invalid_choice': _('`{input}` is not a valid choice.') + } + + def __init__(self, choices, **kwargs): # Allow either single or paired choices style: # choices = [1, 2, 3] # choices = [(1, 'First'), (2, 'Second'), (3, 'Third')] @@ -321,12 +694,14 @@ class ChoiceField(Field): except KeyError: self.fail('invalid_choice', input=data) + def to_primative(self, value): + return value + class MultipleChoiceField(ChoiceField): - MESSAGES = { - 'required': 'This field is required.', - 'invalid_choice': '`{input}` is not a valid choice.', - 'not_a_list': 'Expected a list of items but got type `{input_type}`' + default_error_messages = { + 'invalid_choice': _('`{input}` is not a valid choice.'), + 'not_a_list': _('Expected a list of items but got type `{input_type}`') } def to_native(self, data): @@ -337,72 +712,42 @@ class MultipleChoiceField(ChoiceField): for item in data ]) - -class IntegerField(Field): - MESSAGES = { - 'required': 'This field is required.', - 'invalid_integer': 'A valid integer is required.' - } - - def __init__(self, **kwargs): - max_value = kwargs.pop('max_value', None) - min_value = kwargs.pop('min_value', None) - super(IntegerField, self).__init__(**kwargs) - if max_value is not None: - self.validators.append(validators.MaxValueValidator(max_value)) - if min_value is not None: - self.validators.append(validators.MinValueValidator(min_value)) - - def to_native(self, data): - try: - data = int(str(data)) - except (ValueError, TypeError): - self.fail('invalid_integer') - return data - def to_primative(self, value): - if value is None: - return None - return int(value) + return value -class EmailField(CharField): +# File types... + +class FileField(Field): pass # TODO -class URLField(CharField): +class ImageField(Field): pass # TODO -class RegexField(CharField): - def __init__(self, **kwargs): - self.regex = kwargs.pop('regex') - super(CharField, self).__init__(**kwargs) - +# Advanced field types... -class DateField(CharField): - def __init__(self, **kwargs): - self.input_formats = kwargs.pop('input_formats', None) - super(DateField, self).__init__(**kwargs) +class ReadOnlyField(Field): + """ + A read-only field that simply returns the field value. + If the field is a method with no parameters, the method will be called + and it's return value used as the representation. -class TimeField(CharField): - def __init__(self, **kwargs): - self.input_formats = kwargs.pop('input_formats', None) - super(TimeField, self).__init__(**kwargs) + For example, the following would call `get_expiry_date()` on the object: + class ExampleSerializer(self): + expiry_date = ReadOnlyField(source='get_expiry_date') + """ -class DateTimeField(CharField): def __init__(self, **kwargs): - self.input_formats = kwargs.pop('input_formats', None) - super(DateTimeField, self).__init__(**kwargs) - - -class FileField(Field): - pass # TODO + kwargs['read_only'] = True + super(ReadOnlyField, self).__init__(**kwargs) + def to_native(self, data): + raise NotImplemented('.to_native() not supported.') -class ReadOnlyField(Field): def to_primative(self, value): if is_simple_callable(value): return value() @@ -410,11 +755,28 @@ class ReadOnlyField(Field): class MethodField(Field): + """ + A read-only field that get its representation from calling a method on the + parent serializer class. The method called will be of the form + "get_{field_name}", and should take a single argument, which is the + object being serialized. + + For example: + + class ExampleSerializer(self): + extra_info = MethodField() + + def get_extra_info(self, obj): + return ... # Calculate some data to return. + """ def __init__(self, **kwargs): kwargs['source'] = '*' kwargs['read_only'] = True super(MethodField, self).__init__(**kwargs) + def to_native(self, data): + raise NotImplemented('.to_native() not supported.') + def to_primative(self, value): attr = 'get_{field_name}'.format(field_name=self.field_name) method = getattr(self.parent, attr) @@ -424,35 +786,14 @@ class MethodField(Field): class ModelField(Field): """ A generic field that can be used against an arbitrary model field. - """ - def __init__(self, *args, **kwargs): - try: - self.model_field = kwargs.pop('model_field') - except KeyError: - raise ValueError("ModelField requires 'model_field' kwarg") - - self.min_length = kwargs.pop('min_length', - getattr(self.model_field, 'min_length', None)) - self.max_length = kwargs.pop('max_length', - getattr(self.model_field, 'max_length', None)) - self.min_value = kwargs.pop('min_value', - getattr(self.model_field, 'min_value', None)) - self.max_value = kwargs.pop('max_value', - getattr(self.model_field, 'max_value', None)) - - super(ModelField, self).__init__(*args, **kwargs) - - if self.min_length is not None: - self.validators.append(validators.MinLengthValidator(self.min_length)) - if self.max_length is not None: - self.validators.append(validators.MaxLengthValidator(self.max_length)) - if self.min_value is not None: - self.validators.append(validators.MinValueValidator(self.min_value)) - if self.max_value is not None: - self.validators.append(validators.MaxValueValidator(self.max_value)) - def get_attribute(self, instance): - return get_attribute(instance, self.source_attrs[:-1]) + This is used by `ModelSerializer` when dealing with custom model fields, + that do not have a serializer field to be mapped to. + """ + def __init__(self, model_field, **kwargs): + self.model_field = model_field + kwargs['source'] = '*' + super(ModelField, self).__init__(**kwargs) def to_native(self, data): rel = getattr(self.model_field, 'rel', None) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 93226d32..8ca28387 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -10,15 +10,15 @@ python primitives. 2. The process of marshalling between python primitives and request and response content is handled by parsers and renderers. """ +from django.core import validators from django.core.exceptions import ValidationError from django.db import models from django.utils import six from collections import namedtuple, OrderedDict from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.settings import api_settings -from rest_framework.utils import html +from rest_framework.utils import html, modelinfo, representation import copy -import inspect # Note: We do the following so that users of the framework can use this style: # @@ -146,12 +146,10 @@ class SerializerMetaclass(type): class Serializer(BaseSerializer): def __new__(cls, *args, **kwargs): - many = kwargs.pop('many', False) - if many: - class DynamicListSerializer(ListSerializer): - child = cls() - return DynamicListSerializer(*args, **kwargs) - return super(Serializer, cls).__new__(cls) + if kwargs.pop('many', False): + kwargs['child'] = cls() + return ListSerializer(*args, **kwargs) + return super(Serializer, cls).__new__(cls, *args, **kwargs) def __init__(self, *args, **kwargs): self.context = kwargs.pop('context', {}) @@ -248,6 +246,9 @@ class Serializer(BaseSerializer): error = errors.get(field.field_name) yield FieldResult(field, value, error) + def __repr__(self): + return representation.serializer_repr(self, indent=1) + class ListSerializer(BaseSerializer): child = None @@ -299,26 +300,8 @@ class ListSerializer(BaseSerializer): self.instance = self.create(self.validated_data) return self.instance - -def _resolve_model(obj): - """ - Resolve supplied `obj` to a Django model class. - - `obj` must be a Django model class itself, or a string - representation of one. Useful in situtations like GH #1225 where - Django may not have resolved a string-based reference to a model in - another model's foreign key definition. - - String representations should have the format: - 'appname.ModelName' - """ - if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: - app_name, model_name = obj.split('.') - return models.get_model(app_name, model_name) - elif inspect.isclass(obj) and issubclass(obj, models.Model): - return obj - else: - raise ValueError("{0} is not a Django model".format(obj)) + def __repr__(self): + return representation.list_repr(self, indent=1) class ModelSerializerOptions(object): @@ -334,24 +317,25 @@ class ModelSerializerOptions(object): class ModelSerializer(Serializer): field_mapping = { models.AutoField: IntegerField, - # models.FloatField: FloatField, + models.BigIntegerField: IntegerField, + models.BooleanField: BooleanField, + models.CharField: CharField, + models.CommaSeparatedIntegerField: CharField, + models.DateField: DateField, + models.DateTimeField: DateTimeField, + models.DecimalField: DecimalField, + models.EmailField: EmailField, + models.FileField: FileField, + models.FloatField: FloatField, models.IntegerField: IntegerField, + models.NullBooleanField: BooleanField, models.PositiveIntegerField: IntegerField, - models.SmallIntegerField: IntegerField, models.PositiveSmallIntegerField: IntegerField, - models.DateTimeField: DateTimeField, - models.DateField: DateField, + models.SlugField: SlugField, + models.SmallIntegerField: IntegerField, + models.TextField: CharField, models.TimeField: TimeField, - # models.DecimalField: DecimalField, - models.EmailField: EmailField, - models.CharField: CharField, models.URLField: URLField, - # models.SlugField: SlugField, - models.TextField: CharField, - models.CommaSeparatedIntegerField: CharField, - models.BooleanField: BooleanField, - models.NullBooleanField: BooleanField, - models.FileField: FileField, # models.ImageField: ImageField, } @@ -392,85 +376,31 @@ class ModelSerializer(Serializer): """ Return all the fields that should be serialized for the model. """ - cls = self.opts.model - opts = cls._meta.concrete_model._meta + info = modelinfo.get_field_info(self.opts.model) ret = OrderedDict() - nested = bool(self.opts.depth) - # Deal with adding the primary key field - pk_field = opts.pk - while pk_field.rel and pk_field.rel.parent_link: - # If model is a child via multitable inheritance, use parent's pk - pk_field = pk_field.rel.to._meta.pk - - serializer_pk_field = self.get_pk_field(pk_field) + serializer_pk_field = self.get_pk_field(info.pk) if serializer_pk_field: - ret[pk_field.name] = serializer_pk_field - - # Deal with forward relationships - forward_rels = [field for field in opts.fields if field.serialize] - forward_rels += [field for field in opts.many_to_many if field.serialize] - - for model_field in forward_rels: - has_through_model = False - - if model_field.rel: - to_many = isinstance(model_field, - models.fields.related.ManyToManyField) - related_model = _resolve_model(model_field.rel.to) - - if to_many and not model_field.rel.through._meta.auto_created: - has_through_model = True + ret[info.pk.name] = serializer_pk_field - if model_field.rel and nested: - field = self.get_nested_field(model_field, related_model, to_many) - elif model_field.rel: - field = self.get_related_field(model_field, related_model, to_many) - else: - field = self.get_field(model_field) - - if field: - if has_through_model: - field.read_only = True + # Regular fields + for field_name, field in info.fields.items(): + ret[field_name] = self.get_field(field) - ret[model_field.name] = field - - # Deal with reverse relationships - if not self.opts.fields: - reverse_rels = [] - else: - # Reverse relationships are only included if they are explicitly - # present in the `fields` option on the serializer - reverse_rels = opts.get_all_related_objects() - reverse_rels += opts.get_all_related_many_to_many_objects() - - for relation in reverse_rels: - accessor_name = relation.get_accessor_name() - if not self.opts.fields or accessor_name not in self.opts.fields: - continue - related_model = relation.model - to_many = relation.field.rel.multiple - has_through_model = False - is_m2m = isinstance(relation.field, - models.fields.related.ManyToManyField) - - if ( - is_m2m and - hasattr(relation.field.rel, 'through') and - not relation.field.rel.through._meta.auto_created - ): - has_through_model = True - - if nested: - field = self.get_nested_field(None, related_model, to_many) + # Forward relations + for field_name, relation_info in info.forward_relations.items(): + if self.opts.depth: + ret[field_name] = self.get_nested_field(*relation_info) else: - field = self.get_related_field(None, related_model, to_many) + ret[field_name] = self.get_related_field(*relation_info) - if field: - if has_through_model: - field.read_only = True - - ret[accessor_name] = field + # Reverse relations + for accessor_name, relation_info in info.reverse_relations.items(): + if accessor_name in self.opts.fields: + if self.opts.depth: + ret[field_name] = self.get_nested_field(*relation_info) + else: + ret[field_name] = self.get_related_field(*relation_info) return ret @@ -480,7 +410,7 @@ class ModelSerializer(Serializer): """ return self.get_field(model_field) - def get_nested_field(self, model_field, related_model, to_many): + def get_nested_field(self, model_field, related_model, to_many, has_through_model): """ Creates a default instance of a nested relational field. @@ -491,59 +421,148 @@ class ModelSerializer(Serializer): model = related_model depth = self.opts.depth - 1 - return NestedModelSerializer(many=to_many) + kwargs = {'read_only': True} + if to_many: + kwargs['many'] = True + return NestedModelSerializer(**kwargs) - def get_related_field(self, model_field, related_model, to_many): + def get_related_field(self, model_field, related_model, to_many, has_through_model): """ Creates a default instance of a flat relational field. Note that model_field will be `None` for reverse relationships. """ - # TODO: filter queryset using: - # .using(db).complex_filter(self.rel.limit_choices_to) + kwargs = { + 'queryset': related_model._default_manager, + } - kwargs = {} - # 'queryset': related_model._default_manager, - # 'many': to_many - # } + if to_many: + kwargs['many'] = True + + if has_through_model: + kwargs['read_only'] = True + kwargs.pop('queryset', None) if model_field: - kwargs['required'] = not(model_field.null or model_field.blank) + if model_field.null or model_field.blank: + kwargs['required'] = False # if model_field.help_text is not None: # kwargs['help_text'] = model_field.help_text if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name if not model_field.editable: kwargs['read_only'] = True - if model_field.verbose_name is not None: - kwargs['label'] = model_field.verbose_name + kwargs.pop('queryset', None) - return IntegerField(**kwargs) - # TODO: return PrimaryKeyRelatedField(**kwargs) + return PrimaryKeyRelatedField(**kwargs) def get_field(self, model_field): """ Creates a default instance of a basic non-relational field. """ kwargs = {} + validator_kwarg = model_field.validators if model_field.null or model_field.blank: kwargs['required'] = False + if model_field.verbose_name is not None: + kwargs['label'] = model_field.verbose_name + if isinstance(model_field, models.AutoField) or not model_field.editable: kwargs['read_only'] = True + # Read only implies that the field is not required. + # We have a cleaner repr on the instance if we don't set it. + kwargs.pop('required', None) if model_field.has_default(): kwargs['default'] = model_field.get_default() - - if issubclass(model_field.__class__, models.TextField): - kwargs['widget'] = widgets.Textarea - - if model_field.verbose_name is not None: - kwargs['label'] = model_field.verbose_name - - if model_field.validators is not None: - kwargs['validators'] = model_field.validators + # Having a default implies that the field is not required. + # We have a cleaner repr on the instance if we don't set it. + kwargs.pop('required', None) + + # Ensure that max_length is passed explicitly as a keyword arg, + # rather than as a validator. + max_length = getattr(model_field, 'max_length', None) + if max_length is not None: + kwargs['max_length'] = max_length + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MaxLengthValidator) + ] + + # Ensure that min_length is passed explicitly as a keyword arg, + # rather than as a validator. + min_length = getattr(model_field, 'min_length', None) + if min_length is not None: + kwargs['min_length'] = min_length + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MinLengthValidator) + ] + + # Ensure that max_value is passed explicitly as a keyword arg, + # rather than as a validator. + max_value = next(( + validator.limit_value for validator in validator_kwarg + if isinstance(validator, validators.MaxValueValidator) + ), None) + if max_value is not None: + kwargs['max_value'] = max_value + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MaxValueValidator) + ] + + # Ensure that max_value is passed explicitly as a keyword arg, + # rather than as a validator. + min_value = next(( + validator.limit_value for validator in validator_kwarg + if isinstance(validator, validators.MinValueValidator) + ), None) + if min_value is not None: + kwargs['min_value'] = min_value + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MinValueValidator) + ] + + # URLField does not need to include the URLValidator argument, + # as it is explicitly added in. + if isinstance(model_field, models.URLField): + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.URLValidator) + ] + + # EmailField does not need to include the validate_email argument, + # as it is explicitly added in. + if isinstance(model_field, models.EmailField): + validator_kwarg = [ + validator for validator in validator_kwarg + if validator is not validators.validate_email + ] + + # SlugField do not need to include the 'validate_slug' argument, + if isinstance(model_field, models.SlugField): + validator_kwarg = [ + validator for validator in validator_kwarg + if validator is not validators.validate_slug + ] + + max_digits = getattr(model_field, 'max_digits', None) + if max_digits is not None: + kwargs['max_digits'] = max_digits + + decimal_places = getattr(model_field, 'decimal_places', None) + if decimal_places is not None: + kwargs['decimal_places'] = decimal_places + + if validator_kwarg: + kwargs['validators'] = validator_kwarg + + # if issubclass(model_field.__class__, models.TextField): + # kwargs['widget'] = widgets.Textarea # if model_field.help_text is not None: # kwargs['help_text'] = model_field.help_text @@ -555,31 +574,10 @@ class ModelSerializer(Serializer): kwargs['empty'] = None return ChoiceField(**kwargs) - # put this below the ChoiceField because min_value isn't a valid initializer - if issubclass(model_field.__class__, models.PositiveIntegerField) or \ - issubclass(model_field.__class__, models.PositiveSmallIntegerField): - kwargs['min_value'] = 0 - if model_field.null and \ issubclass(model_field.__class__, (models.CharField, models.TextField)): kwargs['allow_none'] = True - # attribute_dict = { - # models.CharField: ['max_length'], - # models.CommaSeparatedIntegerField: ['max_length'], - # models.DecimalField: ['max_digits', 'decimal_places'], - # models.EmailField: ['max_length'], - # models.FileField: ['max_length'], - # models.ImageField: ['max_length'], - # models.SlugField: ['max_length'], - # models.URLField: ['max_length'], - # } - - # if model_field.__class__ in attribute_dict: - # attributes = attribute_dict[model_field.__class__] - # for attribute in attributes: - # kwargs.update({attribute: getattr(model_field, attribute)}) - try: return self.field_mapping[model_field.__class__](**kwargs) except KeyError: @@ -594,28 +592,21 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): super(HyperlinkedModelSerializerOptions, self).__init__(meta) self.view_name = getattr(meta, 'view_name', None) self.lookup_field = getattr(meta, 'lookup_field', None) - self.url_field_name = getattr(meta, 'url_field_name', api_settings.URL_FIELD_NAME) class HyperlinkedModelSerializer(ModelSerializer): _options_class = HyperlinkedModelSerializerOptions - _default_view_name = '%(model_name)s-detail' - _hyperlink_field_class = HyperlinkedRelatedField - _hyperlink_identify_field_class = HyperlinkedIdentityField def get_default_fields(self): fields = super(HyperlinkedModelSerializer, self).get_default_fields() if self.opts.view_name is None: - self.opts.view_name = self._get_default_view_name(self.opts.model) + self.opts.view_name = self.get_default_view_name(self.opts.model) - if self.opts.url_field_name not in fields: - url_field = self._hyperlink_identify_field_class( - view_name=self.opts.view_name, - lookup_field=self.opts.lookup_field - ) + url_field_name = api_settings.URL_FIELD_NAME + if url_field_name not in fields: ret = fields.__class__() - ret[self.opts.url_field_name] = url_field + ret[url_field_name] = self.get_url_field() ret.update(fields) fields = ret @@ -625,39 +616,48 @@ class HyperlinkedModelSerializer(ModelSerializer): if self.opts.fields and model_field.name in self.opts.fields: return self.get_field(model_field) - def get_related_field(self, model_field, related_model, to_many): + def get_url_field(self): + kwargs = { + 'view_name': self.get_default_view_name(self.opts.model) + } + if self.opts.lookup_field: + kwargs['lookup_field'] = self.opts.lookup_field + return HyperlinkedIdentityField(**kwargs) + + def get_related_field(self, model_field, related_model, to_many, has_through_model): """ Creates a default instance of a flat relational field. """ - # TODO: filter queryset using: - # .using(db).complex_filter(self.rel.limit_choices_to) - # kwargs = { - # 'queryset': related_model._default_manager, - # 'view_name': self._get_default_view_name(related_model), - # 'many': to_many - # } - kwargs = {} + kwargs = { + 'queryset': related_model._default_manager, + 'view_name': self.get_default_view_name(related_model), + } + + if to_many: + kwargs['many'] = True + + if has_through_model: + kwargs['read_only'] = True + kwargs.pop('queryset', None) if model_field: - kwargs['required'] = not(model_field.null or model_field.blank) + if model_field.null or model_field.blank: + kwargs['required'] = False # if model_field.help_text is not None: # kwargs['help_text'] = model_field.help_text if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name + if not model_field.editable: + kwargs['read_only'] = True + kwargs.pop('queryset', None) - return IntegerField(**kwargs) - # if self.opts.lookup_field: - # kwargs['lookup_field'] = self.opts.lookup_field - - # return self._hyperlink_field_class(**kwargs) + return HyperlinkedRelatedField(**kwargs) - def _get_default_view_name(self, model): + def get_default_view_name(self, model): """ - Return the view name to use if 'view_name' is not specified in 'Meta' + Return the view name to use for related models. """ - model_meta = model._meta - format_kwargs = { - 'app_label': model_meta.app_label, - 'model_name': model_meta.object_name.lower() + return '%(model_name)s-detail' % { + 'app_label': model._meta.app_label, + 'model_name': model._meta.object_name.lower() } - return self._default_view_name % format_kwargs diff --git a/rest_framework/utils/humanize_datetime.py b/rest_framework/utils/humanize_datetime.py new file mode 100644 index 00000000..649f2abc --- /dev/null +++ b/rest_framework/utils/humanize_datetime.py @@ -0,0 +1,47 @@ +""" +Helper functions that convert strftime formats into more readable representations. +""" +from rest_framework import ISO_8601 + + +def datetime_formats(formats): + format = ', '.join(formats).replace( + ISO_8601, + 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]' + ) + return humanize_strptime(format) + + +def date_formats(formats): + format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]') + return humanize_strptime(format) + + +def time_formats(formats): + format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]') + return humanize_strptime(format) + + +def humanize_strptime(format_string): + # Note that we're missing some of the locale specific mappings that + # don't really make sense. + mapping = { + "%Y": "YYYY", + "%y": "YY", + "%m": "MM", + "%b": "[Jan-Dec]", + "%B": "[January-December]", + "%d": "DD", + "%H": "hh", + "%I": "hh", # Requires '%p' to differentiate from '%H'. + "%M": "mm", + "%S": "ss", + "%f": "uuuuuu", + "%a": "[Mon-Sun]", + "%A": "[Monday-Sunday]", + "%p": "[AM|PM]", + "%z": "[+HHMM|-HHMM]" + } + for key, val in mapping.items(): + format_string = format_string.replace(key, val) + return format_string diff --git a/rest_framework/utils/modelinfo.py b/rest_framework/utils/modelinfo.py new file mode 100644 index 00000000..c0513886 --- /dev/null +++ b/rest_framework/utils/modelinfo.py @@ -0,0 +1,97 @@ +""" +Helper functions for returning the field information that is associated +with a model class. +""" +from collections import namedtuple, OrderedDict +from django.db import models +from django.utils import six +import inspect + +FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations']) +RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model']) + + +def _resolve_model(obj): + """ + Resolve supplied `obj` to a Django model class. + + `obj` must be a Django model class itself, or a string + representation of one. Useful in situtations like GH #1225 where + Django may not have resolved a string-based reference to a model in + another model's foreign key definition. + + String representations should have the format: + 'appname.ModelName' + """ + if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: + app_name, model_name = obj.split('.') + return models.get_model(app_name, model_name) + elif inspect.isclass(obj) and issubclass(obj, models.Model): + return obj + raise ValueError("{0} is not a Django model".format(obj)) + + +def get_field_info(model): + """ + Given a model class, returns a `FieldInfo` instance containing metadata + about the various field types on the model. + """ + opts = model._meta.concrete_model._meta + + # Deal with the primary key. + pk = opts.pk + while pk.rel and pk.rel.parent_link: + # If model is a child via multitable inheritance, use parent's pk. + pk = pk.rel.to._meta.pk + + # Deal with regular fields. + fields = OrderedDict() + for field in [field for field in opts.fields if field.serialize and not field.rel]: + fields[field.name] = field + + # Deal with forward relationships. + forward_relations = OrderedDict() + for field in [field for field in opts.fields if field.serialize and field.rel]: + forward_relations[field.name] = RelationInfo( + field=field, + related=_resolve_model(field.rel.to), + to_many=False, + has_through_model=False + ) + + # Deal with forward many-to-many relationships. + for field in [field for field in opts.many_to_many if field.serialize]: + forward_relations[field.name] = RelationInfo( + field=field, + related=_resolve_model(field.rel.to), + to_many=True, + has_through_model=( + not field.rel.through._meta.auto_created + ) + ) + + # Deal with reverse relationships. + reverse_relations = OrderedDict() + for relation in opts.get_all_related_objects(): + accessor_name = relation.get_accessor_name() + reverse_relations[accessor_name] = RelationInfo( + field=None, + related=relation.model, + to_many=relation.field.rel.multiple, + has_through_model=False + ) + + # Deal with reverse many-to-many relationships. + for relation in opts.get_all_related_many_to_many_objects(): + accessor_name = relation.get_accessor_name() + reverse_relations[accessor_name] = RelationInfo( + field=None, + related=relation.model, + to_many=True, + has_through_model=( + hasattr(relation.field.rel, 'through') and + not relation.field.rel.through._meta.auto_created + ) + ) + + return FieldInfo(pk, fields, forward_relations, reverse_relations) diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py new file mode 100644 index 00000000..1de21597 --- /dev/null +++ b/rest_framework/utils/representation.py @@ -0,0 +1,72 @@ +""" +Helper functions for creating user-friendly representations +of serializer classes and serializer fields. +""" +import re + + +def smart_repr(value): + value = repr(value) + + # Representations like u'help text' + # should simply be presented as 'help text' + if value.startswith("u'") and value.endswith("'"): + return value[1:] + + # Representations like + # + # Should be presented as + # + value = re.sub(' at 0x[0-9a-f]{8,10}>', '>', value) + + return value + + +def field_repr(field, force_many=False): + kwargs = field._kwargs + if force_many: + kwargs = kwargs.copy() + kwargs['many'] = True + kwargs.pop('child', None) + + arg_string = ', '.join([smart_repr(val) for val in field._args]) + kwarg_string = ', '.join([ + '%s=%s' % (key, smart_repr(val)) + for key, val in sorted(kwargs.items()) + ]) + if arg_string and kwarg_string: + arg_string += ', ' + + if force_many: + class_name = force_many.__class__.__name__ + else: + class_name = field.__class__.__name__ + + return "%s(%s%s)" % (class_name, arg_string, kwarg_string) + + +def serializer_repr(serializer, indent, force_many=None): + ret = field_repr(serializer, force_many) + ':' + indent_str = ' ' * indent + + if force_many: + fields = force_many.fields + else: + fields = serializer.fields + + for field_name, field in fields.items(): + ret += '\n' + indent_str + field_name + ' = ' + if hasattr(field, 'fields'): + ret += serializer_repr(field, indent + 1) + elif hasattr(field, 'child'): + ret += list_repr(field, indent + 1) + else: + ret += field_repr(field) + return ret + + +def list_repr(serializer, indent): + child = serializer.child + if hasattr(child, 'fields'): + return serializer_repr(serializer, indent, force_many=child) + return field_repr(serializer) -- cgit v1.2.3 From 234369aefdf08d7d0161d851866990754c00d31f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 10 Sep 2014 08:53:33 +0100 Subject: Tweaks --- rest_framework/fields.py | 6 ++++-- rest_framework/serializers.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 043a44ed..e2bd5700 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -190,7 +190,7 @@ class Field(object): raise SkipField() return self.default - def validate(self, data=empty): + def validate_value(self, data=empty): """ Validate a simple representation and return the internal value. @@ -506,6 +506,7 @@ class DateField(Field): default_timezone = timezone.get_default_timezone() value = timezone.make_naive(value, default_timezone) return value.date() + if isinstance(value, datetime.date): return value @@ -560,6 +561,7 @@ class DateTimeField(Field): if isinstance(value, datetime.datetime): return value + if isinstance(value, datetime.date): value = datetime.datetime(value.year, value.month, value.day) if settings.USE_TZ: @@ -675,7 +677,7 @@ class ChoiceField(Field): for item in choices ] if all(pairs): - self.choices = {key: val for key, val in choices} + self.choices = {key: display_value for key, display_value in choices} else: self.choices = {item: item for item in choices} diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 8ca28387..0727b8cd 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -205,7 +205,7 @@ class Serializer(BaseSerializer): validate_method = getattr(self, 'validate_' + field.field_name, None) primitive_value = field.get_value(data) try: - validated_value = field.validate(primitive_value) + validated_value = field.validate_value(primitive_value) if validate_method is not None: validated_value = validate_method(validated_value) except ValidationError as exc: @@ -327,6 +327,7 @@ class ModelSerializer(Serializer): models.EmailField: EmailField, models.FileField: FileField, models.FloatField: FloatField, + models.ImageField: ImageField, models.IntegerField: IntegerField, models.NullBooleanField: BooleanField, models.PositiveIntegerField: IntegerField, @@ -336,7 +337,6 @@ class ModelSerializer(Serializer): models.TextField: CharField, models.TimeField: TimeField, models.URLField: URLField, - # models.ImageField: ImageField, } _options_class = ModelSerializerOptions -- cgit v1.2.3 From 01c8c0cad977fc0787dbfc78bd34f4fd37e613f4 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 10 Sep 2014 13:52:16 +0100 Subject: Added help_text argument to fields --- rest_framework/compat.py | 23 ++++++------ rest_framework/fields.py | 5 ++- rest_framework/serializers.py | 86 +++++++++++++++++++++---------------------- 3 files changed, 56 insertions(+), 58 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/compat.py b/rest_framework/compat.py index fa0f0bfb..70b38df9 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -39,6 +39,17 @@ except ImportError: django_filters = None +if django.VERSION >= (1, 6): + def clean_manytomany_helptext(text): + return text +else: + # Up to version 1.5 many to many fields automatically suffix + # the `help_text` attribute with hardcoded text. + def clean_manytomany_helptext(text): + if text.endswith(' Hold down "Control", or "Command" on a Mac, to select more than one.'): + text = text[:-69] + return text + # Django-guardian is optional. Import only if guardian is in INSTALLED_APPS # Fixes (#1712). We keep the try/except for the test suite. guardian = None @@ -99,18 +110,8 @@ def get_concrete_model(model_cls): return model_cls -# View._allowed_methods only present from 1.5 onwards -if django.VERSION >= (1, 5): - from django.views.generic import View -else: - from django.views.generic import View as DjangoView - - class View(DjangoView): - def _allowed_methods(self): - return [m.upper() for m in self.http_method_names if hasattr(self, m)] - - # PATCH method is not implemented by Django +from django.views.generic import View if 'patch' not in View.http_method_names: View.http_method_names = View.http_method_names + ['patch'] diff --git a/rest_framework/fields.py b/rest_framework/fields.py index e2bd5700..e71dce1d 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -101,7 +101,8 @@ class Field(object): def __init__(self, read_only=False, write_only=False, required=None, default=empty, initial=None, source=None, - label=None, style=None, error_messages=None, validators=[]): + label=None, help_text=None, style=None, + error_messages=None, validators=[]): self._creation_counter = Field._creation_counter Field._creation_counter += 1 @@ -122,6 +123,7 @@ class Field(object): self.source = source self.initial = initial self.label = label + self.help_text = help_text self.style = {} if style is None else style self.validators = validators or self.default_validators[:] @@ -372,7 +374,6 @@ class IntegerField(Field): self.validators.append(validators.MaxValueValidator(max_value)) if min_value is not None: self.validators.append(validators.MinValueValidator(min_value)) - print self.__class__.__name__, self.validators def to_native(self, data): try: diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 0727b8cd..459f8a8c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -15,6 +15,7 @@ from django.core.exceptions import ValidationError from django.db import models from django.utils import six from collections import namedtuple, OrderedDict +from rest_framework.compat import clean_manytomany_helptext from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.settings import api_settings from rest_framework.utils import html, modelinfo, representation @@ -117,8 +118,9 @@ class SerializerMetaclass(type): """ This metaclass sets a dictionary named `base_fields` on the class. - Any fields included as attributes on either the class or it's superclasses - will be include in the `base_fields` dictionary. + Any instances of `Field` included as attributes on either the class + or on any of its superclasses will be include in the + `base_fields` dictionary. """ @classmethod @@ -379,6 +381,10 @@ class ModelSerializer(Serializer): info = modelinfo.get_field_info(self.opts.model) ret = OrderedDict() + serializer_url_field = self.get_url_field() + if serializer_url_field: + ret[api_settings.URL_FIELD_NAME] = serializer_url_field + serializer_pk_field = self.get_pk_field(info.pk) if serializer_pk_field: ret[info.pk.name] = serializer_pk_field @@ -404,6 +410,9 @@ class ModelSerializer(Serializer): return ret + def get_url_field(self): + return None + def get_pk_field(self, model_field): """ Returns a default instance of the pk field. @@ -446,13 +455,14 @@ class ModelSerializer(Serializer): if model_field: if model_field.null or model_field.blank: kwargs['required'] = False - # if model_field.help_text is not None: - # kwargs['help_text'] = model_field.help_text - if model_field.verbose_name is not None: + if model_field.verbose_name: kwargs['label'] = model_field.verbose_name if not model_field.editable: kwargs['read_only'] = True kwargs.pop('queryset', None) + help_text = clean_manytomany_helptext(model_field.help_text) + if help_text: + kwargs['help_text'] = help_text return PrimaryKeyRelatedField(**kwargs) @@ -469,6 +479,9 @@ class ModelSerializer(Serializer): if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name + if model_field.help_text: + kwargs['help_text'] = model_field.help_text + if isinstance(model_field, models.AutoField) or not model_field.editable: kwargs['read_only'] = True # Read only implies that the field is not required. @@ -481,6 +494,14 @@ class ModelSerializer(Serializer): # We have a cleaner repr on the instance if we don't set it. kwargs.pop('required', None) + if model_field.flatchoices: + # If this model field contains choices, then use a ChoiceField, + # rather than the standard serializer field for this type. + # Note that we return this prior to setting any validation type + # keyword arguments, as those are not valid initializers. + kwargs['choices'] = model_field.flatchoices + return ChoiceField(**kwargs) + # Ensure that max_length is passed explicitly as a keyword arg, # rather than as a validator. max_length = getattr(model_field, 'max_length', None) @@ -561,23 +582,6 @@ class ModelSerializer(Serializer): if validator_kwarg: kwargs['validators'] = validator_kwarg - # if issubclass(model_field.__class__, models.TextField): - # kwargs['widget'] = widgets.Textarea - - # if model_field.help_text is not None: - # kwargs['help_text'] = model_field.help_text - - # TODO: TypedChoiceField? - if model_field.flatchoices: # This ModelField contains choices - kwargs['choices'] = model_field.flatchoices - if model_field.null: - kwargs['empty'] = None - return ChoiceField(**kwargs) - - if model_field.null and \ - issubclass(model_field.__class__, (models.CharField, models.TextField)): - kwargs['allow_none'] = True - try: return self.field_mapping[model_field.__class__](**kwargs) except KeyError: @@ -597,33 +601,24 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): class HyperlinkedModelSerializer(ModelSerializer): _options_class = HyperlinkedModelSerializerOptions - def get_default_fields(self): - fields = super(HyperlinkedModelSerializer, self).get_default_fields() - - if self.opts.view_name is None: - self.opts.view_name = self.get_default_view_name(self.opts.model) - - url_field_name = api_settings.URL_FIELD_NAME - if url_field_name not in fields: - ret = fields.__class__() - ret[url_field_name] = self.get_url_field() - ret.update(fields) - fields = ret - - return fields - - def get_pk_field(self, model_field): - if self.opts.fields and model_field.name in self.opts.fields: - return self.get_field(model_field) - def get_url_field(self): + if self.opts.view_name is not None: + view_name = self.opts.view_name + else: + view_name = self.get_default_view_name(self.opts.model) + kwargs = { - 'view_name': self.get_default_view_name(self.opts.model) + 'view_name': view_name } if self.opts.lookup_field: kwargs['lookup_field'] = self.opts.lookup_field + return HyperlinkedIdentityField(**kwargs) + def get_pk_field(self, model_field): + if self.opts.fields and model_field.name in self.opts.fields: + return self.get_field(model_field) + def get_related_field(self, model_field, related_model, to_many, has_through_model): """ Creates a default instance of a flat relational field. @@ -643,13 +638,14 @@ class HyperlinkedModelSerializer(ModelSerializer): if model_field: if model_field.null or model_field.blank: kwargs['required'] = False - # if model_field.help_text is not None: - # kwargs['help_text'] = model_field.help_text - if model_field.verbose_name is not None: + if model_field.verbose_name: kwargs['label'] = model_field.verbose_name if not model_field.editable: kwargs['read_only'] = True kwargs.pop('queryset', None) + help_text = clean_manytomany_helptext(model_field.help_text) + if help_text: + kwargs['help_text'] = help_text return HyperlinkedRelatedField(**kwargs) -- cgit v1.2.3 From 80ba0473473501968154c5cc5dd5922e53d96a70 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 10 Sep 2014 16:57:22 +0100 Subject: Compat fixes --- rest_framework/compat.py | 12 +++++++++++- rest_framework/fields.py | 14 +++++++------- rest_framework/serializers.py | 25 +++++++++++++------------ rest_framework/utils/modelinfo.py | 9 +++++---- 4 files changed, 36 insertions(+), 24 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 70b38df9..7c05bed9 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -110,8 +110,18 @@ def get_concrete_model(model_cls): return model_cls +# View._allowed_methods only present from 1.5 onwards +if django.VERSION >= (1, 5): + from django.views.generic import View +else: + from django.views.generic import View as DjangoView + + class View(DjangoView): + def _allowed_methods(self): + return [m.upper() for m in self.http_method_names if hasattr(self, m)] + + # PATCH method is not implemented by Django -from django.views.generic import View if 'patch' not in View.http_method_names: View.http_method_names = View.http_method_names + ['patch'] diff --git a/rest_framework/fields.py b/rest_framework/fields.py index e71dce1d..3ec28908 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -266,8 +266,8 @@ class BooleanField(Field): default_error_messages = { 'invalid': _('`{input}` is not a valid boolean.') } - TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True} - FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False} + TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True)) + FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False)) def get_value(self, dictionary): if html.is_html_input(dictionary): @@ -678,16 +678,16 @@ class ChoiceField(Field): for item in choices ] if all(pairs): - self.choices = {key: display_value for key, display_value in choices} + self.choices = dict([(key, display_value) for key, display_value in choices]) else: - self.choices = {item: item for item in choices} + self.choices = dict([(item, item) for item in choices]) # Map the string representation of choices to the underlying value. # Allows us to deal with eg. integer choices while supporting either # integer or string input, but still get the correct datatype out. - self.choice_strings_to_values = { - str(key): key for key in self.choices.keys() - } + self.choice_strings_to_values = dict([ + (str(key), key) for key in self.choices.keys() + ]) super(ChoiceField, self).__init__(**kwargs) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 459f8a8c..13e57939 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -14,7 +14,8 @@ from django.core import validators from django.core.exceptions import ValidationError from django.db import models from django.utils import six -from collections import namedtuple, OrderedDict +from django.utils.datastructures import SortedDict +from collections import namedtuple from rest_framework.compat import clean_manytomany_helptext from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.settings import api_settings @@ -91,10 +92,10 @@ class BaseSerializer(Field): if self.instance is not None: self._data = self.to_primative(self.instance) elif self._initial_data is not None: - self._data = { - field_name: field.get_value(self._initial_data) + self._data = dict([ + (field_name, field.get_value(self._initial_data)) for field_name, field in self.fields.items() - } + ]) else: self._data = self.get_initial() return self._data @@ -137,7 +138,7 @@ class SerializerMetaclass(type): if hasattr(base, 'base_fields'): fields = list(base.base_fields.items()) + fields - return OrderedDict(fields) + return SortedDict(fields) def __new__(cls, name, bases, attrs): attrs['base_fields'] = cls._get_fields(bases, attrs) @@ -180,10 +181,10 @@ class Serializer(BaseSerializer): field.bind(field_name, self, root) def get_initial(self): - return { - field.field_name: field.get_initial() + return dict([ + (field.field_name, field.get_initial()) for field in self.fields.values() - } + ]) def get_value(self, dictionary): # We override the default field access in order to support @@ -222,14 +223,14 @@ class Serializer(BaseSerializer): try: return self.validate(ret) - except ValidationError, exc: + except ValidationError as exc: raise ValidationError({'non_field_errors': exc.messages}) def to_primative(self, instance): """ Object instance -> Dict of primitive datatypes. """ - ret = OrderedDict() + ret = SortedDict() fields = [field for field in self.fields.values() if not field.write_only] for field in fields: @@ -368,7 +369,7 @@ class ModelSerializer(Serializer): # If `fields` is set on the `Meta` class, # then use only those fields, and in that order. if self.opts.fields: - fields = OrderedDict([ + fields = SortedDict([ (key, fields[key]) for key in self.opts.fields ]) @@ -379,7 +380,7 @@ class ModelSerializer(Serializer): Return all the fields that should be serialized for the model. """ info = modelinfo.get_field_info(self.opts.model) - ret = OrderedDict() + ret = SortedDict() serializer_url_field = self.get_url_field() if serializer_url_field: diff --git a/rest_framework/utils/modelinfo.py b/rest_framework/utils/modelinfo.py index c0513886..a7a0346c 100644 --- a/rest_framework/utils/modelinfo.py +++ b/rest_framework/utils/modelinfo.py @@ -2,9 +2,10 @@ Helper functions for returning the field information that is associated with a model class. """ -from collections import namedtuple, OrderedDict +from collections import namedtuple from django.db import models from django.utils import six +from django.utils.datastructures import SortedDict import inspect FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations']) @@ -45,12 +46,12 @@ def get_field_info(model): pk = pk.rel.to._meta.pk # Deal with regular fields. - fields = OrderedDict() + fields = SortedDict() for field in [field for field in opts.fields if field.serialize and not field.rel]: fields[field.name] = field # Deal with forward relationships. - forward_relations = OrderedDict() + forward_relations = SortedDict() for field in [field for field in opts.fields if field.serialize and field.rel]: forward_relations[field.name] = RelationInfo( field=field, @@ -71,7 +72,7 @@ def get_field_info(model): ) # Deal with reverse relationships. - reverse_relations = OrderedDict() + reverse_relations = SortedDict() for relation in opts.get_all_related_objects(): accessor_name = relation.get_accessor_name() reverse_relations[accessor_name] = RelationInfo( -- cgit v1.2.3 From 54ccf7230d0fcdabe8c2457539e314893915a34b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 11 Sep 2014 13:43:46 +0100 Subject: Improve memory address removal for serializer representations --- rest_framework/utils/representation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py index 1de21597..e2a37497 100644 --- a/rest_framework/utils/representation.py +++ b/rest_framework/utils/representation.py @@ -17,7 +17,7 @@ def smart_repr(value): # # Should be presented as # - value = re.sub(' at 0x[0-9a-f]{8,10}>', '>', value) + value = re.sub(' at 0x[0-9a-f]{8,32}>', '>', value) return value -- cgit v1.2.3 From 3318f75a7166cbac76a40d0461ca7b3e4640d3a2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 11 Sep 2014 13:50:53 +0100 Subject: Improve memory address removal for serializer representations --- rest_framework/utils/representation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py index e2a37497..1a4d1a62 100644 --- a/rest_framework/utils/representation.py +++ b/rest_framework/utils/representation.py @@ -17,7 +17,7 @@ def smart_repr(value): # # Should be presented as # - value = re.sub(' at 0x[0-9a-f]{8,32}>', '>', value) + value = re.sub(' at 0x[0-9a-f]{4,32}>', '>', value) return value -- cgit v1.2.3 From ab40780dc2f341a271c2f489659dcd48eb47c07d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 11 Sep 2014 20:22:32 +0100 Subject: Tidy up lookup_class --- rest_framework/serializers.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 8fe999ae..4322f213 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -317,17 +317,17 @@ class ModelSerializerOptions(object): self.depth = getattr(meta, 'depth', 0) -def lookup_class(mapping, obj): +def lookup_class(mapping, instance): """ Takes a dictionary with classes as keys, and an object. Traverses the object's inheritance hierarchy in method resolution order, and returns the first matching value - from the dictionary or None. + from the dictionary or raises a KeyError if nothing matches. """ - return next( - (mapping[cls] for cls in inspect.getmro(obj.__class__) if cls in mapping), - None - ) + for cls in inspect.getmro(instance.__class__): + if cls in mapping: + return mapping[cls] + raise KeyError('Class %s not found in lookup.', cls.__name__) class ModelSerializer(Serializer): @@ -341,6 +341,7 @@ class ModelSerializer(Serializer): models.DateTimeField: DateTimeField, models.DecimalField: DecimalField, models.EmailField: EmailField, + models.Field: ModelField, models.FileField: FileField, models.FloatField: FloatField, models.ImageField: ImageField, @@ -484,6 +485,7 @@ class ModelSerializer(Serializer): """ Creates a default instance of a basic non-relational field. """ + serializer_cls = lookup_class(self.field_mapping, model_field) kwargs = {} validator_kwarg = model_field.validators @@ -602,11 +604,10 @@ class ModelSerializer(Serializer): if validator_kwarg: kwargs['validators'] = validator_kwarg - cls = lookup_class(self.field_mapping, model_field) - if cls is None: - cls = ModelField + if issubclass(serializer_cls, ModelField): kwargs['model_field'] = model_field - return cls(**kwargs) + + return serializer_cls(**kwargs) class HyperlinkedModelSerializerOptions(ModelSerializerOptions): -- cgit v1.2.3 From bf52d04f4c370d6917599d26c84b73124d5ef366 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 11 Sep 2014 20:37:27 +0100 Subject: Nice manager representations on serializer classes --- rest_framework/utils/representation.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py index 1a4d1a62..71db1886 100644 --- a/rest_framework/utils/representation.py +++ b/rest_framework/utils/representation.py @@ -2,10 +2,23 @@ Helper functions for creating user-friendly representations of serializer classes and serializer fields. """ +from django.db import models import re +def manager_repr(value): + model = value.model + opts = model._meta + for _, name, manager in opts.concrete_managers + opts.abstract_managers: + if manager == value: + return '%s.%s.all()' % (model._meta.object_name, name) + return repr(value) + + def smart_repr(value): + if isinstance(value, models.Manager): + return manager_repr(value) + value = repr(value) # Representations like u'help text' -- cgit v1.2.3 From 19b8f779de82fa4737b37fb4359145af0b07a56c Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 11 Sep 2014 20:43:44 +0100 Subject: Throttles now use Retry-After header and no longer support the custom style --- rest_framework/views.py | 1 - 1 file changed, 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/views.py b/rest_framework/views.py index 3b7b1c16..cd394b2d 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -62,7 +62,6 @@ def exception_handler(exc): if getattr(exc, 'auth_header', None): headers['WWW-Authenticate'] = exc.auth_header if getattr(exc, 'wait', None): - headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait headers['Retry-After'] = '%d' % exc.wait return Response({'detail': exc.detail}, -- cgit v1.2.3 From 55650a743d579e0bc1643c8812428746b0271984 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 11 Sep 2014 20:49:10 +0100 Subject: no longer tightly coupled to private queryset API --- rest_framework/generics.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index c2c59154..408b1246 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -3,6 +3,7 @@ Generic views that provide commonly needed behaviour. """ from __future__ import unicode_literals +from django.db.models.query import QuerySet from django.core.exceptions import PermissionDenied from django.core.paginator import Paginator, InvalidPage from django.http import Http404 @@ -214,7 +215,9 @@ class GenericAPIView(views.APIView): % self.__class__.__name__ ) - return self.queryset._clone() + if isinstance(self.queryset, QuerySet): + return self.queryset.all() + return self.queryset def get_object(self): """ -- cgit v1.2.3 From a7518719917c7ad8e699119b442cfeb568ba1dde Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 11 Sep 2014 20:50:26 +0100 Subject: no longer tightly coupled to private queryset API --- rest_framework/generics.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 408b1246..338d56a6 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -215,9 +215,11 @@ class GenericAPIView(views.APIView): % self.__class__.__name__ ) + queryset = self.queryset if isinstance(self.queryset, QuerySet): - return self.queryset.all() - return self.queryset + # Ensure queryset is re-evaluated on each request. + queryset = queryset.all() + return queryset def get_object(self): """ -- cgit v1.2.3 From 040bfcc09c851bb3dadd60558c78a1f7937e9fbd Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 11 Sep 2014 21:48:54 +0100 Subject: NotImplemented stubs for Field, and DecimalField improvements --- rest_framework/fields.py | 27 +++++++++++++++++++++------ rest_framework/pagination.py | 4 ++-- rest_framework/utils/encoders.py | 2 +- 3 files changed, 24 insertions(+), 9 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 7496a629..20b8ffbf 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -229,13 +229,13 @@ class Field(object): """ Transform the *incoming* primative data into a native value. """ - return data + raise NotImplementedError('to_native() must be implemented.') def to_primative(self, value): """ Transform the *outgoing* native value into primative data. """ - return value + raise NotImplementedError('to_primative() must be implemented.') def fail(self, key, **kwargs): """ @@ -429,9 +429,10 @@ class DecimalField(Field): 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.') } - def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, **kwargs): - self.max_value, self.min_value = max_value, min_value - self.max_digits, self.max_decimal_places = max_digits, decimal_places + def __init__(self, max_digits, decimal_places, coerce_to_string=True, max_value=None, min_value=None, **kwargs): + self.max_digits = max_digits + self.decimal_places = decimal_places + self.coerce_to_string = coerce_to_string super(DecimalField, self).__init__(**kwargs) if max_value is not None: self.validators.append(validators.MaxValueValidator(max_value)) @@ -478,12 +479,26 @@ class DecimalField(Field): if self.max_digits is not None and digits > self.max_digits: self.fail('max_digits', max_digits=self.max_digits) if self.decimal_places is not None and decimals > self.decimal_places: - self.fail('max_decimal_places', max_decimal_places=self.max_decimal_places) + self.fail('max_decimal_places', max_decimal_places=self.decimal_places) if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places) return value + def to_primative(self, value): + if not self.coerce_to_string: + return value + + if isinstance(value, decimal.Decimal): + context = decimal.getcontext().copy() + context.prec = self.max_digits + quantized = value.quantize( + decimal.Decimal('.1') ** self.decimal_places, + context=context + ) + return '{0:f}'.format(quantized) + return '%.*f' % (self.max_decimal_places, value) + # Date & time fields... diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 9cf31629..d82d2d3b 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -37,7 +37,7 @@ class PreviousPageField(serializers.Field): return replace_query_param(url, self.page_field, page) -class DefaultObjectSerializer(serializers.Field): +class DefaultObjectSerializer(serializers.ReadOnlyField): """ If no object serializer is specified, then this serializer will be applied as the default. @@ -79,6 +79,6 @@ class PaginationSerializer(BasePaginationSerializer): """ A default implementation of a pagination serializer. """ - count = serializers.Field(source='paginator.count') + count = serializers.ReadOnlyField(source='paginator.count') next = NextPageField(source='*') previous = PreviousPageField(source='*') diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 6a2f6126..7992b6b1 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -43,7 +43,7 @@ class JSONEncoder(json.JSONEncoder): elif isinstance(o, datetime.timedelta): return str(o.total_seconds()) elif isinstance(o, decimal.Decimal): - return str(o) + return float(o) elif isinstance(o, QuerySet): return list(o) elif hasattr(o, 'tolist'): -- cgit v1.2.3 From 1e53eb0aa2998385e26aa0a4d8542013bc7b575b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 11 Sep 2014 21:57:32 +0100 Subject: DecimalFields should still be quantized even without coerce_to_string --- rest_framework/fields.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 20b8ffbf..a56ea96b 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -486,9 +486,6 @@ class DecimalField(Field): return value def to_primative(self, value): - if not self.coerce_to_string: - return value - if isinstance(value, decimal.Decimal): context = decimal.getcontext().copy() context.prec = self.max_digits @@ -496,7 +493,12 @@ class DecimalField(Field): decimal.Decimal('.1') ** self.decimal_places, context=context ) + if not self.coerce_to_string: + return quantized return '{0:f}'.format(quantized) + + if not self.coerce_to_string: + return value return '%.*f' % (self.max_decimal_places, value) -- cgit v1.2.3 From adcb64ab4198f35c61d5be68956d201685ed3538 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 12 Sep 2014 09:12:56 +0100 Subject: MethodField -> SerializerMethodField --- rest_framework/fields.py | 21 +++++++++------------ rest_framework/serializers.py | 3 ++- rest_framework/utils/modelinfo.py | 3 ++- 3 files changed, 13 insertions(+), 14 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index a56ea96b..5f198767 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -768,16 +768,13 @@ class ReadOnlyField(Field): kwargs['read_only'] = True super(ReadOnlyField, self).__init__(**kwargs) - def to_native(self, data): - raise NotImplemented('.to_native() not supported.') - def to_primative(self, value): if is_simple_callable(value): return value() return value -class MethodField(Field): +class SerializerMethodField(Field): """ A read-only field that get its representation from calling a method on the parent serializer class. The method called will be of the form @@ -787,22 +784,22 @@ class MethodField(Field): For example: class ExampleSerializer(self): - extra_info = MethodField() + extra_info = SerializerMethodField() def get_extra_info(self, obj): return ... # Calculate some data to return. """ - def __init__(self, **kwargs): + def __init__(self, method_attr=None, **kwargs): + self.method_attr = method_attr kwargs['source'] = '*' kwargs['read_only'] = True - super(MethodField, self).__init__(**kwargs) - - def to_native(self, data): - raise NotImplemented('.to_native() not supported.') + super(SerializerMethodField, self).__init__(**kwargs) def to_primative(self, value): - attr = 'get_{field_name}'.format(field_name=self.field_name) - method = getattr(self.parent, attr) + method_attr = self.method_attr + if method_attr is None: + method_attr = 'get_{field_name}'.format(field_name=self.field_name) + method = getattr(self.parent, method_attr) return method(value) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 4322f213..388fe29f 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -598,7 +598,8 @@ class ModelSerializer(Serializer): if isinstance(model_field, models.BooleanField): # models.BooleanField has `blank=True`, but *is* actually # required *unless* a default is provided. - # Also note that <1.6 `default=False`, >=1.6 `default=None`. + # Also note that Django<1.6 uses `default=False` for + # models.BooleanField, but Django>=1.6 uses `default=None`. kwargs.pop('required', None) if validator_kwarg: diff --git a/rest_framework/utils/modelinfo.py b/rest_framework/utils/modelinfo.py index a7a0346c..960fa4d0 100644 --- a/rest_framework/utils/modelinfo.py +++ b/rest_framework/utils/modelinfo.py @@ -1,6 +1,7 @@ """ Helper functions for returning the field information that is associated -with a model class. +with a model class. This includes returning all the forward and reverse +relationships and their associated metadata. """ from collections import namedtuple from django.db import models -- cgit v1.2.3 From 0d354e8f92c7daaf8dac3b80f0fd64f983f21e0b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 12 Sep 2014 09:49:35 +0100 Subject: to_internal_value() and to_representation() --- rest_framework/fields.py | 86 ++++++++++++++++++++++--------------------- rest_framework/pagination.py | 4 +- rest_framework/relations.py | 2 +- rest_framework/serializers.py | 28 +++++++------- 4 files changed, 61 insertions(+), 59 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 5f198767..a96f9ba8 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -195,7 +195,7 @@ class Field(object): raise SkipField() return self.default - def validate_value(self, data=empty): + def run_validation(self, data=empty): """ Validate a simple representation and return the internal value. @@ -208,7 +208,7 @@ class Field(object): self.fail('required') return self.get_default() - value = self.to_native(data) + value = self.to_internal_value(data) self.run_validators(value) return value @@ -225,17 +225,17 @@ class Field(object): if errors: raise ValidationError(errors) - def to_native(self, data): + def to_internal_value(self, data): """ Transform the *incoming* primative data into a native value. """ - raise NotImplementedError('to_native() must be implemented.') + raise NotImplementedError('to_internal_value() must be implemented.') - def to_primative(self, value): + def to_representation(self, value): """ Transform the *outgoing* native value into primative data. """ - raise NotImplementedError('to_primative() must be implemented.') + raise NotImplementedError('to_representation() must be implemented.') def fail(self, key, **kwargs): """ @@ -279,14 +279,14 @@ class BooleanField(Field): return dictionary.get(self.field_name, False) return dictionary.get(self.field_name, empty) - def to_native(self, data): + def to_internal_value(self, data): if data in self.TRUE_VALUES: return True elif data in self.FALSE_VALUES: return False self.fail('invalid', input=data) - def to_primative(self, value): + def to_representation(self, value): if value is None: return None if value in self.TRUE_VALUES: @@ -309,12 +309,14 @@ class CharField(Field): self.min_length = kwargs.pop('min_length', None) super(CharField, self).__init__(**kwargs) - def to_native(self, data): + def to_internal_value(self, data): if data == '' and not self.allow_blank: self.fail('blank') + if data is None: + return None return str(data) - def to_primative(self, value): + def to_representation(self, value): if value is None: return None return str(value) @@ -326,17 +328,17 @@ class EmailField(CharField): } default_validators = [validators.validate_email] - def to_native(self, data): - ret = super(EmailField, self).to_native(data) - if ret is None: + def to_internal_value(self, data): + if data == '' and not self.allow_blank: + self.fail('blank') + if data is None: return None - return ret.strip() + return str(data).strip() - def to_primative(self, value): - ret = super(EmailField, self).to_primative(value) - if ret is None: + def to_representation(self, value): + if value is None: return None - return ret.strip() + return str(value).strip() class RegexField(CharField): @@ -378,14 +380,14 @@ class IntegerField(Field): if min_value is not None: self.validators.append(validators.MinValueValidator(min_value)) - def to_native(self, data): + def to_internal_value(self, data): try: data = int(str(data)) except (ValueError, TypeError): self.fail('invalid') return data - def to_primative(self, value): + def to_representation(self, value): if value is None: return None return int(value) @@ -405,7 +407,12 @@ class FloatField(Field): if min_value is not None: self.validators.append(validators.MinValueValidator(min_value)) - def to_primative(self, value): + def to_internal_value(self, value): + if value is None: + return None + return float(value) + + def to_representation(self, value): if value is None: return None try: @@ -413,11 +420,6 @@ class FloatField(Field): except (TypeError, ValueError): self.fail('invalid', value=value) - def to_native(self, value): - if value is None: - return None - return float(value) - class DecimalField(Field): default_error_messages = { @@ -439,7 +441,7 @@ class DecimalField(Field): if min_value is not None: self.validators.append(validators.MinValueValidator(min_value)) - def from_native(self, value): + def to_internal_value(self, value): """ Validates that the input is a decimal number. Returns a Decimal instance. Returns None for empty values. Ensures that there are no more @@ -485,7 +487,7 @@ class DecimalField(Field): return value - def to_primative(self, value): + def to_representation(self, value): if isinstance(value, decimal.Decimal): context = decimal.getcontext().copy() context.prec = self.max_digits @@ -516,7 +518,7 @@ class DateField(Field): self.format = format if format is not None else self.format super(DateField, self).__init__(*args, **kwargs) - def from_native(self, value): + def to_internal_value(self, value): if value in validators.EMPTY_VALUES: return None @@ -552,7 +554,7 @@ class DateField(Field): msg = self.error_messages['invalid'] % humanized_format raise ValidationError(msg) - def to_primative(self, value): + def to_representation(self, value): if value is None or self.format is None: return value @@ -576,7 +578,7 @@ class DateTimeField(Field): self.format = format if format is not None else self.format super(DateTimeField, self).__init__(*args, **kwargs) - def from_native(self, value): + def to_internal_value(self, value): if value in validators.EMPTY_VALUES: return None @@ -618,7 +620,7 @@ class DateTimeField(Field): msg = self.error_messages['invalid'] % humanized_format raise ValidationError(msg) - def to_primative(self, value): + def to_representation(self, value): if value is None or self.format is None: return value @@ -670,7 +672,7 @@ class TimeField(Field): msg = self.error_messages['invalid'] % humanized_format raise ValidationError(msg) - def to_primative(self, value): + def to_representation(self, value): if value is None or self.format is None: return value @@ -711,13 +713,13 @@ class ChoiceField(Field): super(ChoiceField, self).__init__(**kwargs) - def to_native(self, data): + def to_internal_value(self, data): try: return self.choice_strings_to_values[str(data)] except KeyError: self.fail('invalid_choice', input=data) - def to_primative(self, value): + def to_representation(self, value): return value @@ -727,15 +729,15 @@ class MultipleChoiceField(ChoiceField): 'not_a_list': _('Expected a list of items but got type `{input_type}`') } - def to_native(self, data): + def to_internal_value(self, data): if not hasattr(data, '__iter__'): self.fail('not_a_list', input_type=type(data).__name__) return set([ - super(MultipleChoiceField, self).to_native(item) + super(MultipleChoiceField, self).to_internal_value(item) for item in data ]) - def to_primative(self, value): + def to_representation(self, value): return value @@ -768,7 +770,7 @@ class ReadOnlyField(Field): kwargs['read_only'] = True super(ReadOnlyField, self).__init__(**kwargs) - def to_primative(self, value): + def to_representation(self, value): if is_simple_callable(value): return value() return value @@ -795,7 +797,7 @@ class SerializerMethodField(Field): kwargs['read_only'] = True super(SerializerMethodField, self).__init__(**kwargs) - def to_primative(self, value): + def to_representation(self, value): method_attr = self.method_attr if method_attr is None: method_attr = 'get_{field_name}'.format(field_name=self.field_name) @@ -815,13 +817,13 @@ class ModelField(Field): kwargs['source'] = '*' super(ModelField, self).__init__(**kwargs) - def to_native(self, data): + def to_internal_value(self, data): rel = getattr(self.model_field, 'rel', None) if rel is not None: return rel.to._meta.get_field(rel.field_name).to_python(data) return self.model_field.to_python(data) - def to_primative(self, obj): + def to_representation(self, obj): value = self.model_field._get_val_from_obj(obj) if is_protected_type(value): return value diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index d82d2d3b..c5a9270a 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -13,7 +13,7 @@ class NextPageField(serializers.Field): """ page_field = 'page' - def to_primative(self, value): + def to_representation(self, value): if not value.has_next(): return None page = value.next_page_number() @@ -28,7 +28,7 @@ class PreviousPageField(serializers.Field): """ page_field = 'page' - def to_primative(self, value): + def to_representation(self, value): if not value.has_previous(): return None page = value.previous_page_number() diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 661a1249..30a252db 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -110,7 +110,7 @@ class HyperlinkedIdentityField(RelatedField): def get_attribute(self, instance): return instance - def to_primative(self, value): + def to_representation(self, value): request = self.context.get('request', None) format = self.context.get('format', None) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 388fe29f..502b1e19 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -47,11 +47,11 @@ class BaseSerializer(Field): self.instance = instance self._initial_data = data - def to_native(self, data): - raise NotImplementedError('`to_native()` must be implemented.') + def to_internal_value(self, data): + raise NotImplementedError('`to_internal_value()` must be implemented.') - def to_primative(self, instance): - raise NotImplementedError('`to_primative()` must be implemented.') + def to_representation(self, instance): + raise NotImplementedError('`to_representation()` must be implemented.') def update(self, instance, attrs): raise NotImplementedError('`update()` must be implemented.') @@ -74,7 +74,7 @@ class BaseSerializer(Field): def is_valid(self, raise_exception=False): if not hasattr(self, '_validated_data'): try: - self._validated_data = self.to_native(self._initial_data) + self._validated_data = self.to_internal_value(self._initial_data) except ValidationError as exc: self._validated_data = {} self._errors = exc.message_dict @@ -90,7 +90,7 @@ class BaseSerializer(Field): def data(self): if not hasattr(self, '_data'): if self.instance is not None: - self._data = self.to_primative(self.instance) + self._data = self.to_representation(self.instance) elif self._initial_data is not None: self._data = dict([ (field_name, field.get_value(self._initial_data)) @@ -193,7 +193,7 @@ class Serializer(BaseSerializer): return html.parse_html_dict(dictionary, prefix=self.field_name) return dictionary.get(self.field_name, empty) - def to_native(self, data): + def to_internal_value(self, data): """ Dict of native values <- Dict of primitive datatypes. """ @@ -208,7 +208,7 @@ class Serializer(BaseSerializer): validate_method = getattr(self, 'validate_' + field.field_name, None) primitive_value = field.get_value(data) try: - validated_value = field.validate_value(primitive_value) + validated_value = field.run_validation(primitive_value) if validate_method is not None: validated_value = validate_method(validated_value) except ValidationError as exc: @@ -226,7 +226,7 @@ class Serializer(BaseSerializer): except ValidationError as exc: raise ValidationError({'non_field_errors': exc.messages}) - def to_primative(self, instance): + def to_representation(self, instance): """ Object instance -> Dict of primitive datatypes. """ @@ -235,7 +235,7 @@ class Serializer(BaseSerializer): for field in fields: native_value = field.get_attribute(instance) - ret[field.field_name] = field.to_primative(native_value) + ret[field.field_name] = field.to_representation(native_value) return ret @@ -279,20 +279,20 @@ class ListSerializer(BaseSerializer): return html.parse_html_list(dictionary, prefix=self.field_name) return dictionary.get(self.field_name, empty) - def to_native(self, data): + def to_internal_value(self, data): """ List of dicts of native values <- List of dicts of primitive datatypes. """ if html.is_html_input(data): data = html.parse_html_list(data) - return [self.child.validate(item) for item in data] + return [self.child.run_validation(item) for item in data] - def to_primative(self, data): + def to_representation(self, data): """ List of object instances -> List of dicts of primitive datatypes. """ - return [self.child.to_primative(item) for item in data] + return [self.child.to_representation(item) for item in data] def create(self, attrs_list): return [self.child.create(attrs) for attrs in attrs_list] -- cgit v1.2.3 From 6db3356c4d1aa4f9a042b0ec67d47238abc16dd7 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 12 Sep 2014 10:21:35 +0100 Subject: NON_FIELD_ERRORS_KEY setting --- rest_framework/serializers.py | 8 ++++++-- rest_framework/settings.py | 1 + rest_framework/views.py | 8 +++++++- 3 files changed, 14 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 502b1e19..0c2aedfa 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -198,7 +198,9 @@ class Serializer(BaseSerializer): Dict of native values <- Dict of primitive datatypes. """ if not isinstance(data, dict): - raise ValidationError({'non_field_errors': ['Invalid data']}) + raise ValidationError({ + api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data'] + }) ret = {} errors = {} @@ -224,7 +226,9 @@ class Serializer(BaseSerializer): try: return self.validate(ret) except ValidationError as exc: - raise ValidationError({'non_field_errors': exc.messages}) + raise ValidationError({ + api_settings.NON_FIELD_ERRORS_KEY: exc.messages + }) def to_representation(self, instance): """ diff --git a/rest_framework/settings.py b/rest_framework/settings.py index bbe7a56a..f48643b5 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -77,6 +77,7 @@ DEFAULTS = { # Exception handling 'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler', + 'NON_FIELD_ERRORS_KEY': 'non_field_errors', # Testing 'TEST_REQUEST_RENDERER_CLASSES': ( diff --git a/rest_framework/views.py b/rest_framework/views.py index cd394b2d..9f08a4ad 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,7 +3,7 @@ Provides an APIView class that is the base of all views in REST framework. """ from __future__ import unicode_literals -from django.core.exceptions import PermissionDenied, ValidationError +from django.core.exceptions import PermissionDenied, ValidationError, NON_FIELD_ERRORS from django.http import Http404 from django.utils.datastructures import SortedDict from django.views.decorators.csrf import csrf_exempt @@ -69,6 +69,12 @@ def exception_handler(exc): headers=headers) elif isinstance(exc, ValidationError): + # ValidationErrors may include the non-field key named '__all__'. + # When returning a response we map this to a key name that can be + # modified in settings. + if NON_FIELD_ERRORS in exc.message_dict: + errors = exc.message_dict.pop(NON_FIELD_ERRORS) + exc.message_dict[api_settings.NON_FIELD_ERRORS_KEY] = errors return Response(exc.message_dict, status=status.HTTP_400_BAD_REQUEST) -- cgit v1.2.3 From 250755def707e1397876614fa0c08130d9fcc449 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 12 Sep 2014 10:59:51 +0100 Subject: Clean up relational fields queryset usage --- rest_framework/fields.py | 15 ++++------ rest_framework/generics.py | 2 +- rest_framework/relations.py | 73 ++++++++++++++++++++++++--------------------- 3 files changed, 46 insertions(+), 44 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index a96f9ba8..4f06d186 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -508,7 +508,7 @@ class DecimalField(Field): class DateField(Field): default_error_messages = { - 'invalid': _("Date has wrong format. Use one of these formats instead: %s"), + 'invalid': _('Date has wrong format. Use one of these formats instead: {format}'), } input_formats = api_settings.DATE_INPUT_FORMATS format = api_settings.DATE_FORMAT @@ -551,8 +551,7 @@ class DateField(Field): return parsed.date() humanized_format = humanize_datetime.date_formats(self.input_formats) - msg = self.error_messages['invalid'] % humanized_format - raise ValidationError(msg) + self.fail('invalid', format=humanized_format) def to_representation(self, value): if value is None or self.format is None: @@ -568,7 +567,7 @@ class DateField(Field): class DateTimeField(Field): default_error_messages = { - 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"), + 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'), } input_formats = api_settings.DATETIME_INPUT_FORMATS format = api_settings.DATETIME_FORMAT @@ -617,8 +616,7 @@ class DateTimeField(Field): return parsed humanized_format = humanize_datetime.datetime_formats(self.input_formats) - msg = self.error_messages['invalid'] % humanized_format - raise ValidationError(msg) + self.fail('invalid', format=humanized_format) def to_representation(self, value): if value is None or self.format is None: @@ -634,7 +632,7 @@ class DateTimeField(Field): class TimeField(Field): default_error_messages = { - 'invalid': _("Time has wrong format. Use one of these formats instead: %s"), + 'invalid': _('Time has wrong format. Use one of these formats instead: {format}'), } input_formats = api_settings.TIME_INPUT_FORMATS format = api_settings.TIME_FORMAT @@ -669,8 +667,7 @@ class TimeField(Field): return parsed.time() humanized_format = humanize_datetime.time_formats(self.input_formats) - msg = self.error_messages['invalid'] % humanized_format - raise ValidationError(msg) + self.fail('invalid', format=humanized_format) def to_representation(self, value): if value is None or self.format is None: diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 338d56a6..eb6b64ef 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -216,7 +216,7 @@ class GenericAPIView(views.APIView): ) queryset = self.queryset - if isinstance(self.queryset, QuerySet): + if isinstance(queryset, QuerySet): # Ensure queryset is re-evaluated on each request. queryset = queryset.all() return queryset diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 30a252db..e23a4152 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -2,28 +2,35 @@ from rest_framework.fields import Field from rest_framework.reverse import reverse from django.core.exceptions import ObjectDoesNotExist from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch +from django.db.models.query import QuerySet from rest_framework.compat import urlparse -def get_default_queryset(serializer_class, field_name): - manager = getattr(serializer_class.opts.model, field_name) - if hasattr(manager, 'related'): - # Forward relationships - return manager.related.model._default_manager.all() - # Reverse relationships - return manager.field.rel.to._default_manager.all() - - class RelatedField(Field): def __init__(self, **kwargs): self.queryset = kwargs.pop('queryset', None) self.many = kwargs.pop('many', False) + assert self.queryset is not None or kwargs.get('read_only', False), ( + 'Relational field must provide a `queryset` argument, ' + 'or set read_only=`True`.' + ) super(RelatedField, self).__init__(**kwargs) - def bind(self, field_name, parent, root): - super(RelatedField, self).bind(field_name, parent, root) - if self.queryset is None and not self.read_only: - self.queryset = get_default_queryset(parent, self.source) + def get_queryset(self): + queryset = self.queryset + if isinstance(queryset, QuerySet): + # Ensure queryset is re-evaluated whenever used. + queryset = queryset.all() + return queryset + + +class StringRelatedField(Field): + def __init__(self, **kwargs): + kwargs['read_only'] = True + super(StringRelatedField, self).__init__(**kwargs) + + def to_representation(self, value): + return str(value) class PrimaryKeyRelatedField(RelatedField): @@ -33,9 +40,9 @@ class PrimaryKeyRelatedField(RelatedField): 'incorrect_type': 'Incorrect type. Expected pk value, received {data_type}.', } - def from_native(self, data): + def to_internal_value(self, data): try: - return self.queryset.get(pk=data) + return self.get_queryset().get(pk=data) except ObjectDoesNotExist: self.fail('does_not_exist', pk_value=data) except (TypeError, ValueError): @@ -68,9 +75,9 @@ class HyperlinkedRelatedField(RelatedField): """ lookup_value = view_kwargs[self.lookup_url_kwarg] lookup_kwargs = {self.lookup_field: lookup_value} - return self.queryset.get(**lookup_kwargs) + return self.get_queryset().get(**lookup_kwargs) - def from_native(self, value): + def to_internal_value(self, value): try: http_prefix = value.startswith(('http:', 'https:')) except AttributeError: @@ -102,13 +109,26 @@ class HyperlinkedIdentityField(RelatedField): def __init__(self, **kwargs): kwargs['read_only'] = True + kwargs['source'] = '*' self.view_name = kwargs.pop('view_name') self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) super(HyperlinkedIdentityField, self).__init__(**kwargs) - def get_attribute(self, instance): - return instance + def get_url(self, obj, view_name, request, format): + """ + Given an object, return the URL that hyperlinks to the object. + + May raise a `NoReverseMatch` if the `view_name` and `lookup_field` + attributes are not configured to correctly match the URL conf. + """ + # Unsaved objects will not yet have a valid URL. + if obj.pk is None: + return None + + lookup_value = getattr(obj, self.lookup_field) + kwargs = {self.lookup_url_kwarg: lookup_value} + return reverse(view_name, kwargs=kwargs, request=request, format=format) def to_representation(self, value): request = self.context.get('request', None) @@ -144,21 +164,6 @@ class HyperlinkedIdentityField(RelatedField): ) raise Exception(msg % self.view_name) - def get_url(self, obj, view_name, request, format): - """ - Given an object, return the URL that hyperlinks to the object. - - May raise a `NoReverseMatch` if the `view_name` and `lookup_field` - attributes are not configured to correctly match the URL conf. - """ - # Unsaved objects will not yet have a valid URL. - if obj.pk is None: - return None - - lookup_value = getattr(obj, self.lookup_field) - kwargs = {self.lookup_url_kwarg: lookup_value} - return reverse(view_name, kwargs=kwargs, request=request, format=format) - class SlugRelatedField(RelatedField): def __init__(self, **kwargs): -- cgit v1.2.3 From 5e39e159ee6aee90755709cfecec7d22c7ea3049 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 12 Sep 2014 11:38:22 +0100 Subject: UNICODE_JSON and COMPACT_JSON settings --- rest_framework/fields.py | 20 ++++++++++---------- rest_framework/parsers.py | 2 +- rest_framework/renderers.py | 33 ++++++++++++--------------------- rest_framework/settings.py | 3 +++ 4 files changed, 26 insertions(+), 32 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 4f06d186..9d96cf5c 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -137,6 +137,16 @@ class Field(object): messages.update(error_messages or {}) self.error_messages = messages + def __new__(cls, *args, **kwargs): + """ + When a field is instantiated, we store the arguments that were used, + so that we can present a helpful representation of the object. + """ + instance = super(Field, cls).__new__(cls) + instance._args = args + instance._kwargs = kwargs + return instance + def bind(self, field_name, parent, root): """ Setup the context for the field instance. @@ -249,16 +259,6 @@ class Field(object): raise AssertionError(msg) raise ValidationError(msg.format(**kwargs)) - def __new__(cls, *args, **kwargs): - """ - When a field is instantiated, we store the arguments that were used, - so that we can present a helpful representation of the object. - """ - instance = super(Field, cls).__new__(cls) - instance._args = args - instance._kwargs = kwargs - return instance - def __repr__(self): return representation.field_repr(self) diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index c287908d..fa02ecf1 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -48,7 +48,7 @@ class JSONParser(BaseParser): """ media_type = 'application/json' - renderer_class = renderers.UnicodeJSONRenderer + renderer_class = renderers.JSONRenderer def parse(self, stream, media_type=None, parser_context=None): """ diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index dfc5a39f..3bf03e62 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -26,6 +26,10 @@ from rest_framework.utils.breadcrumbs import get_breadcrumbs from rest_framework import exceptions, status, VERSION +def zero_as_none(value): + return None if value == 0 else value + + class BaseRenderer(object): """ All renderers should extend this class, setting the `media_type` @@ -44,13 +48,13 @@ class BaseRenderer(object): class JSONRenderer(BaseRenderer): """ Renderer which serializes to JSON. - Applies JSON's backslash-u character escaping for non-ascii characters. """ media_type = 'application/json' format = 'json' encoder_class = encoders.JSONEncoder - ensure_ascii = True + ensure_ascii = not api_settings.UNICODE_JSON + compact = api_settings.COMPACT_JSON # We don't set a charset because JSON is a binary encoding, # that can be encoded as utf-8, utf-16 or utf-32. @@ -62,9 +66,10 @@ class JSONRenderer(BaseRenderer): if accepted_media_type: # If the media type looks like 'application/json; indent=4', # then pretty print the result. + # Note that we coerce `indent=0` into `indent=None`. base_media_type, params = parse_header(accepted_media_type.encode('ascii')) try: - return max(min(int(params['indent']), 8), 0) + return zero_as_none(max(min(int(params['indent']), 8), 0)) except (KeyError, ValueError, TypeError): pass @@ -81,10 +86,12 @@ class JSONRenderer(BaseRenderer): renderer_context = renderer_context or {} indent = self.get_indent(accepted_media_type, renderer_context) + separators = (',', ':') if (indent is None and self.compact) else (', ', ': ') ret = json.dumps( data, cls=self.encoder_class, - indent=indent, ensure_ascii=self.ensure_ascii + indent=indent, ensure_ascii=self.ensure_ascii, + separators=separators ) # On python 2.x json.dumps() returns bytestrings if ensure_ascii=True, @@ -96,14 +103,6 @@ class JSONRenderer(BaseRenderer): return ret -class UnicodeJSONRenderer(JSONRenderer): - ensure_ascii = False - """ - Renderer which serializes to JSON. - Does *not* apply JSON's character escaping for non-ascii characters. - """ - - class JSONPRenderer(JSONRenderer): """ Renderer which serializes to json, @@ -196,7 +195,7 @@ class YAMLRenderer(BaseRenderer): format = 'yaml' encoder = encoders.SafeDumper charset = 'utf-8' - ensure_ascii = True + ensure_ascii = False def render(self, data, accepted_media_type=None, renderer_context=None): """ @@ -210,14 +209,6 @@ class YAMLRenderer(BaseRenderer): return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder, allow_unicode=not self.ensure_ascii) -class UnicodeYAMLRenderer(YAMLRenderer): - """ - Renderer which serializes to YAML. - Does *not* apply character escaping for non-ascii characters. - """ - ensure_ascii = False - - class TemplateHTMLRenderer(BaseRenderer): """ An HTML renderer for use with templates. diff --git a/rest_framework/settings.py b/rest_framework/settings.py index f48643b5..e55610bb 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -112,6 +112,9 @@ DEFAULTS = { ), 'TIME_FORMAT': None, + # Encoding + 'UNICODE_JSON': True, + 'COMPACT_JSON': True } -- cgit v1.2.3 From 22af49bf8ffc73afc9b638f1b9cd2e909c6c89a8 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 12 Sep 2014 11:50:20 +0100 Subject: Tidy up JSONEncoder --- rest_framework/utils/encoders.py | 67 ++++++++++++++++++++-------------------- 1 file changed, 34 insertions(+), 33 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 7992b6b1..174b08b8 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -7,7 +7,6 @@ from django.db.models.query import QuerySet from django.utils.datastructures import SortedDict from django.utils.functional import Promise from rest_framework.compat import force_text -# from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata import datetime import decimal import types @@ -17,45 +16,47 @@ import json class JSONEncoder(json.JSONEncoder): """ JSONEncoder subclass that knows how to encode date/time/timedelta, - decimal types, and generators. + decimal types, generators and other basic python objects. """ - def default(self, o): + def default(self, obj): # For Date Time string spec, see ECMA 262 # http://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15 - if isinstance(o, Promise): - return force_text(o) - elif isinstance(o, datetime.datetime): - r = o.isoformat() - if o.microsecond: - r = r[:23] + r[26:] - if r.endswith('+00:00'): - r = r[:-6] + 'Z' - return r - elif isinstance(o, datetime.date): - return o.isoformat() - elif isinstance(o, datetime.time): - if timezone and timezone.is_aware(o): + if isinstance(obj, Promise): + return force_text(obj) + elif isinstance(obj, datetime.datetime): + representation = obj.isoformat() + if obj.microsecond: + representation = representation[:23] + representation[26:] + if representation.endswith('+00:00'): + representation = representation[:-6] + 'Z' + return representation + elif isinstance(obj, datetime.date): + return obj.isoformat() + elif isinstance(obj, datetime.time): + if timezone and timezone.is_aware(obj): raise ValueError("JSON can't represent timezone-aware times.") - r = o.isoformat() - if o.microsecond: - r = r[:12] - return r - elif isinstance(o, datetime.timedelta): - return str(o.total_seconds()) - elif isinstance(o, decimal.Decimal): - return float(o) - elif isinstance(o, QuerySet): - return list(o) - elif hasattr(o, 'tolist'): - return o.tolist() - elif hasattr(o, '__getitem__'): + representation = obj.isoformat() + if obj.microsecond: + representation = representation[:12] + return representation + elif isinstance(obj, datetime.timedelta): + return str(obj.total_seconds()) + elif isinstance(obj, decimal.Decimal): + # Serializers will coerce decimals to strings by default. + return float(obj) + elif isinstance(obj, QuerySet): + return list(obj) + elif hasattr(obj, 'tolist'): + # Numpy arrays and array scalars. + return obj.tolist() + elif hasattr(obj, '__getitem__'): try: - return dict(o) + return dict(obj) except: pass - elif hasattr(o, '__iter__'): - return [i for i in o] - return super(JSONEncoder, self).default(o) + elif hasattr(obj, '__iter__'): + return [item for item in obj] + return super(JSONEncoder, self).default(obj) try: -- cgit v1.2.3 From 79715f01f8c34fdd55c2291b6b21d09fa3a8153e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 12 Sep 2014 12:10:22 +0100 Subject: Coerce dates etc to ISO_8601 in seralizer, by default. --- rest_framework/fields.py | 24 +++++++++++++----------- rest_framework/settings.py | 21 ++++++++------------- 2 files changed, 21 insertions(+), 24 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 9d96cf5c..e1855ff7 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -431,10 +431,12 @@ class DecimalField(Field): 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.') } - def __init__(self, max_digits, decimal_places, coerce_to_string=True, max_value=None, min_value=None, **kwargs): + coerce_to_string = api_settings.COERCE_DECIMAL_TO_STRING + + def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, **kwargs): self.max_digits = max_digits self.decimal_places = decimal_places - self.coerce_to_string = coerce_to_string + self.coerce_to_string = coerce_to_string if (coerce_to_string is not None) else self.coerce_to_string super(DecimalField, self).__init__(**kwargs) if max_value is not None: self.validators.append(validators.MaxValueValidator(max_value)) @@ -510,12 +512,12 @@ class DateField(Field): default_error_messages = { 'invalid': _('Date has wrong format. Use one of these formats instead: {format}'), } - input_formats = api_settings.DATE_INPUT_FORMATS format = api_settings.DATE_FORMAT + input_formats = api_settings.DATE_INPUT_FORMATS - def __init__(self, input_formats=None, format=None, *args, **kwargs): - self.input_formats = input_formats if input_formats is not None else self.input_formats + def __init__(self, format=None, input_formats=None, *args, **kwargs): self.format = format if format is not None else self.format + self.input_formats = input_formats if input_formats is not None else self.input_formats super(DateField, self).__init__(*args, **kwargs) def to_internal_value(self, value): @@ -569,12 +571,12 @@ class DateTimeField(Field): default_error_messages = { 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'), } - input_formats = api_settings.DATETIME_INPUT_FORMATS format = api_settings.DATETIME_FORMAT + input_formats = api_settings.DATETIME_INPUT_FORMATS - def __init__(self, input_formats=None, format=None, *args, **kwargs): - self.input_formats = input_formats if input_formats is not None else self.input_formats + def __init__(self, format=None, input_formats=None, *args, **kwargs): self.format = format if format is not None else self.format + self.input_formats = input_formats if input_formats is not None else self.input_formats super(DateTimeField, self).__init__(*args, **kwargs) def to_internal_value(self, value): @@ -634,12 +636,12 @@ class TimeField(Field): default_error_messages = { 'invalid': _('Time has wrong format. Use one of these formats instead: {format}'), } - input_formats = api_settings.TIME_INPUT_FORMATS format = api_settings.TIME_FORMAT + input_formats = api_settings.TIME_INPUT_FORMATS - def __init__(self, input_formats=None, format=None, *args, **kwargs): - self.input_formats = input_formats if input_formats is not None else self.input_formats + def __init__(self, format=None, input_formats=None, *args, **kwargs): self.format = format if format is not None else self.format + self.input_formats = input_formats if input_formats is not None else self.input_formats super(TimeField, self).__init__(*args, **kwargs) def from_native(self, value): diff --git a/rest_framework/settings.py b/rest_framework/settings.py index e55610bb..421e146c 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -97,24 +97,19 @@ DEFAULTS = { 'URL_FIELD_NAME': 'url', # Input and output formats - 'DATE_INPUT_FORMATS': ( - ISO_8601, - ), - 'DATE_FORMAT': None, + 'DATE_FORMAT': ISO_8601, + 'DATE_INPUT_FORMATS': (ISO_8601,), - 'DATETIME_INPUT_FORMATS': ( - ISO_8601, - ), - 'DATETIME_FORMAT': None, + 'DATETIME_FORMAT': ISO_8601, + 'DATETIME_INPUT_FORMATS': (ISO_8601,), - 'TIME_INPUT_FORMATS': ( - ISO_8601, - ), - 'TIME_FORMAT': None, + 'TIME_FORMAT': ISO_8601, + 'TIME_INPUT_FORMATS': (ISO_8601,), # Encoding 'UNICODE_JSON': True, - 'COMPACT_JSON': True + 'COMPACT_JSON': True, + 'COERCE_DECIMAL_TO_STRING': True } -- cgit v1.2.3 From b73a205cc021983d9a508b447f30e144a1ce4129 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 12 Sep 2014 17:03:42 +0100 Subject: Tests for relational fields (not including many=True) --- rest_framework/relations.py | 143 +++++++++++++++++++++++++++++--------------- 1 file changed, 94 insertions(+), 49 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index e23a4152..75ec89a8 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,19 +1,24 @@ +from rest_framework.compat import smart_text, urlparse from rest_framework.fields import Field from rest_framework.reverse import reverse -from django.core.exceptions import ObjectDoesNotExist +from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch from django.db.models.query import QuerySet -from rest_framework.compat import urlparse +from django.utils.translation import ugettext_lazy as _ class RelatedField(Field): def __init__(self, **kwargs): self.queryset = kwargs.pop('queryset', None) self.many = kwargs.pop('many', False) - assert self.queryset is not None or kwargs.get('read_only', False), ( + assert self.queryset is not None or kwargs.get('read_only', None), ( 'Relational field must provide a `queryset` argument, ' 'or set read_only=`True`.' ) + assert not (self.queryset is not None and kwargs.get('read_only', None)), ( + 'Relational fields should not provide a `queryset` argument, ' + 'when setting read_only=`True`.' + ) super(RelatedField, self).__init__(**kwargs) def get_queryset(self): @@ -25,6 +30,11 @@ class RelatedField(Field): class StringRelatedField(Field): + """ + A read only field that represents its targets using their + plain string representation. + """ + def __init__(self, **kwargs): kwargs['read_only'] = True super(StringRelatedField, self).__init__(**kwargs) @@ -34,10 +44,10 @@ class StringRelatedField(Field): class PrimaryKeyRelatedField(RelatedField): - MESSAGES = { + default_error_messages = { 'required': 'This field is required.', 'does_not_exist': "Invalid pk '{pk_value}' - object does not exist.", - 'incorrect_type': 'Incorrect type. Expected pk value, received {data_type}.', + 'incorrect_type': 'Incorrect type. Expected pk value, received {data_type}.', } def to_internal_value(self, data): @@ -48,22 +58,33 @@ class PrimaryKeyRelatedField(RelatedField): except (TypeError, ValueError): self.fail('incorrect_type', data_type=type(data).__name__) + def to_representation(self, value): + return value.pk + class HyperlinkedRelatedField(RelatedField): lookup_field = 'pk' - MESSAGES = { + default_error_messages = { 'required': 'This field is required.', 'no_match': 'Invalid hyperlink - No URL match', 'incorrect_match': 'Invalid hyperlink - Incorrect URL match.', - 'does_not_exist': "Invalid hyperlink - Object does not exist.", - 'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.', + 'does_not_exist': 'Invalid hyperlink - Object does not exist.', + 'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.', } - def __init__(self, **kwargs): - self.view_name = kwargs.pop('view_name') + def __init__(self, view_name, **kwargs): + self.view_name = view_name self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) + self.format = kwargs.pop('format', None) + + # We include these simply for dependancy injection in tests. + # We can't add them as class attributes or they would expect an + # implict `self` argument to be passed. + self.reverse = reverse + self.resolve = resolve + super(HyperlinkedRelatedField, self).__init__(**kwargs) def get_object(self, view_name, view_args, view_kwargs): @@ -77,21 +98,36 @@ class HyperlinkedRelatedField(RelatedField): lookup_kwargs = {self.lookup_field: lookup_value} return self.get_queryset().get(**lookup_kwargs) - def to_internal_value(self, value): + def get_url(self, obj, view_name, request, format): + """ + Given an object, return the URL that hyperlinks to the object. + + May raise a `NoReverseMatch` if the `view_name` and `lookup_field` + attributes are not configured to correctly match the URL conf. + """ + # Unsaved objects will not yet have a valid URL. + if obj.pk is None: + return None + + lookup_value = getattr(obj, self.lookup_field) + kwargs = {self.lookup_url_kwarg: lookup_value} + return self.reverse(view_name, kwargs=kwargs, request=request, format=format) + + def to_internal_value(self, data): try: - http_prefix = value.startswith(('http:', 'https:')) + http_prefix = data.startswith(('http:', 'https:')) except AttributeError: - self.fail('incorrect_type', data_type=type(value).__name__) + self.fail('incorrect_type', data_type=type(data).__name__) if http_prefix: # If needed convert absolute URLs to relative path - value = urlparse.urlparse(value).path + data = urlparse.urlparse(data).path prefix = get_script_prefix() - if value.startswith(prefix): - value = '/' + value[len(prefix):] + if data.startswith(prefix): + data = '/' + data[len(prefix):] try: - match = resolve(value) + match = self.resolve(data) except Exception: self.fail('no_match') @@ -103,41 +139,14 @@ class HyperlinkedRelatedField(RelatedField): except (ObjectDoesNotExist, TypeError, ValueError): self.fail('does_not_exist') - -class HyperlinkedIdentityField(RelatedField): - lookup_field = 'pk' - - def __init__(self, **kwargs): - kwargs['read_only'] = True - kwargs['source'] = '*' - self.view_name = kwargs.pop('view_name') - self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) - self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) - super(HyperlinkedIdentityField, self).__init__(**kwargs) - - def get_url(self, obj, view_name, request, format): - """ - Given an object, return the URL that hyperlinks to the object. - - May raise a `NoReverseMatch` if the `view_name` and `lookup_field` - attributes are not configured to correctly match the URL conf. - """ - # Unsaved objects will not yet have a valid URL. - if obj.pk is None: - return None - - lookup_value = getattr(obj, self.lookup_field) - kwargs = {self.lookup_url_kwarg: lookup_value} - return reverse(view_name, kwargs=kwargs, request=request, format=format) - def to_representation(self, value): request = self.context.get('request', None) format = self.context.get('format', None) assert request is not None, ( - "`HyperlinkedIdentityField` requires the request in the serializer" + "`%s` requires the request in the serializer" " context. Add `context={'request': request}` when instantiating " - "the serializer." + "the serializer." % self.__class__.__name__ ) # By default use whatever format is given for the current context @@ -162,9 +171,45 @@ class HyperlinkedIdentityField(RelatedField): 'model in your API, or incorrectly configured the ' '`lookup_field` attribute on this field.' ) - raise Exception(msg % self.view_name) + raise ImproperlyConfigured(msg % self.view_name) + + +class HyperlinkedIdentityField(HyperlinkedRelatedField): + """ + A read-only field that represents the identity URL for an object, itself. + + This is in contrast to `HyperlinkedRelatedField` which represents the + URL of relationships to other objects. + """ + + def __init__(self, view_name, **kwargs): + kwargs['read_only'] = True + kwargs['source'] = '*' + super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs) class SlugRelatedField(RelatedField): - def __init__(self, **kwargs): - self.slug_field = kwargs.pop('slug_field', None) + """ + A read-write field the represents the target of the relationship + by a unique 'slug' attribute. + """ + + default_error_messages = { + 'does_not_exist': _("Object with {slug_name}={value} does not exist."), + 'invalid': _('Invalid value.'), + } + + def __init__(self, slug_field, **kwargs): + self.slug_field = slug_field + super(SlugRelatedField, self).__init__(**kwargs) + + def to_internal_value(self, data): + try: + return self.get_queryset().get(**{self.slug_field: data}) + except ObjectDoesNotExist: + self.fail('does_not_exist', slug_name=self.slug_field, value=smart_text(data)) + except (TypeError, ValueError): + self.fail('invalid') + + def to_representation(self, obj): + return getattr(obj, self.slug_field) -- cgit v1.2.3 From 0ac52e0808288892717c017e57c57aa8ad81e6d3 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 12 Sep 2014 17:06:37 +0100 Subject: Use Resolver404 instead of base Exception --- rest_framework/relations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 75ec89a8..46fe55ef 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -2,7 +2,7 @@ from rest_framework.compat import smart_text, urlparse from rest_framework.fields import Field from rest_framework.reverse import reverse from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured -from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch +from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch, Resolver404 from django.db.models.query import QuerySet from django.utils.translation import ugettext_lazy as _ @@ -128,7 +128,7 @@ class HyperlinkedRelatedField(RelatedField): try: match = self.resolve(data) - except Exception: + except Resolver404: self.fail('no_match') if match.view_name != self.view_name: -- cgit v1.2.3 From e6c88a423361b084ba171af7a74a183bd557e73e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 12 Sep 2014 19:54:27 +0100 Subject: Drop usage of validatiors.EMPTY_VALUES --- rest_framework/fields.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index e1855ff7..33ab0682 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -223,7 +223,7 @@ class Field(object): return value def run_validators(self, value): - if value in validators.EMPTY_VALUES: + if value in (None, '', [], (), {}): return errors = [] @@ -450,7 +450,7 @@ class DecimalField(Field): than max_digits in the number, and no more than decimal_places digits after the decimal point. """ - if value in validators.EMPTY_VALUES: + if value in (None, ''): return None value = smart_text(value).strip() @@ -521,7 +521,7 @@ class DateField(Field): super(DateField, self).__init__(*args, **kwargs) def to_internal_value(self, value): - if value in validators.EMPTY_VALUES: + if value in (None, ''): return None if isinstance(value, datetime.datetime): @@ -580,7 +580,7 @@ class DateTimeField(Field): super(DateTimeField, self).__init__(*args, **kwargs) def to_internal_value(self, value): - if value in validators.EMPTY_VALUES: + if value in (None, ''): return None if isinstance(value, datetime.datetime): @@ -645,7 +645,7 @@ class TimeField(Field): super(TimeField, self).__init__(*args, **kwargs) def from_native(self, value): - if value in validators.EMPTY_VALUES: + if value in (None, ''): return None if isinstance(value, datetime.time): -- cgit v1.2.3 From afb28a44ad1737cd6fcd6da50ba9552f38293368 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 12 Sep 2014 21:32:20 +0100 Subject: Dealing with reverse relationships --- rest_framework/serializers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 0c2aedfa..ecb2829b 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -157,7 +157,7 @@ class Serializer(BaseSerializer): def __init__(self, *args, **kwargs): self.context = kwargs.pop('context', {}) kwargs.pop('partial', None) - kwargs.pop('many', False) + kwargs.pop('many', None) super(Serializer, self).__init__(*args, **kwargs) @@ -423,9 +423,9 @@ class ModelSerializer(Serializer): for accessor_name, relation_info in info.reverse_relations.items(): if accessor_name in self.opts.fields: if self.opts.depth: - ret[field_name] = self.get_nested_field(*relation_info) + ret[accessor_name] = self.get_nested_field(*relation_info) else: - ret[field_name] = self.get_related_field(*relation_info) + ret[accessor_name] = self.get_related_field(*relation_info) return ret @@ -444,7 +444,7 @@ class ModelSerializer(Serializer): Note that model_field will be `None` for reverse relationships. """ - class NestedModelSerializer(ModelSerializer): + class NestedModelSerializer(ModelSerializer): # Not right! class Meta: model = related_model depth = self.opts.depth - 1 -- cgit v1.2.3 From 40dc588a372375608701f7e521dea6d860a49eb2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 15 Sep 2014 09:50:51 +0100 Subject: Drop label from serializer fields when not needed --- rest_framework/fields.py | 6 ++- rest_framework/serializers.py | 56 ++++++++++++--------- rest_framework/utils/model_meta.py | 99 ++++++++++++++++++++++++++++++++++++++ rest_framework/utils/modelinfo.py | 99 -------------------------------------- 4 files changed, 137 insertions(+), 123 deletions(-) create mode 100644 rest_framework/utils/model_meta.py delete mode 100644 rest_framework/utils/modelinfo.py (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 33ab0682..1818e705 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -80,6 +80,10 @@ def set_value(dictionary, keys, value): dictionary[keys[-1]] = value +def field_name_to_label(field_name): + return field_name.replace('_', ' ').capitalize() + + class SkipField(Exception): pass @@ -158,7 +162,7 @@ class Field(object): # `self.label` should deafult to being based on the field name. if self.label is None: - self.label = self.field_name.replace('_', ' ').capitalize() + self.label = field_name_to_label(self.field_name) # self.source should default to being the same as the field name. if self.source is None: diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index ecb2829b..ba8d475f 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -15,11 +15,12 @@ from django.core.exceptions import ValidationError from django.db import models from django.utils import six from django.utils.datastructures import SortedDict +from django.utils.text import capfirst from collections import namedtuple from rest_framework.compat import clean_manytomany_helptext from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.settings import api_settings -from rest_framework.utils import html, modelinfo, representation +from rest_framework.utils import html, model_meta, representation import copy # Note: We do the following so that users of the framework can use this style: @@ -334,6 +335,14 @@ def lookup_class(mapping, instance): raise KeyError('Class %s not found in lookup.', cls.__name__) +def needs_label(model_field, field_name): + """ + Returns `True` if the label based on the model's verbose name + is not equal to the default label it would have based on it's field name. + """ + return capfirst(model_field.verbose_name) != field_name_to_label(field_name) + + class ModelSerializer(Serializer): field_mapping = { models.AutoField: IntegerField, @@ -397,54 +406,55 @@ class ModelSerializer(Serializer): """ Return all the fields that should be serialized for the model. """ - info = modelinfo.get_field_info(self.opts.model) + info = model_meta.get_field_info(self.opts.model) ret = SortedDict() serializer_url_field = self.get_url_field() if serializer_url_field: ret[api_settings.URL_FIELD_NAME] = serializer_url_field - serializer_pk_field = self.get_pk_field(info.pk) + field_name = info.pk.name + serializer_pk_field = self.get_pk_field(field_name, info.pk) if serializer_pk_field: - ret[info.pk.name] = serializer_pk_field + ret[field_name] = serializer_pk_field # Regular fields for field_name, field in info.fields.items(): - ret[field_name] = self.get_field(field) + ret[field_name] = self.get_field(field_name, field) # Forward relations for field_name, relation_info in info.forward_relations.items(): if self.opts.depth: - ret[field_name] = self.get_nested_field(*relation_info) + ret[field_name] = self.get_nested_field(field_name, *relation_info) else: - ret[field_name] = self.get_related_field(*relation_info) + ret[field_name] = self.get_related_field(field_name, *relation_info) # Reverse relations for accessor_name, relation_info in info.reverse_relations.items(): if accessor_name in self.opts.fields: if self.opts.depth: - ret[accessor_name] = self.get_nested_field(*relation_info) + ret[accessor_name] = self.get_nested_field(accessor_name, *relation_info) else: - ret[accessor_name] = self.get_related_field(*relation_info) + ret[accessor_name] = self.get_related_field(accessor_name, *relation_info) return ret def get_url_field(self): return None - def get_pk_field(self, model_field): + def get_pk_field(self, field_name, model_field): """ Returns a default instance of the pk field. """ - return self.get_field(model_field) + return self.get_field(field_name, model_field) - def get_nested_field(self, model_field, related_model, to_many, has_through_model): + def get_nested_field(self, field_name, model_field, related_model, to_many, has_through_model): """ Creates a default instance of a nested relational field. Note that model_field will be `None` for reverse relationships. """ - class NestedModelSerializer(ModelSerializer): # Not right! + class NestedModelSerializer(ModelSerializer): class Meta: model = related_model depth = self.opts.depth - 1 @@ -454,7 +464,7 @@ class ModelSerializer(Serializer): kwargs['many'] = True return NestedModelSerializer(**kwargs) - def get_related_field(self, model_field, related_model, to_many, has_through_model): + def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model): """ Creates a default instance of a flat relational field. @@ -474,8 +484,8 @@ class ModelSerializer(Serializer): if model_field: if model_field.null or model_field.blank: kwargs['required'] = False - if model_field.verbose_name: - kwargs['label'] = model_field.verbose_name + if model_field.verbose_name and needs_label(model_field, field_name): + kwargs['label'] = capfirst(model_field.verbose_name) if not model_field.editable: kwargs['read_only'] = True kwargs.pop('queryset', None) @@ -485,7 +495,7 @@ class ModelSerializer(Serializer): return PrimaryKeyRelatedField(**kwargs) - def get_field(self, model_field): + def get_field(self, field_name, model_field): """ Creates a default instance of a basic non-relational field. """ @@ -496,8 +506,8 @@ class ModelSerializer(Serializer): if model_field.null or model_field.blank: kwargs['required'] = False - if model_field.verbose_name is not None: - kwargs['label'] = model_field.verbose_name + if model_field.verbose_name and needs_label(model_field, field_name): + kwargs['label'] = capfirst(model_field.verbose_name) if model_field.help_text: kwargs['help_text'] = model_field.help_text @@ -642,11 +652,11 @@ class HyperlinkedModelSerializer(ModelSerializer): return HyperlinkedIdentityField(**kwargs) - def get_pk_field(self, model_field): + def get_pk_field(self, field_name, model_field): if self.opts.fields and model_field.name in self.opts.fields: return self.get_field(model_field) - def get_related_field(self, model_field, related_model, to_many, has_through_model): + def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model): """ Creates a default instance of a flat relational field. """ @@ -665,8 +675,8 @@ class HyperlinkedModelSerializer(ModelSerializer): if model_field: if model_field.null or model_field.blank: kwargs['required'] = False - if model_field.verbose_name: - kwargs['label'] = model_field.verbose_name + if model_field.verbose_name and needs_label(model_field, field_name): + kwargs['label'] = capfirst(model_field.verbose_name) if not model_field.editable: kwargs['read_only'] = True kwargs.pop('queryset', None) diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py new file mode 100644 index 00000000..960fa4d0 --- /dev/null +++ b/rest_framework/utils/model_meta.py @@ -0,0 +1,99 @@ +""" +Helper functions for returning the field information that is associated +with a model class. This includes returning all the forward and reverse +relationships and their associated metadata. +""" +from collections import namedtuple +from django.db import models +from django.utils import six +from django.utils.datastructures import SortedDict +import inspect + +FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations']) +RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model']) + + +def _resolve_model(obj): + """ + Resolve supplied `obj` to a Django model class. + + `obj` must be a Django model class itself, or a string + representation of one. Useful in situtations like GH #1225 where + Django may not have resolved a string-based reference to a model in + another model's foreign key definition. + + String representations should have the format: + 'appname.ModelName' + """ + if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: + app_name, model_name = obj.split('.') + return models.get_model(app_name, model_name) + elif inspect.isclass(obj) and issubclass(obj, models.Model): + return obj + raise ValueError("{0} is not a Django model".format(obj)) + + +def get_field_info(model): + """ + Given a model class, returns a `FieldInfo` instance containing metadata + about the various field types on the model. + """ + opts = model._meta.concrete_model._meta + + # Deal with the primary key. + pk = opts.pk + while pk.rel and pk.rel.parent_link: + # If model is a child via multitable inheritance, use parent's pk. + pk = pk.rel.to._meta.pk + + # Deal with regular fields. + fields = SortedDict() + for field in [field for field in opts.fields if field.serialize and not field.rel]: + fields[field.name] = field + + # Deal with forward relationships. + forward_relations = SortedDict() + for field in [field for field in opts.fields if field.serialize and field.rel]: + forward_relations[field.name] = RelationInfo( + field=field, + related=_resolve_model(field.rel.to), + to_many=False, + has_through_model=False + ) + + # Deal with forward many-to-many relationships. + for field in [field for field in opts.many_to_many if field.serialize]: + forward_relations[field.name] = RelationInfo( + field=field, + related=_resolve_model(field.rel.to), + to_many=True, + has_through_model=( + not field.rel.through._meta.auto_created + ) + ) + + # Deal with reverse relationships. + reverse_relations = SortedDict() + for relation in opts.get_all_related_objects(): + accessor_name = relation.get_accessor_name() + reverse_relations[accessor_name] = RelationInfo( + field=None, + related=relation.model, + to_many=relation.field.rel.multiple, + has_through_model=False + ) + + # Deal with reverse many-to-many relationships. + for relation in opts.get_all_related_many_to_many_objects(): + accessor_name = relation.get_accessor_name() + reverse_relations[accessor_name] = RelationInfo( + field=None, + related=relation.model, + to_many=True, + has_through_model=( + hasattr(relation.field.rel, 'through') and + not relation.field.rel.through._meta.auto_created + ) + ) + + return FieldInfo(pk, fields, forward_relations, reverse_relations) diff --git a/rest_framework/utils/modelinfo.py b/rest_framework/utils/modelinfo.py deleted file mode 100644 index 960fa4d0..00000000 --- a/rest_framework/utils/modelinfo.py +++ /dev/null @@ -1,99 +0,0 @@ -""" -Helper functions for returning the field information that is associated -with a model class. This includes returning all the forward and reverse -relationships and their associated metadata. -""" -from collections import namedtuple -from django.db import models -from django.utils import six -from django.utils.datastructures import SortedDict -import inspect - -FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations']) -RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model']) - - -def _resolve_model(obj): - """ - Resolve supplied `obj` to a Django model class. - - `obj` must be a Django model class itself, or a string - representation of one. Useful in situtations like GH #1225 where - Django may not have resolved a string-based reference to a model in - another model's foreign key definition. - - String representations should have the format: - 'appname.ModelName' - """ - if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: - app_name, model_name = obj.split('.') - return models.get_model(app_name, model_name) - elif inspect.isclass(obj) and issubclass(obj, models.Model): - return obj - raise ValueError("{0} is not a Django model".format(obj)) - - -def get_field_info(model): - """ - Given a model class, returns a `FieldInfo` instance containing metadata - about the various field types on the model. - """ - opts = model._meta.concrete_model._meta - - # Deal with the primary key. - pk = opts.pk - while pk.rel and pk.rel.parent_link: - # If model is a child via multitable inheritance, use parent's pk. - pk = pk.rel.to._meta.pk - - # Deal with regular fields. - fields = SortedDict() - for field in [field for field in opts.fields if field.serialize and not field.rel]: - fields[field.name] = field - - # Deal with forward relationships. - forward_relations = SortedDict() - for field in [field for field in opts.fields if field.serialize and field.rel]: - forward_relations[field.name] = RelationInfo( - field=field, - related=_resolve_model(field.rel.to), - to_many=False, - has_through_model=False - ) - - # Deal with forward many-to-many relationships. - for field in [field for field in opts.many_to_many if field.serialize]: - forward_relations[field.name] = RelationInfo( - field=field, - related=_resolve_model(field.rel.to), - to_many=True, - has_through_model=( - not field.rel.through._meta.auto_created - ) - ) - - # Deal with reverse relationships. - reverse_relations = SortedDict() - for relation in opts.get_all_related_objects(): - accessor_name = relation.get_accessor_name() - reverse_relations[accessor_name] = RelationInfo( - field=None, - related=relation.model, - to_many=relation.field.rel.multiple, - has_through_model=False - ) - - # Deal with reverse many-to-many relationships. - for relation in opts.get_all_related_many_to_many_objects(): - accessor_name = relation.get_accessor_name() - reverse_relations[accessor_name] = RelationInfo( - field=None, - related=relation.model, - to_many=True, - has_through_model=( - hasattr(relation.field.rel, 'through') and - not relation.field.rel.through._meta.auto_created - ) - ) - - return FieldInfo(pk, fields, forward_relations, reverse_relations) -- cgit v1.2.3 From d196608d5af912057baba79ab13d05d876368ad2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 15 Sep 2014 13:55:09 +0100 Subject: Fix nested model serializer base class --- rest_framework/serializers.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index ba8d475f..40d76897 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -368,6 +368,7 @@ class ModelSerializer(Serializer): models.TimeField: TimeField, models.URLField: URLField, } + nested_class = None # We fill this in at the end of this module. _options_class = ModelSerializerOptions @@ -454,7 +455,7 @@ class ModelSerializer(Serializer): Note that model_field will be `None` for reverse relationships. """ - class NestedModelSerializer(ModelSerializer): + class NestedModelSerializer(self.nested_class): class Meta: model = related_model depth = self.opts.depth - 1 @@ -694,3 +695,7 @@ class HyperlinkedModelSerializer(ModelSerializer): 'app_label': model._meta.app_label, 'model_name': model._meta.object_name.lower() } + + +ModelSerializer.nested_class = ModelSerializer +HyperlinkedModelSerializer.nested_class = HyperlinkedModelSerializer -- cgit v1.2.3 From c0155fd9dc654dc5932effd46a00f66495ce700b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Sep 2014 14:11:53 +0100 Subject: Update comments --- rest_framework/serializers.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 40d76897..1fea1380 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -410,10 +410,12 @@ class ModelSerializer(Serializer): info = model_meta.get_field_info(self.opts.model) ret = SortedDict() + # URL field serializer_url_field = self.get_url_field() if serializer_url_field: ret[api_settings.URL_FIELD_NAME] = serializer_url_field + # Primary key field field_name = info.pk.name serializer_pk_field = self.get_pk_field(field_name, info.pk) if serializer_pk_field: -- cgit v1.2.3 From 5b7e4af0d657a575cb15eea85a63a7100c636085 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 18 Sep 2014 11:20:56 +0100 Subject: get_base_field() refactor --- rest_framework/fields.py | 6 +- rest_framework/relations.py | 9 +- rest_framework/serializers.py | 464 ++++++++-------------------------- rest_framework/utils/field_mapping.py | 215 ++++++++++++++++ rest_framework/utils/model_meta.py | 46 +++- 5 files changed, 364 insertions(+), 376 deletions(-) create mode 100644 rest_framework/utils/field_mapping.py (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 1818e705..0c78b3fb 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -80,10 +80,6 @@ def set_value(dictionary, keys, value): dictionary[keys[-1]] = value -def field_name_to_label(field_name): - return field_name.replace('_', ' ').capitalize() - - class SkipField(Exception): pass @@ -162,7 +158,7 @@ class Field(object): # `self.label` should deafult to being based on the field name. if self.label is None: - self.label = field_name_to_label(self.field_name) + self.label = field_name.replace('_', ' ').capitalize() # self.source should default to being the same as the field name. if self.source is None: diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 46fe55ef..9f44ab63 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -73,7 +73,8 @@ class HyperlinkedRelatedField(RelatedField): 'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.', } - def __init__(self, view_name, **kwargs): + def __init__(self, view_name=None, **kwargs): + assert view_name is not None, 'The `view_name` argument is required.' self.view_name = view_name self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) @@ -182,7 +183,8 @@ class HyperlinkedIdentityField(HyperlinkedRelatedField): URL of relationships to other objects. """ - def __init__(self, view_name, **kwargs): + def __init__(self, view_name=None, **kwargs): + assert view_name is not None, 'The `view_name` argument is required.' kwargs['read_only'] = True kwargs['source'] = '*' super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs) @@ -199,7 +201,8 @@ class SlugRelatedField(RelatedField): 'invalid': _('Invalid value.'), } - def __init__(self, slug_field, **kwargs): + def __init__(self, slug_field=None, **kwargs): + assert slug_field is not None, 'The `slug_field` argument is required.' self.slug_field = slug_field super(SlugRelatedField, self).__init__(**kwargs) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 1fea1380..99dcc349 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -10,17 +10,19 @@ python primitives. 2. The process of marshalling between python primitives and request and response content is handled by parsers and renderers. """ -from django.core import validators from django.core.exceptions import ValidationError from django.db import models from django.utils import six from django.utils.datastructures import SortedDict -from django.utils.text import capfirst from collections import namedtuple -from rest_framework.compat import clean_manytomany_helptext from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.settings import api_settings from rest_framework.utils import html, model_meta, representation +from rest_framework.utils.field_mapping import ( + get_url_kwargs, get_field_kwargs, + get_relation_kwargs, get_nested_relation_kwargs, + lookup_class +) import copy # Note: We do the following so that users of the framework can use this style: @@ -126,7 +128,7 @@ class SerializerMetaclass(type): """ @classmethod - def _get_fields(cls, bases, attrs): + def _get_declared_fields(cls, bases, attrs): fields = [(field_name, attrs.pop(field_name)) for field_name, obj in list(attrs.items()) if isinstance(obj, Field)] @@ -136,25 +138,18 @@ class SerializerMetaclass(type): # fields. Note that we loop over the bases in *reverse*. This is necessary # in order to maintain the correct order of fields. for base in bases[::-1]: - if hasattr(base, 'base_fields'): - fields = list(base.base_fields.items()) + fields + if hasattr(base, '_declared_fields'): + fields = list(base._declared_fields.items()) + fields return SortedDict(fields) def __new__(cls, name, bases, attrs): - attrs['base_fields'] = cls._get_fields(bases, attrs) + attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs) return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) @six.add_metaclass(SerializerMetaclass) class Serializer(BaseSerializer): - - def __new__(cls, *args, **kwargs): - if kwargs.pop('many', False): - kwargs['child'] = cls() - return ListSerializer(*args, **kwargs) - return super(Serializer, cls).__new__(cls, *args, **kwargs) - def __init__(self, *args, **kwargs): self.context = kwargs.pop('context', {}) kwargs.pop('partial', None) @@ -165,14 +160,22 @@ class Serializer(BaseSerializer): # Every new serializer is created with a clone of the field instances. # This allows users to dynamically modify the fields on a serializer # instance without affecting every other serializer class. - self.fields = self.get_fields() + self.fields = self._get_base_fields() # Setup all the child fields, to provide them with the current context. for field_name, field in self.fields.items(): field.bind(field_name, self, self) - def get_fields(self): - return copy.deepcopy(self.base_fields) + def __new__(cls, *args, **kwargs): + # We override this method in order to automagically create + # `ListSerializer` classes instead when `many=True` is set. + if kwargs.pop('many', False): + kwargs['child'] = cls() + return ListSerializer(*args, **kwargs) + return super(Serializer, cls).__new__(cls, *args, **kwargs) + + def _get_base_fields(self): + return copy.deepcopy(self._declared_fields) def bind(self, field_name, parent, root): # If the serializer is used as a field then when it becomes bound @@ -312,39 +315,8 @@ class ListSerializer(BaseSerializer): return representation.list_repr(self, indent=1) -class ModelSerializerOptions(object): - """ - Meta class options for ModelSerializer - """ - def __init__(self, meta): - self.model = getattr(meta, 'model') - self.fields = getattr(meta, 'fields', ()) - self.depth = getattr(meta, 'depth', 0) - - -def lookup_class(mapping, instance): - """ - Takes a dictionary with classes as keys, and an object. - Traverses the object's inheritance hierarchy in method - resolution order, and returns the first matching value - from the dictionary or raises a KeyError if nothing matches. - """ - for cls in inspect.getmro(instance.__class__): - if cls in mapping: - return mapping[cls] - raise KeyError('Class %s not found in lookup.', cls.__name__) - - -def needs_label(model_field, field_name): - """ - Returns `True` if the label based on the model's verbose name - is not equal to the default label it would have based on it's field name. - """ - return capfirst(model_field.verbose_name) != field_name_to_label(field_name) - - class ModelSerializer(Serializer): - field_mapping = { + _field_mapping = { models.AutoField: IntegerField, models.BigIntegerField: IntegerField, models.BooleanField: BooleanField, @@ -368,16 +340,10 @@ class ModelSerializer(Serializer): models.TimeField: TimeField, models.URLField: URLField, } - nested_class = None # We fill this in at the end of this module. - - _options_class = ModelSerializerOptions - - def __init__(self, *args, **kwargs): - self.opts = self._options_class(self.Meta) - super(ModelSerializer, self).__init__(*args, **kwargs) + _related_class = PrimaryKeyRelatedField def create(self, attrs): - ModelClass = self.opts.model + ModelClass = self.Meta.model return ModelClass.objects.create(**attrs) def update(self, obj, attrs): @@ -385,319 +351,97 @@ class ModelSerializer(Serializer): setattr(obj, attr, value) obj.save() - def get_fields(self): - # Get the explicitly declared fields. - fields = copy.deepcopy(self.base_fields) + def _get_base_fields(self): + declared_fields = copy.deepcopy(self._declared_fields) - # Add in the default fields. - for key, val in self.get_default_fields().items(): - if key not in fields: - fields[key] = val - - # If `fields` is set on the `Meta` class, - # then use only those fields, and in that order. - if self.opts.fields: - fields = SortedDict([ - (key, fields[key]) for key in self.opts.fields - ]) - - return fields - - def get_default_fields(self): - """ - Return all the fields that should be serialized for the model. - """ - info = model_meta.get_field_info(self.opts.model) ret = SortedDict() + model = getattr(self.Meta, 'model') + fields = getattr(self.Meta, 'fields', None) + depth = getattr(self.Meta, 'depth', 0) + + # Retrieve metadata about fields & relationships on the model class. + info = model_meta.get_field_info(model) + + # Use the default set of fields if none is supplied explicitly. + if fields is None: + fields = self._get_default_field_names(declared_fields, info) + + for field_name in fields: + if field_name in declared_fields: + # Field is explicitly declared on the class, use that. + ret[field_name] = declared_fields[field_name] + continue + + elif field_name == api_settings.URL_FIELD_NAME: + # Create the URL field. + field_cls = HyperlinkedIdentityField + kwargs = get_url_kwargs(model) + + elif field_name in info.fields_and_pk: + # Create regular model fields. + model_field = info.fields_and_pk[field_name] + field_cls = lookup_class(self._field_mapping, model_field) + kwargs = get_field_kwargs(field_name, model_field) + if 'choices' in kwargs: + # Fields with choices get coerced into `ChoiceField` + # instead of using their regular typed field. + field_cls = ChoiceField + if not issubclass(field_cls, ModelField): + # `model_field` is only valid for the fallback case of + # `ModelField`, which is used when no other typed field + # matched to the model field. + kwargs.pop('model_field', None) + + elif field_name in info.relations: + # Create forward and reverse relationships. + relation_info = info.relations[field_name] + if depth: + field_cls = self._get_nested_class(depth, relation_info) + kwargs = get_nested_relation_kwargs(relation_info) + else: + field_cls = self._related_class + kwargs = get_relation_kwargs(field_name, relation_info) + # `view_name` is only valid for hyperlinked relationships. + if not issubclass(field_cls, HyperlinkedRelatedField): + kwargs.pop('view_name', None) - # URL field - serializer_url_field = self.get_url_field() - if serializer_url_field: - ret[api_settings.URL_FIELD_NAME] = serializer_url_field - - # Primary key field - field_name = info.pk.name - serializer_pk_field = self.get_pk_field(field_name, info.pk) - if serializer_pk_field: - ret[field_name] = serializer_pk_field - - # Regular fields - for field_name, field in info.fields.items(): - ret[field_name] = self.get_field(field_name, field) - - # Forward relations - for field_name, relation_info in info.forward_relations.items(): - if self.opts.depth: - ret[field_name] = self.get_nested_field(field_name, *relation_info) else: - ret[field_name] = self.get_related_field(field_name, *relation_info) + assert False, 'Field name `%s` is not valid.' % field_name - # Reverse relations - for accessor_name, relation_info in info.reverse_relations.items(): - if accessor_name in self.opts.fields: - if self.opts.depth: - ret[accessor_name] = self.get_nested_field(accessor_name, *relation_info) - else: - ret[accessor_name] = self.get_related_field(accessor_name, *relation_info) + ret[field_name] = field_cls(**kwargs) return ret - def get_url_field(self): - return None - - def get_pk_field(self, field_name, model_field): - """ - Returns a default instance of the pk field. - """ - return self.get_field(field_name, model_field) - - def get_nested_field(self, field_name, model_field, related_model, to_many, has_through_model): - """ - Creates a default instance of a nested relational field. + def _get_default_field_names(self, declared_fields, model_info): + return ( + [model_info.pk.name] + + list(declared_fields.keys()) + + list(model_info.fields.keys()) + + list(model_info.forward_relations.keys()) + ) - Note that model_field will be `None` for reverse relationships. - """ - class NestedModelSerializer(self.nested_class): + def _get_nested_class(self, nested_depth, relation_info): + class NestedSerializer(ModelSerializer): class Meta: - model = related_model - depth = self.opts.depth - 1 - - kwargs = {'read_only': True} - if to_many: - kwargs['many'] = True - return NestedModelSerializer(**kwargs) - - def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model): - """ - Creates a default instance of a flat relational field. - - Note that model_field will be `None` for reverse relationships. - """ - kwargs = { - 'queryset': related_model._default_manager, - } - - if to_many: - kwargs['many'] = True - - if has_through_model: - kwargs['read_only'] = True - kwargs.pop('queryset', None) - - if model_field: - if model_field.null or model_field.blank: - kwargs['required'] = False - if model_field.verbose_name and needs_label(model_field, field_name): - kwargs['label'] = capfirst(model_field.verbose_name) - if not model_field.editable: - kwargs['read_only'] = True - kwargs.pop('queryset', None) - help_text = clean_manytomany_helptext(model_field.help_text) - if help_text: - kwargs['help_text'] = help_text - - return PrimaryKeyRelatedField(**kwargs) - - def get_field(self, field_name, model_field): - """ - Creates a default instance of a basic non-relational field. - """ - serializer_cls = lookup_class(self.field_mapping, model_field) - kwargs = {} - validator_kwarg = model_field.validators - - if model_field.null or model_field.blank: - kwargs['required'] = False - - if model_field.verbose_name and needs_label(model_field, field_name): - kwargs['label'] = capfirst(model_field.verbose_name) - - if model_field.help_text: - kwargs['help_text'] = model_field.help_text - - if isinstance(model_field, models.AutoField) or not model_field.editable: - kwargs['read_only'] = True - # Read only implies that the field is not required. - # We have a cleaner repr on the instance if we don't set it. - kwargs.pop('required', None) - - if model_field.has_default(): - kwargs['default'] = model_field.get_default() - # Having a default implies that the field is not required. - # We have a cleaner repr on the instance if we don't set it. - kwargs.pop('required', None) - - if model_field.flatchoices: - # If this model field contains choices, then use a ChoiceField, - # rather than the standard serializer field for this type. - # Note that we return this prior to setting any validation type - # keyword arguments, as those are not valid initializers. - kwargs['choices'] = model_field.flatchoices - return ChoiceField(**kwargs) - - # Ensure that max_length is passed explicitly as a keyword arg, - # rather than as a validator. - max_length = getattr(model_field, 'max_length', None) - if max_length is not None: - kwargs['max_length'] = max_length - validator_kwarg = [ - validator for validator in validator_kwarg - if not isinstance(validator, validators.MaxLengthValidator) - ] - - # Ensure that min_length is passed explicitly as a keyword arg, - # rather than as a validator. - min_length = getattr(model_field, 'min_length', None) - if min_length is not None: - kwargs['min_length'] = min_length - validator_kwarg = [ - validator for validator in validator_kwarg - if not isinstance(validator, validators.MinLengthValidator) - ] - - # Ensure that max_value is passed explicitly as a keyword arg, - # rather than as a validator. - max_value = next(( - validator.limit_value for validator in validator_kwarg - if isinstance(validator, validators.MaxValueValidator) - ), None) - if max_value is not None: - kwargs['max_value'] = max_value - validator_kwarg = [ - validator for validator in validator_kwarg - if not isinstance(validator, validators.MaxValueValidator) - ] - - # Ensure that max_value is passed explicitly as a keyword arg, - # rather than as a validator. - min_value = next(( - validator.limit_value for validator in validator_kwarg - if isinstance(validator, validators.MinValueValidator) - ), None) - if min_value is not None: - kwargs['min_value'] = min_value - validator_kwarg = [ - validator for validator in validator_kwarg - if not isinstance(validator, validators.MinValueValidator) - ] - - # URLField does not need to include the URLValidator argument, - # as it is explicitly added in. - if isinstance(model_field, models.URLField): - validator_kwarg = [ - validator for validator in validator_kwarg - if not isinstance(validator, validators.URLValidator) - ] - - # EmailField does not need to include the validate_email argument, - # as it is explicitly added in. - if isinstance(model_field, models.EmailField): - validator_kwarg = [ - validator for validator in validator_kwarg - if validator is not validators.validate_email - ] - - # SlugField do not need to include the 'validate_slug' argument, - if isinstance(model_field, models.SlugField): - validator_kwarg = [ - validator for validator in validator_kwarg - if validator is not validators.validate_slug - ] - - max_digits = getattr(model_field, 'max_digits', None) - if max_digits is not None: - kwargs['max_digits'] = max_digits - - decimal_places = getattr(model_field, 'decimal_places', None) - if decimal_places is not None: - kwargs['decimal_places'] = decimal_places - - if isinstance(model_field, models.BooleanField): - # models.BooleanField has `blank=True`, but *is* actually - # required *unless* a default is provided. - # Also note that Django<1.6 uses `default=False` for - # models.BooleanField, but Django>=1.6 uses `default=None`. - kwargs.pop('required', None) - - if validator_kwarg: - kwargs['validators'] = validator_kwarg - - if issubclass(serializer_cls, ModelField): - kwargs['model_field'] = model_field - - return serializer_cls(**kwargs) - - -class HyperlinkedModelSerializerOptions(ModelSerializerOptions): - """ - Options for HyperlinkedModelSerializer - """ - def __init__(self, meta): - super(HyperlinkedModelSerializerOptions, self).__init__(meta) - self.view_name = getattr(meta, 'view_name', None) - self.lookup_field = getattr(meta, 'lookup_field', None) + model = relation_info.related + depth = nested_depth + return NestedSerializer class HyperlinkedModelSerializer(ModelSerializer): - _options_class = HyperlinkedModelSerializerOptions - - def get_url_field(self): - if self.opts.view_name is not None: - view_name = self.opts.view_name - else: - view_name = self.get_default_view_name(self.opts.model) - - kwargs = { - 'view_name': view_name - } - if self.opts.lookup_field: - kwargs['lookup_field'] = self.opts.lookup_field - - return HyperlinkedIdentityField(**kwargs) - - def get_pk_field(self, field_name, model_field): - if self.opts.fields and model_field.name in self.opts.fields: - return self.get_field(model_field) - - def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model): - """ - Creates a default instance of a flat relational field. - """ - kwargs = { - 'queryset': related_model._default_manager, - 'view_name': self.get_default_view_name(related_model), - } - - if to_many: - kwargs['many'] = True - - if has_through_model: - kwargs['read_only'] = True - kwargs.pop('queryset', None) - - if model_field: - if model_field.null or model_field.blank: - kwargs['required'] = False - if model_field.verbose_name and needs_label(model_field, field_name): - kwargs['label'] = capfirst(model_field.verbose_name) - if not model_field.editable: - kwargs['read_only'] = True - kwargs.pop('queryset', None) - help_text = clean_manytomany_helptext(model_field.help_text) - if help_text: - kwargs['help_text'] = help_text - - return HyperlinkedRelatedField(**kwargs) - - def get_default_view_name(self, model): - """ - Return the view name to use for related models. - """ - return '%(model_name)s-detail' % { - 'app_label': model._meta.app_label, - 'model_name': model._meta.object_name.lower() - } - - -ModelSerializer.nested_class = ModelSerializer -HyperlinkedModelSerializer.nested_class = HyperlinkedModelSerializer + _related_class = HyperlinkedRelatedField + + def _get_default_field_names(self, declared_fields, model_info): + return ( + [api_settings.URL_FIELD_NAME] + + list(declared_fields.keys()) + + list(model_info.fields.keys()) + + list(model_info.forward_relations.keys()) + ) + + def _get_nested_class(self, nested_depth, relation_info): + class NestedSerializer(HyperlinkedModelSerializer): + class Meta: + model = relation_info.related + depth = nested_depth + return NestedSerializer diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py new file mode 100644 index 00000000..be72e444 --- /dev/null +++ b/rest_framework/utils/field_mapping.py @@ -0,0 +1,215 @@ +""" +Helper functions for mapping model fields to a dictionary of default +keyword arguments that should be used for their equivelent serializer fields. +""" +from django.core import validators +from django.db import models +from django.utils.text import capfirst +from rest_framework.compat import clean_manytomany_helptext +import inspect + + +def lookup_class(mapping, instance): + """ + Takes a dictionary with classes as keys, and an object. + Traverses the object's inheritance hierarchy in method + resolution order, and returns the first matching value + from the dictionary or raises a KeyError if nothing matches. + """ + for cls in inspect.getmro(instance.__class__): + if cls in mapping: + return mapping[cls] + raise KeyError('Class %s not found in lookup.', cls.__name__) + + +def needs_label(model_field, field_name): + """ + Returns `True` if the label based on the model's verbose name + is not equal to the default label it would have based on it's field name. + """ + default_label = field_name.replace('_', ' ').capitalize() + return capfirst(model_field.verbose_name) != default_label + + +def get_detail_view_name(model): + """ + Given a model class, return the view name to use for URL relationships + that refer to instances of the model. + """ + return '%(model_name)s-detail' % { + 'app_label': model._meta.app_label, + 'model_name': model._meta.object_name.lower() + } + + +def get_field_kwargs(field_name, model_field): + """ + Creates a default instance of a basic non-relational field. + """ + kwargs = {} + validator_kwarg = model_field.validators + + if model_field.null or model_field.blank: + kwargs['required'] = False + + if model_field.verbose_name and needs_label(model_field, field_name): + kwargs['label'] = capfirst(model_field.verbose_name) + + if model_field.help_text: + kwargs['help_text'] = model_field.help_text + + if isinstance(model_field, models.AutoField) or not model_field.editable: + kwargs['read_only'] = True + # Read only implies that the field is not required. + # We have a cleaner repr on the instance if we don't set it. + kwargs.pop('required', None) + + if model_field.has_default(): + kwargs['default'] = model_field.get_default() + # Having a default implies that the field is not required. + # We have a cleaner repr on the instance if we don't set it. + kwargs.pop('required', None) + + if model_field.flatchoices: + # If this model field contains choices, then return now, + # any further keyword arguments are not valid. + kwargs['choices'] = model_field.flatchoices + return kwargs + + # Ensure that max_length is passed explicitly as a keyword arg, + # rather than as a validator. + max_length = getattr(model_field, 'max_length', None) + if max_length is not None: + kwargs['max_length'] = max_length + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MaxLengthValidator) + ] + + # Ensure that min_length is passed explicitly as a keyword arg, + # rather than as a validator. + min_length = getattr(model_field, 'min_length', None) + if min_length is not None: + kwargs['min_length'] = min_length + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MinLengthValidator) + ] + + # Ensure that max_value is passed explicitly as a keyword arg, + # rather than as a validator. + max_value = next(( + validator.limit_value for validator in validator_kwarg + if isinstance(validator, validators.MaxValueValidator) + ), None) + if max_value is not None: + kwargs['max_value'] = max_value + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MaxValueValidator) + ] + + # Ensure that max_value is passed explicitly as a keyword arg, + # rather than as a validator. + min_value = next(( + validator.limit_value for validator in validator_kwarg + if isinstance(validator, validators.MinValueValidator) + ), None) + if min_value is not None: + kwargs['min_value'] = min_value + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MinValueValidator) + ] + + # URLField does not need to include the URLValidator argument, + # as it is explicitly added in. + if isinstance(model_field, models.URLField): + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.URLValidator) + ] + + # EmailField does not need to include the validate_email argument, + # as it is explicitly added in. + if isinstance(model_field, models.EmailField): + validator_kwarg = [ + validator for validator in validator_kwarg + if validator is not validators.validate_email + ] + + # SlugField do not need to include the 'validate_slug' argument, + if isinstance(model_field, models.SlugField): + validator_kwarg = [ + validator for validator in validator_kwarg + if validator is not validators.validate_slug + ] + + max_digits = getattr(model_field, 'max_digits', None) + if max_digits is not None: + kwargs['max_digits'] = max_digits + + decimal_places = getattr(model_field, 'decimal_places', None) + if decimal_places is not None: + kwargs['decimal_places'] = decimal_places + + if isinstance(model_field, models.BooleanField): + # models.BooleanField has `blank=True`, but *is* actually + # required *unless* a default is provided. + # Also note that Django<1.6 uses `default=False` for + # models.BooleanField, but Django>=1.6 uses `default=None`. + kwargs.pop('required', None) + + if validator_kwarg: + kwargs['validators'] = validator_kwarg + + # The following will only be used by ModelField classes. + # Gets removed for everything else. + kwargs['model_field'] = model_field + + return kwargs + + +def get_relation_kwargs(field_name, relation_info): + """ + Creates a default instance of a flat relational field. + """ + model_field, related_model, to_many, has_through_model = relation_info + kwargs = { + 'queryset': related_model._default_manager, + 'view_name': get_detail_view_name(related_model) + } + + if to_many: + kwargs['many'] = True + + if has_through_model: + kwargs['read_only'] = True + kwargs.pop('queryset', None) + + if model_field: + if model_field.null or model_field.blank: + kwargs['required'] = False + if model_field.verbose_name and needs_label(model_field, field_name): + kwargs['label'] = capfirst(model_field.verbose_name) + if not model_field.editable: + kwargs['read_only'] = True + kwargs.pop('queryset', None) + help_text = clean_manytomany_helptext(model_field.help_text) + if help_text: + kwargs['help_text'] = help_text + + return kwargs + + +def get_nested_relation_kwargs(relation_info): + kwargs = {'read_only': True} + if relation_info.to_many: + kwargs['many'] = True + return kwargs + + +def get_url_kwargs(model_field): + return { + 'view_name': get_detail_view_name(model_field) + } diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py index 960fa4d0..b6c41174 100644 --- a/rest_framework/utils/model_meta.py +++ b/rest_framework/utils/model_meta.py @@ -1,7 +1,9 @@ """ -Helper functions for returning the field information that is associated +Helper function for returning the field information that is associated with a model class. This includes returning all the forward and reverse relationships and their associated metadata. + +Usage: `get_field_info(model)` returns a `FieldInfo` instance. """ from collections import namedtuple from django.db import models @@ -9,8 +11,22 @@ from django.utils import six from django.utils.datastructures import SortedDict import inspect -FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations']) -RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model']) + +FieldInfo = namedtuple('FieldResult', [ + 'pk', # Model field instance + 'fields', # Dict of field name -> model field instance + 'forward_relations', # Dict of field name -> RelationInfo + 'reverse_relations', # Dict of field name -> RelationInfo + 'fields_and_pk', # Shortcut for 'pk' + 'fields' + 'relations' # Shortcut for 'forward_relations' + 'reverse_relations' +]) + +RelationInfo = namedtuple('RelationInfo', [ + 'model_field', + 'related', + 'to_many', + 'has_through_model' +]) def _resolve_model(obj): @@ -55,7 +71,7 @@ def get_field_info(model): forward_relations = SortedDict() for field in [field for field in opts.fields if field.serialize and field.rel]: forward_relations[field.name] = RelationInfo( - field=field, + model_field=field, related=_resolve_model(field.rel.to), to_many=False, has_through_model=False @@ -64,7 +80,7 @@ def get_field_info(model): # Deal with forward many-to-many relationships. for field in [field for field in opts.many_to_many if field.serialize]: forward_relations[field.name] = RelationInfo( - field=field, + model_field=field, related=_resolve_model(field.rel.to), to_many=True, has_through_model=( @@ -77,7 +93,7 @@ def get_field_info(model): for relation in opts.get_all_related_objects(): accessor_name = relation.get_accessor_name() reverse_relations[accessor_name] = RelationInfo( - field=None, + model_field=None, related=relation.model, to_many=relation.field.rel.multiple, has_through_model=False @@ -87,7 +103,7 @@ def get_field_info(model): for relation in opts.get_all_related_many_to_many_objects(): accessor_name = relation.get_accessor_name() reverse_relations[accessor_name] = RelationInfo( - field=None, + model_field=None, related=relation.model, to_many=True, has_through_model=( @@ -96,4 +112,18 @@ def get_field_info(model): ) ) - return FieldInfo(pk, fields, forward_relations, reverse_relations) + # Shortcut that merges both regular fields and the pk, + # for simplifying regular field lookup. + fields_and_pk = SortedDict() + fields_and_pk['pk'] = pk + fields_and_pk[pk.name] = pk + fields_and_pk.update(fields) + + # Shortcut that merges both forward and reverse relationships + + relations = SortedDict( + list(forward_relations.items()) + + list(reverse_relations.items()) + ) + + return FieldInfo(pk, fields, forward_relations, reverse_relations, fields_and_pk, relations) -- cgit v1.2.3 From 87734be5f41de921ac32ad1f6664db243aab6d07 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 18 Sep 2014 12:17:21 +0100 Subject: Configuration correctness tests on ModelSerializer --- rest_framework/serializers.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 99dcc349..9f3e53fd 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -10,7 +10,7 @@ python primitives. 2. The process of marshalling between python primitives and request and response content is handled by parsers and renderers. """ -from django.core.exceptions import ValidationError +from django.core.exceptions import ImproperlyConfigured, ValidationError from django.db import models from django.utils import six from django.utils.datastructures import SortedDict @@ -358,6 +358,7 @@ class ModelSerializer(Serializer): model = getattr(self.Meta, 'model') fields = getattr(self.Meta, 'fields', None) depth = getattr(self.Meta, 'depth', 0) + extra_kwargs = getattr(self.Meta, 'extra_kwargs', {}) # Retrieve metadata about fields & relationships on the model class. info = model_meta.get_field_info(model) @@ -405,9 +406,32 @@ class ModelSerializer(Serializer): if not issubclass(field_cls, HyperlinkedRelatedField): kwargs.pop('view_name', None) - else: - assert False, 'Field name `%s` is not valid.' % field_name + elif hasattr(model, field_name): + # Create a read only field for model methods and properties. + field_cls = ReadOnlyField + kwargs = {} + else: + raise ImproperlyConfigured( + 'Field name `%s` is not valid for model `%s`.' % + (field_name, model.__class__.__name__) + ) + + # Check that any fields declared on the class are + # also explicity included in `Meta.fields`. + missing_fields = set(declared_fields.keys()) - set(fields) + if missing_fields: + missing_field = list(missing_fields)[0] + raise ImproperlyConfigured( + 'Field `%s` has been declared on serializer `%s`, but ' + 'is missing from `Meta.fields`.' % + (missing_field, self.__class__.__name__) + ) + + # Populate any kwargs defined in `Meta.extra_kwargs` + kwargs.update(extra_kwargs.get(field_name, {})) + + # Create the serializer field. ret[field_name] = field_cls(**kwargs) return ret -- cgit v1.2.3 From 9fdb2280d11db126771686d626aa8a0247b8a46c Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 18 Sep 2014 14:23:00 +0100 Subject: First pass on ManyRelation --- rest_framework/relations.py | 42 +++++++++++++++++++++++++++++++++- rest_framework/utils/representation.py | 2 ++ 2 files changed, 43 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 9f44ab63..474d3e75 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -10,7 +10,6 @@ from django.utils.translation import ugettext_lazy as _ class RelatedField(Field): def __init__(self, **kwargs): self.queryset = kwargs.pop('queryset', None) - self.many = kwargs.pop('many', False) assert self.queryset is not None or kwargs.get('read_only', None), ( 'Relational field must provide a `queryset` argument, ' 'or set read_only=`True`.' @@ -21,6 +20,13 @@ class RelatedField(Field): ) super(RelatedField, self).__init__(**kwargs) + def __new__(cls, *args, **kwargs): + # We override this method in order to automagically create + # `ManyRelation` classes instead when `many=True` is set. + if kwargs.pop('many', False): + return ManyRelation(child_relation=cls(*args, **kwargs)) + return super(RelatedField, cls).__new__(cls, *args, **kwargs) + def get_queryset(self): queryset = self.queryset if isinstance(queryset, QuerySet): @@ -216,3 +222,37 @@ class SlugRelatedField(RelatedField): def to_representation(self, obj): return getattr(obj, self.slug_field) + + +class ManyRelation(Field): + """ + Relationships with `many=True` transparently get coerced into instead being + a ManyRelation with a child relationship. + + The `ManyRelation` class is responsible for handling iterating through + the values and passing each one to the child relationship. + + You shouldn't need to be using this class directly yourself. + """ + + def __init__(self, child_relation=None, *args, **kwargs): + self.child_relation = child_relation + assert child_relation is not None, '`child_relation` is a required argument.' + super(ManyRelation, self).__init__(*args, **kwargs) + + def bind(self, field_name, parent, root): + # ManyRelation needs to provide the current context to the child relation. + super(ManyRelation, self).bind(field_name, parent, root) + self.child_relation.bind(field_name, parent, root) + + def to_internal_value(self, data): + return [ + self.child_relation.to_internal_value(item) + for item in data + ] + + def to_representation(self, obj): + return [ + self.child_relation.to_representation(value) + for value in obj.all() + ] diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py index 71db1886..e64fdd22 100644 --- a/rest_framework/utils/representation.py +++ b/rest_framework/utils/representation.py @@ -73,6 +73,8 @@ def serializer_repr(serializer, indent, force_many=None): ret += serializer_repr(field, indent + 1) elif hasattr(field, 'child'): ret += list_repr(field, indent + 1) + elif hasattr(field, 'child_relation'): + ret += field_repr(field.child_relation, force_many=field.child_relation) else: ret += field_repr(field) return ret -- cgit v1.2.3 From 106362b437f45e04faaea759df57a66a8a2d7cfd Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 18 Sep 2014 14:58:08 +0100 Subject: ModelSerializer.create() to handle many to many by default --- rest_framework/relations.py | 5 ++++- rest_framework/serializers.py | 20 +++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 474d3e75..5aa1f8bd 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -24,7 +24,10 @@ class RelatedField(Field): # We override this method in order to automagically create # `ManyRelation` classes instead when `many=True` is set. if kwargs.pop('many', False): - return ManyRelation(child_relation=cls(*args, **kwargs)) + return ManyRelation( + child_relation=cls(*args, **kwargs), + read_only=kwargs.get('read_only', False) + ) return super(RelatedField, cls).__new__(cls, *args, **kwargs) def get_queryset(self): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 9f3e53fd..03e20df8 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -344,7 +344,25 @@ class ModelSerializer(Serializer): def create(self, attrs): ModelClass = self.Meta.model - return ModelClass.objects.create(**attrs) + + # Remove many-to-many relationships from attrs. + # They are not valid arguments to the default `.create()` method, + # as they require that the instance has already been saved. + info = model_meta.get_field_info(ModelClass) + many_to_many = {} + for key, relation_info in info.relations.items(): + if relation_info.to_many and (key in attrs): + many_to_many[key] = attrs.pop(key) + + instance = ModelClass.objects.create(**attrs) + + # Save many to many relationships after the instance is created. + if many_to_many: + for key, value in many_to_many.items(): + setattr(instance, key, value) + instance.save() + + return instance def update(self, obj, attrs): for attr, value in attrs.items(): -- cgit v1.2.3 From f90049316a3ecca6c92e10b57bfa5becbceff386 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 18 Sep 2014 15:47:27 +0100 Subject: Added a model update integration test --- rest_framework/serializers.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 03e20df8..d2740fc2 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -350,17 +350,16 @@ class ModelSerializer(Serializer): # as they require that the instance has already been saved. info = model_meta.get_field_info(ModelClass) many_to_many = {} - for key, relation_info in info.relations.items(): - if relation_info.to_many and (key in attrs): - many_to_many[key] = attrs.pop(key) + for field_name, relation_info in info.relations.items(): + if relation_info.to_many and (field_name in attrs): + many_to_many[field_name] = attrs.pop(field_name) instance = ModelClass.objects.create(**attrs) - # Save many to many relationships after the instance is created. + # Save many-to-many relationships after the instance is created. if many_to_many: - for key, value in many_to_many.items(): - setattr(instance, key, value) - instance.save() + for field_name, value in many_to_many.items(): + setattr(instance, field_name, value) return instance -- cgit v1.2.3 From cf72b9a8b755652cec4ad19a27488e3a79c2e401 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 19 Sep 2014 16:43:13 +0100 Subject: Moar tests --- rest_framework/serializers.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index d2740fc2..d9f9c8cb 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -24,6 +24,7 @@ from rest_framework.utils.field_mapping import ( lookup_class ) import copy +import inspect # Note: We do the following so that users of the framework can use this style: # @@ -268,6 +269,7 @@ class ListSerializer(BaseSerializer): def __init__(self, *args, **kwargs): self.child = kwargs.pop('child', copy.deepcopy(self.child)) assert self.child is not None, '`child` is a required argument.' + assert not inspect.isclass(self.child), '`child` has not been instantiated.' self.context = kwargs.pop('context', {}) kwargs.pop('partial', None) -- cgit v1.2.3 From af46fd6b00f1d7f018049c19094af58acb1415fb Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 22 Sep 2014 12:25:57 +0100 Subject: Field tests and associated cleanup --- rest_framework/fields.py | 64 +++++++++++++++++++++++------------------------- 1 file changed, 31 insertions(+), 33 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 0c78b3fb..db75ddf9 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -12,7 +12,6 @@ from rest_framework.utils import html, representation, humanize_datetime import datetime import decimal import inspect -import warnings class empty: @@ -395,7 +394,7 @@ class IntegerField(Field): class FloatField(Field): default_error_messages = { - 'invalid': _("'%s' value must be a float."), + 'invalid': _("A valid number is required."), } def __init__(self, **kwargs): @@ -410,20 +409,20 @@ class FloatField(Field): def to_internal_value(self, value): if value is None: return None - return float(value) + try: + return float(value) + except (TypeError, ValueError): + self.fail('invalid') def to_representation(self, value): if value is None: return None - try: - return float(value) - except (TypeError, ValueError): - self.fail('invalid', value=value) + return float(value) class DecimalField(Field): default_error_messages = { - 'invalid': _('Enter a number.'), + 'invalid': _('A valid number is required.'), 'max_value': _('Ensure this value is less than or equal to {max_value}.'), 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), 'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'), @@ -485,7 +484,7 @@ class DecimalField(Field): if self.decimal_places is not None and decimals > self.decimal_places: self.fail('max_decimal_places', max_decimal_places=self.decimal_places) if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): - self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places) + self.fail('max_whole_digits', max_whole_digits=self.max_digits - self.decimal_places) return value @@ -511,6 +510,7 @@ class DecimalField(Field): class DateField(Field): default_error_messages = { 'invalid': _('Date has wrong format. Use one of these formats instead: {format}'), + 'datetime': _('Expected a date but got a datetime.'), } format = api_settings.DATE_FORMAT input_formats = api_settings.DATE_INPUT_FORMATS @@ -525,12 +525,7 @@ class DateField(Field): return None if isinstance(value, datetime.datetime): - if timezone and settings.USE_TZ and timezone.is_aware(value): - # Convert aware datetimes to the default time zone - # before casting them to dates (#17742). - default_timezone = timezone.get_default_timezone() - value = timezone.make_naive(value, default_timezone) - return value.date() + self.fail('datetime') if isinstance(value, datetime.date): return value @@ -570,35 +565,38 @@ class DateField(Field): class DateTimeField(Field): default_error_messages = { 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'), + 'date': _('Expected a datetime but got a date.'), } format = api_settings.DATETIME_FORMAT input_formats = api_settings.DATETIME_INPUT_FORMATS + default_timezone = timezone.get_default_timezone() if settings.USE_TZ else None - def __init__(self, format=None, input_formats=None, *args, **kwargs): + def __init__(self, format=None, input_formats=None, default_timezone=None, *args, **kwargs): self.format = format if format is not None else self.format self.input_formats = input_formats if input_formats is not None else self.input_formats + self.default_timezone = default_timezone if default_timezone is not None else self.default_timezone super(DateTimeField, self).__init__(*args, **kwargs) + def enforce_timezone(self, value): + """ + When `self.default_timezone` is `None`, always return naive datetimes. + When `self.default_timezone` is not `None`, always return aware datetimes. + """ + if (self.default_timezone is not None) and not timezone.is_aware(value): + return timezone.make_aware(value, self.default_timezone) + elif (self.default_timezone is None) and timezone.is_aware(value): + return timezone.make_naive(value, timezone.UTC()) + return value + def to_internal_value(self, value): if value in (None, ''): return None - if isinstance(value, datetime.datetime): - return value + if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime): + self.fail('date') - if isinstance(value, datetime.date): - value = datetime.datetime(value.year, value.month, value.day) - if settings.USE_TZ: - # For backwards compatibility, interpret naive datetimes in - # local time. This won't work during DST change, but we can't - # do much about it, so we let the exceptions percolate up the - # call stack. - warnings.warn("DateTimeField received a naive datetime (%s)" - " while time zone support is active." % value, - RuntimeWarning) - default_timezone = timezone.get_default_timezone() - value = timezone.make_aware(value, default_timezone) - return value + if isinstance(value, datetime.datetime): + return self.enforce_timezone(value) for format in self.input_formats: if format.lower() == ISO_8601: @@ -608,14 +606,14 @@ class DateTimeField(Field): pass else: if parsed is not None: - return parsed + return self.enforce_timezone(parsed) else: try: parsed = datetime.datetime.strptime(value, format) except (ValueError, TypeError): pass else: - return parsed + return self.enforce_timezone(parsed) humanized_format = humanize_datetime.datetime_formats(self.input_formats) self.fail('invalid', format=humanized_format) -- cgit v1.2.3 From afb3f8ab0ad6c33b147292e9777ba0ddf3871d14 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 22 Sep 2014 13:26:47 +0100 Subject: Tests and tweaks for text fields --- rest_framework/fields.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index db75ddf9..35bd5c4b 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -12,6 +12,7 @@ from rest_framework.utils import html, representation, humanize_datetime import datetime import decimal import inspect +import re class empty: @@ -325,7 +326,11 @@ class EmailField(CharField): default_error_messages = { 'invalid': _('Enter a valid email address.') } - default_validators = [validators.validate_email] + + def __init__(self, **kwargs): + super(EmailField, self).__init__(**kwargs) + validator = validators.EmailValidator(message=self.error_messages['invalid']) + self.validators = [validator] + self.validators def to_internal_value(self, data): if data == '' and not self.allow_blank: @@ -341,26 +346,37 @@ class EmailField(CharField): class RegexField(CharField): + default_error_messages = { + 'invalid': _('This value does not match the required pattern.') + } + def __init__(self, regex, **kwargs): - kwargs['validators'] = ( - [validators.RegexValidator(regex)] + - kwargs.get('validators', []) - ) super(RegexField, self).__init__(**kwargs) + validator = validators.RegexValidator(regex, message=self.error_messages['invalid']) + self.validators = [validator] + self.validators class SlugField(CharField): default_error_messages = { 'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.") } - default_validators = [validators.validate_slug] + + def __init__(self, **kwargs): + super(SlugField, self).__init__(**kwargs) + slug_regex = re.compile(r'^[-a-zA-Z0-9_]+$') + validator = validators.RegexValidator(slug_regex, message=self.error_messages['invalid']) + self.validators = [validator] + self.validators class URLField(CharField): default_error_messages = { 'invalid': _("Enter a valid URL.") } - default_validators = [validators.URLValidator()] + + def __init__(self, **kwargs): + super(URLField, self).__init__(**kwargs) + validator = validators.URLValidator(message=self.error_messages['invalid']) + self.validators = [validator] + self.validators # Number types... @@ -642,7 +658,7 @@ class TimeField(Field): self.input_formats = input_formats if input_formats is not None else self.input_formats super(TimeField, self).__init__(*args, **kwargs) - def from_native(self, value): + def to_internal_value(self, value): if value in (None, ''): return None -- cgit v1.2.3 From c54f394904c3f93211b8aa073de4e9e50110f831 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 22 Sep 2014 13:57:45 +0100 Subject: Ensure 'messages' in fields are respected in preference to default validator messages --- rest_framework/compat.py | 19 +++++++++++++++++++ rest_framework/fields.py | 34 ++++++++++++++++++++++------------ 2 files changed, 41 insertions(+), 12 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 7c05bed9..2b4ddb02 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -121,6 +121,25 @@ else: return [m.upper() for m in self.http_method_names if hasattr(self, m)] + +# MinValueValidator and MaxValueValidator only accept `message` in 1.8+ +if django.VERSION >= (1, 8): + from django.core.validators import MinValueValidator, MaxValueValidator +else: + from django.core.validators import MinValueValidator as DjangoMinValueValidator + from django.core.validators import MaxValueValidator as DjangoMaxValueValidator + + class MinValueValidator(DjangoMinValueValidator): + def __init__(self, *args, **kwargs): + self.message = kwargs.pop('message', self.message) + super(MinValueValidator, self).__init__(*args, **kwargs) + + class MaxValueValidator(DjangoMaxValueValidator): + def __init__(self, *args, **kwargs): + self.message = kwargs.pop('message', self.message) + super(MaxValueValidator, self).__init__(*args, **kwargs) + + # PATCH method is not implemented by Django if 'patch' not in View.http_method_names: View.http_method_names = View.http_method_names + ['patch'] diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 35bd5c4b..5105dfcb 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -6,7 +6,7 @@ from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.encoding import is_protected_type from django.utils.translation import ugettext_lazy as _ from rest_framework import ISO_8601 -from rest_framework.compat import smart_text +from rest_framework.compat import smart_text, MinValueValidator, MaxValueValidator from rest_framework.settings import api_settings from rest_framework.utils import html, representation, humanize_datetime import datetime @@ -330,7 +330,7 @@ class EmailField(CharField): def __init__(self, **kwargs): super(EmailField, self).__init__(**kwargs) validator = validators.EmailValidator(message=self.error_messages['invalid']) - self.validators = [validator] + self.validators + self.validators.append(validator) def to_internal_value(self, data): if data == '' and not self.allow_blank: @@ -353,7 +353,7 @@ class RegexField(CharField): def __init__(self, regex, **kwargs): super(RegexField, self).__init__(**kwargs) validator = validators.RegexValidator(regex, message=self.error_messages['invalid']) - self.validators = [validator] + self.validators + self.validators.append(validator) class SlugField(CharField): @@ -365,7 +365,7 @@ class SlugField(CharField): super(SlugField, self).__init__(**kwargs) slug_regex = re.compile(r'^[-a-zA-Z0-9_]+$') validator = validators.RegexValidator(slug_regex, message=self.error_messages['invalid']) - self.validators = [validator] + self.validators + self.validators.append(validator) class URLField(CharField): @@ -376,14 +376,16 @@ class URLField(CharField): def __init__(self, **kwargs): super(URLField, self).__init__(**kwargs) validator = validators.URLValidator(message=self.error_messages['invalid']) - self.validators = [validator] + self.validators + self.validators.append(validator) # Number types... class IntegerField(Field): default_error_messages = { - 'invalid': _('A valid integer is required.') + 'invalid': _('A valid integer is required.'), + 'max_value': _('Ensure this value is less than or equal to {max_value}.'), + 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), } def __init__(self, **kwargs): @@ -391,9 +393,11 @@ class IntegerField(Field): min_value = kwargs.pop('min_value', None) super(IntegerField, self).__init__(**kwargs) if max_value is not None: - self.validators.append(validators.MaxValueValidator(max_value)) + message = self.error_messages['max_value'].format(max_value=max_value) + self.validators.append(MaxValueValidator(max_value, message=message)) if min_value is not None: - self.validators.append(validators.MinValueValidator(min_value)) + message = self.error_messages['min_value'].format(min_value=min_value) + self.validators.append(MinValueValidator(min_value, message=message)) def to_internal_value(self, data): try: @@ -411,6 +415,8 @@ class IntegerField(Field): class FloatField(Field): default_error_messages = { 'invalid': _("A valid number is required."), + 'max_value': _('Ensure this value is less than or equal to {max_value}.'), + 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), } def __init__(self, **kwargs): @@ -418,9 +424,11 @@ class FloatField(Field): min_value = kwargs.pop('min_value', None) super(FloatField, self).__init__(**kwargs) if max_value is not None: - self.validators.append(validators.MaxValueValidator(max_value)) + message = self.error_messages['max_value'].format(max_value=max_value) + self.validators.append(MaxValueValidator(max_value, message=message)) if min_value is not None: - self.validators.append(validators.MinValueValidator(min_value)) + message = self.error_messages['min_value'].format(min_value=min_value) + self.validators.append(MinValueValidator(min_value, message=message)) def to_internal_value(self, value): if value is None: @@ -454,9 +462,11 @@ class DecimalField(Field): self.coerce_to_string = coerce_to_string if (coerce_to_string is not None) else self.coerce_to_string super(DecimalField, self).__init__(**kwargs) if max_value is not None: - self.validators.append(validators.MaxValueValidator(max_value)) + message = self.error_messages['max_value'].format(max_value=max_value) + self.validators.append(MaxValueValidator(max_value, message=message)) if min_value is not None: - self.validators.append(validators.MinValueValidator(min_value)) + message = self.error_messages['min_value'].format(min_value=min_value) + self.validators.append(MinValueValidator(min_value, message=message)) def to_internal_value(self, value): """ -- cgit v1.2.3 From 249253a144ba4381581809fb3f27959c7bd6e577 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 22 Sep 2014 14:54:33 +0100 Subject: Fix compat issues --- rest_framework/fields.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 5105dfcb..5fb99a42 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -209,8 +209,10 @@ class Field(object): """ Validate a simple representation and return the internal value. - The provided data may be `empty` if no representation was included. - May return `empty` if the field should not be included in the + The provided data may be `empty` if no representation was included + in the input. + + May raise `SkipField` if the field should not be included in the validated data. """ if data is empty: @@ -223,6 +225,10 @@ class Field(object): return value def run_validators(self, value): + """ + Test the given value against all the validators on the field, + and either raise a `ValidationError` or simply return. + """ if value in (None, '', [], (), {}): return @@ -753,8 +759,9 @@ class MultipleChoiceField(ChoiceField): } def to_internal_value(self, data): - if not hasattr(data, '__iter__'): + if isinstance(data, type('')) or not hasattr(data, '__iter__'): self.fail('not_a_list', input_type=type(data).__name__) + return set([ super(MultipleChoiceField, self).to_internal_value(item) for item in data -- cgit v1.2.3 From 4db23cae213decc3e8a8613ad5c76a545f8cfb1a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 22 Sep 2014 15:34:06 +0100 Subject: Tweaks to DecimalField --- rest_framework/fields.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 5fb99a42..db7ceabb 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -521,20 +521,21 @@ class DecimalField(Field): return value def to_representation(self, value): - if isinstance(value, decimal.Decimal): - context = decimal.getcontext().copy() - context.prec = self.max_digits - quantized = value.quantize( - decimal.Decimal('.1') ** self.decimal_places, - context=context - ) - if not self.coerce_to_string: - return quantized - return '{0:f}'.format(quantized) + if value in (None, ''): + return None + if not isinstance(value, decimal.Decimal): + value = decimal.Decimal(value) + + context = decimal.getcontext().copy() + context.prec = self.max_digits + quantized = value.quantize( + decimal.Decimal('.1') ** self.decimal_places, + context=context + ) if not self.coerce_to_string: - return value - return '%.*f' % (self.max_decimal_places, value) + return quantized + return '{0:f}'.format(quantized) # Date & time fields... -- cgit v1.2.3 From 5586b6581d9d8db05276c08f2c6deffec04ade4f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 22 Sep 2014 16:02:59 +0100 Subject: Support format=None for date/time fields --- rest_framework/fields.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index db7ceabb..cbd3334a 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -548,8 +548,8 @@ class DateField(Field): format = api_settings.DATE_FORMAT input_formats = api_settings.DATE_INPUT_FORMATS - def __init__(self, format=None, input_formats=None, *args, **kwargs): - self.format = format if format is not None else self.format + def __init__(self, format=empty, input_formats=None, *args, **kwargs): + self.format = format if format is not empty else self.format self.input_formats = input_formats if input_formats is not None else self.input_formats super(DateField, self).__init__(*args, **kwargs) @@ -604,8 +604,8 @@ class DateTimeField(Field): input_formats = api_settings.DATETIME_INPUT_FORMATS default_timezone = timezone.get_default_timezone() if settings.USE_TZ else None - def __init__(self, format=None, input_formats=None, default_timezone=None, *args, **kwargs): - self.format = format if format is not None else self.format + def __init__(self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs): + self.format = format if format is not empty else self.format self.input_formats = input_formats if input_formats is not None else self.input_formats self.default_timezone = default_timezone if default_timezone is not None else self.default_timezone super(DateTimeField, self).__init__(*args, **kwargs) @@ -670,8 +670,8 @@ class TimeField(Field): format = api_settings.TIME_FORMAT input_formats = api_settings.TIME_INPUT_FORMATS - def __init__(self, format=None, input_formats=None, *args, **kwargs): - self.format = format if format is not None else self.format + def __init__(self, format=empty, input_formats=None, *args, **kwargs): + self.format = format if format is not empty else self.format self.input_formats = input_formats if input_formats is not None else self.input_formats super(TimeField, self).__init__(*args, **kwargs) -- cgit v1.2.3 From e5f0a97595ff9280c7876fc917f6feb27b5ea95d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 22 Sep 2014 16:45:06 +0100 Subject: More compat fixes --- rest_framework/compat.py | 23 +++++++++++++++++++++++ rest_framework/fields.py | 8 ++++---- 2 files changed, 27 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 2b4ddb02..7303c32a 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -139,6 +139,29 @@ else: self.message = kwargs.pop('message', self.message) super(MaxValueValidator, self).__init__(*args, **kwargs) +# URLValidator only accept `message` in 1.6+ +if django.VERSION >= (1, 6): + from django.core.validators import URLValidator +else: + from django.core.validators import URLValidator as DjangoURLValidator + + class URLValidator(DjangoURLValidator): + def __init__(self, *args, **kwargs): + self.message = kwargs.pop('message', self.message) + super(URLValidator, self).__init__(*args, **kwargs) + + +# EmailValidator requires explicit regex prior to 1.6+ +if django.VERSION >= (1, 6): + from django.core.validators import EmailValidator +else: + from django.core.validators import EmailValidator as DjangoEmailValidator + from django.core.validators import email_re + + class EmailValidator(DjangoEmailValidator): + def __init__(self, *args, **kwargs): + super(EmailValidator, self).__init__(email_re, *args, **kwargs) + # PATCH method is not implemented by Django if 'patch' not in View.http_method_names: diff --git a/rest_framework/fields.py b/rest_framework/fields.py index cbd3334a..12975ae4 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -6,7 +6,7 @@ from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.encoding import is_protected_type from django.utils.translation import ugettext_lazy as _ from rest_framework import ISO_8601 -from rest_framework.compat import smart_text, MinValueValidator, MaxValueValidator +from rest_framework.compat import smart_text, EmailValidator, MinValueValidator, MaxValueValidator, URLValidator from rest_framework.settings import api_settings from rest_framework.utils import html, representation, humanize_datetime import datetime @@ -335,7 +335,7 @@ class EmailField(CharField): def __init__(self, **kwargs): super(EmailField, self).__init__(**kwargs) - validator = validators.EmailValidator(message=self.error_messages['invalid']) + validator = EmailValidator(message=self.error_messages['invalid']) self.validators.append(validator) def to_internal_value(self, data): @@ -381,7 +381,7 @@ class URLField(CharField): def __init__(self, **kwargs): super(URLField, self).__init__(**kwargs) - validator = validators.URLValidator(message=self.error_messages['invalid']) + validator = URLValidator(message=self.error_messages['invalid']) self.validators.append(validator) @@ -525,7 +525,7 @@ class DecimalField(Field): return None if not isinstance(value, decimal.Decimal): - value = decimal.Decimal(value) + value = decimal.Decimal(str(value).strip()) context = decimal.getcontext().copy() context.prec = self.max_digits -- cgit v1.2.3 From b5454dd02290130a7fb0a0e375f3efecc58edc6d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 22 Sep 2014 16:50:04 +0100 Subject: Tests and tweaks for choice fields --- rest_framework/fields.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 12975ae4..500018f3 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -750,7 +750,7 @@ class ChoiceField(Field): self.fail('invalid_choice', input=data) def to_representation(self, value): - return value + return self.choice_strings_to_values[str(value)] class MultipleChoiceField(ChoiceField): @@ -769,7 +769,7 @@ class MultipleChoiceField(ChoiceField): ]) def to_representation(self, value): - return value + return [self.choice_strings_to_values[str(item)] for item in value] # File types... -- cgit v1.2.3 From 5a95baf2a2258fb5297062ac18582129c05fb320 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 22 Sep 2014 16:52:57 +0100 Subject: Tests & tweaks for ChoiceField --- rest_framework/fields.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 500018f3..80eadf1e 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -769,7 +769,9 @@ class MultipleChoiceField(ChoiceField): ]) def to_representation(self, value): - return [self.choice_strings_to_values[str(item)] for item in value] + return set([ + self.choice_strings_to_values[str(item)] for item in value + ]) # File types... -- cgit v1.2.3 From 5d80f7f932bfcc0630ac0fdbf07072a53197b98f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 22 Sep 2014 17:46:02 +0100 Subject: allow_blank, allow_null --- rest_framework/fields.py | 40 ++++++++++++++++------------------------ 1 file changed, 16 insertions(+), 24 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 80eadf1e..48a3e1ab 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -98,14 +98,15 @@ class Field(object): _creation_counter = 0 default_error_messages = { - 'required': _('This field is required.') + 'required': _('This field is required.'), + 'null': _('This field may not be null.') } default_validators = [] def __init__(self, read_only=False, write_only=False, required=None, default=empty, initial=None, source=None, label=None, help_text=None, style=None, - error_messages=None, validators=[]): + error_messages=None, validators=[], allow_null=False): self._creation_counter = Field._creation_counter Field._creation_counter += 1 @@ -129,6 +130,7 @@ class Field(object): self.help_text = help_text self.style = {} if style is None else style self.validators = validators or self.default_validators[:] + self.allow_null = allow_null # Collect default error message from self and parent classes messages = {} @@ -220,6 +222,11 @@ class Field(object): self.fail('required') return self.get_default() + if data is None: + if not self.allow_null: + self.fail('null') + return None + value = self.to_internal_value(data) self.run_validators(value) return value @@ -315,11 +322,14 @@ class CharField(Field): self.min_length = kwargs.pop('min_length', None) super(CharField, self).__init__(**kwargs) + def run_validation(self, data=empty): + if data == '': + if not self.allow_blank: + self.fail('blank') + return '' + return super(CharField, self).run_validation(data) + def to_internal_value(self, data): - if data == '' and not self.allow_blank: - self.fail('blank') - if data is None: - return None return str(data) def to_representation(self, value): @@ -339,10 +349,6 @@ class EmailField(CharField): self.validators.append(validator) def to_internal_value(self, data): - if data == '' and not self.allow_blank: - self.fail('blank') - if data is None: - return None return str(data).strip() def to_representation(self, value): @@ -437,8 +443,6 @@ class FloatField(Field): self.validators.append(MinValueValidator(min_value, message=message)) def to_internal_value(self, value): - if value is None: - return None try: return float(value) except (TypeError, ValueError): @@ -481,9 +485,6 @@ class DecimalField(Field): than max_digits in the number, and no more than decimal_places digits after the decimal point. """ - if value in (None, ''): - return None - value = smart_text(value).strip() try: value = decimal.Decimal(value) @@ -554,9 +555,6 @@ class DateField(Field): super(DateField, self).__init__(*args, **kwargs) def to_internal_value(self, value): - if value in (None, ''): - return None - if isinstance(value, datetime.datetime): self.fail('datetime') @@ -622,9 +620,6 @@ class DateTimeField(Field): return value def to_internal_value(self, value): - if value in (None, ''): - return None - if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime): self.fail('date') @@ -676,9 +671,6 @@ class TimeField(Field): super(TimeField, self).__init__(*args, **kwargs) def to_internal_value(self, value): - if value in (None, ''): - return None - if isinstance(value, datetime.time): return value -- cgit v1.2.3 From b187f53453d3885cd918f5f9f4490bcc8e3e2410 Mon Sep 17 00:00:00 2001 From: Danilo Bargen Date: Mon, 2 Jun 2014 00:41:58 +0200 Subject: Changed return status for CSRF failures to HTTP 403 By default, Django returns "HTTP 403 Forbidden" responses when CSRF validation failed[1]. CSRF is a case of authorization, not of authentication. Therefore `PermissionDenied` should be raised instead of `AuthenticationFailed`. [1] https://docs.djangoproject.com/en/dev/ref/contrib/csrf/#rejected-requests --- rest_framework/authentication.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index f3fec05e..36d74dd9 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -129,7 +129,7 @@ class SessionAuthentication(BaseAuthentication): reason = CSRFCheck().process_view(request, None, (), {}) if reason: # CSRF failed, bail with explicit error message - raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason) + raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) class TokenAuthentication(BaseAuthentication): -- cgit v1.2.3 From f22d0afc3dfc7478e084d1d6ed6b53f71641dec6 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 23 Sep 2014 14:15:00 +0100 Subject: Tests for field choices --- rest_framework/fields.py | 21 +++++++------ rest_framework/serializers.py | 3 ++ rest_framework/utils/field_mapping.py | 58 ++++++++++++++++++----------------- 3 files changed, 44 insertions(+), 38 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 48a3e1ab..f5bae734 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -102,6 +102,7 @@ class Field(object): 'null': _('This field may not be null.') } default_validators = [] + default_empty_html = None def __init__(self, read_only=False, write_only=False, required=None, default=empty, initial=None, source=None, @@ -185,6 +186,11 @@ class Field(object): Given the *incoming* primative data, return the value for this field that should be validated and transformed to a native value. """ + if html.is_html_input(dictionary): + # HTML forms will represent empty fields as '', and cannot + # represent None or False values directly. + ret = dictionary.get(self.field_name, '') + return self.default_empty_html if (ret == '') else ret return dictionary.get(self.field_name, empty) def get_attribute(self, instance): @@ -236,9 +242,6 @@ class Field(object): Test the given value against all the validators on the field, and either raise a `ValidationError` or simply return. """ - if value in (None, '', [], (), {}): - return - errors = [] for validator in self.validators: try: @@ -282,16 +285,10 @@ class BooleanField(Field): default_error_messages = { 'invalid': _('`{input}` is not a valid boolean.') } + default_empty_html = False TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True)) FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False)) - def get_value(self, dictionary): - if html.is_html_input(dictionary): - # HTML forms do not send a `False` value on an empty checkbox, - # so we override the default empty value to be False. - return dictionary.get(self.field_name, False) - return dictionary.get(self.field_name, empty) - def to_internal_value(self, data): if data in self.TRUE_VALUES: return True @@ -315,6 +312,7 @@ class CharField(Field): default_error_messages = { 'blank': _('This field may not be blank.') } + default_empty_html = '' def __init__(self, **kwargs): self.allow_blank = kwargs.pop('allow_blank', False) @@ -323,6 +321,9 @@ class CharField(Field): super(CharField, self).__init__(**kwargs) def run_validation(self, data=empty): + # Test for the empty string here so that it does not get validated, + # and so that subclasses do not need to handle it explicitly + # inside the `to_internal_value()` method. if data == '': if not self.allow_blank: self.fail('blank') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index d9f9c8cb..949f5915 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -411,6 +411,9 @@ class ModelSerializer(Serializer): # `ModelField`, which is used when no other typed field # matched to the model field. kwargs.pop('model_field', None) + if not issubclass(field_cls, CharField): + # `allow_blank` is only valid for textual fields. + kwargs.pop('allow_blank', None) elif field_name in info.relations: # Create forward and reverse relationships. diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index be72e444..1c718ccb 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -49,8 +49,9 @@ def get_field_kwargs(field_name, model_field): kwargs = {} validator_kwarg = model_field.validators - if model_field.null or model_field.blank: - kwargs['required'] = False + # The following will only be used by ModelField classes. + # Gets removed for everything else. + kwargs['model_field'] = model_field if model_field.verbose_name and needs_label(model_field, field_name): kwargs['label'] = capfirst(model_field.verbose_name) @@ -59,23 +60,26 @@ def get_field_kwargs(field_name, model_field): kwargs['help_text'] = model_field.help_text if isinstance(model_field, models.AutoField) or not model_field.editable: + # If this field is read-only, then return early. + # Further keyword arguments are not valid. kwargs['read_only'] = True - # Read only implies that the field is not required. - # We have a cleaner repr on the instance if we don't set it. - kwargs.pop('required', None) + return kwargs if model_field.has_default(): - kwargs['default'] = model_field.get_default() - # Having a default implies that the field is not required. - # We have a cleaner repr on the instance if we don't set it. - kwargs.pop('required', None) + kwargs['required'] = False if model_field.flatchoices: - # If this model field contains choices, then return now, - # any further keyword arguments are not valid. + # If this model field contains choices, then return early. + # Further keyword arguments are not valid. kwargs['choices'] = model_field.flatchoices return kwargs + if model_field.null: + kwargs['allow_null'] = True + + if model_field.blank: + kwargs['allow_blank'] = True + # Ensure that max_length is passed explicitly as a keyword arg, # rather than as a validator. max_length = getattr(model_field, 'max_length', None) @@ -88,7 +92,10 @@ def get_field_kwargs(field_name, model_field): # Ensure that min_length is passed explicitly as a keyword arg, # rather than as a validator. - min_length = getattr(model_field, 'min_length', None) + min_length = next(( + validator.limit_value for validator in validator_kwarg + if isinstance(validator, validators.MinLengthValidator) + ), None) if min_length is not None: kwargs['min_length'] = min_length validator_kwarg = [ @@ -153,20 +160,9 @@ def get_field_kwargs(field_name, model_field): if decimal_places is not None: kwargs['decimal_places'] = decimal_places - if isinstance(model_field, models.BooleanField): - # models.BooleanField has `blank=True`, but *is* actually - # required *unless* a default is provided. - # Also note that Django<1.6 uses `default=False` for - # models.BooleanField, but Django>=1.6 uses `default=None`. - kwargs.pop('required', None) - if validator_kwarg: kwargs['validators'] = validator_kwarg - # The following will only be used by ModelField classes. - # Gets removed for everything else. - kwargs['model_field'] = model_field - return kwargs @@ -188,16 +184,22 @@ def get_relation_kwargs(field_name, relation_info): kwargs.pop('queryset', None) if model_field: - if model_field.null or model_field.blank: - kwargs['required'] = False if model_field.verbose_name and needs_label(model_field, field_name): kwargs['label'] = capfirst(model_field.verbose_name) - if not model_field.editable: - kwargs['read_only'] = True - kwargs.pop('queryset', None) help_text = clean_manytomany_helptext(model_field.help_text) if help_text: kwargs['help_text'] = help_text + if not model_field.editable: + kwargs['read_only'] = True + kwargs.pop('queryset', None) + if kwargs.get('read_only', False): + # If this field is read-only, then return early. + # No further keyword arguments are valid. + return kwargs + if model_field.has_default(): + kwargs['required'] = False + if model_field.null: + kwargs['allow_null'] = True return kwargs -- cgit v1.2.3 From 0404f09a7e69f533038d47ca25caad90c0c2659f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 23 Sep 2014 14:30:17 +0100 Subject: NullBooleanField --- rest_framework/fields.py | 37 ++++++++++++++++++++++++++++++++++- rest_framework/serializers.py | 2 +- rest_framework/utils/field_mapping.py | 2 +- 3 files changed, 38 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f5bae734..f859658a 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -289,6 +289,10 @@ class BooleanField(Field): TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True)) FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False)) + def __init__(self, **kwargs): + assert 'allow_null' not in kwargs, '`allow_null` is not a valid option. Use `NullBooleanField` instead.' + super(BooleanField, self).__init__(**kwargs) + def to_internal_value(self, data): if data in self.TRUE_VALUES: return True @@ -297,7 +301,38 @@ class BooleanField(Field): self.fail('invalid', input=data) def to_representation(self, value): - if value is None: + if value in self.TRUE_VALUES: + return True + elif value in self.FALSE_VALUES: + return False + return bool(value) + + +class NullBooleanField(Field): + default_error_messages = { + 'invalid': _('`{input}` is not a valid boolean.') + } + default_empty_html = None + TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True)) + FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False)) + NULL_VALUES = set(('n', 'N', 'null', 'Null', 'NULL', '', None)) + + def __init__(self, **kwargs): + assert 'allow_null' not in kwargs, '`allow_null` is not a valid option.' + kwargs['allow_null'] = True + super(NullBooleanField, self).__init__(**kwargs) + + def to_internal_value(self, data): + if data in self.TRUE_VALUES: + return True + elif data in self.FALSE_VALUES: + return False + elif data in self.NULL_VALUES: + return None + self.fail('invalid', input=data) + + def to_representation(self, value): + if value in self.NULL_VALUES: return None if value in self.TRUE_VALUES: return True diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 949f5915..d8d72a4c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -333,7 +333,7 @@ class ModelSerializer(Serializer): models.FloatField: FloatField, models.ImageField: ImageField, models.IntegerField: IntegerField, - models.NullBooleanField: BooleanField, + models.NullBooleanField: NullBooleanField, models.PositiveIntegerField: IntegerField, models.PositiveSmallIntegerField: IntegerField, models.SlugField: SlugField, diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index 1c718ccb..c208afdc 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -74,7 +74,7 @@ def get_field_kwargs(field_name, model_field): kwargs['choices'] = model_field.flatchoices return kwargs - if model_field.null: + if model_field.null and not isinstance(model_field, models.NullBooleanField): kwargs['allow_null'] = True if model_field.blank: -- cgit v1.2.3 From f4b1dcb167be0bbdaae2cc2a92f651536896dc16 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 24 Sep 2014 14:09:49 +0100 Subject: OPTIONS support --- rest_framework/generics.py | 51 +------------- rest_framework/metadata.py | 126 ++++++++++++++++++++++++++++++++++ rest_framework/serializers.py | 8 +-- rest_framework/settings.py | 2 + rest_framework/utils/field_mapping.py | 20 +++--- rest_framework/views.py | 24 ++----- 6 files changed, 150 insertions(+), 81 deletions(-) create mode 100644 rest_framework/metadata.py (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index eb6b64ef..f49b0a43 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -4,13 +4,11 @@ Generic views that provide commonly needed behaviour. from __future__ import unicode_literals from django.db.models.query import QuerySet -from django.core.exceptions import PermissionDenied from django.core.paginator import Paginator, InvalidPage from django.http import Http404 from django.shortcuts import get_object_or_404 as _get_object_or_404 from django.utils.translation import ugettext as _ -from rest_framework import views, mixins, exceptions -from rest_framework.request import clone_request +from rest_framework import views, mixins from rest_framework.settings import api_settings @@ -249,53 +247,6 @@ class GenericAPIView(views.APIView): return obj - # The following are placeholder methods, - # and are intended to be overridden. - # - # The are not called by GenericAPIView directly, - # but are used by the mixin methods. - def metadata(self, request): - """ - Return a dictionary of metadata about the view. - Used to return responses for OPTIONS requests. - - We override the default behavior, and add some extra information - about the required request body for POST and PUT operations. - """ - ret = super(GenericAPIView, self).metadata(request) - - actions = {} - for method in ('PUT', 'POST'): - if method not in self.allowed_methods: - continue - - cloned_request = clone_request(request, method) - try: - # Test global permissions - self.check_permissions(cloned_request) - # Test object permissions - if method == 'PUT': - try: - self.get_object() - except Http404: - # Http404 should be acceptable and the serializer - # metadata should be populated. Except this so the - # outer "else" clause of the try-except-else block - # will be executed. - pass - except (exceptions.APIException, PermissionDenied): - pass - else: - # If user has appropriate permissions for the view, include - # appropriate metadata about the fields that should be supplied. - serializer = self.get_serializer() - actions[method] = serializer.metadata() - - if actions: - ret['actions'] = actions - - return ret - # Concrete view classes that provide method handlers # by composing the mixin classes with the base view. diff --git a/rest_framework/metadata.py b/rest_framework/metadata.py new file mode 100644 index 00000000..580259de --- /dev/null +++ b/rest_framework/metadata.py @@ -0,0 +1,126 @@ +""" +The metadata API is used to allow cusomization of how `OPTIONS` requests +are handled. We currently provide a single default implementation that returns +some fairly ad-hoc information about the view. + +Future implementations might use JSON schema or other definations in order +to return this information in a more standardized way. +""" +from __future__ import unicode_literals + +from django.core.exceptions import PermissionDenied +from django.http import Http404 +from django.utils import six +from django.utils.datastructures import SortedDict +from rest_framework import exceptions, serializers +from rest_framework.compat import force_text +from rest_framework.request import clone_request +from rest_framework.utils.field_mapping import ClassLookupDict + + +class BaseMetadata(object): + def determine_metadata(self, request, view): + """ + Return a dictionary of metadata about the view. + Used to return responses for OPTIONS requests. + """ + raise NotImplementedError(".determine_metadata() must be overridden.") + + +class SimpleMetadata(BaseMetadata): + """ + This is the default metadata implementation. + It returns an ad-hoc set of information about the view. + There are not any formalized standards for `OPTIONS` responses + for us to base this on. + """ + label_lookup = ClassLookupDict({ + serializers.Field: 'field', + serializers.BooleanField: 'boolean', + serializers.CharField: 'string', + serializers.URLField: 'url', + serializers.EmailField: 'email', + serializers.RegexField: 'regex', + serializers.SlugField: 'slug', + serializers.IntegerField: 'integer', + serializers.FloatField: 'float', + serializers.DecimalField: 'decimal', + serializers.DateField: 'date', + serializers.DateTimeField: 'datetime', + serializers.TimeField: 'time', + serializers.ChoiceField: 'choice', + serializers.MultipleChoiceField: 'multiple choice', + serializers.FileField: 'file upload', + serializers.ImageField: 'image upload', + }) + + def determine_metadata(self, request, view): + metadata = SortedDict() + metadata['name'] = view.get_view_name() + metadata['description'] = view.get_view_description() + metadata['renders'] = [renderer.media_type for renderer in view.renderer_classes] + metadata['parses'] = [parser.media_type for parser in view.parser_classes] + if hasattr(view, 'get_serializer'): + actions = self.determine_actions(request, view) + if actions: + metadata['actions'] = actions + return metadata + + def determine_actions(self, request, view): + """ + For generic class based views we return information about + the fields that are accepted for 'PUT' and 'POST' methods. + """ + actions = {} + for method in set(['PUT', 'POST']) & set(view.allowed_methods): + view.request = clone_request(request, method) + try: + # Test global permissions + if hasattr(view, 'check_permissions'): + view.check_permissions(view.request) + # Test object permissions + if method == 'PUT' and hasattr(view, 'get_object'): + view.get_object() + except (exceptions.APIException, PermissionDenied, Http404): + pass + else: + # If user has appropriate permissions for the view, include + # appropriate metadata about the fields that should be supplied. + serializer = view.get_serializer() + actions[method] = self.get_serializer_info(serializer) + finally: + view.request = request + + return actions + + def get_serializer_info(self, serializer): + """ + Given an instance of a serializer, return a dictionary of metadata + about its fields. + """ + return SortedDict([ + (field_name, self.get_field_info(field)) + for field_name, field in six.iteritems(serializer.fields) + ]) + + def get_field_info(self, field): + """ + Given an instance of a serializer field, return a dictionary + of metadata about it. + """ + field_info = SortedDict() + field_info['type'] = self.label_lookup[field] + field_info['required'] = getattr(field, 'required', False) + + for attr in ['read_only', 'label', 'help_text', 'min_length', 'max_length']: + value = getattr(field, attr, None) + if value is not None and value != '': + field_info[attr] = force_text(value, strings_only=True) + + if hasattr(field, 'choices'): + field_info['choices'] = [ + {'value': choice_value, 'display_name': choice_name} + for choice_value, choice_name in field.choices.items() + ] + + return field_info diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index d8d72a4c..8902294b 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -21,7 +21,7 @@ from rest_framework.utils import html, model_meta, representation from rest_framework.utils.field_mapping import ( get_url_kwargs, get_field_kwargs, get_relation_kwargs, get_nested_relation_kwargs, - lookup_class + ClassLookupDict ) import copy import inspect @@ -318,7 +318,7 @@ class ListSerializer(BaseSerializer): class ModelSerializer(Serializer): - _field_mapping = { + _field_mapping = ClassLookupDict({ models.AutoField: IntegerField, models.BigIntegerField: IntegerField, models.BooleanField: BooleanField, @@ -341,7 +341,7 @@ class ModelSerializer(Serializer): models.TextField: CharField, models.TimeField: TimeField, models.URLField: URLField, - } + }) _related_class = PrimaryKeyRelatedField def create(self, attrs): @@ -400,7 +400,7 @@ class ModelSerializer(Serializer): elif field_name in info.fields_and_pk: # Create regular model fields. model_field = info.fields_and_pk[field_name] - field_cls = lookup_class(self._field_mapping, model_field) + field_cls = self._field_mapping[model_field] kwargs = get_field_kwargs(field_name, model_field) if 'choices' in kwargs: # Fields with choices get coerced into `ChoiceField` diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 421e146c..d7fb0a43 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -45,6 +45,7 @@ DEFAULTS = { ), 'DEFAULT_THROTTLE_CLASSES': (), 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', + 'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata', # Genric view behavior 'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.serializers.ModelSerializer', @@ -121,6 +122,7 @@ IMPORT_STRINGS = ( 'DEFAULT_PERMISSION_CLASSES', 'DEFAULT_THROTTLE_CLASSES', 'DEFAULT_CONTENT_NEGOTIATION_CLASS', + 'DEFAULT_METADATA_CLASS', 'DEFAULT_MODEL_SERIALIZER_CLASS', 'DEFAULT_PAGINATION_SERIALIZER_CLASS', 'DEFAULT_FILTER_BACKENDS', diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index c208afdc..c3794083 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -9,17 +9,21 @@ from rest_framework.compat import clean_manytomany_helptext import inspect -def lookup_class(mapping, instance): +class ClassLookupDict(object): """ - Takes a dictionary with classes as keys, and an object. - Traverses the object's inheritance hierarchy in method - resolution order, and returns the first matching value + Takes a dictionary with classes as keys. + Lookups against this object will traverses the object's inheritance + hierarchy in method resolution order, and returns the first matching value from the dictionary or raises a KeyError if nothing matches. """ - for cls in inspect.getmro(instance.__class__): - if cls in mapping: - return mapping[cls] - raise KeyError('Class %s not found in lookup.', cls.__name__) + def __init__(self, mapping): + self.mapping = mapping + + def __getitem__(self, key): + for cls in inspect.getmro(key.__class__): + if cls in self.mapping: + return self.mapping[cls] + raise KeyError('Class %s not found in lookup.', cls.__name__) def needs_label(model_field, field_name): diff --git a/rest_framework/views.py b/rest_framework/views.py index 9f08a4ad..835e223a 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -5,7 +5,6 @@ from __future__ import unicode_literals from django.core.exceptions import PermissionDenied, ValidationError, NON_FIELD_ERRORS from django.http import Http404 -from django.utils.datastructures import SortedDict from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions from rest_framework.compat import smart_text, HttpResponseBase, View @@ -99,6 +98,7 @@ class APIView(View): throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS + metadata_class = api_settings.DEFAULT_METADATA_CLASS # Allow dependancy injection of other settings to make testing easier. settings = api_settings @@ -418,22 +418,8 @@ class APIView(View): def options(self, request, *args, **kwargs): """ Handler method for HTTP 'OPTIONS' request. - We may as well implement this as Django will otherwise provide - a less useful default implementation. """ - return Response(self.metadata(request), status=status.HTTP_200_OK) - - def metadata(self, request): - """ - Return a dictionary of metadata about the view. - Used to return responses for OPTIONS requests. - """ - # By default we can't provide any form-like information, however the - # generic views override this implementation and add additional - # information for POST and PUT methods, based on the serializer. - ret = SortedDict() - ret['name'] = self.get_view_name() - ret['description'] = self.get_view_description() - ret['renders'] = [renderer.media_type for renderer in self.renderer_classes] - ret['parses'] = [parser.media_type for parser in self.parser_classes] - return ret + if self.metadata_class is None: + return self.http_method_not_allowed(request, *args, **kwargs) + data = self.metadata_class().determine_metadata(request, self) + return Response(data, status=status.HTTP_200_OK) -- cgit v1.2.3 From 127c0bd3d68860dd6567d81047257fbc3e70b4b9 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 24 Sep 2014 20:25:59 +0100 Subject: Custom deepcopy on Field classes --- rest_framework/fields.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f859658a..1f7d964a 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -9,6 +9,7 @@ from rest_framework import ISO_8601 from rest_framework.compat import smart_text, EmailValidator, MinValueValidator, MaxValueValidator, URLValidator from rest_framework.settings import api_settings from rest_framework.utils import html, representation, humanize_datetime +import copy import datetime import decimal import inspect @@ -150,6 +151,11 @@ class Field(object): instance._kwargs = kwargs return instance + def __deepcopy__(self, memo): + args = copy.deepcopy(self._args) + kwargs = copy.deepcopy(self._kwargs) + return self.__class__(*args, **kwargs) + def bind(self, field_name, parent, root): """ Setup the context for the field instance. -- cgit v1.2.3 From fb1546ee50953faae8af505a0c90da00ac08ad92 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 24 Sep 2014 20:53:37 +0100 Subject: Enforce field_name != source --- rest_framework/fields.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 1f7d964a..9280ea3a 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -160,6 +160,17 @@ class Field(object): """ Setup the context for the field instance. """ + + # In order to enforce a consistent style, we error if a redundant + # 'source' argument has been used. For example: + # my_field = serializer.CharField(source='my_field') + assert self._kwargs.get('source') != field_name, ( + "It is redundant to specify `source='%s'` on field '%s' in " + "serializer '%s', as it is the same the field name. " + "Remove the `source` keyword argument." % + (field_name, self.__class__.__name__, parent.__class__.__name__) + ) + self.field_name = field_name self.parent = parent self.root = root -- cgit v1.2.3 From 1420c76453c37c023a901dd0938d717b7b5e52ca Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 25 Sep 2014 10:49:25 +0100 Subject: Ensure proper sorting of 'choices' attribute on ChoiceField --- rest_framework/fields.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 9280ea3a..d1aebbaf 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -2,6 +2,7 @@ from django.conf import settings from django.core import validators from django.core.exceptions import ValidationError from django.utils import timezone +from django.utils.datastructures import SortedDict from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.encoding import is_protected_type from django.utils.translation import ugettext_lazy as _ @@ -166,7 +167,7 @@ class Field(object): # my_field = serializer.CharField(source='my_field') assert self._kwargs.get('source') != field_name, ( "It is redundant to specify `source='%s'` on field '%s' in " - "serializer '%s', as it is the same the field name. " + "serializer '%s', because it is the same as the field name. " "Remove the `source` keyword argument." % (field_name, self.__class__.__name__, parent.__class__.__name__) ) @@ -303,6 +304,7 @@ class BooleanField(Field): 'invalid': _('`{input}` is not a valid boolean.') } default_empty_html = False + initial = False TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True)) FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False)) @@ -365,6 +367,7 @@ class CharField(Field): 'blank': _('This field may not be blank.') } default_empty_html = '' + initial = '' def __init__(self, **kwargs): self.allow_blank = kwargs.pop('allow_blank', False) @@ -775,9 +778,9 @@ class ChoiceField(Field): for item in choices ] if all(pairs): - self.choices = dict([(key, display_value) for key, display_value in choices]) + self.choices = SortedDict([(key, display_value) for key, display_value in choices]) else: - self.choices = dict([(item, item) for item in choices]) + self.choices = SortedDict([(item, item) for item in choices]) # Map the string representation of choices to the underlying value. # Allows us to deal with eg. integer choices while supporting either -- cgit v1.2.3 From b22c9602fa0f717b688fdb35e4f6f42c189af3f3 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 25 Sep 2014 11:04:18 +0100 Subject: Automatic field binding --- rest_framework/metadata.py | 3 +-- rest_framework/pagination.py | 1 - rest_framework/serializers.py | 30 +++++++++++++++++++++++++----- 3 files changed, 26 insertions(+), 8 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/metadata.py b/rest_framework/metadata.py index 580259de..af4bc396 100644 --- a/rest_framework/metadata.py +++ b/rest_framework/metadata.py @@ -10,7 +10,6 @@ from __future__ import unicode_literals from django.core.exceptions import PermissionDenied from django.http import Http404 -from django.utils import six from django.utils.datastructures import SortedDict from rest_framework import exceptions, serializers from rest_framework.compat import force_text @@ -100,7 +99,7 @@ class SimpleMetadata(BaseMetadata): """ return SortedDict([ (field_name, self.get_field_info(field)) - for field_name, field in six.iteritems(serializer.fields) + for field_name, field in serializer.fields.items() ]) def get_field_info(self, field): diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index c5a9270a..fb451285 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -72,7 +72,6 @@ class BasePaginationSerializer(serializers.Serializer): child=object_serializer(), source='object_list' ) - self.fields[results_field].bind(results_field, self, self) class PaginationSerializer(BasePaginationSerializer): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 8902294b..12e38090 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -149,6 +149,28 @@ class SerializerMetaclass(type): return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) +class BindingDict(object): + def __init__(self, serializer): + self.serializer = serializer + self.fields = SortedDict() + + def __setitem__(self, key, field): + self.fields[key] = field + field.bind(field_name=key, parent=self.serializer, root=self.serializer) + + def __getitem__(self, key): + return self.fields[key] + + def __delitem__(self, key): + del self.fields[key] + + def items(self): + return self.fields.items() + + def values(self): + return self.fields.values() + + @six.add_metaclass(SerializerMetaclass) class Serializer(BaseSerializer): def __init__(self, *args, **kwargs): @@ -161,11 +183,9 @@ class Serializer(BaseSerializer): # Every new serializer is created with a clone of the field instances. # This allows users to dynamically modify the fields on a serializer # instance without affecting every other serializer class. - self.fields = self._get_base_fields() - - # Setup all the child fields, to provide them with the current context. - for field_name, field in self.fields.items(): - field.bind(field_name, self, self) + self.fields = BindingDict(self) + for key, value in self._get_base_fields().items(): + self.fields[key] = value def __new__(cls, *args, **kwargs): # We override this method in order to automagically create -- cgit v1.2.3 From 64632da3718f501cb8174243385d38b547c2fefd Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 25 Sep 2014 11:40:32 +0100 Subject: Clean up bind - no longer needs to be called multiple times in nested fields --- rest_framework/fields.py | 21 ++++++++++++++++----- rest_framework/relations.py | 5 ----- rest_framework/serializers.py | 26 +++++++++----------------- 3 files changed, 25 insertions(+), 27 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index d1aebbaf..446732c3 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -109,7 +109,8 @@ class Field(object): def __init__(self, read_only=False, write_only=False, required=None, default=empty, initial=None, source=None, label=None, help_text=None, style=None, - error_messages=None, validators=[], allow_null=False): + error_messages=None, validators=[], allow_null=False, + context=None): self._creation_counter = Field._creation_counter Field._creation_counter += 1 @@ -135,6 +136,11 @@ class Field(object): self.validators = validators or self.default_validators[:] self.allow_null = allow_null + # These are set up by `.bind()` when the field is added to a serializer. + self.field_name = None + self.parent = None + self._context = {} if (context is None) else context + # Collect default error message from self and parent classes messages = {} for cls in reversed(self.__class__.__mro__): @@ -157,7 +163,14 @@ class Field(object): kwargs = copy.deepcopy(self._kwargs) return self.__class__(*args, **kwargs) - def bind(self, field_name, parent, root): + @property + def context(self): + root = self + while root.parent is not None: + root = root.parent + return root._context + + def bind(self, field_name, parent): """ Setup the context for the field instance. """ @@ -174,10 +187,8 @@ class Field(object): self.field_name = field_name self.parent = parent - self.root = root - self.context = parent.context - # `self.label` should deafult to being based on the field name. + # `self.label` should default to being based on the field name. if self.label is None: self.label = field_name.replace('_', ' ').capitalize() diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 5aa1f8bd..b37a6fed 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -243,11 +243,6 @@ class ManyRelation(Field): assert child_relation is not None, '`child_relation` is a required argument.' super(ManyRelation, self).__init__(*args, **kwargs) - def bind(self, field_name, parent, root): - # ManyRelation needs to provide the current context to the child relation. - super(ManyRelation, self).bind(field_name, parent, root) - self.child_relation.bind(field_name, parent, root) - def to_internal_value(self, data): return [ self.child_relation.to_internal_value(item) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 12e38090..04721c7a 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -150,13 +150,20 @@ class SerializerMetaclass(type): class BindingDict(object): + """ + This dict-like object is used to store fields on a serializer. + + This ensures that whenever fields are added to the serializer we call + `field.bind()` so that the `field_name` and `parent` attributes + can be set correctly. + """ def __init__(self, serializer): self.serializer = serializer self.fields = SortedDict() def __setitem__(self, key, field): self.fields[key] = field - field.bind(field_name=key, parent=self.serializer, root=self.serializer) + field.bind(field_name=key, parent=self.serializer) def __getitem__(self, key): return self.fields[key] @@ -174,7 +181,6 @@ class BindingDict(object): @six.add_metaclass(SerializerMetaclass) class Serializer(BaseSerializer): def __init__(self, *args, **kwargs): - self.context = kwargs.pop('context', {}) kwargs.pop('partial', None) kwargs.pop('many', None) @@ -198,13 +204,6 @@ class Serializer(BaseSerializer): def _get_base_fields(self): return copy.deepcopy(self._declared_fields) - def bind(self, field_name, parent, root): - # If the serializer is used as a field then when it becomes bound - # it also needs to bind all its child fields. - super(Serializer, self).bind(field_name, parent, root) - for field_name, field in self.fields.items(): - field.bind(field_name, self, root) - def get_initial(self): return dict([ (field.field_name, field.get_initial()) @@ -290,17 +289,10 @@ class ListSerializer(BaseSerializer): self.child = kwargs.pop('child', copy.deepcopy(self.child)) assert self.child is not None, '`child` is a required argument.' assert not inspect.isclass(self.child), '`child` has not been instantiated.' - self.context = kwargs.pop('context', {}) kwargs.pop('partial', None) super(ListSerializer, self).__init__(*args, **kwargs) - self.child.bind('', self, self) - - def bind(self, field_name, parent, root): - # If the list is used as a field then it needs to provide - # the current context to the child serializer. - super(ListSerializer, self).bind(field_name, parent, root) - self.child.bind(field_name, self, root) + self.child.bind(field_name='', parent=self) def get_value(self, dictionary): # We override the default field access in order to support -- cgit v1.2.3 From b47ca158b9ba9733baad080e648d24b0465ec697 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 25 Sep 2014 12:09:12 +0100 Subject: Check for redundant on SerializerMethodField --- rest_framework/fields.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 446732c3..328e93ef 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -178,7 +178,7 @@ class Field(object): # In order to enforce a consistent style, we error if a redundant # 'source' argument has been used. For example: # my_field = serializer.CharField(source='my_field') - assert self._kwargs.get('source') != field_name, ( + assert self.source != field_name, ( "It is redundant to specify `source='%s'` on field '%s' in " "serializer '%s', because it is the same as the field name. " "Remove the `source` keyword argument." % @@ -883,17 +883,32 @@ class SerializerMethodField(Field): def get_extra_info(self, obj): return ... # Calculate some data to return. """ - def __init__(self, method_attr=None, **kwargs): - self.method_attr = method_attr + def __init__(self, method_name=None, **kwargs): + self.method_name = method_name kwargs['source'] = '*' kwargs['read_only'] = True super(SerializerMethodField, self).__init__(**kwargs) + def bind(self, field_name, parent): + # In order to enforce a consistent style, we error if a redundant + # 'method_name' argument has been used. For example: + # my_field = serializer.CharField(source='my_field') + default_method_name = 'get_{field_name}'.format(field_name=field_name) + assert self.method_name != default_method_name, ( + "It is redundant to specify `%s` on SerializerMethodField '%s' in " + "serializer '%s', because it is the same as the default method name. " + "Remove the `method_name` argument." % + (self.method_name, field_name, parent.__class__.__name__) + ) + + # The method name should default to `get_{field_name}`. + if self.method_name is None: + self.method_name = default_method_name + + super(SerializerMethodField, self).bind(field_name, parent) + def to_representation(self, value): - method_attr = self.method_attr - if method_attr is None: - method_attr = 'get_{field_name}'.format(field_name=self.field_name) - method = getattr(self.parent, method_attr) + method = getattr(self.parent, self.method_name) return method(value) -- cgit v1.2.3 From 8ee92f8a18c3a31a2a95233f36754203dc60bb18 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 25 Sep 2014 13:10:33 +0100 Subject: Refuse to downcast from datetime to date or time --- rest_framework/fields.py | 118 ++++++++++++++++++++++++++--------------------- 1 file changed, 65 insertions(+), 53 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 328e93ef..d855e6fd 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -608,120 +608,126 @@ class DecimalField(Field): # Date & time fields... -class DateField(Field): +class DateTimeField(Field): default_error_messages = { - 'invalid': _('Date has wrong format. Use one of these formats instead: {format}'), - 'datetime': _('Expected a date but got a datetime.'), + 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'), + 'date': _('Expected a datetime but got a date.'), } - format = api_settings.DATE_FORMAT - input_formats = api_settings.DATE_INPUT_FORMATS + format = api_settings.DATETIME_FORMAT + input_formats = api_settings.DATETIME_INPUT_FORMATS + default_timezone = timezone.get_default_timezone() if settings.USE_TZ else None - def __init__(self, format=empty, input_formats=None, *args, **kwargs): + def __init__(self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs): self.format = format if format is not empty else self.format self.input_formats = input_formats if input_formats is not None else self.input_formats - super(DateField, self).__init__(*args, **kwargs) + self.default_timezone = default_timezone if default_timezone is not None else self.default_timezone + super(DateTimeField, self).__init__(*args, **kwargs) + + def enforce_timezone(self, value): + """ + When `self.default_timezone` is `None`, always return naive datetimes. + When `self.default_timezone` is not `None`, always return aware datetimes. + """ + if (self.default_timezone is not None) and not timezone.is_aware(value): + return timezone.make_aware(value, self.default_timezone) + elif (self.default_timezone is None) and timezone.is_aware(value): + return timezone.make_naive(value, timezone.UTC()) + return value def to_internal_value(self, value): - if isinstance(value, datetime.datetime): - self.fail('datetime') + if (isinstance(value, datetime.date) and not isinstance(value, datetime.datetime): + self.fail('date') - if isinstance(value, datetime.date): - return value + if isinstance(value, datetime.datetime): + return self.enforce_timezone(value) for format in self.input_formats: if format.lower() == ISO_8601: try: - parsed = parse_date(value) + parsed = parse_datetime(value) except (ValueError, TypeError): pass else: if parsed is not None: - return parsed + return self.enforce_timezone(parsed) else: try: parsed = datetime.datetime.strptime(value, format) except (ValueError, TypeError): pass else: - return parsed.date() + return self.enforce_timezone(parsed) - humanized_format = humanize_datetime.date_formats(self.input_formats) + humanized_format = humanize_datetime.datetime_formats(self.input_formats) self.fail('invalid', format=humanized_format) def to_representation(self, value): if value is None or self.format is None: return value - if isinstance(value, datetime.datetime): - value = value.date() - if self.format.lower() == ISO_8601: - return value.isoformat() + ret = value.isoformat() + if ret.endswith('+00:00'): + ret = ret[:-6] + 'Z' + return ret return value.strftime(self.format) -class DateTimeField(Field): +class DateField(Field): default_error_messages = { - 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'), - 'date': _('Expected a datetime but got a date.'), + 'invalid': _('Date has wrong format. Use one of these formats instead: {format}'), + 'datetime': _('Expected a date but got a datetime.'), } - format = api_settings.DATETIME_FORMAT - input_formats = api_settings.DATETIME_INPUT_FORMATS - default_timezone = timezone.get_default_timezone() if settings.USE_TZ else None + format = api_settings.DATE_FORMAT + input_formats = api_settings.DATE_INPUT_FORMATS - def __init__(self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs): + def __init__(self, format=empty, input_formats=None, *args, **kwargs): self.format = format if format is not empty else self.format self.input_formats = input_formats if input_formats is not None else self.input_formats - self.default_timezone = default_timezone if default_timezone is not None else self.default_timezone - super(DateTimeField, self).__init__(*args, **kwargs) - - def enforce_timezone(self, value): - """ - When `self.default_timezone` is `None`, always return naive datetimes. - When `self.default_timezone` is not `None`, always return aware datetimes. - """ - if (self.default_timezone is not None) and not timezone.is_aware(value): - return timezone.make_aware(value, self.default_timezone) - elif (self.default_timezone is None) and timezone.is_aware(value): - return timezone.make_naive(value, timezone.UTC()) - return value + super(DateField, self).__init__(*args, **kwargs) def to_internal_value(self, value): - if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime): - self.fail('date') - if isinstance(value, datetime.datetime): - return self.enforce_timezone(value) + self.fail('datetime') + + if isinstance(value, datetime.date): + return value for format in self.input_formats: if format.lower() == ISO_8601: try: - parsed = parse_datetime(value) + parsed = parse_date(value) except (ValueError, TypeError): pass else: if parsed is not None: - return self.enforce_timezone(parsed) + return parsed else: try: parsed = datetime.datetime.strptime(value, format) except (ValueError, TypeError): pass else: - return self.enforce_timezone(parsed) + return parsed.date() - humanized_format = humanize_datetime.datetime_formats(self.input_formats) + humanized_format = humanize_datetime.date_formats(self.input_formats) self.fail('invalid', format=humanized_format) def to_representation(self, value): if value is None or self.format is None: return value + # Applying a `DateField` to a datetime value is almost always + # not a sensible thing to do, as it means naively dropping + # any explicit or implicit timezone info. + assert not isinstance(value, datetime.datetime), ( + 'Expected a `date`, but got a `datetime`. Refusing to coerce, ' + 'as this may mean losing timezone information. Use a custom ' + 'read-only field and deal with timezone issues explicitly.' + ) + if self.format.lower() == ISO_8601: - ret = value.isoformat() - if ret.endswith('+00:00'): - ret = ret[:-6] + 'Z' - return ret + return value.isoformat() return value.strftime(self.format) @@ -765,8 +771,14 @@ class TimeField(Field): if value is None or self.format is None: return value - if isinstance(value, datetime.datetime): - value = value.time() + # Applying a `TimeField` to a datetime value is almost always + # not a sensible thing to do, as it means naively dropping + # any explicit or implicit timezone info. + assert not isinstance(value, datetime.datetime), ( + 'Expected a `time`, but got a `datetime`. Refusing to coerce, ' + 'as this may mean losing timezone information. Use a custom ' + 'read-only field and deal with timezone issues explicitly.' + ) if self.format.lower() == ISO_8601: return value.isoformat() -- cgit v1.2.3 From 3a5335f09f58439f8e3c0bddbed8e4c7eeb32482 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 25 Sep 2014 13:12:02 +0100 Subject: Fix syntax error --- rest_framework/fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index d855e6fd..7beccbb7 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -635,7 +635,7 @@ class DateTimeField(Field): return value def to_internal_value(self, value): - if (isinstance(value, datetime.date) and not isinstance(value, datetime.datetime): + if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime): self.fail('date') if isinstance(value, datetime.datetime): -- cgit v1.2.3 From 417fe1b675bd1d42518fb89a6f81547caef5b735 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 25 Sep 2014 13:37:26 +0100 Subject: Partial support --- rest_framework/fields.py | 35 +++++++++++++++++++++++++---------- rest_framework/serializers.py | 6 ++++-- 2 files changed, 29 insertions(+), 12 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 7beccbb7..032bfd04 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -109,8 +109,7 @@ class Field(object): def __init__(self, read_only=False, write_only=False, required=None, default=empty, initial=None, source=None, label=None, help_text=None, style=None, - error_messages=None, validators=[], allow_null=False, - context=None): + error_messages=None, validators=[], allow_null=False): self._creation_counter = Field._creation_counter Field._creation_counter += 1 @@ -139,7 +138,6 @@ class Field(object): # These are set up by `.bind()` when the field is added to a serializer. self.field_name = None self.parent = None - self._context = {} if (context is None) else context # Collect default error message from self and parent classes messages = {} @@ -163,13 +161,6 @@ class Field(object): kwargs = copy.deepcopy(self._kwargs) return self.__class__(*args, **kwargs) - @property - def context(self): - root = self - while root.parent is not None: - root = root.parent - return root._context - def bind(self, field_name, parent): """ Setup the context for the field instance. @@ -254,6 +245,8 @@ class Field(object): """ if data is empty: if self.required: + if getattr(self.root, 'partial', False): + raise SkipField() self.fail('required') return self.get_default() @@ -304,7 +297,29 @@ class Field(object): raise AssertionError(msg) raise ValidationError(msg.format(**kwargs)) + @property + def root(self): + """ + Returns the top-level serializer for this field. + """ + root = self + while root.parent is not None: + root = root.parent + return root + + @property + def context(self): + """ + Returns the context as passed to the root serializer on initialization. + """ + return getattr(self.root, '_context', {}) + def __repr__(self): + """ + Fields are represented using their initial calling arguments. + This allows us to create descriptive representations for serializer + instances that show all the declared fields on the serializer. + """ return representation.field_repr(self) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 04721c7a..b6a1898c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -181,8 +181,9 @@ class BindingDict(object): @six.add_metaclass(SerializerMetaclass) class Serializer(BaseSerializer): def __init__(self, *args, **kwargs): - kwargs.pop('partial', None) kwargs.pop('many', None) + self.partial = kwargs.pop('partial', False) + self._context = kwargs.pop('context', {}) super(Serializer, self).__init__(*args, **kwargs) @@ -289,7 +290,8 @@ class ListSerializer(BaseSerializer): self.child = kwargs.pop('child', copy.deepcopy(self.child)) assert self.child is not None, '`child` is a required argument.' assert not inspect.isclass(self.child), '`child` has not been instantiated.' - kwargs.pop('partial', None) + self.partial = kwargs.pop('partial', False) + self._context = kwargs.pop('context', {}) super(ListSerializer, self).__init__(*args, **kwargs) self.child.bind(field_name='', parent=self) -- cgit v1.2.3 From 2859eaf524bca23f27e666d24a0b63ba61698a76 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 26 Sep 2014 10:46:52 +0100 Subject: request.data attribute --- rest_framework/authtoken/views.py | 2 +- rest_framework/fields.py | 51 +++++++++++++++++++++++---------------- rest_framework/filters.py | 6 ++--- rest_framework/generics.py | 4 +-- rest_framework/mixins.py | 6 ++--- rest_framework/negotiation.py | 4 +-- rest_framework/renderers.py | 4 +-- rest_framework/request.py | 37 +++++++++++++++++++++++++--- rest_framework/serializers.py | 21 +++++++--------- 9 files changed, 86 insertions(+), 49 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index 94e6f061..103abb27 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -16,7 +16,7 @@ class ObtainAuthToken(APIView): model = Token def post(self, request): - serializer = self.serializer_class(data=request.DATA) + serializer = self.serializer_class(data=request.data) if serializer.is_valid(): user = serializer.validated_data['user'] token, created = Token.objects.get_or_create(user=user) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 032bfd04..ec07a413 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -56,7 +56,7 @@ def get_attribute(instance, attrs): except AttributeError as exc: try: return instance[attr] - except (KeyError, TypeError): + except (KeyError, TypeError, AttributeError): raise exc return instance @@ -90,6 +90,7 @@ NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`' NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`' NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' +USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField' MISSING_ERROR_MESSAGE = ( 'ValidationError raised by `{class_name}`, but error key `{key}` does ' 'not exist in the `error_messages` dictionary.' @@ -105,9 +106,10 @@ class Field(object): } default_validators = [] default_empty_html = None + initial = None def __init__(self, read_only=False, write_only=False, - required=None, default=empty, initial=None, source=None, + required=None, default=empty, initial=empty, source=None, label=None, help_text=None, style=None, error_messages=None, validators=[], allow_null=False): self._creation_counter = Field._creation_counter @@ -122,13 +124,14 @@ class Field(object): assert not (read_only and required), NOT_READ_ONLY_REQUIRED assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT assert not (required and default is not empty), NOT_REQUIRED_DEFAULT + assert not (read_only and self.__class__ == Field), USE_READONLYFIELD self.read_only = read_only self.write_only = write_only self.required = required self.default = default self.source = source - self.initial = initial + self.initial = self.initial if (initial is empty) else initial self.label = label self.help_text = help_text self.style = {} if style is None else style @@ -146,24 +149,10 @@ class Field(object): messages.update(error_messages or {}) self.error_messages = messages - def __new__(cls, *args, **kwargs): - """ - When a field is instantiated, we store the arguments that were used, - so that we can present a helpful representation of the object. - """ - instance = super(Field, cls).__new__(cls) - instance._args = args - instance._kwargs = kwargs - return instance - - def __deepcopy__(self, memo): - args = copy.deepcopy(self._args) - kwargs = copy.deepcopy(self._kwargs) - return self.__class__(*args, **kwargs) - def bind(self, field_name, parent): """ - Setup the context for the field instance. + Initializes the field name and parent for the field instance. + Called when a field is added to the parent serializer instance. """ # In order to enforce a consistent style, we error if a redundant @@ -244,9 +233,9 @@ class Field(object): validated data. """ if data is empty: + if getattr(self.root, 'partial', False): + raise SkipField() if self.required: - if getattr(self.root, 'partial', False): - raise SkipField() self.fail('required') return self.get_default() @@ -314,6 +303,25 @@ class Field(object): """ return getattr(self.root, '_context', {}) + def __new__(cls, *args, **kwargs): + """ + When a field is instantiated, we store the arguments that were used, + so that we can present a helpful representation of the object. + """ + instance = super(Field, cls).__new__(cls) + instance._args = args + instance._kwargs = kwargs + return instance + + def __deepcopy__(self, memo): + """ + When cloning fields we instantiate using the arguments it was + originally created with, rather than copying the complete state. + """ + args = copy.deepcopy(self._args) + kwargs = copy.deepcopy(self._kwargs) + return self.__class__(*args, **kwargs) + def __repr__(self): """ Fields are represented using their initial calling arguments. @@ -358,6 +366,7 @@ class NullBooleanField(Field): 'invalid': _('`{input}` is not a valid boolean.') } default_empty_html = None + initial = None TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True)) FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False)) NULL_VALUES = set(('n', 'N', 'null', 'Null', 'NULL', '', None)) diff --git a/rest_framework/filters.py b/rest_framework/filters.py index c580f935..085dfe65 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -64,7 +64,7 @@ class DjangoFilterBackend(BaseFilterBackend): filter_class = self.get_filter_class(view, queryset) if filter_class: - return filter_class(request.QUERY_PARAMS, queryset=queryset).qs + return filter_class(request.query_params, queryset=queryset).qs return queryset @@ -78,7 +78,7 @@ class SearchFilter(BaseFilterBackend): Search terms are set by a ?search=... query parameter, and may be comma and/or whitespace delimited. """ - params = request.QUERY_PARAMS.get(self.search_param, '') + params = request.query_params.get(self.search_param, '') return params.replace(',', ' ').split() def construct_search(self, field_name): @@ -121,7 +121,7 @@ class OrderingFilter(BaseFilterBackend): the `ordering_param` value on the OrderingFilter or by specifying an `ORDERING_PARAM` value in the API settings. """ - params = request.QUERY_PARAMS.get(self.ordering_param) + params = request.query_params.get(self.ordering_param) if params: return [param.strip() for param in params.split(',')] diff --git a/rest_framework/generics.py b/rest_framework/generics.py index f49b0a43..cf903dab 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -112,7 +112,7 @@ class GenericAPIView(views.APIView): paginator = self.paginator_class(queryset, page_size) page_kwarg = self.kwargs.get(self.page_kwarg) - page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) + page_query_param = self.request.query_params.get(self.page_kwarg) page = page_kwarg or page_query_param or 1 try: page_number = paginator.validate_number(page) @@ -166,7 +166,7 @@ class GenericAPIView(views.APIView): if self.paginate_by_param: try: return strict_positive_int( - self.request.QUERY_PARAMS[self.paginate_by_param], + self.request.query_params[self.paginate_by_param], cutoff=self.max_paginate_by ) except (KeyError, ValueError): diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 14a6b44b..04b7a763 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -18,7 +18,7 @@ class CreateModelMixin(object): Create a model instance. """ def create(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.DATA) + serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) serializer.save() headers = self.get_success_headers(serializer.data) @@ -62,7 +62,7 @@ class UpdateModelMixin(object): def update(self, request, *args, **kwargs): partial = kwargs.pop('partial', False) instance = self.get_object() - serializer = self.get_serializer(instance, data=request.DATA, partial=partial) + serializer = self.get_serializer(instance, data=request.data, partial=partial) serializer.is_valid(raise_exception=True) serializer.save() return Response(serializer.data) @@ -95,7 +95,7 @@ class AllowPUTAsCreateMixin(object): def update(self, request, *args, **kwargs): partial = kwargs.pop('partial', False) instance = self.get_object_or_none() - serializer = self.get_serializer(instance, data=request.DATA, partial=partial) + serializer = self.get_serializer(instance, data=request.data, partial=partial) serializer.is_valid(raise_exception=True) if instance is None: diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py index ca7b5397..1838130a 100644 --- a/rest_framework/negotiation.py +++ b/rest_framework/negotiation.py @@ -38,7 +38,7 @@ class DefaultContentNegotiation(BaseContentNegotiation): """ # Allow URL style format override. eg. "?format=json format_query_param = self.settings.URL_FORMAT_OVERRIDE - format = format_suffix or request.QUERY_PARAMS.get(format_query_param) + format = format_suffix or request.query_params.get(format_query_param) if format: renderers = self.filter_renderers(renderers, format) @@ -87,5 +87,5 @@ class DefaultContentNegotiation(BaseContentNegotiation): Allows URL style accept override. eg. "?accept=application/json" """ header = request.META.get('HTTP_ACCEPT', '*/*') - header = request.QUERY_PARAMS.get(self.settings.URL_ACCEPT_OVERRIDE, header) + header = request.query_params.get(self.settings.URL_ACCEPT_OVERRIDE, header) return [token.strip() for token in header.split(',')] diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 3bf03e62..225f9fe8 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -120,7 +120,7 @@ class JSONPRenderer(JSONRenderer): Determine the name of the callback to wrap around the json output. """ request = renderer_context.get('request', None) - params = request and request.QUERY_PARAMS or {} + params = request and request.query_params or {} return params.get(self.callback_parameter, self.default_callback) def render(self, data, accepted_media_type=None, renderer_context=None): @@ -426,7 +426,7 @@ class BrowsableAPIRenderer(BaseRenderer): """ if request.method == method: try: - data = request.DATA + data = request.data # files = request.FILES except ParseError: data = None diff --git a/rest_framework/request.py b/rest_framework/request.py index 27532661..d80baa70 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -4,7 +4,7 @@ The Request class is used as a wrapper around the standard request object. The wrapped request then offers a richer API, in particular : - content automatically parsed according to `Content-Type` header, - and available as `request.DATA` + and available as `request.data` - full support of PUT method, including support for file uploads - form overloading of HTTP method, content type and content """ @@ -13,6 +13,7 @@ from django.conf import settings from django.http import QueryDict from django.http.multipartparser import parse_header from django.utils.datastructures import MultiValueDict +from django.utils.datastructures import MergeDict as DjangoMergeDict from rest_framework import HTTP_HEADER_ENCODING from rest_framework import exceptions from rest_framework.compat import BytesIO @@ -58,6 +59,15 @@ class override_method(object): self.view.action = self.action +class MergeDict(DjangoMergeDict, dict): + """ + Using this as a workaround until the parsers API is properly + addressed in 3.1. + """ + def __init__(self, *dicts): + self.dicts = dicts + + class Empty(object): """ Placeholder for unset attributes. @@ -82,6 +92,7 @@ def clone_request(request, method): parser_context=request.parser_context) ret._data = request._data ret._files = request._files + ret._full_data = request._full_data ret._content_type = request._content_type ret._stream = request._stream ret._method = method @@ -133,6 +144,7 @@ class Request(object): self.parser_context = parser_context self._data = Empty self._files = Empty + self._full_data = Empty self._method = Empty self._content_type = Empty self._stream = Empty @@ -186,12 +198,25 @@ class Request(object): return self._stream @property - def QUERY_PARAMS(self): + def query_params(self): """ More semantically correct name for request.GET. """ return self._request.GET + @property + def QUERY_PARAMS(self): + """ + Synonym for `.query_params`, for backwards compatibility. + """ + return self._request.GET + + @property + def data(self): + if not _hasattr(self, '_full_data'): + self._load_data_and_files() + return self._full_data + @property def DATA(self): """ @@ -272,6 +297,10 @@ class Request(object): if not _hasattr(self, '_data'): self._data, self._files = self._parse() + if self._files: + self._full_data = MergeDict(self._data, self._files) + else: + self._full_data = self._data def _load_method_and_content_type(self): """ @@ -333,6 +362,7 @@ class Request(object): # At this point we're committed to parsing the request as form data. self._data = self._request.POST self._files = self._request.FILES + self._full_data = MergeDict(self._data, self._files) # Method overloading - change the method and remove the param from the content. if ( @@ -350,7 +380,7 @@ class Request(object): ): self._content_type = self._data[self._CONTENTTYPE_PARAM] self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding'])) - self._data, self._files = (Empty, Empty) + self._data, self._files, self._full_data = (Empty, Empty, Empty) def _parse(self): """ @@ -380,6 +410,7 @@ class Request(object): # logging the request or similar. self._data = QueryDict('', encoding=self._request._encoding) self._files = MultiValueDict() + self._full_data = self._data raise # Parser classes may return the raw data, or a diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index b6a1898c..a2b878ec 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -57,21 +57,24 @@ class BaseSerializer(Field): def to_representation(self, instance): raise NotImplementedError('`to_representation()` must be implemented.') - def update(self, instance, attrs): + def update(self, instance, validated_data): raise NotImplementedError('`update()` must be implemented.') - def create(self, attrs): + def create(self, validated_data): raise NotImplementedError('`create()` must be implemented.') def save(self, extras=None): - attrs = self.validated_data + validated_data = self.validated_data if extras is not None: - attrs = dict(list(attrs.items()) + list(extras.items())) + validated_data = dict( + list(validated_data.items()) + + list(extras.items()) + ) if self.instance is not None: - self.update(self.instance, attrs) + self.update(self.instance, validated_data) else: - self.instance = self.create(attrs) + self.instance = self.create(validated_data) return self.instance @@ -321,12 +324,6 @@ class ListSerializer(BaseSerializer): def create(self, attrs_list): return [self.child.create(attrs) for attrs in attrs_list] - def save(self): - if self.instance is not None: - self.update(self.instance, self.validated_data) - self.instance = self.create(self.validated_data) - return self.instance - def __repr__(self): return representation.list_repr(self, indent=1) -- cgit v1.2.3 From 43e80c74b225e17edfe8a90da893823bf50b946f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 26 Sep 2014 11:56:29 +0100 Subject: Release notes --- rest_framework/serializers.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index a2b878ec..86bed773 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -75,6 +75,9 @@ class BaseSerializer(Field): self.update(self.instance, validated_data) else: self.instance = self.create(validated_data) + assert self.instance is not None, ( + '`create()` did not return an object instance.' + ) return self.instance -- cgit v1.2.3 From 8b8623c5f84d443d26804cac52a793a3037a1dd0 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 26 Sep 2014 12:48:20 +0100 Subject: Allow many, partial and context in BaseSerializer --- rest_framework/serializers.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 86bed773..245ec26f 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -47,9 +47,20 @@ class BaseSerializer(Field): """ def __init__(self, instance=None, data=None, **kwargs): - super(BaseSerializer, self).__init__(**kwargs) self.instance = instance self._initial_data = data + self.partial = kwargs.pop('partial', False) + self._context = kwargs.pop('context', {}) + kwargs.pop('many', None) + super(BaseSerializer, self).__init__(**kwargs) + + def __new__(cls, *args, **kwargs): + # We override this method in order to automagically create + # `ListSerializer` classes instead when `many=True` is set. + if kwargs.pop('many', False): + kwargs['child'] = cls() + return ListSerializer(*args, **kwargs) + return super(BaseSerializer, cls).__new__(cls, *args, **kwargs) def to_internal_value(self, data): raise NotImplementedError('`to_internal_value()` must be implemented.') @@ -187,10 +198,6 @@ class BindingDict(object): @six.add_metaclass(SerializerMetaclass) class Serializer(BaseSerializer): def __init__(self, *args, **kwargs): - kwargs.pop('many', None) - self.partial = kwargs.pop('partial', False) - self._context = kwargs.pop('context', {}) - super(Serializer, self).__init__(*args, **kwargs) # Every new serializer is created with a clone of the field instances. @@ -200,14 +207,6 @@ class Serializer(BaseSerializer): for key, value in self._get_base_fields().items(): self.fields[key] = value - def __new__(cls, *args, **kwargs): - # We override this method in order to automagically create - # `ListSerializer` classes instead when `many=True` is set. - if kwargs.pop('many', False): - kwargs['child'] = cls() - return ListSerializer(*args, **kwargs) - return super(Serializer, cls).__new__(cls, *args, **kwargs) - def _get_base_fields(self): return copy.deepcopy(self._declared_fields) @@ -296,9 +295,6 @@ class ListSerializer(BaseSerializer): self.child = kwargs.pop('child', copy.deepcopy(self.child)) assert self.child is not None, '`child` is a required argument.' assert not inspect.isclass(self.child), '`child` has not been instantiated.' - self.partial = kwargs.pop('partial', False) - self._context = kwargs.pop('context', {}) - super(ListSerializer, self).__init__(*args, **kwargs) self.child.bind(field_name='', parent=self) -- cgit v1.2.3 From 2e87de01430d7fec83f00948e60c8d61b317053b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 26 Sep 2014 13:08:20 +0100 Subject: Added ListField --- rest_framework/fields.py | 38 ++++++++++++++++++++++++++++++++++++++ rest_framework/serializers.py | 6 ++++-- 2 files changed, 42 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index ec07a413..cf42d36c 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -881,6 +881,44 @@ class ImageField(Field): # Advanced field types... +class ListField(Field): + child = None + initial = [] + default_error_messages = { + 'not_a_list': _('Expected a list of items but got type `{input_type}`') + } + + def __init__(self, *args, **kwargs): + self.child = kwargs.pop('child', copy.deepcopy(self.child)) + assert self.child is not None, '`child` is a required argument.' + assert not inspect.isclass(self.child), '`child` has not been instantiated.' + super(ListField, self).__init__(*args, **kwargs) + self.child.bind(field_name='', parent=self) + + def get_value(self, dictionary): + # We override the default field access in order to support + # lists in HTML forms. + if html.is_html_input(dictionary): + return html.parse_html_list(dictionary, prefix=self.field_name) + return dictionary.get(self.field_name, empty) + + def to_internal_value(self, data): + """ + List of dicts of native values <- List of dicts of primitive datatypes. + """ + if html.is_html_input(data): + data = html.parse_html_list(data) + if isinstance(data, type('')) or not hasattr(data, '__iter__'): + self.fail('not_a_list', input_type=type(data).__name__) + return [self.child.run_validation(item) for item in data] + + def to_representation(self, data): + """ + List of object instances -> List of dicts of primitive datatypes. + """ + return [self.child.to_representation(item) for item in data] + + class ReadOnlyField(Field): """ A read-only field that simply returns the field value. diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 245ec26f..fa2e8fb1 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -287,6 +287,9 @@ class Serializer(BaseSerializer): return representation.serializer_repr(self, indent=1) +# There's some replication of `ListField` here, +# but that's probably better than obfuscating the call hierarchy. + class ListSerializer(BaseSerializer): child = None initial = [] @@ -301,7 +304,7 @@ class ListSerializer(BaseSerializer): def get_value(self, dictionary): # We override the default field access in order to support # lists in HTML forms. - if is_html_input(dictionary): + if html.is_html_input(dictionary): return html.parse_html_list(dictionary, prefix=self.field_name) return dictionary.get(self.field_name, empty) @@ -311,7 +314,6 @@ class ListSerializer(BaseSerializer): """ if html.is_html_input(data): data = html.parse_html_list(data) - return [self.child.run_validation(item) for item in data] def to_representation(self, data): -- cgit v1.2.3 From 33ccf40b76ddae790c34c294a133219e68efb946 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 26 Sep 2014 13:14:08 +0100 Subject: Update version number --- rest_framework/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 7f724c18..261c9c98 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -8,7 +8,7 @@ ______ _____ _____ _____ __ """ __title__ = 'Django REST framework' -__version__ = '2.4.3' +__version__ = '3.0.0' __author__ = 'Tom Christie' __license__ = 'BSD 2-Clause' __copyright__ = 'Copyright 2011-2014 Tom Christie' -- cgit v1.2.3 From 609014460861fdfe82054551790d6439292dde7b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 26 Sep 2014 14:32:44 +0100 Subject: Simplify serialization slightly --- rest_framework/fields.py | 11 +++++++---- rest_framework/serializers.py | 3 +-- 2 files changed, 8 insertions(+), 6 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index cf42d36c..4c49aaba 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -202,12 +202,15 @@ class Field(object): return self.default_empty_html if (ret == '') else ret return dictionary.get(self.field_name, empty) - def get_attribute(self, instance): + def get_field_representation(self, instance): """ - Given the *outgoing* object instance, return the value for this field - that should be returned as a primative value. + Given the outgoing object instance, return the primative value + that should be used for this field. """ - return get_attribute(instance, self.source_attrs) + attribute = get_attribute(instance, self.source_attrs) + if attribute is None: + return None + return self.to_representation(attribute) def get_default(self): """ diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index fa2e8fb1..080b958d 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -268,8 +268,7 @@ class Serializer(BaseSerializer): fields = [field for field in self.fields.values() if not field.write_only] for field in fields: - native_value = field.get_attribute(instance) - ret[field.field_name] = field.to_representation(native_value) + ret[field.field_name] = field.get_field_representation(instance) return ret -- cgit v1.2.3 From dee3f78cb688b1bee892ef78d6eec23ccf67a80e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 26 Sep 2014 17:06:20 +0100 Subject: FileField and ImageField --- rest_framework/compat.py | 9 ----- rest_framework/fields.py | 82 ++++++++++++++++++++++++++++++++++++---------- rest_framework/settings.py | 3 +- 3 files changed, 66 insertions(+), 28 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 7303c32a..89af9b48 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -84,15 +84,6 @@ except ImportError: from collections import UserDict from collections import MutableMapping as DictMixin -# Try to import PIL in either of the two ways it can end up installed. -try: - from PIL import Image -except ImportError: - try: - import Image - except ImportError: - Image = None - def get_model_name(model_cls): try: diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 4c49aaba..f4b53279 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,3 +1,4 @@ +from django import forms from django.conf import settings from django.core import validators from django.core.exceptions import ValidationError @@ -427,8 +428,6 @@ class CharField(Field): return str(data) def to_representation(self, value): - if value is None: - return None return str(value) @@ -446,8 +445,6 @@ class EmailField(CharField): return str(data).strip() def to_representation(self, value): - if value is None: - return None return str(value).strip() @@ -513,8 +510,6 @@ class IntegerField(Field): return data def to_representation(self, value): - if value is None: - return None return int(value) @@ -543,8 +538,6 @@ class FloatField(Field): self.fail('invalid') def to_representation(self, value): - if value is None: - return None return float(value) @@ -616,9 +609,6 @@ class DecimalField(Field): return value def to_representation(self, value): - if value in (None, ''): - return None - if not isinstance(value, decimal.Decimal): value = decimal.Decimal(str(value).strip()) @@ -689,7 +679,7 @@ class DateTimeField(Field): self.fail('invalid', format=humanized_format) def to_representation(self, value): - if value is None or self.format is None: + if self.format is None: return value if self.format.lower() == ISO_8601: @@ -741,7 +731,7 @@ class DateField(Field): self.fail('invalid', format=humanized_format) def to_representation(self, value): - if value is None or self.format is None: + if self.format is None: return value # Applying a `DateField` to a datetime value is almost always @@ -795,7 +785,7 @@ class TimeField(Field): self.fail('invalid', format=humanized_format) def to_representation(self, value): - if value is None or self.format is None: + if self.format is None: return value # Applying a `TimeField` to a datetime value is almost always @@ -875,14 +865,68 @@ class MultipleChoiceField(ChoiceField): # File types... class FileField(Field): - pass # TODO + default_error_messages = { + 'required': _("No file was submitted."), + 'invalid': _("The submitted data was not a file. Check the encoding type on the form."), + 'no_name': _("No filename could be determined."), + 'empty': _("The submitted file is empty."), + 'max_length': _('Ensure this filename has at most {max_length} characters (it has {length}).'), + } + use_url = api_settings.UPLOADED_FILES_USE_URL + def __init__(self, *args, **kwargs): + self.max_length = kwargs.pop('max_length', None) + self.allow_empty_file = kwargs.pop('allow_empty_file', False) + self.use_url = kwargs.pop('use_url', self.use_url) + super(FileField, self).__init__(*args, **kwargs) -class ImageField(Field): - pass # TODO + def to_internal_value(self, data): + try: + # `UploadedFile` objects should have name and size attributes. + file_name = data.name + file_size = data.size + except AttributeError: + self.fail('invalid') + if not file_name: + self.fail('no_name') + if not self.allow_empty_file and not file_size: + self.fail('empty') + if self.max_length and len(file_name) > self.max_length: + self.fail('max_length', max_length=self.max_length, length=len(file_name)) -# Advanced field types... + return data + + def to_representation(self, value): + if self.use_url: + return settings.MEDIA_URL + value.url + return value.name + + +class ImageField(FileField): + default_error_messages = { + 'invalid_image': _( + 'Upload a valid image. The file you uploaded was either not an ' + 'image or a corrupted image.' + ), + } + + def __init__(self, *args, **kwargs): + self._DjangoImageField = kwargs.pop('_DjangoImageField', forms.ImageField) + super(ImageField, self).__init__(*args, **kwargs) + + def to_internal_value(self, data): + # Image validation is a bit grungy, so we'll just outright + # defer to Django's implementation so we don't need to + # consider it, or treat PIL as a test dependancy. + file_object = super(ImageField, self).to_internal_value(data) + django_field = self._DjangoImageField() + django_field.error_messages = self.error_messages + django_field.to_python(file_object) + return file_object + + +# Composite field types... class ListField(Field): child = None @@ -922,6 +966,8 @@ class ListField(Field): return [self.child.to_representation(item) for item in data] +# Miscellaneous field types... + class ReadOnlyField(Field): """ A read-only field that simply returns the field value. diff --git a/rest_framework/settings.py b/rest_framework/settings.py index d7fb0a43..1e8c27fc 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -110,7 +110,8 @@ DEFAULTS = { # Encoding 'UNICODE_JSON': True, 'COMPACT_JSON': True, - 'COERCE_DECIMAL_TO_STRING': True + 'COERCE_DECIMAL_TO_STRING': True, + 'UPLOADED_FILES_USE_URL': True } -- cgit v1.2.3 From 43fd5a873051c99600386c1fdc9fa368edeb6eda Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 29 Sep 2014 09:24:03 +0100 Subject: Uniqueness validation --- rest_framework/fields.py | 4 +++ rest_framework/utils/field_mapping.py | 5 +++ rest_framework/validators.py | 57 +++++++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+) create mode 100644 rest_framework/validators.py (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f4b53279..231f693c 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -150,6 +150,10 @@ class Field(object): messages.update(error_messages or {}) self.error_messages = messages + for validator in validators: + if getattr(validator, 'requires_context', False): + validator.serializer_field = self + def bind(self, field_name, parent): """ Initializes the field name and parent for the field instance. diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index c3794083..cf9d910a 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -6,6 +6,7 @@ from django.core import validators from django.db import models from django.utils.text import capfirst from rest_framework.compat import clean_manytomany_helptext +from rest_framework.validators import UniqueValidator import inspect @@ -156,6 +157,10 @@ def get_field_kwargs(field_name, model_field): if validator is not validators.validate_slug ] + if getattr(model_field, 'unique', False): + validator = UniqueValidator(queryset=model_field.model._default_manager) + validator_kwarg.append(validator) + max_digits = getattr(model_field, 'max_digits', None) if max_digits is not None: kwargs['max_digits'] = max_digits diff --git a/rest_framework/validators.py b/rest_framework/validators.py new file mode 100644 index 00000000..f5fbeb3c --- /dev/null +++ b/rest_framework/validators.py @@ -0,0 +1,57 @@ +from django.core.exceptions import ValidationError + + +class UniqueValidator: + # Validators with `requires_context` will have the field instance + # passed to them when the field is instantiated. + requires_context = True + + def __init__(self, queryset): + self.queryset = queryset + self.serializer_field = None + + def get_queryset(self): + return self.queryset.all() + + def __call__(self, value): + field = self.serializer_field + + # Determine the model field name that the serializer field corresponds to. + field_name = field.source_attrs[0] if field.source_attrs else field.field_name + + # Determine the existing instance, if this is an update operation. + instance = getattr(field.parent, 'instance', None) + + # Ensure uniqueness. + filter_kwargs = {field_name: value} + queryset = self.get_queryset().filter(**filter_kwargs) + if instance: + queryset = queryset.exclude(pk=instance.pk) + if queryset.exists(): + raise ValidationError('This field must be unique.') + + +class UniqueTogetherValidator: + requires_context = True + + def __init__(self, queryset, fields): + self.queryset = queryset + self.fields = fields + self.serializer_field = None + + def __call__(self, value): + serializer = self.serializer_field + + # Determine the existing instance, if this is an update operation. + instance = getattr(serializer, 'instance', None) + + # Ensure uniqueness. + filter_kwargs = dict([ + (field_name, value[field_name]) for field_name in self.fields + ]) + queryset = self.get_queryset().filter(**filter_kwargs) + if instance: + queryset = queryset.exclude(pk=instance.pk) + if queryset.exists(): + field_names = ' and '.join(self.fields) + raise ValidationError('The fields %s must make a unique set.' % field_names) -- cgit v1.2.3 From 9805a085fb115785f272489dc24b51ba6f8e6329 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 29 Sep 2014 11:23:02 +0100 Subject: UniqueTogetherValidator --- rest_framework/serializers.py | 80 ++++++++++++++++++++++++++++++++++++++----- rest_framework/validators.py | 38 +++++++++++++++----- 2 files changed, 101 insertions(+), 17 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 080b958d..09ad376a 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -23,6 +23,7 @@ from rest_framework.utils.field_mapping import ( get_relation_kwargs, get_nested_relation_kwargs, ClassLookupDict ) +from rest_framework.validators import UniqueTogetherValidator import copy import inspect @@ -95,7 +96,7 @@ class BaseSerializer(Field): def is_valid(self, raise_exception=False): if not hasattr(self, '_validated_data'): try: - self._validated_data = self.to_internal_value(self._initial_data) + self._validated_data = self.run_validation(self._initial_data) except ValidationError as exc: self._validated_data = {} self._errors = exc.message_dict @@ -223,15 +224,43 @@ class Serializer(BaseSerializer): return html.parse_html_dict(dictionary, prefix=self.field_name) return dictionary.get(self.field_name, empty) - def to_internal_value(self, data): + def run_validation(self, data=empty): """ - Dict of native values <- Dict of primitive datatypes. + We override the default `run_validation`, because the validation + performed by validators and the `.validate()` method should + be coerced into an error dictionary with a 'non_fields_error' key. """ + if data is empty: + if getattr(self.root, 'partial', False): + raise SkipField() + if self.required: + self.fail('required') + return self.get_default() + + if data is None: + if not self.allow_null: + self.fail('null') + return None + if not isinstance(data, dict): raise ValidationError({ api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data'] }) + value = self.to_internal_value(data) + try: + self.run_validators(value) + self.validate(value) + except ValidationError as exc: + raise ValidationError({ + api_settings.NON_FIELD_ERRORS_KEY: exc.messages + }) + return value + + def to_internal_value(self, data): + """ + Dict of native values <- Dict of primitive datatypes. + """ ret = {} errors = {} fields = [field for field in self.fields.values() if not field.read_only] @@ -253,12 +282,7 @@ class Serializer(BaseSerializer): if errors: raise ValidationError(errors) - try: - return self.validate(ret) - except ValidationError as exc: - raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: exc.messages - }) + return ret def to_representation(self, instance): """ @@ -355,6 +379,14 @@ class ModelSerializer(Serializer): }) _related_class = PrimaryKeyRelatedField + def __init__(self, *args, **kwargs): + super(ModelSerializer, self).__init__(*args, **kwargs) + if 'validators' not in kwargs: + validators = self.get_unique_together_validators() + if validators: + self.validators.extend(validators) + self._kwargs['validators'] = validators + def create(self, attrs): ModelClass = self.Meta.model @@ -381,6 +413,36 @@ class ModelSerializer(Serializer): setattr(obj, attr, value) obj.save() + def get_unique_together_validators(self): + field_names = set([ + field.source for field in self.fields.values() + if (field.source != '*') and ('.' not in field.source) + ]) + + validators = [] + model_class = self.Meta.model + + for unique_together in model_class._meta.unique_together: + if field_names.issuperset(set(unique_together)): + validator = UniqueTogetherValidator( + queryset=model_class._default_manager, + fields=unique_together + ) + validator.serializer_field = self + validators.append(validator) + + for parent_class in model_class._meta.parents.keys(): + for unique_together in parent_class._meta.unique_together: + if field_names.issuperset(set(unique_together)): + validator = UniqueTogetherValidator( + queryset=parent_class._default_manager, + fields=unique_together + ) + validator.serializer_field = self + validators.append(validator) + + return validators + def _get_base_fields(self): declared_fields = copy.deepcopy(self._declared_fields) diff --git a/rest_framework/validators.py b/rest_framework/validators.py index f5fbeb3c..20de4b42 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -1,18 +1,26 @@ +""" +We perform uniqueness checks explicitly on the serializer class, rather +the using Django's `.full_clean()`. + +This gives us better seperation of concerns, allows us to use single-step +object creation, and makes it possible to switch between using the implicit +`ModelSerializer` class and an equivelent explicit `Serializer` class. +""" from django.core.exceptions import ValidationError +from django.utils.translation import ugettext_lazy as _ +from rest_framework.utils.representation import smart_repr class UniqueValidator: # Validators with `requires_context` will have the field instance # passed to them when the field is instantiated. requires_context = True + message = _('This field must be unique.') def __init__(self, queryset): self.queryset = queryset self.serializer_field = None - def get_queryset(self): - return self.queryset.all() - def __call__(self, value): field = self.serializer_field @@ -24,15 +32,22 @@ class UniqueValidator: # Ensure uniqueness. filter_kwargs = {field_name: value} - queryset = self.get_queryset().filter(**filter_kwargs) + queryset = self.queryset.filter(**filter_kwargs) if instance: queryset = queryset.exclude(pk=instance.pk) if queryset.exists(): - raise ValidationError('This field must be unique.') + raise ValidationError(self.message) + + def __repr__(self): + return '<%s(queryset=%s)>' % ( + self.__class__.__name__, + smart_repr(self.queryset) + ) class UniqueTogetherValidator: requires_context = True + message = _('The fields {field_names} must make a unique set.') def __init__(self, queryset, fields): self.queryset = queryset @@ -49,9 +64,16 @@ class UniqueTogetherValidator: filter_kwargs = dict([ (field_name, value[field_name]) for field_name in self.fields ]) - queryset = self.get_queryset().filter(**filter_kwargs) + queryset = self.queryset.filter(**filter_kwargs) if instance: queryset = queryset.exclude(pk=instance.pk) if queryset.exists(): - field_names = ' and '.join(self.fields) - raise ValidationError('The fields %s must make a unique set.' % field_names) + field_names = ', '.join(self.fields) + raise ValidationError(self.message.format(field_names=field_names)) + + def __repr__(self): + return '<%s(queryset=%s, fields=%s)>' % ( + self.__class__.__name__, + smart_repr(self.queryset), + smart_repr(self.fields) + ) -- cgit v1.2.3 From d2d412993f537952fd7809ded3e981f85ec318e9 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 29 Sep 2014 11:24:21 +0100 Subject: .validate() on serializer fields --- rest_framework/fields.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 231f693c..fee6080a 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -254,6 +254,7 @@ class Field(object): value = self.to_internal_value(data) self.run_validators(value) + self.validate(value) return value def run_validators(self, value): @@ -270,6 +271,9 @@ class Field(object): if errors: raise ValidationError(errors) + def validate(self, value): + pass + def to_internal_value(self, data): """ Transform the *incoming* primative data into a native value. -- cgit v1.2.3 From d1b2c8ac7faec65483cbddf4f1718ca4f5805246 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 29 Sep 2014 14:12:26 +0100 Subject: Absolute URLs for file fields --- rest_framework/fields.py | 12 +++++++----- rest_framework/serializers.py | 2 -- 2 files changed, 7 insertions(+), 7 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index fee6080a..f7ea3b0c 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -150,10 +150,6 @@ class Field(object): messages.update(error_messages or {}) self.error_messages = messages - for validator in validators: - if getattr(validator, 'requires_context', False): - validator.serializer_field = self - def bind(self, field_name, parent): """ Initializes the field name and parent for the field instance. @@ -264,6 +260,8 @@ class Field(object): """ errors = [] for validator in self.validators: + if getattr(validator, 'requires_context', False): + validator.serializer_field = self try: validator(value) except ValidationError as exc: @@ -907,7 +905,11 @@ class FileField(Field): def to_representation(self, value): if self.use_url: - return settings.MEDIA_URL + value.url + url = settings.MEDIA_URL + value.url + request = self.context.get('request', None) + if request is not None: + return request.build_absolute_uri(url) + return url return value.name diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 09ad376a..0faa5671 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -428,7 +428,6 @@ class ModelSerializer(Serializer): queryset=model_class._default_manager, fields=unique_together ) - validator.serializer_field = self validators.append(validator) for parent_class in model_class._meta.parents.keys(): @@ -438,7 +437,6 @@ class ModelSerializer(Serializer): queryset=parent_class._default_manager, fields=unique_together ) - validator.serializer_field = self validators.append(validator) return validators -- cgit v1.2.3 From 381771731f48c75e7d5951e353049cceec386512 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 1 Oct 2014 13:09:14 +0100 Subject: Use six.text_type instead of str everywhere --- rest_framework/compat.py | 9 +++++---- rest_framework/fields.py | 22 +++++++++++----------- rest_framework/filters.py | 3 ++- rest_framework/generics.py | 5 +++-- rest_framework/parsers.py | 3 ++- rest_framework/relations.py | 3 ++- rest_framework/reverse.py | 3 ++- rest_framework/utils/encoders.py | 6 +++--- 8 files changed, 30 insertions(+), 24 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 89af9b48..3993cee6 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -5,11 +5,12 @@ versions of django/python, and compatibility wrappers around optional packages. # flake8: noqa from __future__ import unicode_literals -import django -import inspect + from django.core.exceptions import ImproperlyConfigured from django.conf import settings from django.utils import six +import django +import inspect # Handle django.utils.encoding rename in 1.5 onwards. @@ -177,12 +178,12 @@ class RequestFactory(DjangoRequestFactory): r = { 'PATH_INFO': self._get_path(parsed), 'QUERY_STRING': force_text(parsed[4]), - 'REQUEST_METHOD': str(method), + 'REQUEST_METHOD': six.text_type(method), } if data: r.update({ 'CONTENT_LENGTH': len(data), - 'CONTENT_TYPE': str(content_type), + 'CONTENT_TYPE': six.text_type(content_type), 'wsgi.input': FakePayload(data), }) elif django.VERSION <= (1, 4): diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f7ea3b0c..f3ff2233 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -2,7 +2,7 @@ from django import forms from django.conf import settings from django.core import validators from django.core.exceptions import ValidationError -from django.utils import timezone +from django.utils import six, timezone from django.utils.datastructures import SortedDict from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.encoding import is_protected_type @@ -431,10 +431,10 @@ class CharField(Field): return super(CharField, self).run_validation(data) def to_internal_value(self, data): - return str(data) + return six.text_type(data) def to_representation(self, value): - return str(value) + return six.text_type(value) class EmailField(CharField): @@ -448,10 +448,10 @@ class EmailField(CharField): self.validators.append(validator) def to_internal_value(self, data): - return str(data).strip() + return six.text_type(data).strip() def to_representation(self, value): - return str(value).strip() + return six.text_type(value).strip() class RegexField(CharField): @@ -510,7 +510,7 @@ class IntegerField(Field): def to_internal_value(self, data): try: - data = int(str(data)) + data = int(six.text_type(data)) except (ValueError, TypeError): self.fail('invalid') return data @@ -616,7 +616,7 @@ class DecimalField(Field): def to_representation(self, value): if not isinstance(value, decimal.Decimal): - value = decimal.Decimal(str(value).strip()) + value = decimal.Decimal(six.text_type(value).strip()) context = decimal.getcontext().copy() context.prec = self.max_digits @@ -832,19 +832,19 @@ class ChoiceField(Field): # Allows us to deal with eg. integer choices while supporting either # integer or string input, but still get the correct datatype out. self.choice_strings_to_values = dict([ - (str(key), key) for key in self.choices.keys() + (six.text_type(key), key) for key in self.choices.keys() ]) super(ChoiceField, self).__init__(**kwargs) def to_internal_value(self, data): try: - return self.choice_strings_to_values[str(data)] + return self.choice_strings_to_values[six.text_type(data)] except KeyError: self.fail('invalid_choice', input=data) def to_representation(self, value): - return self.choice_strings_to_values[str(value)] + return self.choice_strings_to_values[six.text_type(value)] class MultipleChoiceField(ChoiceField): @@ -864,7 +864,7 @@ class MultipleChoiceField(ChoiceField): def to_representation(self, value): return set([ - self.choice_strings_to_values[str(item)] for item in value + self.choice_strings_to_values[six.text_type(item)] for item in value ]) diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 085dfe65..4c485668 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -3,6 +3,7 @@ Provides generic filtering backends that can be used to filter the results returned by list views. """ from __future__ import unicode_literals + from django.core.exceptions import ImproperlyConfigured from django.db import models from django.utils import six @@ -97,7 +98,7 @@ class SearchFilter(BaseFilterBackend): if not search_fields: return queryset - orm_lookups = [self.construct_search(str(search_field)) + orm_lookups = [self.construct_search(six.text_type(search_field)) for search_field in search_fields] for search_term in self.get_search_terms(request): diff --git a/rest_framework/generics.py b/rest_framework/generics.py index cf903dab..3d6cf168 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -3,10 +3,11 @@ Generic views that provide commonly needed behaviour. """ from __future__ import unicode_literals -from django.db.models.query import QuerySet from django.core.paginator import Paginator, InvalidPage +from django.db.models.query import QuerySet from django.http import Http404 from django.shortcuts import get_object_or_404 as _get_object_or_404 +from django.utils import six from django.utils.translation import ugettext as _ from rest_framework import views, mixins from rest_framework.settings import api_settings @@ -127,7 +128,7 @@ class GenericAPIView(views.APIView): error_format = _('Invalid page (%(page_number)s): %(message)s') raise Http404(error_format % { 'page_number': page_number, - 'message': str(exc) + 'message': six.text_type(exc) }) return page diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index fa02ecf1..ccb82f03 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -5,6 +5,7 @@ They give us a generic way of being able to handle various media types on the request, such as form content or json encoded data. """ from __future__ import unicode_literals + from django.conf import settings from django.core.files.uploadhandler import StopFutureHandlers from django.http import QueryDict @@ -132,7 +133,7 @@ class MultiPartParser(BaseParser): data, files = parser.parse() return DataAndFiles(data, files) except MultiPartParserError as exc: - raise ParseError('Multipart form parse error - %s' % str(exc)) + raise ParseError('Multipart form parse error - %s' % six.text_type(exc)) class XMLParser(BaseParser): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index b37a6fed..b5effc6c 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -4,6 +4,7 @@ from rest_framework.reverse import reverse from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch, Resolver404 from django.db.models.query import QuerySet +from django.utils import six from django.utils.translation import ugettext_lazy as _ @@ -49,7 +50,7 @@ class StringRelatedField(Field): super(StringRelatedField, self).__init__(**kwargs) def to_representation(self, value): - return str(value) + return six.text_type(value) class PrimaryKeyRelatedField(RelatedField): diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py index a51b07f5..a74e8aa2 100644 --- a/rest_framework/reverse.py +++ b/rest_framework/reverse.py @@ -3,6 +3,7 @@ Provide reverse functions that return fully qualified URLs """ from __future__ import unicode_literals from django.core.urlresolvers import reverse as django_reverse +from django.utils import six from django.utils.functional import lazy @@ -20,4 +21,4 @@ def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra return url -reverse_lazy = lazy(reverse, str) +reverse_lazy = lazy(reverse, six.text_type) diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 174b08b8..7c4179a1 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -2,8 +2,8 @@ Helper classes for parsers. """ from __future__ import unicode_literals -from django.utils import timezone from django.db.models.query import QuerySet +from django.utils import six, timezone from django.utils.datastructures import SortedDict from django.utils.functional import Promise from rest_framework.compat import force_text @@ -40,7 +40,7 @@ class JSONEncoder(json.JSONEncoder): representation = representation[:12] return representation elif isinstance(obj, datetime.timedelta): - return str(obj.total_seconds()) + return six.text_type(obj.total_seconds()) elif isinstance(obj, decimal.Decimal): # Serializers will coerce decimals to strings by default. return float(obj) @@ -72,7 +72,7 @@ else: than the usual behaviour of sorting the keys. """ def represent_decimal(self, data): - return self.represent_scalar('tag:yaml.org,2002:str', str(data)) + return self.represent_scalar('tag:yaml.org,2002:str', six.text_type(data)) def represent_mapping(self, tag, mapping, flow_style=None): value = [] -- cgit v1.2.3 From c630a12e26f29145784523dd1b01ab0b3576f42c Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 1 Oct 2014 13:24:47 +0100 Subject: Deal with lazy strings in serializer reprs --- rest_framework/utils/representation.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py index e64fdd22..180b51f8 100644 --- a/rest_framework/utils/representation.py +++ b/rest_framework/utils/representation.py @@ -3,6 +3,8 @@ Helper functions for creating user-friendly representations of serializer classes and serializer fields. """ from django.db import models +from django.utils.functional import Promise +from rest_framework.compat import force_text import re @@ -19,6 +21,9 @@ def smart_repr(value): if isinstance(value, models.Manager): return manager_repr(value) + if isinstance(value, Promise) and value._delegate_text: + value = force_text(value) + value = repr(value) # Representations like u'help text' -- cgit v1.2.3 From c171fa21ac62538331755524057d2435f33ec8a5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 1 Oct 2014 19:44:46 +0100 Subject: First pass at HTML form rendering --- rest_framework/renderers.py | 47 ++++++++++++++++++++-- rest_framework/serializers.py | 2 + .../templates/rest_framework/fields/attrs.html | 1 + .../rest_framework/fields/horizontal/checkbox.html | 10 +++++ .../rest_framework/fields/horizontal/fieldset.html | 10 +++++ .../rest_framework/fields/horizontal/input.html | 7 ++++ .../rest_framework/fields/horizontal/label.html | 1 + .../rest_framework/fields/horizontal/select.html | 10 +++++ .../fields/horizontal/select_checkbox.html | 22 ++++++++++ .../fields/horizontal/select_multiple.html | 10 +++++ .../fields/horizontal/select_radio.html | 22 ++++++++++ .../rest_framework/fields/horizontal/textarea.html | 7 ++++ .../rest_framework/fields/inline/checkbox.html | 6 +++ .../rest_framework/fields/inline/fieldset.html | 3 ++ .../rest_framework/fields/inline/input.html | 4 ++ .../rest_framework/fields/inline/label.html | 1 + .../rest_framework/fields/inline/select.html | 8 ++++ .../fields/inline/select_checkbox.html | 11 +++++ .../fields/inline/select_multiple.html | 8 ++++ .../rest_framework/fields/inline/select_radio.html | 11 +++++ .../rest_framework/fields/inline/textarea.html | 4 ++ .../rest_framework/fields/vertical/checkbox.html | 6 +++ .../rest_framework/fields/vertical/fieldset.html | 6 +++ .../rest_framework/fields/vertical/input.html | 5 +++ .../rest_framework/fields/vertical/label.html | 1 + .../rest_framework/fields/vertical/select.html | 8 ++++ .../fields/vertical/select_checkbox.html | 22 ++++++++++ .../fields/vertical/select_multiple.html | 8 ++++ .../fields/vertical/select_radio.html | 22 ++++++++++ .../rest_framework/fields/vertical/textarea.html | 5 +++ rest_framework/templates/rest_framework/form.html | 40 ++++++++++++------ rest_framework/templatetags/rest_framework.py | 8 ++++ rest_framework/utils/field_mapping.py | 3 ++ 33 files changed, 324 insertions(+), 15 deletions(-) create mode 100644 rest_framework/templates/rest_framework/fields/attrs.html create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/checkbox.html create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/fieldset.html create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/input.html create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/label.html create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/select.html create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/select_checkbox.html create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/select_multiple.html create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/select_radio.html create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/textarea.html create mode 100644 rest_framework/templates/rest_framework/fields/inline/checkbox.html create mode 100644 rest_framework/templates/rest_framework/fields/inline/fieldset.html create mode 100644 rest_framework/templates/rest_framework/fields/inline/input.html create mode 100644 rest_framework/templates/rest_framework/fields/inline/label.html create mode 100644 rest_framework/templates/rest_framework/fields/inline/select.html create mode 100644 rest_framework/templates/rest_framework/fields/inline/select_checkbox.html create mode 100644 rest_framework/templates/rest_framework/fields/inline/select_multiple.html create mode 100644 rest_framework/templates/rest_framework/fields/inline/select_radio.html create mode 100644 rest_framework/templates/rest_framework/fields/inline/textarea.html create mode 100644 rest_framework/templates/rest_framework/fields/vertical/checkbox.html create mode 100644 rest_framework/templates/rest_framework/fields/vertical/fieldset.html create mode 100644 rest_framework/templates/rest_framework/fields/vertical/input.html create mode 100644 rest_framework/templates/rest_framework/fields/vertical/label.html create mode 100644 rest_framework/templates/rest_framework/fields/vertical/select.html create mode 100644 rest_framework/templates/rest_framework/fields/vertical/select_checkbox.html create mode 100644 rest_framework/templates/rest_framework/fields/vertical/select_multiple.html create mode 100644 rest_framework/templates/rest_framework/fields/vertical/select_radio.html create mode 100644 rest_framework/templates/rest_framework/fields/vertical/textarea.html (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 225f9fe8..6483a47c 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -13,17 +13,18 @@ import django from django import forms from django.core.exceptions import ImproperlyConfigured from django.http.multipartparser import parse_header -from django.template import RequestContext, loader, Template +from django.template import Context, RequestContext, loader, Template from django.test.client import encode_multipart from django.utils import six from django.utils.xmlutils import SimplerXMLGenerator +from rest_framework import exceptions, serializers, status, VERSION from rest_framework.compat import StringIO, smart_text, yaml from rest_framework.exceptions import ParseError from rest_framework.settings import api_settings from rest_framework.request import is_form_media_type, override_method from rest_framework.utils import encoders from rest_framework.utils.breadcrumbs import get_breadcrumbs -from rest_framework import exceptions, status, VERSION +from rest_framework.utils.field_mapping import ClassLookupDict def zero_as_none(value): @@ -341,6 +342,42 @@ class HTMLFormRenderer(BaseRenderer): template = 'rest_framework/form.html' charset = 'utf-8' + field_templates = ClassLookupDict({ + serializers.Field: { + 'default': 'input.html' + }, + serializers.BooleanField: { + 'default': 'checkbox.html' + }, + serializers.CharField: { + 'default': 'input.html', + 'textarea': 'textarea.html' + }, + serializers.ChoiceField: { + 'default': 'select.html', + 'radio': 'select_radio.html' + }, + serializers.MultipleChoiceField: { + 'default': 'select_multiple.html', + 'checkbox': 'select_checkbox.html' + } + }) + + def render_field(self, field, value, errors, layout=None): + layout = layout or 'vertical' + style_type = field.style.get('type', 'default') + if style_type == 'textarea' and layout == 'inline': + style_type = 'default' + base = self.field_templates[field][style_type] + template_name = 'rest_framework/fields/' + layout + '/' + base + template = loader.get_template(template_name) + context = Context({ + 'field': field, + 'value': value, + 'errors': errors + }) + return template.render(context) + def render(self, data, accepted_media_type=None, renderer_context=None): """ Render serializer data and return an HTML form, as a string. @@ -349,7 +386,11 @@ class HTMLFormRenderer(BaseRenderer): request = renderer_context['request'] template = loader.get_template(self.template) - context = RequestContext(request, {'form': data}) + context = RequestContext(request, { + 'form': data, + 'layout': getattr(getattr(data, 'Meta', None), 'layout', 'vertical'), + 'renderer': self + }) return template.render(context) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 0faa5671..5da81247 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -302,6 +302,8 @@ class Serializer(BaseSerializer): def __iter__(self): errors = self.errors if hasattr(self, '_errors') else {} for field in self.fields.values(): + if field.read_only: + continue value = self.data.get(field.field_name) if self.data else None error = errors.get(field.field_name) yield FieldResult(field, value, error) diff --git a/rest_framework/templates/rest_framework/fields/attrs.html b/rest_framework/templates/rest_framework/fields/attrs.html new file mode 100644 index 00000000..b5a4dbcf --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/attrs.html @@ -0,0 +1 @@ +name="{{ field.field_name }}" {% if field.style.placeholder %}placeholder="{{ field.style.placeholder }}"{% endif %} {% if field.style.rows %}rows="{{ field.style.rows }}"{% endif %} diff --git a/rest_framework/templates/rest_framework/fields/horizontal/checkbox.html b/rest_framework/templates/rest_framework/fields/horizontal/checkbox.html new file mode 100644 index 00000000..dce4a5cf --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/horizontal/checkbox.html @@ -0,0 +1,10 @@ +
+
+
+ +
+
+
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/fieldset.html b/rest_framework/templates/rest_framework/fields/horizontal/fieldset.html new file mode 100644 index 00000000..86417633 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/horizontal/fieldset.html @@ -0,0 +1,10 @@ +
+ {% if field.label %} +
+ {{ field.label }} +
+ {% endif %} + {% for field_item in value.field_items.values() %} + {{ renderer.render_field(field_item, layout=layout) }} + {% endfor %} +
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/input.html b/rest_framework/templates/rest_framework/fields/horizontal/input.html new file mode 100644 index 00000000..310154bb --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/horizontal/input.html @@ -0,0 +1,7 @@ +
+ {% include "rest_framework/fields/horizontal/label.html" %} +
+ + {% if field.help_text %}

{{ field.help_text }}

{% endif %} +
+
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/label.html b/rest_framework/templates/rest_framework/fields/horizontal/label.html new file mode 100644 index 00000000..bf21f78c --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/horizontal/label.html @@ -0,0 +1 @@ +{% if field.label %}{% endif %} diff --git a/rest_framework/templates/rest_framework/fields/horizontal/select.html b/rest_framework/templates/rest_framework/fields/horizontal/select.html new file mode 100644 index 00000000..3f8cab0a --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/horizontal/select.html @@ -0,0 +1,10 @@ +
+ {% include "rest_framework/fields/horizontal/label.html" %} +
+ +
+
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/select_checkbox.html b/rest_framework/templates/rest_framework/fields/horizontal/select_checkbox.html new file mode 100644 index 00000000..659eede8 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/horizontal/select_checkbox.html @@ -0,0 +1,22 @@ +
+ {% include "rest_framework/fields/horizontal/label.html" %} +
+ {% if field.style.inline %} + {% for key, text in field.choices.items %} + + {% endfor %} + {% else %} + {% for key, text in field.choices.items %} +
+ +
+ {% endfor %} + {% endif %} +
+
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/select_multiple.html b/rest_framework/templates/rest_framework/fields/horizontal/select_multiple.html new file mode 100644 index 00000000..da25eb2b --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/horizontal/select_multiple.html @@ -0,0 +1,10 @@ +
+ {% include "rest_framework/fields/horizontal/label.html" %} +
+ +
+
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/select_radio.html b/rest_framework/templates/rest_framework/fields/horizontal/select_radio.html new file mode 100644 index 00000000..188f05e2 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/horizontal/select_radio.html @@ -0,0 +1,22 @@ +
+ {% include "rest_framework/fields/horizontal/label.html" %} +
+ {% if field.style.inline %} + {% for key, text in field.choices.items %} + + {% endfor %} + {% else %} + {% for key, text in field.choices.items %} +
+ +
+ {% endfor %} + {% endif %} +
+
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/textarea.html b/rest_framework/templates/rest_framework/fields/horizontal/textarea.html new file mode 100644 index 00000000..e99266f3 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/horizontal/textarea.html @@ -0,0 +1,7 @@ +
+ {% include "rest_framework/fields/horizontal/label.html" %} +
+ + {% if field.help_text %}

{{ field.help_text }}

{% endif %} +
+
diff --git a/rest_framework/templates/rest_framework/fields/inline/checkbox.html b/rest_framework/templates/rest_framework/fields/inline/checkbox.html new file mode 100644 index 00000000..01d30aae --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/inline/checkbox.html @@ -0,0 +1,6 @@ +
+ +
diff --git a/rest_framework/templates/rest_framework/fields/inline/fieldset.html b/rest_framework/templates/rest_framework/fields/inline/fieldset.html new file mode 100644 index 00000000..d22982fd --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/inline/fieldset.html @@ -0,0 +1,3 @@ +{% for field_item in value.field_items.values() %} + {{ renderer.render_field(field_item, layout=layout) }} +{% endfor %} diff --git a/rest_framework/templates/rest_framework/fields/inline/input.html b/rest_framework/templates/rest_framework/fields/inline/input.html new file mode 100644 index 00000000..aefd1672 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/inline/input.html @@ -0,0 +1,4 @@ +
+ {% include "rest_framework/fields/inline/label.html" %} + +
diff --git a/rest_framework/templates/rest_framework/fields/inline/label.html b/rest_framework/templates/rest_framework/fields/inline/label.html new file mode 100644 index 00000000..7d546a57 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/inline/label.html @@ -0,0 +1 @@ +{% if field.label %}{% endif %} diff --git a/rest_framework/templates/rest_framework/fields/inline/select.html b/rest_framework/templates/rest_framework/fields/inline/select.html new file mode 100644 index 00000000..cb9a7013 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/inline/select.html @@ -0,0 +1,8 @@ +
+ {% include "rest_framework/fields/inline/label.html" %} + +
diff --git a/rest_framework/templates/rest_framework/fields/inline/select_checkbox.html b/rest_framework/templates/rest_framework/fields/inline/select_checkbox.html new file mode 100644 index 00000000..424df93e --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/inline/select_checkbox.html @@ -0,0 +1,11 @@ +
+ {% include "rest_framework/fields/inline/label.html" %} + {% for key, text in field.choices.items %} +
+ +
+ {% endfor %} +
diff --git a/rest_framework/templates/rest_framework/fields/inline/select_multiple.html b/rest_framework/templates/rest_framework/fields/inline/select_multiple.html new file mode 100644 index 00000000..6fdfd672 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/inline/select_multiple.html @@ -0,0 +1,8 @@ +
+ {% include "rest_framework/fields/inline/label.html" %} + +
diff --git a/rest_framework/templates/rest_framework/fields/inline/select_radio.html b/rest_framework/templates/rest_framework/fields/inline/select_radio.html new file mode 100644 index 00000000..ddabc9e9 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/inline/select_radio.html @@ -0,0 +1,11 @@ +
+ {% include "rest_framework/fields/inline/label.html" %} + {% for key, text in field.choices.items %} +
+ +
+ {% endfor %} +
diff --git a/rest_framework/templates/rest_framework/fields/inline/textarea.html b/rest_framework/templates/rest_framework/fields/inline/textarea.html new file mode 100644 index 00000000..31366809 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/inline/textarea.html @@ -0,0 +1,4 @@ +
+ {% include "rest_framework/fields/inline/label.html" %} + +
diff --git a/rest_framework/templates/rest_framework/fields/vertical/checkbox.html b/rest_framework/templates/rest_framework/fields/vertical/checkbox.html new file mode 100644 index 00000000..01d30aae --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/vertical/checkbox.html @@ -0,0 +1,6 @@ +
+ +
diff --git a/rest_framework/templates/rest_framework/fields/vertical/fieldset.html b/rest_framework/templates/rest_framework/fields/vertical/fieldset.html new file mode 100644 index 00000000..cad32df9 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/vertical/fieldset.html @@ -0,0 +1,6 @@ +
+ {% if field.label %}{{ field.label }}{% endif %} + {% for field_item in value.field_items.values() %} + {{ renderer.render_field(field_item, layout=layout) }} + {% endfor %} +
diff --git a/rest_framework/templates/rest_framework/fields/vertical/input.html b/rest_framework/templates/rest_framework/fields/vertical/input.html new file mode 100644 index 00000000..c25407d1 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/vertical/input.html @@ -0,0 +1,5 @@ +
+ {% include "rest_framework/fields/vertical/label.html" %} + + {% if field.help_text %}

{{ field.help_text }}

{% endif %} +
diff --git a/rest_framework/templates/rest_framework/fields/vertical/label.html b/rest_framework/templates/rest_framework/fields/vertical/label.html new file mode 100644 index 00000000..651939b2 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/vertical/label.html @@ -0,0 +1 @@ +{% if field.label %}{% endif %} diff --git a/rest_framework/templates/rest_framework/fields/vertical/select.html b/rest_framework/templates/rest_framework/fields/vertical/select.html new file mode 100644 index 00000000..44679d8a --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/vertical/select.html @@ -0,0 +1,8 @@ +
+ {% include "rest_framework/fields/vertical/label.html" %} + +
diff --git a/rest_framework/templates/rest_framework/fields/vertical/select_checkbox.html b/rest_framework/templates/rest_framework/fields/vertical/select_checkbox.html new file mode 100644 index 00000000..e60574c0 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/vertical/select_checkbox.html @@ -0,0 +1,22 @@ +
+ {% include "rest_framework/fields/vertical/label.html" %} + {% if field.style.inline %} +
+ {% for key, text in field.choices.items %} + + {% endfor %} +
+ {% else %} + {% for key, text in field.choices.items %} +
+ +
+ {% endfor %} + {% endif %} +
diff --git a/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html b/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html new file mode 100644 index 00000000..f0fa418b --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html @@ -0,0 +1,8 @@ +
+ {% include "rest_framework/fields/vertical/label.html" %} + +
diff --git a/rest_framework/templates/rest_framework/fields/vertical/select_radio.html b/rest_framework/templates/rest_framework/fields/vertical/select_radio.html new file mode 100644 index 00000000..4ffe38ea --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/vertical/select_radio.html @@ -0,0 +1,22 @@ +
+ {% include "rest_framework/fields/vertical/label.html" %} + {% if field.style.inline %} +
+ {% for key, text in field.choices.items %} + + {% endfor %} +
+ {% else %} + {% for key, text in field.choices.items %} +
+ +
+ {% endfor %} + {% endif %} +
diff --git a/rest_framework/templates/rest_framework/fields/vertical/textarea.html b/rest_framework/templates/rest_framework/fields/vertical/textarea.html new file mode 100644 index 00000000..33cb27c7 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/vertical/textarea.html @@ -0,0 +1,5 @@ +
+ {% include "rest_framework/fields/vertical/label.html" %} + + {% if field.help_text %}

{{ field.help_text }}

{% endif %} +
diff --git a/rest_framework/templates/rest_framework/form.html b/rest_framework/templates/rest_framework/form.html index b1e148df..64b1b0bc 100644 --- a/rest_framework/templates/rest_framework/form.html +++ b/rest_framework/templates/rest_framework/form.html @@ -1,15 +1,31 @@ + + + + + +
+ +

User update

+
+ {% load rest_framework %} -{% csrf_token %} -{{ form.non_field_errors }} -{% for field in form.fields.values %} - {% if not field.read_only %} -
- {{ field.label_tag|add_class:"control-label" }} -
- {{ field.widget_html }} - {% if field.help_text %}{{ field.help_text }}{% endif %} - {% for error in field.errors %}{{ error }}{% endfor %} +
+ {% csrf_token %} + {% for field, value, errors in form %} + {% render_field field value errors layout=layout renderer=renderer %} + {% endfor %} + + {% if layout == "horizontal" %} +
+
+ +
-
+ {% else %} + {% endif %} -{% endfor %} + + +
+
+ diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 864d64dd..88ff9d4e 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -31,6 +31,14 @@ class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') # And the template tags themselves... +# @register.simple_tag +# def render_field(field, value, errors, renderer): +# return renderer.render_field(field, value, errors) +@register.simple_tag +def render_field(field, value, errors, layout=None, renderer=None): + return renderer.render_field(field, value, errors, layout) + + @register.simple_tag def optional_login(request): """ diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index cf9d910a..b4d33e39 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -79,6 +79,9 @@ def get_field_kwargs(field_name, model_field): kwargs['choices'] = model_field.flatchoices return kwargs + if isinstance(model_field, models.TextField): + kwargs['style'] = {'type': 'textarea'} + if model_field.null and not isinstance(model_field, models.NullBooleanField): kwargs['allow_null'] = True -- cgit v1.2.3 From ffc6aa3abcb0f823b43b63db1666913565e6f934 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 1 Oct 2014 21:35:27 +0100 Subject: More forms support --- rest_framework/relations.py | 20 +++++++++++++++++ rest_framework/renderers.py | 25 ++++++++++++++++++++-- .../fields/vertical/select_multiple.html | 2 +- rest_framework/templates/rest_framework/form.html | 2 +- 4 files changed, 45 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index b5effc6c..8c135672 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -38,6 +38,16 @@ class RelatedField(Field): queryset = queryset.all() return queryset + @property + def choices(self): + return dict([ + ( + str(self.to_representation(item)), + str(item) + ) + for item in self.queryset.all() + ]) + class StringRelatedField(Field): """ @@ -255,3 +265,13 @@ class ManyRelation(Field): self.child_relation.to_representation(value) for value in obj.all() ] + + @property + def choices(self): + return dict([ + ( + str(self.child_relation.to_representation(item)), + str(item) + ) + for item in self.child_relation.queryset.all() + ]) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 6483a47c..297c60d8 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -360,22 +360,43 @@ class HTMLFormRenderer(BaseRenderer): serializers.MultipleChoiceField: { 'default': 'select_multiple.html', 'checkbox': 'select_checkbox.html' + }, + serializers.ManyRelation: { + 'default': 'select_multiple.html', + 'checkbox': 'select_checkbox.html' } }) + input_type = ClassLookupDict({ + serializers.Field: 'text', + serializers.EmailField: 'email', + serializers.URLField: 'url', + serializers.IntegerField: 'number', + serializers.DateTimeField: 'datetime-local', + serializers.DateField: 'date', + serializers.TimeField: 'time', + }) + def render_field(self, field, value, errors, layout=None): layout = layout or 'vertical' style_type = field.style.get('type', 'default') if style_type == 'textarea' and layout == 'inline': style_type = 'default' + + input_type = self.input_type[field] + if input_type == 'datetime-local': + value = value.rstrip('Z') + base = self.field_templates[field][style_type] template_name = 'rest_framework/fields/' + layout + '/' + base template = loader.get_template(template_name) context = Context({ 'field': field, 'value': value, - 'errors': errors + 'errors': errors, + 'input_type': input_type }) + return template.render(context) def render(self, data, accepted_media_type=None, renderer_context=None): @@ -388,7 +409,7 @@ class HTMLFormRenderer(BaseRenderer): template = loader.get_template(self.template) context = RequestContext(request, { 'form': data, - 'layout': getattr(getattr(data, 'Meta', None), 'layout', 'vertical'), + 'layout': getattr(getattr(data, 'Meta', None), 'layout', 'horizontal'), 'renderer': self }) return template.render(context) diff --git a/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html b/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html index f0fa418b..00b25b4b 100644 --- a/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html +++ b/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html @@ -1,7 +1,7 @@
{% include "rest_framework/fields/vertical/label.html" %} diff --git a/rest_framework/templates/rest_framework/form.html b/rest_framework/templates/rest_framework/form.html index 64b1b0bc..658aa293 100644 --- a/rest_framework/templates/rest_framework/form.html +++ b/rest_framework/templates/rest_framework/form.html @@ -9,7 +9,7 @@
{% load rest_framework %} -
+ {% csrf_token %} {% for field, value, errors in form %} {% render_field field value errors layout=layout renderer=renderer %} -- cgit v1.2.3 From 79e91dff92443ab1f301638ac280bd3231a2ca15 Mon Sep 17 00:00:00 2001 From: Omer Katz Date: Thu, 2 Oct 2014 16:44:20 +0300 Subject: The encoder now returns tuples instead of lists. Tuples take a little less memory which is significant when serializing a lot of objects. --- rest_framework/utils/encoders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 7c4179a1..486186c9 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -45,7 +45,7 @@ class JSONEncoder(json.JSONEncoder): # Serializers will coerce decimals to strings by default. return float(obj) elif isinstance(obj, QuerySet): - return list(obj) + return tuple(obj) elif hasattr(obj, 'tolist'): # Numpy arrays and array scalars. return obj.tolist() @@ -55,7 +55,7 @@ class JSONEncoder(json.JSONEncoder): except: pass elif hasattr(obj, '__iter__'): - return [item for item in obj] + return tuple(item for item in obj) return super(JSONEncoder, self).default(obj) -- cgit v1.2.3 From df7b6fcf58417fd95e49655eb140b387899b1ceb Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 2 Oct 2014 16:24:24 +0100 Subject: First pass on incorperating the form rendering into the browsable API --- rest_framework/fields.py | 8 +- rest_framework/relations.py | 4 +- rest_framework/renderers.py | 62 +++++----- rest_framework/serializers.py | 129 +++++++++++++++------ .../static/rest_framework/css/bootstrap-tweaks.css | 18 ++- rest_framework/templates/rest_framework/base.html | 56 ++++----- .../rest_framework/fields/horizontal/checkbox.html | 2 +- .../rest_framework/fields/horizontal/fieldset.html | 2 +- .../rest_framework/fields/horizontal/input.html | 2 +- .../rest_framework/fields/horizontal/select.html | 2 +- .../fields/horizontal/select_checkbox.html | 4 +- .../fields/horizontal/select_multiple.html | 2 +- .../fields/horizontal/select_radio.html | 4 +- .../rest_framework/fields/horizontal/textarea.html | 2 +- .../rest_framework/fields/inline/checkbox.html | 2 +- .../rest_framework/fields/inline/fieldset.html | 2 +- .../rest_framework/fields/inline/input.html | 2 +- .../rest_framework/fields/inline/select.html | 2 +- .../fields/inline/select_checkbox.html | 2 +- .../fields/inline/select_multiple.html | 2 +- .../rest_framework/fields/inline/select_radio.html | 2 +- .../rest_framework/fields/inline/textarea.html | 2 +- .../rest_framework/fields/vertical/fieldset.html | 2 +- .../rest_framework/fields/vertical/input.html | 2 +- .../rest_framework/fields/vertical/select.html | 2 +- .../fields/vertical/select_checkbox.html | 4 +- .../fields/vertical/select_multiple.html | 2 +- .../fields/vertical/select_radio.html | 4 +- .../rest_framework/fields/vertical/textarea.html | 2 +- rest_framework/templates/rest_framework/form.html | 14 ++- rest_framework/templatetags/rest_framework.py | 7 +- rest_framework/utils/field_mapping.py | 9 +- 32 files changed, 217 insertions(+), 144 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f3ff2233..c794963e 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -689,10 +689,10 @@ class DateTimeField(Field): return value if self.format.lower() == ISO_8601: - ret = value.isoformat() - if ret.endswith('+00:00'): - ret = ret[:-6] + 'Z' - return ret + value = value.isoformat() + if value.endswith('+00:00'): + value = value[:-6] + 'Z' + return value return value.strftime(self.format) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 8c135672..988b9ede 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -127,7 +127,7 @@ class HyperlinkedRelatedField(RelatedField): attributes are not configured to correctly match the URL conf. """ # Unsaved objects will not yet have a valid URL. - if obj.pk is None: + if obj.pk: return None lookup_value = getattr(obj, self.lookup_field) @@ -248,11 +248,13 @@ class ManyRelation(Field): You shouldn't need to be using this class directly yourself. """ + initial = [] def __init__(self, child_relation=None, *args, **kwargs): self.child_relation = child_relation assert child_relation is not None, '`child_relation` is a required argument.' super(ManyRelation, self).__init__(*args, **kwargs) + self.child_relation.bind(field_name='', parent=self) def to_internal_value(self, data): return [ diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 297c60d8..931dd434 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -377,23 +377,21 @@ class HTMLFormRenderer(BaseRenderer): serializers.TimeField: 'time', }) - def render_field(self, field, value, errors, layout=None): + def render_field(self, field, layout=None): layout = layout or 'vertical' style_type = field.style.get('type', 'default') if style_type == 'textarea' and layout == 'inline': style_type = 'default' input_type = self.input_type[field] - if input_type == 'datetime-local': - value = value.rstrip('Z') + if input_type == 'datetime-local' and isinstance(field.value, six.text_type): + field.value = field.value.rstrip('Z') base = self.field_templates[field][style_type] template_name = 'rest_framework/fields/' + layout + '/' + base template = loader.get_template(template_name) context = Context({ 'field': field, - 'value': value, - 'errors': errors, 'input_type': input_type }) @@ -408,7 +406,7 @@ class HTMLFormRenderer(BaseRenderer): template = loader.get_template(self.template) context = RequestContext(request, { - 'form': data, + 'form': data.serializer, 'layout': getattr(getattr(data, 'Meta', None), 'layout', 'horizontal'), 'renderer': self }) @@ -479,27 +477,29 @@ class BrowsableAPIRenderer(BaseRenderer): return False # Doesn't have permissions return True - def get_rendered_html_form(self, view, method, request): + def get_rendered_html_form(self, data, view, method, request): """ Return a string representing a rendered HTML form, possibly bound to either the input or output data. In the absence of the View having an associated form then return None. """ + serializer = getattr(data, 'serializer', None) + if serializer and not getattr(serializer, 'many', False): + instance = getattr(serializer, 'instance', None) + else: + instance = None + if request.method == method: try: data = request.data - # files = request.FILES except ParseError: data = None - # files = None else: data = None - # files = None with override_method(view, request, method) as request: - obj = getattr(view, 'object', None) - if not self.show_form_for_method(view, method, request, obj): + if not self.show_form_for_method(view, method, request, instance): return if method in ('DELETE', 'OPTIONS'): @@ -511,19 +511,24 @@ class BrowsableAPIRenderer(BaseRenderer): ): return - serializer = view.get_serializer(instance=obj, data=data) - serializer.is_valid() - data = serializer.data - + serializer = view.get_serializer(instance=instance, data=data) + if data is not None: + serializer.is_valid() form_renderer = self.form_renderer_class() - return form_renderer.render(data, self.accepted_media_type, self.renderer_context) + return form_renderer.render(serializer.data, self.accepted_media_type, self.renderer_context) - def get_raw_data_form(self, view, method, request): + def get_raw_data_form(self, data, view, method, request): """ Returns a form that allows for arbitrary content types to be tunneled via standard HTML forms. (Which are typically application/x-www-form-urlencoded) """ + serializer = getattr(data, 'serializer', None) + if serializer and not getattr(serializer, 'many', False): + instance = getattr(serializer, 'instance', None) + else: + instance = None + with override_method(view, request, method) as request: # If we're not using content overloading there's no point in # supplying a generic form, as the view won't treat the form's @@ -533,8 +538,7 @@ class BrowsableAPIRenderer(BaseRenderer): return None # Check permissions - obj = getattr(view, 'object', None) - if not self.show_form_for_method(view, method, request, obj): + if not self.show_form_for_method(view, method, request, instance): return # If possible, serialize the initial content for the generic form @@ -545,8 +549,8 @@ class BrowsableAPIRenderer(BaseRenderer): # corresponding renderer that can be used to render the data. # Get a read-only version of the serializer - serializer = view.get_serializer(instance=obj) - if obj is None: + serializer = view.get_serializer(instance=instance) + if instance is None: for name, field in serializer.fields.items(): if getattr(field, 'read_only', None): del serializer.fields[name] @@ -606,9 +610,9 @@ class BrowsableAPIRenderer(BaseRenderer): renderer = self.get_default_renderer(view) - raw_data_post_form = self.get_raw_data_form(view, 'POST', request) - raw_data_put_form = self.get_raw_data_form(view, 'PUT', request) - raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request) + raw_data_post_form = self.get_raw_data_form(data, view, 'POST', request) + raw_data_put_form = self.get_raw_data_form(data, view, 'PUT', request) + raw_data_patch_form = self.get_raw_data_form(data, view, 'PATCH', request) raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form response_headers = dict(response.items()) @@ -632,10 +636,10 @@ class BrowsableAPIRenderer(BaseRenderer): 'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes], 'response_headers': response_headers, - # 'put_form': self.get_rendered_html_form(view, 'PUT', request), - # 'post_form': self.get_rendered_html_form(view, 'POST', request), - # 'delete_form': self.get_rendered_html_form(view, 'DELETE', request), - # 'options_form': self.get_rendered_html_form(view, 'OPTIONS', request), + 'put_form': self.get_rendered_html_form(data, view, 'PUT', request), + 'post_form': self.get_rendered_html_form(data, view, 'POST', request), + 'delete_form': self.get_rendered_html_form(data, view, 'DELETE', request), + 'options_form': self.get_rendered_html_form(data, view, 'OPTIONS', request), 'raw_data_put_form': raw_data_put_form, 'raw_data_post_form': raw_data_post_form, diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 5da81247..0f24ed40 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -14,7 +14,6 @@ from django.core.exceptions import ImproperlyConfigured, ValidationError from django.db import models from django.utils import six from django.utils.datastructures import SortedDict -from collections import namedtuple from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.settings import api_settings from rest_framework.utils import html, model_meta, representation @@ -38,8 +37,8 @@ from rest_framework.relations import * # NOQA from rest_framework.fields import * # NOQA -FieldResult = namedtuple('FieldResult', ['field', 'value', 'error']) - +# BaseSerializer +# -------------- class BaseSerializer(Field): """ @@ -113,11 +112,6 @@ class BaseSerializer(Field): if not hasattr(self, '_data'): if self.instance is not None: self._data = self.to_representation(self.instance) - elif self._initial_data is not None: - self._data = dict([ - (field_name, field.get_value(self._initial_data)) - for field_name, field in self.fields.items() - ]) else: self._data = self.get_initial() return self._data @@ -137,34 +131,48 @@ class BaseSerializer(Field): return self._validated_data -class SerializerMetaclass(type): +# Serializer & ListSerializer classes +# ----------------------------------- + +class ReturnDict(SortedDict): """ - This metaclass sets a dictionary named `base_fields` on the class. + Return object from `serialier.data` for the `Serializer` class. + Includes a backlink to the serializer instance for renderers + to use if they need richer field information. + """ + def __init__(self, *args, **kwargs): + self.serializer = kwargs.pop('serializer') + super(ReturnDict, self).__init__(*args, **kwargs) - Any instances of `Field` included as attributes on either the class - or on any of its superclasses will be include in the - `base_fields` dictionary. + +class ReturnList(list): + """ + Return object from `serialier.data` for the `SerializerList` class. + Includes a backlink to the serializer instance for renderers + to use if they need richer field information. """ + def __init__(self, *args, **kwargs): + self.serializer = kwargs.pop('serializer') + super(ReturnList, self).__init__(*args, **kwargs) - @classmethod - def _get_declared_fields(cls, bases, attrs): - fields = [(field_name, attrs.pop(field_name)) - for field_name, obj in list(attrs.items()) - if isinstance(obj, Field)] - fields.sort(key=lambda x: x[1]._creation_counter) - # If this class is subclassing another Serializer, add that Serializer's - # fields. Note that we loop over the bases in *reverse*. This is necessary - # in order to maintain the correct order of fields. - for base in bases[::-1]: - if hasattr(base, '_declared_fields'): - fields = list(base._declared_fields.items()) + fields +class BoundField(object): + """ + A field object that also includes `.value` and `.error` properties. + Returned when iterating over a serializer instance, + providing an API similar to Django forms and form fields. + """ + def __init__(self, field, value, errors): + self._field = field + self.value = value + self.errors = errors - return SortedDict(fields) + def __getattr__(self, attr_name): + return getattr(self._field, attr_name) - def __new__(cls, name, bases, attrs): - attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs) - return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) + @property + def _proxy_class(self): + return self._field.__class__ class BindingDict(object): @@ -196,6 +204,36 @@ class BindingDict(object): return self.fields.values() +class SerializerMetaclass(type): + """ + This metaclass sets a dictionary named `base_fields` on the class. + + Any instances of `Field` included as attributes on either the class + or on any of its superclasses will be include in the + `base_fields` dictionary. + """ + + @classmethod + def _get_declared_fields(cls, bases, attrs): + fields = [(field_name, attrs.pop(field_name)) + for field_name, obj in list(attrs.items()) + if isinstance(obj, Field)] + fields.sort(key=lambda x: x[1]._creation_counter) + + # If this class is subclassing another Serializer, add that Serializer's + # fields. Note that we loop over the bases in *reverse*. This is necessary + # in order to maintain the correct order of fields. + for base in bases[::-1]: + if hasattr(base, '_declared_fields'): + fields = list(base._declared_fields.items()) + fields + + return SortedDict(fields) + + def __new__(cls, name, bases, attrs): + attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs) + return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) + + @six.add_metaclass(SerializerMetaclass) class Serializer(BaseSerializer): def __init__(self, *args, **kwargs): @@ -212,10 +250,18 @@ class Serializer(BaseSerializer): return copy.deepcopy(self._declared_fields) def get_initial(self): - return dict([ + if self._initial_data is not None: + return ReturnDict([ + (field_name, field.get_value(self._initial_data)) + for field_name, field in self.fields.items() + ], serializer=self) + #return self.to_representation(self._initial_data) + + return ReturnDict([ (field.field_name, field.get_initial()) for field in self.fields.values() - ]) + if not field.write_only + ], serializer=self) def get_value(self, dictionary): # We override the default field access in order to support @@ -288,7 +334,7 @@ class Serializer(BaseSerializer): """ Object instance -> Dict of primitive datatypes. """ - ret = SortedDict() + ret = ReturnDict(serializer=self) fields = [field for field in self.fields.values() if not field.write_only] for field in fields: @@ -302,11 +348,9 @@ class Serializer(BaseSerializer): def __iter__(self): errors = self.errors if hasattr(self, '_errors') else {} for field in self.fields.values(): - if field.read_only: - continue value = self.data.get(field.field_name) if self.data else None error = errors.get(field.field_name) - yield FieldResult(field, value, error) + yield BoundField(field, value, error) def __repr__(self): return representation.serializer_repr(self, indent=1) @@ -317,7 +361,7 @@ class Serializer(BaseSerializer): class ListSerializer(BaseSerializer): child = None - initial = [] + many = True def __init__(self, *args, **kwargs): self.child = kwargs.pop('child', copy.deepcopy(self.child)) @@ -326,6 +370,11 @@ class ListSerializer(BaseSerializer): super(ListSerializer, self).__init__(*args, **kwargs) self.child.bind(field_name='', parent=self) + def get_initial(self): + if self._initial_data is not None: + return self.to_representation(self._initial_data) + return ReturnList(serializer=self) + def get_value(self, dictionary): # We override the default field access in order to support # lists in HTML forms. @@ -345,7 +394,10 @@ class ListSerializer(BaseSerializer): """ List of object instances -> List of dicts of primitive datatypes. """ - return [self.child.to_representation(item) for item in data] + return ReturnList( + [self.child.to_representation(item) for item in data], + serializer=self + ) def create(self, attrs_list): return [self.child.create(attrs) for attrs in attrs_list] @@ -354,6 +406,9 @@ class ListSerializer(BaseSerializer): return representation.list_repr(self, indent=1) +# ModelSerializer & HyperlinkedModelSerializer +# -------------------------------------------- + class ModelSerializer(Serializer): _field_mapping = ClassLookupDict({ models.AutoField: IntegerField, diff --git a/rest_framework/static/rest_framework/css/bootstrap-tweaks.css b/rest_framework/static/rest_framework/css/bootstrap-tweaks.css index 6fa1e6cb..84389b1d 100644 --- a/rest_framework/static/rest_framework/css/bootstrap-tweaks.css +++ b/rest_framework/static/rest_framework/css/bootstrap-tweaks.css @@ -10,6 +10,12 @@ a single block in the template. background: transparent; border-top-color: transparent; padding-top: 0; + text-align: right; +} + +#generic-content-form textarea { + font-family:Consolas,Monaco,Lucida Console,Liberation Mono,DejaVu Sans Mono,Bitstream Vera Sans Mono,Courier New, monospace; + font-size: 80%; } .navbar-inverse .brand a { @@ -29,7 +35,7 @@ a single block in the template. z-index: 3; } -.navbar .navbar-inner { +.navbar { background: #2C2C2C; color: white; border: none; @@ -37,7 +43,7 @@ a single block in the template. border-radius: 0px; } -.navbar .navbar-inner .nav li, .navbar .navbar-inner .nav li a, .navbar .navbar-inner .brand:hover { +.navbar .nav li, .navbar .nav li a, .navbar .brand:hover { color: white; } @@ -45,11 +51,11 @@ a single block in the template. background: #2C2C2C; } -.navbar .navbar-inner .dropdown-menu li a, .navbar .navbar-inner .dropdown-menu li { +.navbar .dropdown-menu li a, .navbar .dropdown-menu li { color: #A30000; } -.navbar .navbar-inner .dropdown-menu li a:hover { +.navbar .dropdown-menu li a:hover { background: #EEEEEE; color: #C20000; } @@ -61,10 +67,10 @@ html { background: none; } -body, .navbar .navbar-inner .container-fluid { +/*body, .navbar .container-fluid { max-width: 1150px; margin: 0 auto; -} +}*/ body { background: url("../img/grid.png") repeat-x; diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index a84ccf26..2e03dd98 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -15,7 +15,8 @@ {% block style %} {% block bootstrap_theme %} - + + {% endblock %} @@ -26,44 +27,42 @@ {% block body %} - +
{% block navbar %} - diff --git a/rest_framework/templates/rest_framework/fields/inline/select_checkbox.html b/rest_framework/templates/rest_framework/fields/inline/select_checkbox.html index 424df93e..0f33fb69 100644 --- a/rest_framework/templates/rest_framework/fields/inline/select_checkbox.html +++ b/rest_framework/templates/rest_framework/fields/inline/select_checkbox.html @@ -3,7 +3,7 @@ {% for key, text in field.choices.items %}
diff --git a/rest_framework/templates/rest_framework/fields/inline/select_multiple.html b/rest_framework/templates/rest_framework/fields/inline/select_multiple.html index 6fdfd672..7c9e5168 100644 --- a/rest_framework/templates/rest_framework/fields/inline/select_multiple.html +++ b/rest_framework/templates/rest_framework/fields/inline/select_multiple.html @@ -2,7 +2,7 @@ {% include "rest_framework/fields/inline/label.html" %}
diff --git a/rest_framework/templates/rest_framework/fields/inline/select_radio.html b/rest_framework/templates/rest_framework/fields/inline/select_radio.html index ddabc9e9..177c0eeb 100644 --- a/rest_framework/templates/rest_framework/fields/inline/select_radio.html +++ b/rest_framework/templates/rest_framework/fields/inline/select_radio.html @@ -3,7 +3,7 @@ {% for key, text in field.choices.items %}
diff --git a/rest_framework/templates/rest_framework/fields/inline/textarea.html b/rest_framework/templates/rest_framework/fields/inline/textarea.html index 31366809..8259487b 100644 --- a/rest_framework/templates/rest_framework/fields/inline/textarea.html +++ b/rest_framework/templates/rest_framework/fields/inline/textarea.html @@ -1,4 +1,4 @@
{% include "rest_framework/fields/inline/label.html" %} - +
diff --git a/rest_framework/templates/rest_framework/fields/vertical/fieldset.html b/rest_framework/templates/rest_framework/fields/vertical/fieldset.html index cad32df9..8708916b 100644 --- a/rest_framework/templates/rest_framework/fields/vertical/fieldset.html +++ b/rest_framework/templates/rest_framework/fields/vertical/fieldset.html @@ -1,6 +1,6 @@
{% if field.label %}{{ field.label }}{% endif %} - {% for field_item in value.field_items.values() %} + {% for field_item in field.value.field_items.values() %} {{ renderer.render_field(field_item, layout=layout) }} {% endfor %}
diff --git a/rest_framework/templates/rest_framework/fields/vertical/input.html b/rest_framework/templates/rest_framework/fields/vertical/input.html index c25407d1..3ee2716a 100644 --- a/rest_framework/templates/rest_framework/fields/vertical/input.html +++ b/rest_framework/templates/rest_framework/fields/vertical/input.html @@ -1,5 +1,5 @@
{% include "rest_framework/fields/vertical/label.html" %} - + {% if field.help_text %}

{{ field.help_text }}

{% endif %}
diff --git a/rest_framework/templates/rest_framework/fields/vertical/select.html b/rest_framework/templates/rest_framework/fields/vertical/select.html index 44679d8a..dcc9a3cd 100644 --- a/rest_framework/templates/rest_framework/fields/vertical/select.html +++ b/rest_framework/templates/rest_framework/fields/vertical/select.html @@ -2,7 +2,7 @@ {% include "rest_framework/fields/vertical/label.html" %}
diff --git a/rest_framework/templates/rest_framework/fields/vertical/select_checkbox.html b/rest_framework/templates/rest_framework/fields/vertical/select_checkbox.html index e60574c0..1fbe6a94 100644 --- a/rest_framework/templates/rest_framework/fields/vertical/select_checkbox.html +++ b/rest_framework/templates/rest_framework/fields/vertical/select_checkbox.html @@ -4,7 +4,7 @@
{% for key, text in field.choices.items %} {% endfor %} @@ -13,7 +13,7 @@ {% for key, text in field.choices.items %}
diff --git a/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html b/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html index 00b25b4b..2cc40d99 100644 --- a/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html +++ b/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html @@ -2,7 +2,7 @@ {% include "rest_framework/fields/vertical/label.html" %}
diff --git a/rest_framework/templates/rest_framework/fields/vertical/select_radio.html b/rest_framework/templates/rest_framework/fields/vertical/select_radio.html index 4ffe38ea..470bcb0b 100644 --- a/rest_framework/templates/rest_framework/fields/vertical/select_radio.html +++ b/rest_framework/templates/rest_framework/fields/vertical/select_radio.html @@ -4,7 +4,7 @@
{% for key, text in field.choices.items %} {% endfor %} @@ -13,7 +13,7 @@ {% for key, text in field.choices.items %}
diff --git a/rest_framework/templates/rest_framework/fields/vertical/textarea.html b/rest_framework/templates/rest_framework/fields/vertical/textarea.html index 33cb27c7..406cfa77 100644 --- a/rest_framework/templates/rest_framework/fields/vertical/textarea.html +++ b/rest_framework/templates/rest_framework/fields/vertical/textarea.html @@ -1,5 +1,5 @@
{% include "rest_framework/fields/vertical/label.html" %} - + {% if field.help_text %}

{{ field.help_text }}

{% endif %}
diff --git a/rest_framework/templates/rest_framework/form.html b/rest_framework/templates/rest_framework/form.html index 658aa293..162c5633 100644 --- a/rest_framework/templates/rest_framework/form.html +++ b/rest_framework/templates/rest_framework/form.html @@ -1,4 +1,4 @@ - + {% load rest_framework %} {% csrf_token %} - {% for field, value, errors in form %} - {% render_field field value errors layout=layout renderer=renderer %} + {% for field in form %} + {% if not field.read_only %} + {% render_field field layout=layout renderer=renderer %} + {% endif %} {% endfor %} {% if layout == "horizontal" %} @@ -25,7 +27,7 @@ {% endif %} - + diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 88ff9d4e..49a4c338 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -31,12 +31,9 @@ class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') # And the template tags themselves... -# @register.simple_tag -# def render_field(field, value, errors, renderer): -# return renderer.render_field(field, value, errors) @register.simple_tag -def render_field(field, value, errors, layout=None, renderer=None): - return renderer.render_field(field, value, errors, layout) +def render_field(field, layout=None, renderer=None): + return renderer.render_field(field, layout) @register.simple_tag diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index b4d33e39..30fae370 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -21,7 +21,14 @@ class ClassLookupDict(object): self.mapping = mapping def __getitem__(self, key): - for cls in inspect.getmro(key.__class__): + if hasattr(key, '_proxy_class'): + # Deal with proxy classes. Ie. BoundField behaves as if it + # is a Field instance when using ClassLookupDict. + base_class = key._proxy_class + else: + base_class = key.__class__ + + for cls in inspect.getmro(base_class): if cls in self.mapping: return self.mapping[cls] raise KeyError('Class %s not found in lookup.', cls.__name__) -- cgit v1.2.3 From fec7c4b45812d22423e73ec3ab801857a55d7340 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 2 Oct 2014 18:13:15 +0100 Subject: Browsable API tweaks --- rest_framework/fields.py | 5 ++--- rest_framework/relations.py | 1 + rest_framework/serializers.py | 1 + rest_framework/static/rest_framework/css/bootstrap-tweaks.css | 2 +- rest_framework/static/rest_framework/css/default.css | 3 +-- 5 files changed, 6 insertions(+), 6 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index c794963e..3f22660c 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -106,7 +106,7 @@ class Field(object): 'null': _('This field may not be null.') } default_validators = [] - default_empty_html = None + default_empty_html = empty initial = None def __init__(self, read_only=False, write_only=False, @@ -375,7 +375,6 @@ class NullBooleanField(Field): default_error_messages = { 'invalid': _('`{input}` is not a valid boolean.') } - default_empty_html = None initial = None TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True)) FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False)) @@ -411,7 +410,6 @@ class CharField(Field): default_error_messages = { 'blank': _('This field may not be blank.') } - default_empty_html = '' initial = '' def __init__(self, **kwargs): @@ -852,6 +850,7 @@ class MultipleChoiceField(ChoiceField): 'invalid_choice': _('`{input}` is not a valid choice.'), 'not_a_list': _('Expected a list of items but got type `{input_type}`') } + default_empty_html = [] def to_internal_value(self, data): if isinstance(data, type('')) or not hasattr(data, '__iter__'): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 988b9ede..4f971917 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -249,6 +249,7 @@ class ManyRelation(Field): You shouldn't need to be using this class directly yourself. """ initial = [] + default_empty_html = [] def __init__(self, child_relation=None, *args, **kwargs): self.child_relation = child_relation diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 0f24ed40..21cb7ea2 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -254,6 +254,7 @@ class Serializer(BaseSerializer): return ReturnDict([ (field_name, field.get_value(self._initial_data)) for field_name, field in self.fields.items() + if field.get_value(self._initial_data) is not empty ], serializer=self) #return self.to_representation(self._initial_data) diff --git a/rest_framework/static/rest_framework/css/bootstrap-tweaks.css b/rest_framework/static/rest_framework/css/bootstrap-tweaks.css index 84389b1d..6a37cae2 100644 --- a/rest_framework/static/rest_framework/css/bootstrap-tweaks.css +++ b/rest_framework/static/rest_framework/css/bootstrap-tweaks.css @@ -173,7 +173,7 @@ footer a:hover { .page-header { border-bottom: none; padding-bottom: 0px; - margin-bottom: 20px; + margin: 0; } /* custom general page styles */ diff --git a/rest_framework/static/rest_framework/css/default.css b/rest_framework/static/rest_framework/css/default.css index 461cdfe5..82c6033b 100644 --- a/rest_framework/static/rest_framework/css/default.css +++ b/rest_framework/static/rest_framework/css/default.css @@ -33,7 +33,7 @@ h2, h3 { } ul.breadcrumb { - margin: 80px 0 0 0; + margin: 70px 0 0 0; } form select, form input, form textarea { @@ -67,5 +67,4 @@ pre { .page-header { border-bottom: none; padding-bottom: 0px; - margin-bottom: 20px; } -- cgit v1.2.3 From dfab9af294972720f59890967cd9ae1a6c0796b6 Mon Sep 17 00:00:00 2001 From: Craig de Stigter Date: Fri, 3 Oct 2014 08:41:18 +1300 Subject: Minor: fix spelling and grammar, mostly in 3.0 announcement --- rest_framework/compat.py | 2 +- rest_framework/fields.py | 12 ++++++------ rest_framework/relations.py | 2 +- rest_framework/validators.py | 2 +- rest_framework/views.py | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 3993cee6..e4e69580 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -131,7 +131,7 @@ else: self.message = kwargs.pop('message', self.message) super(MaxValueValidator, self).__init__(*args, **kwargs) -# URLValidator only accept `message` in 1.6+ +# URLValidator only accepts `message` in 1.6+ if django.VERSION >= (1, 6): from django.core.validators import URLValidator else: diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f3ff2233..bba8ccae 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -186,14 +186,14 @@ class Field(object): def get_initial(self): """ - Return a value to use when the field is being returned as a primative + Return a value to use when the field is being returned as a primitive value, without any object instance. """ return self.initial def get_value(self, dictionary): """ - Given the *incoming* primative data, return the value for this field + Given the *incoming* primitive data, return the value for this field that should be validated and transformed to a native value. """ if html.is_html_input(dictionary): @@ -205,7 +205,7 @@ class Field(object): def get_field_representation(self, instance): """ - Given the outgoing object instance, return the primative value + Given the outgoing object instance, return the primitive value that should be used for this field. """ attribute = get_attribute(instance, self.source_attrs) @@ -274,13 +274,13 @@ class Field(object): def to_internal_value(self, data): """ - Transform the *incoming* primative data into a native value. + Transform the *incoming* primitive data into a native value. """ raise NotImplementedError('to_internal_value() must be implemented.') def to_representation(self, value): """ - Transform the *outgoing* native value into primative data. + Transform the *outgoing* native value into primitive data. """ raise NotImplementedError('to_representation() must be implemented.') @@ -928,7 +928,7 @@ class ImageField(FileField): def to_internal_value(self, data): # Image validation is a bit grungy, so we'll just outright # defer to Django's implementation so we don't need to - # consider it, or treat PIL as a test dependancy. + # consider it, or treat PIL as a test dependency. file_object = super(ImageField, self).to_internal_value(data) django_field = self._DjangoImageField() django_field.error_messages = self.error_messages diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 8c135672..8141de13 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -100,7 +100,7 @@ class HyperlinkedRelatedField(RelatedField): self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) self.format = kwargs.pop('format', None) - # We include these simply for dependancy injection in tests. + # We include these simply for dependency injection in tests. # We can't add them as class attributes or they would expect an # implict `self` argument to be passed. self.reverse = reverse diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 20de4b42..5bb69ad8 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -2,7 +2,7 @@ We perform uniqueness checks explicitly on the serializer class, rather the using Django's `.full_clean()`. -This gives us better seperation of concerns, allows us to use single-step +This gives us better separation of concerns, allows us to use single-step object creation, and makes it possible to switch between using the implicit `ModelSerializer` class and an equivelent explicit `Serializer` class. """ diff --git a/rest_framework/views.py b/rest_framework/views.py index 835e223a..979229eb 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -100,7 +100,7 @@ class APIView(View): content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS metadata_class = api_settings.DEFAULT_METADATA_CLASS - # Allow dependancy injection of other settings to make testing easier. + # Allow dependency injection of other settings to make testing easier. settings = api_settings @classmethod -- cgit v1.2.3 From 857a8486b1534f89bd482de86d39ff717b6618eb Mon Sep 17 00:00:00 2001 From: Craig de Stigter Date: Fri, 3 Oct 2014 09:00:33 +1300 Subject: More spelling tweaks --- rest_framework/filters.py | 2 +- rest_framework/mixins.py | 2 +- rest_framework/relations.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 4c485668..d188a2d1 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -148,7 +148,7 @@ class OrderingFilter(BaseFilterBackend): if not getattr(field, 'write_only', False) ] elif valid_fields == '__all__': - # View explictly allows filtering on any model field + # View explicitly allows filtering on any model field valid_fields = [field.name for field in queryset.model._meta.fields] valid_fields += queryset.query.aggregates.keys() diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 04b7a763..de334b4b 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -83,7 +83,7 @@ class DestroyModelMixin(object): # The AllowPUTAsCreateMixin was previously the default behaviour -# for PUT requests. This has now been removed and must be *explictly* +# for PUT requests. This has now been removed and must be *explicitly* # included if it is the behavior that you want. # For more info see: ... diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 8141de13..dc9781e7 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -102,7 +102,7 @@ class HyperlinkedRelatedField(RelatedField): # We include these simply for dependency injection in tests. # We can't add them as class attributes or they would expect an - # implict `self` argument to be passed. + # implicit `self` argument to be passed. self.reverse = reverse self.resolve = resolve -- cgit v1.2.3 From 765b0b33bf1fa9b7c6b45d3877d10a05d4e9f6ea Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 3 Oct 2014 13:12:23 +0100 Subject: Revert accidental stupidity --- rest_framework/relations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 4f971917..f9b5ff0d 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -127,7 +127,7 @@ class HyperlinkedRelatedField(RelatedField): attributes are not configured to correctly match the URL conf. """ # Unsaved objects will not yet have a valid URL. - if obj.pk: + if obj.pk is None: return None lookup_value = getattr(obj, self.lookup_field) -- cgit v1.2.3 From e6c5ebdda6d0f169f21498909e2d390c460138a9 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 3 Oct 2014 13:14:17 +0100 Subject: Fix indentation --- rest_framework/serializers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 21cb7ea2..c3a0815e 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -252,10 +252,10 @@ class Serializer(BaseSerializer): def get_initial(self): if self._initial_data is not None: return ReturnDict([ - (field_name, field.get_value(self._initial_data)) - for field_name, field in self.fields.items() - if field.get_value(self._initial_data) is not empty - ], serializer=self) + (field_name, field.get_value(self._initial_data)) + for field_name, field in self.fields.items() + if field.get_value(self._initial_data) is not empty + ], serializer=self) #return self.to_representation(self._initial_data) return ReturnDict([ -- cgit v1.2.3 From 3a3e2bf57d5443dc0b058d5beb3111f87c418947 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 3 Oct 2014 13:42:06 +0100 Subject: Serializer.save() takes keyword arguments, not 'extras' argument --- rest_framework/mixins.py | 4 ++-- rest_framework/serializers.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index de334b4b..bc4ce22f 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -101,8 +101,8 @@ class AllowPUTAsCreateMixin(object): if instance is None: lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field lookup_value = self.kwargs[lookup_url_kwarg] - extras = {self.lookup_field: lookup_value} - serializer.save(extras=extras) + extra_kwargs = {self.lookup_field: lookup_value} + serializer.save(**extra_kwargs) return Response(serializer.data, status=status.HTTP_201_CREATED) serializer.save() diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index c3a0815e..ed024f87 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -74,12 +74,12 @@ class BaseSerializer(Field): def create(self, validated_data): raise NotImplementedError('`create()` must be implemented.') - def save(self, extras=None): + def save(self, **kwargs): validated_data = self.validated_data - if extras is not None: + if kwargs: validated_data = dict( list(validated_data.items()) + - list(extras.items()) + list(kwargs.items()) ) if self.instance is not None: @@ -256,7 +256,6 @@ class Serializer(BaseSerializer): for field_name, field in self.fields.items() if field.get_value(self._initial_data) is not empty ], serializer=self) - #return self.to_representation(self._initial_data) return ReturnDict([ (field.field_name, field.get_initial()) -- cgit v1.2.3 From 2dfe75c23a041493bc83514d8e9e9268b79072d9 Mon Sep 17 00:00:00 2001 From: Jones Chi Date: Fri, 3 Oct 2014 14:42:49 +0800 Subject: Fix follow does not work on APIClient Handle follow just like Django's Client. --- rest_framework/test.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/test.py b/rest_framework/test.py index 9b40353a..74d2c868 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -156,6 +156,52 @@ class APIClient(APIRequestFactory, DjangoClient): kwargs.update(self._credentials) return super(APIClient, self).request(**kwargs) + def get(self, path, data=None, follow=False, **extra): + response = super(APIClient, self).get(path, data=data, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def post(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + response = super(APIClient, self).post( + path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def put(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + response = super(APIClient, self).put( + path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def patch(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + response = super(APIClient, self).patch( + path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def delete(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + response = super(APIClient, self).delete( + path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def options(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + response = super(APIClient, self).options( + path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + def logout(self): self._credentials = {} return super(APIClient, self).logout() -- cgit v1.2.3 From 6bfed6f8525a49fc50df7143ac2d492528b8f2ac Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 7 Oct 2014 17:04:53 +0100 Subject: Enforce uniqueness validation for relational fields --- rest_framework/fields.py | 2 ++ rest_framework/utils/field_mapping.py | 3 +++ 2 files changed, 5 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 0963d4bf..9d577c53 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -224,6 +224,8 @@ class Field(object): """ if self.default is empty: raise SkipField() + if is_simple_callable(self.default): + return self.default() return self.default def run_validation(self, data=empty): diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index 30fae370..fd6da699 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -219,6 +219,9 @@ def get_relation_kwargs(field_name, relation_info): kwargs['required'] = False if model_field.null: kwargs['allow_null'] = True + if getattr(model_field, 'unique', False): + validator = UniqueValidator(queryset=model_field.model._default_manager) + kwargs['validators'] = [validator] return kwargs -- cgit v1.2.3 From 3fa4a1898aee0dabee951f81f790bb2da042ec81 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 7 Oct 2014 17:21:12 +0100 Subject: Reintroduce save hooks --- rest_framework/mixins.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index bc4ce22f..03ebb034 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -20,10 +20,13 @@ class CreateModelMixin(object): def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) - serializer.save() + self.create_valid(serializer) headers = self.get_success_headers(serializer.data) return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + def create_valid(self, serializer): + serializer.save() + def get_success_headers(self, data): try: return {'Location': data[api_settings.URL_FIELD_NAME]} @@ -64,9 +67,12 @@ class UpdateModelMixin(object): instance = self.get_object() serializer = self.get_serializer(instance, data=request.data, partial=partial) serializer.is_valid(raise_exception=True) - serializer.save() + self.update_valid(serializer) return Response(serializer.data) + def update_valid(self, serializer): + serializer.save() + def partial_update(self, request, *args, **kwargs): kwargs['partial'] = True return self.update(request, *args, **kwargs) -- cgit v1.2.3 From 093febb91299e332c810de6a6b6aba57c2b16a91 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 Oct 2014 11:04:08 +0100 Subject: Tests for relational fields --- rest_framework/fields.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 9d577c53..5fb0ec8d 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,7 +1,7 @@ from django import forms from django.conf import settings from django.core import validators -from django.core.exceptions import ValidationError +from django.core.exceptions import ObjectDoesNotExist, ValidationError from django.utils import six, timezone from django.utils.datastructures import SortedDict from django.utils.dateparse import parse_date, parse_datetime, parse_time @@ -54,6 +54,8 @@ def get_attribute(instance, attrs): for attr in attrs: try: instance = getattr(instance, attr) + except ObjectDoesNotExist: + return None except AttributeError as exc: try: return instance[attr] @@ -108,6 +110,7 @@ class Field(object): default_validators = [] default_empty_html = empty initial = None + coerce_blank_to_null = True def __init__(self, read_only=False, write_only=False, required=None, default=empty, initial=empty, source=None, @@ -245,6 +248,9 @@ class Field(object): self.fail('required') return self.get_default() + if data == '' and self.coerce_blank_to_null: + data = None + if data is None: if not self.allow_null: self.fail('null') @@ -413,6 +419,7 @@ class CharField(Field): 'blank': _('This field may not be blank.') } initial = '' + coerce_blank_to_null = False def __init__(self, **kwargs): self.allow_blank = kwargs.pop('allow_blank', False) -- cgit v1.2.3 From 6b09e5f2bba9167404ec329fa12c7f0215ca51ac Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 Oct 2014 11:22:10 +0100 Subject: Tests for generic relationships --- rest_framework/relations.py | 2 +- rest_framework/serializers.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index e5bdf60c..df5025b8 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -49,7 +49,7 @@ class RelatedField(Field): ]) -class StringRelatedField(Field): +class StringRelatedField(RelatedField): """ A read only field that represents its targets using their plain string representation. diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index ed024f87..3d868a9e 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -520,11 +520,6 @@ class ModelSerializer(Serializer): ret[field_name] = declared_fields[field_name] continue - elif field_name == api_settings.URL_FIELD_NAME: - # Create the URL field. - field_cls = HyperlinkedIdentityField - kwargs = get_url_kwargs(model) - elif field_name in info.fields_and_pk: # Create regular model fields. model_field = info.fields_and_pk[field_name] @@ -561,6 +556,11 @@ class ModelSerializer(Serializer): field_cls = ReadOnlyField kwargs = {} + elif field_name == api_settings.URL_FIELD_NAME: + # Create the URL field. + field_cls = HyperlinkedIdentityField + kwargs = get_url_kwargs(model) + else: raise ImproperlyConfigured( 'Field name `%s` is not valid for model `%s`.' % -- cgit v1.2.3 From 0cbb57b40fdb073c7ca09c9d1078926260c646db Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 Oct 2014 12:17:30 +0100 Subject: Tweak pre/post save hooks. Return instance in .update(). --- rest_framework/mixins.py | 13 ++++++++----- rest_framework/serializers.py | 24 ++++++++++++++---------- 2 files changed, 22 insertions(+), 15 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 03ebb034..4c62debb 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -20,11 +20,11 @@ class CreateModelMixin(object): def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) - self.create_valid(serializer) + self.perform_create(serializer) headers = self.get_success_headers(serializer.data) return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) - def create_valid(self, serializer): + def perform_create(self, serializer): serializer.save() def get_success_headers(self, data): @@ -67,10 +67,10 @@ class UpdateModelMixin(object): instance = self.get_object() serializer = self.get_serializer(instance, data=request.data, partial=partial) serializer.is_valid(raise_exception=True) - self.update_valid(serializer) + self.preform_update(serializer) return Response(serializer.data) - def update_valid(self, serializer): + def preform_update(self, serializer): serializer.save() def partial_update(self, request, *args, **kwargs): @@ -84,9 +84,12 @@ class DestroyModelMixin(object): """ def destroy(self, request, *args, **kwargs): instance = self.get_object() - instance.delete() + self.perform_destroy(instance) return Response(status=status.HTTP_204_NO_CONTENT) + def perform_destroy(self, instance): + instance.delete() + # The AllowPUTAsCreateMixin was previously the default behaviour # for PUT requests. This has now been removed and must be *explicitly* diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 3d868a9e..e7cd50d6 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -83,7 +83,10 @@ class BaseSerializer(Field): ) if self.instance is not None: - self.update(self.instance, validated_data) + self.instance = self.update(self.instance, validated_data) + assert self.instance is not None, ( + '`update()` did not return an object instance.' + ) else: self.instance = self.create(validated_data) assert self.instance is not None, ( @@ -444,19 +447,19 @@ class ModelSerializer(Serializer): self.validators.extend(validators) self._kwargs['validators'] = validators - def create(self, attrs): + def create(self, validated_attrs): ModelClass = self.Meta.model - # Remove many-to-many relationships from attrs. + # Remove many-to-many relationships from validated_attrs. # They are not valid arguments to the default `.create()` method, # as they require that the instance has already been saved. info = model_meta.get_field_info(ModelClass) many_to_many = {} for field_name, relation_info in info.relations.items(): - if relation_info.to_many and (field_name in attrs): - many_to_many[field_name] = attrs.pop(field_name) + if relation_info.to_many and (field_name in validated_attrs): + many_to_many[field_name] = validated_attrs.pop(field_name) - instance = ModelClass.objects.create(**attrs) + instance = ModelClass.objects.create(**validated_attrs) # Save many-to-many relationships after the instance is created. if many_to_many: @@ -465,10 +468,11 @@ class ModelSerializer(Serializer): return instance - def update(self, obj, attrs): - for attr, value in attrs.items(): - setattr(obj, attr, value) - obj.save() + def update(self, instance, validated_attrs): + for attr, value in validated_attrs.items(): + setattr(instance, attr, value) + instance.save() + return instance def get_unique_together_validators(self): field_names = set([ -- cgit v1.2.3 From 28f3b314f12cbff33c55602c2c5f5f5cce956171 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 Oct 2014 12:36:28 +0100 Subject: .validate() returning validated data. transform_ hooks. --- rest_framework/serializers.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index e7cd50d6..8513428c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -299,7 +299,8 @@ class Serializer(BaseSerializer): value = self.to_internal_value(data) try: self.run_validators(value) - self.validate(value) + value = self.validate(value) + assert value is not None, '.validate() should return the validated data' except ValidationError as exc: raise ValidationError({ api_settings.NON_FIELD_ERRORS_KEY: exc.messages @@ -341,7 +342,12 @@ class Serializer(BaseSerializer): fields = [field for field in self.fields.values() if not field.write_only] for field in fields: - ret[field.field_name] = field.get_field_representation(instance) + value = field.get_field_representation(instance) + transform_method = getattr(self, 'transform_' + field.field_name, None) + if transform_method is not None: + value = transform_method(value) + + ret[field.field_name] = value return ret -- cgit v1.2.3 From 14ae52a24e93063f77c6010269bf9cd3316627fe Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 Oct 2014 16:09:37 +0100 Subject: More gradual deprecation --- rest_framework/request.py | 16 ++++++++ rest_framework/serializers.py | 71 ++++++++++++++++++++++++++++++++++- rest_framework/utils/field_mapping.py | 22 +++++------ rest_framework/utils/model_meta.py | 4 +- 4 files changed, 99 insertions(+), 14 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/request.py b/rest_framework/request.py index d80baa70..d4352742 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -18,6 +18,7 @@ from rest_framework import HTTP_HEADER_ENCODING from rest_framework import exceptions from rest_framework.compat import BytesIO from rest_framework.settings import api_settings +import warnings def is_form_media_type(media_type): @@ -209,6 +210,11 @@ class Request(object): """ Synonym for `.query_params`, for backwards compatibility. """ + warnings.warn( + "`request.QUERY_PARAMS` is pending deprecation. Use `request.query_params` instead.", + PendingDeprecationWarning, + stacklevel=1 + ) return self._request.GET @property @@ -225,6 +231,11 @@ class Request(object): Similar to usual behaviour of `request.POST`, except that it handles arbitrary parsers, and also works on methods other than POST (eg PUT). """ + warnings.warn( + "`request.DATA` is pending deprecation. Use `request.data` instead.", + PendingDeprecationWarning, + stacklevel=1 + ) if not _hasattr(self, '_data'): self._load_data_and_files() return self._data @@ -237,6 +248,11 @@ class Request(object): Similar to usual behaviour of `request.FILES`, except that it handles arbitrary parsers, and also works on methods other than POST (eg PUT). """ + warnings.warn( + "`request.FILES` is pending deprecation. Use `request.data` instead.", + PendingDeprecationWarning, + stacklevel=1 + ) if not _hasattr(self, '_files'): self._load_data_and_files() return self._files diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 8513428c..9fcbcba7 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -25,6 +25,7 @@ from rest_framework.utils.field_mapping import ( from rest_framework.validators import UniqueTogetherValidator import copy import inspect +import warnings # Note: We do the following so that users of the framework can use this style: # @@ -517,12 +518,24 @@ class ModelSerializer(Serializer): depth = getattr(self.Meta, 'depth', 0) extra_kwargs = getattr(self.Meta, 'extra_kwargs', {}) + extra_kwargs = self._include_additional_options(extra_kwargs) + # Retrieve metadata about fields & relationships on the model class. info = model_meta.get_field_info(model) # Use the default set of fields if none is supplied explicitly. if fields is None: fields = self._get_default_field_names(declared_fields, info) + exclude = getattr(self.Meta, 'exclude', None) + if exclude is not None: + warnings.warn( + "The `Meta.exclude` option is pending deprecation. " + "Use the explicit `Meta.fields` instead.", + PendingDeprecationWarning, + stacklevel=3 + ) + for field_name in exclude: + fields.remove(field_name) for field_name in fields: if field_name in declared_fields: @@ -589,13 +602,69 @@ class ModelSerializer(Serializer): ) # Populate any kwargs defined in `Meta.extra_kwargs` - kwargs.update(extra_kwargs.get(field_name, {})) + extras = extra_kwargs.get(field_name, {}) + if extras.get('read_only', False): + for attr in [ + 'required', 'default', 'allow_blank', 'allow_null', + 'min_length', 'max_length', 'min_value', 'max_value', + 'validators' + ]: + kwargs.pop(attr, None) + kwargs.update(extras) # Create the serializer field. ret[field_name] = field_cls(**kwargs) return ret + def _include_additional_options(self, extra_kwargs): + read_only_fields = getattr(self.Meta, 'read_only_fields', None) + if read_only_fields is not None: + for field_name in read_only_fields: + kwargs = extra_kwargs.get(field_name, {}) + kwargs['read_only'] = True + extra_kwargs[field_name] = kwargs + + # These are all pending deprecation. + write_only_fields = getattr(self.Meta, 'write_only_fields', None) + if write_only_fields is not None: + warnings.warn( + "The `Meta.write_only_fields` option is pending deprecation. " + "Use `Meta.extra_kwargs={: {'write_only': True}}` instead.", + PendingDeprecationWarning, + stacklevel=3 + ) + for field_name in write_only_fields: + kwargs = extra_kwargs.get(field_name, {}) + kwargs['write_only'] = True + extra_kwargs[field_name] = kwargs + + view_name = getattr(self.Meta, 'view_name', None) + if view_name is not None: + warnings.warn( + "The `Meta.view_name` option is pending deprecation. " + "Use `Meta.extra_kwargs={'url': {'view_name': ...}}` instead.", + PendingDeprecationWarning, + stacklevel=3 + ) + kwargs = extra_kwargs.get(field_name, {}) + kwargs['view_name'] = view_name + extra_kwargs[api_settings.URL_FIELD_NAME] = kwargs + + lookup_field = getattr(self.Meta, 'lookup_field', None) + if lookup_field is not None: + warnings.warn( + "The `Meta.lookup_field` option is pending deprecation. " + "Use `Meta.extra_kwargs={'url': {'lookup_field': ...}}` instead.", + PendingDeprecationWarning, + stacklevel=3 + ) + kwargs = extra_kwargs.get(field_name, {}) + kwargs['lookup_field'] = lookup_field + extra_kwargs[api_settings.URL_FIELD_NAME] = kwargs + + return extra_kwargs + def _get_default_field_names(self, declared_fields, model_info): return ( [model_info.pk.name] + diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index fd6da699..6db37146 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -71,6 +71,17 @@ def get_field_kwargs(field_name, model_field): if model_field.help_text: kwargs['help_text'] = model_field.help_text + max_digits = getattr(model_field, 'max_digits', None) + if max_digits is not None: + kwargs['max_digits'] = max_digits + + decimal_places = getattr(model_field, 'decimal_places', None) + if decimal_places is not None: + kwargs['decimal_places'] = decimal_places + + if isinstance(model_field, models.TextField): + kwargs['style'] = {'type': 'textarea'} + if isinstance(model_field, models.AutoField) or not model_field.editable: # If this field is read-only, then return early. # Further keyword arguments are not valid. @@ -86,9 +97,6 @@ def get_field_kwargs(field_name, model_field): kwargs['choices'] = model_field.flatchoices return kwargs - if isinstance(model_field, models.TextField): - kwargs['style'] = {'type': 'textarea'} - if model_field.null and not isinstance(model_field, models.NullBooleanField): kwargs['allow_null'] = True @@ -171,14 +179,6 @@ def get_field_kwargs(field_name, model_field): validator = UniqueValidator(queryset=model_field.model._default_manager) validator_kwarg.append(validator) - max_digits = getattr(model_field, 'max_digits', None) - if max_digits is not None: - kwargs['max_digits'] = max_digits - - decimal_places = getattr(model_field, 'decimal_places', None) - if decimal_places is not None: - kwargs['decimal_places'] = decimal_places - if validator_kwarg: kwargs['validators'] = validator_kwarg diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py index b6c41174..7a95bcdd 100644 --- a/rest_framework/utils/model_meta.py +++ b/rest_framework/utils/model_meta.py @@ -107,8 +107,8 @@ def get_field_info(model): related=relation.model, to_many=True, has_through_model=( - hasattr(relation.field.rel, 'through') and - not relation.field.rel.through._meta.auto_created + (getattr(relation.field.rel, 'through', None) is not None) + and not relation.field.rel.through._meta.auto_created ) ) -- cgit v1.2.3 From 4c015df28cfb7dc7cf29f6dc4985c57e1f5cdc5d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 Oct 2014 16:43:33 +0100 Subject: Tweaks --- rest_framework/relations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index df5025b8..e9dd7dde 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -264,9 +264,10 @@ class ManyRelation(Field): ] def to_representation(self, obj): + iterable = obj.all() if (hasattr(obj, 'all')) else obj return [ self.child_relation.to_representation(value) - for value in obj.all() + for value in iterable ] @property -- cgit v1.2.3 From 5ead8dc89d1a99d6189170dc8dac19cdc8ba7750 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 Oct 2014 16:59:52 +0100 Subject: Support empty file fields --- rest_framework/fields.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 5fb0ec8d..f86f6626 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -913,6 +913,8 @@ class FileField(Field): def to_representation(self, value): if self.use_url: + if not value: + return None url = settings.MEDIA_URL + value.url request = self.context.get('request', None) if request is not None: -- cgit v1.2.3 From f7d43f530a94e686d2f93781471b9ac4e90d0f58 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 Oct 2014 17:03:14 +0100 Subject: Limit blank string -> None to just be on relational fields --- rest_framework/fields.py | 4 ---- rest_framework/relations.py | 8 +++++++- 2 files changed, 7 insertions(+), 5 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f86f6626..b371c7d0 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -110,7 +110,6 @@ class Field(object): default_validators = [] default_empty_html = empty initial = None - coerce_blank_to_null = True def __init__(self, read_only=False, write_only=False, required=None, default=empty, initial=empty, source=None, @@ -248,9 +247,6 @@ class Field(object): self.fail('required') return self.get_default() - if data == '' and self.coerce_blank_to_null: - data = None - if data is None: if not self.allow_null: self.fail('null') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index e9dd7dde..c1e5aa18 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,5 +1,5 @@ from rest_framework.compat import smart_text, urlparse -from rest_framework.fields import Field +from rest_framework.fields import empty, Field from rest_framework.reverse import reverse from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch, Resolver404 @@ -31,6 +31,12 @@ class RelatedField(Field): ) return super(RelatedField, cls).__new__(cls, *args, **kwargs) + def run_validation(self, data=empty): + # We force empty strings to None values for relational fields. + if data == '': + data = None + return super(RelatedField, self).run_validation(data) + def get_queryset(self): queryset = self.queryset if isinstance(queryset, QuerySet): -- cgit v1.2.3 From 5f4cc52ef5c0f603420c6ea809594710a372d336 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 9 Oct 2014 10:11:44 +0100 Subject: Tweaking --- rest_framework/validators.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 5bb69ad8..f76faaa4 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -12,6 +12,9 @@ from rest_framework.utils.representation import smart_repr class UniqueValidator: + """ + Validator that corresponds to `unique=True` on a model field. + """ # Validators with `requires_context` will have the field instance # passed to them when the field is instantiated. requires_context = True @@ -46,6 +49,9 @@ class UniqueValidator: class UniqueTogetherValidator: + """ + Validator that corresponds to `unique_together = (...)` on a model class. + """ requires_context = True message = _('The fields {field_names} must make a unique set.') -- cgit v1.2.3 From 5d247a65c89594a7ab5ce2333612f23eadc6828d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 9 Oct 2014 15:11:19 +0100 Subject: First pass on nested serializers in HTML --- rest_framework/compat.py | 16 +++++++++- rest_framework/fields.py | 28 +++++++++++++---- rest_framework/relations.py | 20 +++++++++++-- rest_framework/renderers.py | 10 ++++++- rest_framework/serializers.py | 35 +++++++++++++++++----- .../rest_framework/fields/horizontal/fieldset.html | 5 ++-- .../fields/horizontal/list_fieldset.html | 13 ++++++++ .../rest_framework/fields/inline/fieldset.html | 5 ++-- .../rest_framework/fields/vertical/fieldset.html | 5 ++-- .../fields/vertical/list_fieldset.html | 7 +++++ 10 files changed, 120 insertions(+), 24 deletions(-) create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/list_fieldset.html create mode 100644 rest_framework/templates/rest_framework/fields/vertical/list_fieldset.html (limited to 'rest_framework') diff --git a/rest_framework/compat.py b/rest_framework/compat.py index e4e69580..4ab23a4d 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -114,12 +114,15 @@ else: -# MinValueValidator and MaxValueValidator only accept `message` in 1.8+ +# MinValueValidator, MaxValueValidator et al. only accept `message` in 1.8+ if django.VERSION >= (1, 8): from django.core.validators import MinValueValidator, MaxValueValidator + from django.core.validators import MinLengthValidator, MaxLengthValidator else: from django.core.validators import MinValueValidator as DjangoMinValueValidator from django.core.validators import MaxValueValidator as DjangoMaxValueValidator + from django.core.validators import MinLengthValidator as DjangoMinLengthValidator + from django.core.validators import MaxLengthValidator as DjangoMaxLengthValidator class MinValueValidator(DjangoMinValueValidator): def __init__(self, *args, **kwargs): @@ -131,6 +134,17 @@ else: self.message = kwargs.pop('message', self.message) super(MaxValueValidator, self).__init__(*args, **kwargs) + class MinLengthValidator(DjangoMinLengthValidator): + def __init__(self, *args, **kwargs): + self.message = kwargs.pop('message', self.message) + super(MinLengthValidator, self).__init__(*args, **kwargs) + + class MaxLengthValidator(DjangoMaxLengthValidator): + def __init__(self, *args, **kwargs): + self.message = kwargs.pop('message', self.message) + super(MaxLengthValidator, self).__init__(*args, **kwargs) + + # URLValidator only accepts `message` in 1.6+ if django.VERSION >= (1, 6): from django.core.validators import URLValidator diff --git a/rest_framework/fields.py b/rest_framework/fields.py index b371c7d0..7053acee 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -8,7 +8,10 @@ from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.encoding import is_protected_type from django.utils.translation import ugettext_lazy as _ from rest_framework import ISO_8601 -from rest_framework.compat import smart_text, EmailValidator, MinValueValidator, MaxValueValidator, URLValidator +from rest_framework.compat import ( + smart_text, EmailValidator, MinValueValidator, MaxValueValidator, + MinLengthValidator, MaxLengthValidator, URLValidator +) from rest_framework.settings import api_settings from rest_framework.utils import html, representation, humanize_datetime import copy @@ -138,7 +141,7 @@ class Field(object): self.label = label self.help_text = help_text self.style = {} if style is None else style - self.validators = validators or self.default_validators[:] + self.validators = validators[:] or self.default_validators[:] self.allow_null = allow_null # These are set up by `.bind()` when the field is added to a serializer. @@ -412,16 +415,24 @@ class NullBooleanField(Field): class CharField(Field): default_error_messages = { - 'blank': _('This field may not be blank.') + 'blank': _('This field may not be blank.'), + 'max_length': _('Ensure this field has no more than {max_length} characters.'), + 'min_length': _('Ensure this field has no more than {min_length} characters.') } initial = '' coerce_blank_to_null = False def __init__(self, **kwargs): self.allow_blank = kwargs.pop('allow_blank', False) - self.max_length = kwargs.pop('max_length', None) - self.min_length = kwargs.pop('min_length', None) + max_length = kwargs.pop('max_length', None) + min_length = kwargs.pop('min_length', None) super(CharField, self).__init__(**kwargs) + if max_length is not None: + message = self.error_messages['max_length'].format(max_length=max_length) + self.validators.append(MaxLengthValidator(max_length, message=message)) + if min_length is not None: + message = self.error_messages['min_length'].format(min_length=min_length) + self.validators.append(MinLengthValidator(min_length, message=message)) def run_validation(self, data=empty): # Test for the empty string here so that it does not get validated, @@ -857,6 +868,13 @@ class MultipleChoiceField(ChoiceField): } default_empty_html = [] + def get_value(self, dictionary): + # We override the default field access in order to support + # lists in HTML forms. + if html.is_html_input(dictionary): + return dictionary.getlist(self.field_name) + return dictionary.get(self.field_name, empty) + def to_internal_value(self, data): if isinstance(data, type('')) or not hasattr(data, '__iter__'): self.fail('not_a_list', input_type=type(data).__name__) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index c1e5aa18..268b95cf 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,6 +1,7 @@ from rest_framework.compat import smart_text, urlparse from rest_framework.fields import empty, Field from rest_framework.reverse import reverse +from rest_framework.utils import html from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch, Resolver404 from django.db.models.query import QuerySet @@ -263,6 +264,13 @@ class ManyRelation(Field): super(ManyRelation, self).__init__(*args, **kwargs) self.child_relation.bind(field_name='', parent=self) + def get_value(self, dictionary): + # We override the default field access in order to support + # lists in HTML forms. + if html.is_html_input(dictionary): + return dictionary.getlist(self.field_name) + return dictionary.get(self.field_name, empty) + def to_internal_value(self, data): return [ self.child_relation.to_internal_value(item) @@ -278,10 +286,16 @@ class ManyRelation(Field): @property def choices(self): + queryset = self.child_relation.queryset + iterable = queryset.all() if (hasattr(queryset, 'all')) else queryset + items_and_representations = [ + (item, self.child_relation.to_representation(item)) + for item in iterable + ] return dict([ ( - str(self.child_relation.to_representation(item)), - str(item) + str(item_representation), + str(item) + ' - ' + str(item_representation) ) - for item in self.child_relation.queryset.all() + for item, item_representation in items_and_representations ]) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 931dd434..4fb36060 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -364,6 +364,12 @@ class HTMLFormRenderer(BaseRenderer): serializers.ManyRelation: { 'default': 'select_multiple.html', 'checkbox': 'select_checkbox.html' + }, + serializers.Serializer: { + 'default': 'fieldset.html' + }, + serializers.ListSerializer: { + 'default': 'list_fieldset.html' } }) @@ -392,7 +398,9 @@ class HTMLFormRenderer(BaseRenderer): template = loader.get_template(template_name) context = Context({ 'field': field, - 'input_type': input_type + 'input_type': input_type, + 'renderer': self, + 'layout': layout }) return template.render(context) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 9fcbcba7..1c006990 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -166,14 +166,25 @@ class BoundField(object): Returned when iterating over a serializer instance, providing an API similar to Django forms and form fields. """ - def __init__(self, field, value, errors): + def __init__(self, field, value, errors, prefix=''): self._field = field self.value = value self.errors = errors + self.name = prefix + self.field_name def __getattr__(self, attr_name): return getattr(self._field, attr_name) + def __iter__(self): + for field in self.fields.values(): + yield self[field.field_name] + + def __getitem__(self, key): + field = self.fields[key] + value = self.value.get(key) if self.value else None + error = self.errors.get(key) if self.errors else None + return BoundField(field, value, error, prefix=self.name + '.') + @property def _proxy_class(self): return self._field.__class__ @@ -355,15 +366,22 @@ class Serializer(BaseSerializer): def validate(self, attrs): return attrs + def __repr__(self): + return representation.serializer_repr(self, indent=1) + + # The following are used for accessing `BoundField` instances on the + # serializer, for the purposes of presenting a form-like API onto the + # field values and field errors. + def __iter__(self): - errors = self.errors if hasattr(self, '_errors') else {} for field in self.fields.values(): - value = self.data.get(field.field_name) if self.data else None - error = errors.get(field.field_name) - yield BoundField(field, value, error) + yield self[field.field_name] - def __repr__(self): - return representation.serializer_repr(self, indent=1) + def __getitem__(self, key): + field = self.fields[key] + value = self.data.get(key) + error = self.errors.get(key) if hasattr(self, '_errors') else None + return BoundField(field, value, error) # There's some replication of `ListField` here, @@ -404,8 +422,9 @@ class ListSerializer(BaseSerializer): """ List of object instances -> List of dicts of primitive datatypes. """ + iterable = data.all() if (hasattr(data, 'all')) else data return ReturnList( - [self.child.to_representation(item) for item in data], + [self.child.to_representation(item) for item in iterable], serializer=self ) diff --git a/rest_framework/templates/rest_framework/fields/horizontal/fieldset.html b/rest_framework/templates/rest_framework/fields/horizontal/fieldset.html index 843a56b2..ff93c6ba 100644 --- a/rest_framework/templates/rest_framework/fields/horizontal/fieldset.html +++ b/rest_framework/templates/rest_framework/fields/horizontal/fieldset.html @@ -1,10 +1,11 @@ +{% load rest_framework %}
{% if field.label %}
{{ field.label }}
{% endif %} - {% for field_item in field.value.field_items.values() %} - {{ renderer.render_field(field_item, layout=layout) }} + {% for nested_field in field %} + {% render_field nested_field layout=layout renderer=renderer %} {% endfor %}
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/list_fieldset.html b/rest_framework/templates/rest_framework/fields/horizontal/list_fieldset.html new file mode 100644 index 00000000..68c75d4f --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/horizontal/list_fieldset.html @@ -0,0 +1,13 @@ +{% load rest_framework %} +
+ {% if field.label %} +
+ {{ field.label }} +
+ {% endif %} +
    + {% for child in field.value %} +
  • TODO
  • + {% endfor %} +
+
diff --git a/rest_framework/templates/rest_framework/fields/inline/fieldset.html b/rest_framework/templates/rest_framework/fields/inline/fieldset.html index 380d4627..ba9f1835 100644 --- a/rest_framework/templates/rest_framework/fields/inline/fieldset.html +++ b/rest_framework/templates/rest_framework/fields/inline/fieldset.html @@ -1,3 +1,4 @@ -{% for field_item in field.value.field_items.values() %} - {{ renderer.render_field(field_item, layout=layout) }} +{% load rest_framework %} +{% for nested_field in field %} + {% render_field nested_field layout=layout renderer=renderer %} {% endfor %} diff --git a/rest_framework/templates/rest_framework/fields/vertical/fieldset.html b/rest_framework/templates/rest_framework/fields/vertical/fieldset.html index 8708916b..248fe904 100644 --- a/rest_framework/templates/rest_framework/fields/vertical/fieldset.html +++ b/rest_framework/templates/rest_framework/fields/vertical/fieldset.html @@ -1,6 +1,7 @@ +{% load rest_framework %}
{% if field.label %}{{ field.label }}{% endif %} - {% for field_item in field.value.field_items.values() %} - {{ renderer.render_field(field_item, layout=layout) }} + {% for nested_field in field %} + {% render_field nested_field layout=layout renderer=renderer %} {% endfor %}
diff --git a/rest_framework/templates/rest_framework/fields/vertical/list_fieldset.html b/rest_framework/templates/rest_framework/fields/vertical/list_fieldset.html new file mode 100644 index 00000000..6b99a834 --- /dev/null +++ b/rest_framework/templates/rest_framework/fields/vertical/list_fieldset.html @@ -0,0 +1,7 @@ +
+ {% if field.label %}{{ field.label }}{% endif %} + +
-- cgit v1.2.3 From f83ed19d22250eb646c9d77ccb1614a78d134e75 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 9 Oct 2014 16:29:34 +0100 Subject: Checks and repr on BoundField --- rest_framework/serializers.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 1c006990..3bd7b17b 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -180,6 +180,7 @@ class BoundField(object): yield self[field.field_name] def __getitem__(self, key): + assert hasattr(self, 'fields'), '"%s" is not a nested field. Cannot perform indexing.' % self.name field = self.fields[key] value = self.value.get(key) if self.value else None error = self.errors.get(key) if self.errors else None @@ -189,6 +190,9 @@ class BoundField(object): def _proxy_class(self): return self._field.__class__ + def __repr__(self): + return '<%s value=%s errors=%s>' % (self.__class__.__name__, self.value, self.errors) + class BindingDict(object): """ -- cgit v1.2.3 From a0e852a4d52558db93209b4616f030b4ae2dcedb Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 9 Oct 2014 16:30:06 +0100 Subject: Use BoundField .name on fields --- rest_framework/templates/rest_framework/fields/attrs.html | 2 +- .../templates/rest_framework/fields/horizontal/checkbox.html | 2 +- rest_framework/templates/rest_framework/fields/horizontal/select.html | 2 +- .../templates/rest_framework/fields/horizontal/select_checkbox.html | 4 ++-- .../templates/rest_framework/fields/horizontal/select_multiple.html | 2 +- .../templates/rest_framework/fields/horizontal/select_radio.html | 4 ++-- rest_framework/templates/rest_framework/fields/inline/checkbox.html | 2 +- rest_framework/templates/rest_framework/fields/inline/select.html | 2 +- .../templates/rest_framework/fields/inline/select_checkbox.html | 2 +- .../templates/rest_framework/fields/inline/select_multiple.html | 2 +- .../templates/rest_framework/fields/inline/select_radio.html | 2 +- rest_framework/templates/rest_framework/fields/vertical/checkbox.html | 2 +- rest_framework/templates/rest_framework/fields/vertical/select.html | 2 +- .../templates/rest_framework/fields/vertical/select_checkbox.html | 4 ++-- .../templates/rest_framework/fields/vertical/select_multiple.html | 2 +- .../templates/rest_framework/fields/vertical/select_radio.html | 4 ++-- 16 files changed, 20 insertions(+), 20 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/templates/rest_framework/fields/attrs.html b/rest_framework/templates/rest_framework/fields/attrs.html index b5a4dbcf..1e23c465 100644 --- a/rest_framework/templates/rest_framework/fields/attrs.html +++ b/rest_framework/templates/rest_framework/fields/attrs.html @@ -1 +1 @@ -name="{{ field.field_name }}" {% if field.style.placeholder %}placeholder="{{ field.style.placeholder }}"{% endif %} {% if field.style.rows %}rows="{{ field.style.rows }}"{% endif %} +name="{{ field.name }}" {% if field.style.placeholder %}placeholder="{{ field.style.placeholder }}"{% endif %} {% if field.style.rows %}rows="{{ field.style.rows }}"{% endif %} diff --git a/rest_framework/templates/rest_framework/fields/horizontal/checkbox.html b/rest_framework/templates/rest_framework/fields/horizontal/checkbox.html index dd3c3cef..ee3bf936 100644 --- a/rest_framework/templates/rest_framework/fields/horizontal/checkbox.html +++ b/rest_framework/templates/rest_framework/fields/horizontal/checkbox.html @@ -2,7 +2,7 @@
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/select.html b/rest_framework/templates/rest_framework/fields/horizontal/select.html index 7367d726..10b4b139 100644 --- a/rest_framework/templates/rest_framework/fields/horizontal/select.html +++ b/rest_framework/templates/rest_framework/fields/horizontal/select.html @@ -1,7 +1,7 @@
{% include "rest_framework/fields/horizontal/label.html" %}
- {% for key, text in field.choices.items %} {% endfor %} diff --git a/rest_framework/templates/rest_framework/fields/horizontal/select_checkbox.html b/rest_framework/templates/rest_framework/fields/horizontal/select_checkbox.html index 381cda2c..6041fa74 100644 --- a/rest_framework/templates/rest_framework/fields/horizontal/select_checkbox.html +++ b/rest_framework/templates/rest_framework/fields/horizontal/select_checkbox.html @@ -4,7 +4,7 @@ {% if field.style.inline %} {% for key, text in field.choices.items %} {% endfor %} @@ -12,7 +12,7 @@ {% for key, text in field.choices.items %}
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/select_multiple.html b/rest_framework/templates/rest_framework/fields/horizontal/select_multiple.html index 29ba8661..c0dbb989 100644 --- a/rest_framework/templates/rest_framework/fields/horizontal/select_multiple.html +++ b/rest_framework/templates/rest_framework/fields/horizontal/select_multiple.html @@ -1,7 +1,7 @@
{% include "rest_framework/fields/horizontal/label.html" %}
- {% for key, text in field.choices.items %} {% endfor %} diff --git a/rest_framework/templates/rest_framework/fields/horizontal/select_radio.html b/rest_framework/templates/rest_framework/fields/horizontal/select_radio.html index 20aab8b2..0eeb9bc6 100644 --- a/rest_framework/templates/rest_framework/fields/horizontal/select_radio.html +++ b/rest_framework/templates/rest_framework/fields/horizontal/select_radio.html @@ -4,7 +4,7 @@ {% if field.style.inline %} {% for key, text in field.choices.items %} {% endfor %} @@ -12,7 +12,7 @@ {% for key, text in field.choices.items %}
diff --git a/rest_framework/templates/rest_framework/fields/inline/checkbox.html b/rest_framework/templates/rest_framework/fields/inline/checkbox.html index 289bbb4d..57fa5dc0 100644 --- a/rest_framework/templates/rest_framework/fields/inline/checkbox.html +++ b/rest_framework/templates/rest_framework/fields/inline/checkbox.html @@ -1,6 +1,6 @@
diff --git a/rest_framework/templates/rest_framework/fields/inline/select.html b/rest_framework/templates/rest_framework/fields/inline/select.html index 9f361c4a..eebb91d2 100644 --- a/rest_framework/templates/rest_framework/fields/inline/select.html +++ b/rest_framework/templates/rest_framework/fields/inline/select.html @@ -1,6 +1,6 @@
{% include "rest_framework/fields/inline/label.html" %} - {% for key, text in field.choices.items %} {% endfor %} diff --git a/rest_framework/templates/rest_framework/fields/inline/select_checkbox.html b/rest_framework/templates/rest_framework/fields/inline/select_checkbox.html index 0f33fb69..b7561cd4 100644 --- a/rest_framework/templates/rest_framework/fields/inline/select_checkbox.html +++ b/rest_framework/templates/rest_framework/fields/inline/select_checkbox.html @@ -3,7 +3,7 @@ {% for key, text in field.choices.items %}
diff --git a/rest_framework/templates/rest_framework/fields/inline/select_multiple.html b/rest_framework/templates/rest_framework/fields/inline/select_multiple.html index 7c9e5168..74e17f9f 100644 --- a/rest_framework/templates/rest_framework/fields/inline/select_multiple.html +++ b/rest_framework/templates/rest_framework/fields/inline/select_multiple.html @@ -1,6 +1,6 @@
{% include "rest_framework/fields/inline/label.html" %} - {% for key, text in field.choices.items %} {% endfor %} diff --git a/rest_framework/templates/rest_framework/fields/inline/select_radio.html b/rest_framework/templates/rest_framework/fields/inline/select_radio.html index 177c0eeb..27927a62 100644 --- a/rest_framework/templates/rest_framework/fields/inline/select_radio.html +++ b/rest_framework/templates/rest_framework/fields/inline/select_radio.html @@ -3,7 +3,7 @@ {% for key, text in field.choices.items %}
diff --git a/rest_framework/templates/rest_framework/fields/vertical/checkbox.html b/rest_framework/templates/rest_framework/fields/vertical/checkbox.html index 01d30aae..9fd4cdaa 100644 --- a/rest_framework/templates/rest_framework/fields/vertical/checkbox.html +++ b/rest_framework/templates/rest_framework/fields/vertical/checkbox.html @@ -1,6 +1,6 @@
diff --git a/rest_framework/templates/rest_framework/fields/vertical/select.html b/rest_framework/templates/rest_framework/fields/vertical/select.html index dcc9a3cd..1a651663 100644 --- a/rest_framework/templates/rest_framework/fields/vertical/select.html +++ b/rest_framework/templates/rest_framework/fields/vertical/select.html @@ -1,6 +1,6 @@
{% include "rest_framework/fields/vertical/label.html" %} - {% for key, text in field.choices.items %} {% endfor %} diff --git a/rest_framework/templates/rest_framework/fields/vertical/select_checkbox.html b/rest_framework/templates/rest_framework/fields/vertical/select_checkbox.html index 1fbe6a94..2e792e6a 100644 --- a/rest_framework/templates/rest_framework/fields/vertical/select_checkbox.html +++ b/rest_framework/templates/rest_framework/fields/vertical/select_checkbox.html @@ -4,7 +4,7 @@
{% for key, text in field.choices.items %} {% endfor %} @@ -13,7 +13,7 @@ {% for key, text in field.choices.items %}
diff --git a/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html b/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html index 2cc40d99..5f4166cd 100644 --- a/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html +++ b/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html @@ -1,6 +1,6 @@
{% include "rest_framework/fields/vertical/label.html" %} - {% for key, text in field.choices.items %} {% endfor %} diff --git a/rest_framework/templates/rest_framework/fields/vertical/select_radio.html b/rest_framework/templates/rest_framework/fields/vertical/select_radio.html index 470bcb0b..2aa0fe28 100644 --- a/rest_framework/templates/rest_framework/fields/vertical/select_radio.html +++ b/rest_framework/templates/rest_framework/fields/vertical/select_radio.html @@ -4,7 +4,7 @@
{% for key, text in field.choices.items %} {% endfor %} @@ -13,7 +13,7 @@ {% for key, text in field.choices.items %}
-- cgit v1.2.3 From d9a199ca0ddf92f999aa37b396596d0e3e0a26d9 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 10 Oct 2014 14:16:09 +0100 Subject: exceptions.ValidationFailed, not Django's ValidationError --- rest_framework/authtoken/serializers.py | 8 ++--- rest_framework/exceptions.py | 14 ++++++++ rest_framework/fields.py | 21 +++++++----- rest_framework/serializers.py | 60 +++++++++++++++++++-------------- rest_framework/views.py | 27 ++++++--------- 5 files changed, 76 insertions(+), 54 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index c2c456de..a808d0a3 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -1,7 +1,7 @@ from django.contrib.auth import authenticate from django.utils.translation import ugettext_lazy as _ -from rest_framework import serializers +from rest_framework import exceptions, serializers class AuthTokenSerializer(serializers.Serializer): @@ -18,13 +18,13 @@ class AuthTokenSerializer(serializers.Serializer): if user: if not user.is_active: msg = _('User account is disabled.') - raise serializers.ValidationError(msg) + raise exceptions.ValidationFailed(msg) else: msg = _('Unable to log in with provided credentials.') - raise serializers.ValidationError(msg) + raise exceptions.ValidationFailed(msg) else: msg = _('Must include "username" and "password"') - raise serializers.ValidationError(msg) + raise exceptions.ValidationFailed(msg) attrs['user'] = user return attrs diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 06b5e8a2..b7c2d16d 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -24,6 +24,20 @@ class APIException(Exception): return self.detail +class ValidationFailed(APIException): + status_code = status.HTTP_400_BAD_REQUEST + + def __init__(self, detail): + # For validation errors the 'detail' key is always required. + # The details should always be coerced to a list if not already. + if not isinstance(detail, dict) and not isinstance(detail, list): + detail = [detail] + self.detail = detail + + def __str__(self): + return str(self.detail) + + class ParseError(APIException): status_code = status.HTTP_400_BAD_REQUEST default_detail = 'Malformed request.' diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 7053acee..b881ad13 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,7 +1,8 @@ -from django import forms from django.conf import settings from django.core import validators -from django.core.exceptions import ObjectDoesNotExist, ValidationError +from django.core.exceptions import ObjectDoesNotExist +from django.core.exceptions import ValidationError as DjangoValidationError +from django.forms import ImageField as DjangoImageField from django.utils import six, timezone from django.utils.datastructures import SortedDict from django.utils.dateparse import parse_date, parse_datetime, parse_time @@ -12,6 +13,7 @@ from rest_framework.compat import ( smart_text, EmailValidator, MinValueValidator, MaxValueValidator, MinLengthValidator, MaxLengthValidator, URLValidator ) +from rest_framework.exceptions import ValidationFailed from rest_framework.settings import api_settings from rest_framework.utils import html, representation, humanize_datetime import copy @@ -98,7 +100,7 @@ NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`' NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField' MISSING_ERROR_MESSAGE = ( - 'ValidationError raised by `{class_name}`, but error key `{key}` does ' + 'ValidationFailed raised by `{class_name}`, but error key `{key}` does ' 'not exist in the `error_messages` dictionary.' ) @@ -263,7 +265,7 @@ class Field(object): def run_validators(self, value): """ Test the given value against all the validators on the field, - and either raise a `ValidationError` or simply return. + and either raise a `ValidationFailed` or simply return. """ errors = [] for validator in self.validators: @@ -271,10 +273,12 @@ class Field(object): validator.serializer_field = self try: validator(value) - except ValidationError as exc: + except ValidationFailed as exc: + errors.extend(exc.detail) + except DjangoValidationError as exc: errors.extend(exc.messages) if errors: - raise ValidationError(errors) + raise ValidationFailed(errors) def validate(self, value): pass @@ -301,7 +305,8 @@ class Field(object): class_name = self.__class__.__name__ msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) raise AssertionError(msg) - raise ValidationError(msg.format(**kwargs)) + message_string = msg.format(**kwargs) + raise ValidationFailed(message_string) @property def root(self): @@ -946,7 +951,7 @@ class ImageField(FileField): } def __init__(self, *args, **kwargs): - self._DjangoImageField = kwargs.pop('_DjangoImageField', forms.ImageField) + self._DjangoImageField = kwargs.pop('_DjangoImageField', DjangoImageField) super(ImageField, self).__init__(*args, **kwargs) def to_internal_value(self, data): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 3bd7b17b..2f683562 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -10,10 +10,11 @@ python primitives. 2. The process of marshalling between python primitives and request and response content is handled by parsers and renderers. """ -from django.core.exceptions import ImproperlyConfigured, ValidationError +from django.core.exceptions import ImproperlyConfigured from django.db import models from django.utils import six from django.utils.datastructures import SortedDict +from rest_framework.exceptions import ValidationFailed from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.settings import api_settings from rest_framework.utils import html, model_meta, representation @@ -100,14 +101,14 @@ class BaseSerializer(Field): if not hasattr(self, '_validated_data'): try: self._validated_data = self.run_validation(self._initial_data) - except ValidationError as exc: + except ValidationFailed as exc: self._validated_data = {} - self._errors = exc.message_dict + self._errors = exc.detail else: self._errors = {} if self._errors and raise_exception: - raise ValidationError(self._errors) + raise ValidationFailed(self._errors) return not bool(self._errors) @@ -175,24 +176,34 @@ class BoundField(object): def __getattr__(self, attr_name): return getattr(self._field, attr_name) + @property + def _proxy_class(self): + return self._field.__class__ + + def __repr__(self): + return '<%s value=%s errors=%s>' % ( + self.__class__.__name__, self.value, self.errors + ) + + +class NestedBoundField(BoundField): + """ + This BoundField additionally implements __iter__ and __getitem__ + in order to support nested bound fields. This class is the type of + BoundField that is used for serializer fields. + """ def __iter__(self): for field in self.fields.values(): yield self[field.field_name] def __getitem__(self, key): - assert hasattr(self, 'fields'), '"%s" is not a nested field. Cannot perform indexing.' % self.name field = self.fields[key] value = self.value.get(key) if self.value else None error = self.errors.get(key) if self.errors else None + if isinstance(field, Serializer): + return NestedBoundField(field, value, error, prefix=self.name + '.') return BoundField(field, value, error, prefix=self.name + '.') - @property - def _proxy_class(self): - return self._field.__class__ - - def __repr__(self): - return '<%s value=%s errors=%s>' % (self.__class__.__name__, self.value, self.errors) - class BindingDict(object): """ @@ -308,7 +319,7 @@ class Serializer(BaseSerializer): return None if not isinstance(data, dict): - raise ValidationError({ + raise ValidationFailed({ api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data'] }) @@ -317,9 +328,9 @@ class Serializer(BaseSerializer): self.run_validators(value) value = self.validate(value) assert value is not None, '.validate() should return the validated data' - except ValidationError as exc: - raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: exc.messages + except ValidationFailed as exc: + raise ValidationFailed({ + api_settings.NON_FIELD_ERRORS_KEY: exc.detail }) return value @@ -338,15 +349,15 @@ class Serializer(BaseSerializer): validated_value = field.run_validation(primitive_value) if validate_method is not None: validated_value = validate_method(validated_value) - except ValidationError as exc: - errors[field.field_name] = exc.messages + except ValidationFailed as exc: + errors[field.field_name] = exc.detail except SkipField: pass else: set_value(ret, field.source_attrs, validated_value) if errors: - raise ValidationError(errors) + raise ValidationFailed(errors) return ret @@ -385,6 +396,8 @@ class Serializer(BaseSerializer): field = self.fields[key] value = self.data.get(key) error = self.errors.get(key) if hasattr(self, '_errors') else None + if isinstance(field, Serializer): + return NestedBoundField(field, value, error) return BoundField(field, value, error) @@ -538,9 +551,12 @@ class ModelSerializer(Serializer): ret = SortedDict() model = getattr(self.Meta, 'model') fields = getattr(self.Meta, 'fields', None) + exclude = getattr(self.Meta, 'exclude', None) depth = getattr(self.Meta, 'depth', 0) extra_kwargs = getattr(self.Meta, 'extra_kwargs', {}) + assert not fields and exclude, "Cannot set both 'fields' and 'exclude'." + extra_kwargs = self._include_additional_options(extra_kwargs) # Retrieve metadata about fields & relationships on the model class. @@ -551,12 +567,6 @@ class ModelSerializer(Serializer): fields = self._get_default_field_names(declared_fields, info) exclude = getattr(self.Meta, 'exclude', None) if exclude is not None: - warnings.warn( - "The `Meta.exclude` option is pending deprecation. " - "Use the explicit `Meta.fields` instead.", - PendingDeprecationWarning, - stacklevel=3 - ) for field_name in exclude: fields.remove(field_name) diff --git a/rest_framework/views.py b/rest_framework/views.py index 979229eb..292431c8 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,7 +3,7 @@ Provides an APIView class that is the base of all views in REST framework. """ from __future__ import unicode_literals -from django.core.exceptions import PermissionDenied, ValidationError, NON_FIELD_ERRORS +from django.core.exceptions import PermissionDenied from django.http import Http404 from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions @@ -63,27 +63,20 @@ def exception_handler(exc): if getattr(exc, 'wait', None): headers['Retry-After'] = '%d' % exc.wait - return Response({'detail': exc.detail}, - status=exc.status_code, - headers=headers) + if isinstance(exc.detail, (list, dict)): + data = exc.detail + else: + data = {'detail': exc.detail} - elif isinstance(exc, ValidationError): - # ValidationErrors may include the non-field key named '__all__'. - # When returning a response we map this to a key name that can be - # modified in settings. - if NON_FIELD_ERRORS in exc.message_dict: - errors = exc.message_dict.pop(NON_FIELD_ERRORS) - exc.message_dict[api_settings.NON_FIELD_ERRORS_KEY] = errors - return Response(exc.message_dict, - status=status.HTTP_400_BAD_REQUEST) + return Response(data, status=exc.status_code, headers=headers) elif isinstance(exc, Http404): - return Response({'detail': 'Not found'}, - status=status.HTTP_404_NOT_FOUND) + data = {'detail': 'Not found'} + return Response(data, status=status.HTTP_404_NOT_FOUND) elif isinstance(exc, PermissionDenied): - return Response({'detail': 'Permission denied'}, - status=status.HTTP_403_FORBIDDEN) + data = {'detail': 'Permission denied'} + return Response(data, status=status.HTTP_403_FORBIDDEN) # Note: Unhandled exceptions will raise a 500 error. return None -- cgit v1.2.3 From d8a8987ab1eb6abbaee1a0de8cfea38eafe21293 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 10 Oct 2014 14:32:02 +0100 Subject: Tweaks --- rest_framework/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 2f683562..0f6cf2bc 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -555,7 +555,7 @@ class ModelSerializer(Serializer): depth = getattr(self.Meta, 'depth', 0) extra_kwargs = getattr(self.Meta, 'extra_kwargs', {}) - assert not fields and exclude, "Cannot set both 'fields' and 'exclude'." + assert not (fields and exclude), "Cannot set both 'fields' and 'exclude'." extra_kwargs = self._include_additional_options(extra_kwargs) -- cgit v1.2.3 From b5a4216aff06bfb36238d0f587d8645db0ee4a69 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 10 Oct 2014 15:08:43 +0100 Subject: Flake8 --- rest_framework/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 0f6cf2bc..f3f5c837 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -555,7 +555,7 @@ class ModelSerializer(Serializer): depth = getattr(self.Meta, 'depth', 0) extra_kwargs = getattr(self.Meta, 'extra_kwargs', {}) - assert not (fields and exclude), "Cannot set both 'fields' and 'exclude'." + assert not (fields and exclude), "Cannot set both 'fields' and 'exclude'." extra_kwargs = self._include_additional_options(extra_kwargs) -- cgit v1.2.3 From 826b5a889704452c53c05a44905f9fa62889ff34 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 10 Oct 2014 15:34:00 +0100 Subject: Relations in 'read_only_fields' should not include a queryset kwarg --- rest_framework/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index f3f5c837..bc9c15eb 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -640,7 +640,7 @@ class ModelSerializer(Serializer): for attr in [ 'required', 'default', 'allow_blank', 'allow_null', 'min_length', 'max_length', 'min_value', 'max_value', - 'validators' + 'validators', 'queryset' ]: kwargs.pop(attr, None) kwargs.update(extras) -- cgit v1.2.3 From 81abf2bf341d8d7b27e2974a01a78c30c796b4d6 Mon Sep 17 00:00:00 2001 From: Andy Freeland Date: Sun, 12 Oct 2014 01:19:14 -0400 Subject: Rename `preform_update` to `perform_update` --- rest_framework/mixins.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 4c62debb..467ff515 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -67,10 +67,10 @@ class UpdateModelMixin(object): instance = self.get_object() serializer = self.get_serializer(instance, data=request.data, partial=partial) serializer.is_valid(raise_exception=True) - self.preform_update(serializer) + self.perform_update(serializer) return Response(serializer.data) - def preform_update(self, serializer): + def perform_update(self, serializer): serializer.save() def partial_update(self, request, *args, **kwargs): -- cgit v1.2.3 From e272a36c9b444c1da3a3d8bc809070deb26d9c64 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 15 Oct 2014 09:24:49 +0100 Subject: Fix 'lookup_field' on ModelSerializer. Closes #1944. --- rest_framework/serializers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index bc9c15eb..c844605f 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -680,7 +680,7 @@ class ModelSerializer(Serializer): PendingDeprecationWarning, stacklevel=3 ) - kwargs = extra_kwargs.get(field_name, {}) + kwargs = extra_kwargs.get(api_settings.URL_FIELD_NAME, {}) kwargs['view_name'] = view_name extra_kwargs[api_settings.URL_FIELD_NAME] = kwargs @@ -692,7 +692,7 @@ class ModelSerializer(Serializer): PendingDeprecationWarning, stacklevel=3 ) - kwargs = extra_kwargs.get(field_name, {}) + kwargs = extra_kwargs.get(api_settings.URL_FIELD_NAME, {}) kwargs['lookup_field'] = lookup_field extra_kwargs[api_settings.URL_FIELD_NAME] = kwargs -- cgit v1.2.3 From e558f806c0e87a329915b7077783f9ed3a79bb07 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 15 Oct 2014 10:04:01 +0100 Subject: Drop template includes --- rest_framework/templates/rest_framework/fields/attrs.html | 1 - .../templates/rest_framework/fields/horizontal/input.html | 6 ++++-- .../templates/rest_framework/fields/horizontal/label.html | 1 - .../templates/rest_framework/fields/horizontal/select.html | 8 +++++--- .../rest_framework/fields/horizontal/select_checkbox.html | 4 +++- .../rest_framework/fields/horizontal/select_multiple.html | 4 +++- .../templates/rest_framework/fields/horizontal/select_radio.html | 4 +++- .../templates/rest_framework/fields/horizontal/textarea.html | 6 ++++-- rest_framework/templates/rest_framework/fields/inline/input.html | 6 ++++-- rest_framework/templates/rest_framework/fields/inline/label.html | 1 - rest_framework/templates/rest_framework/fields/inline/select.html | 4 +++- .../templates/rest_framework/fields/inline/select_checkbox.html | 4 +++- .../templates/rest_framework/fields/inline/select_multiple.html | 4 +++- .../templates/rest_framework/fields/inline/select_radio.html | 4 +++- .../templates/rest_framework/fields/inline/textarea.html | 6 ++++-- .../templates/rest_framework/fields/vertical/input.html | 6 ++++-- .../templates/rest_framework/fields/vertical/label.html | 1 - .../templates/rest_framework/fields/vertical/select.html | 4 +++- .../templates/rest_framework/fields/vertical/select_checkbox.html | 4 +++- .../templates/rest_framework/fields/vertical/select_multiple.html | 4 +++- .../templates/rest_framework/fields/vertical/select_radio.html | 4 +++- .../templates/rest_framework/fields/vertical/textarea.html | 6 ++++-- 22 files changed, 62 insertions(+), 30 deletions(-) delete mode 100644 rest_framework/templates/rest_framework/fields/attrs.html delete mode 100644 rest_framework/templates/rest_framework/fields/horizontal/label.html delete mode 100644 rest_framework/templates/rest_framework/fields/inline/label.html delete mode 100644 rest_framework/templates/rest_framework/fields/vertical/label.html (limited to 'rest_framework') diff --git a/rest_framework/templates/rest_framework/fields/attrs.html b/rest_framework/templates/rest_framework/fields/attrs.html deleted file mode 100644 index 1e23c465..00000000 --- a/rest_framework/templates/rest_framework/fields/attrs.html +++ /dev/null @@ -1 +0,0 @@ -name="{{ field.name }}" {% if field.style.placeholder %}placeholder="{{ field.style.placeholder }}"{% endif %} {% if field.style.rows %}rows="{{ field.style.rows }}"{% endif %} diff --git a/rest_framework/templates/rest_framework/fields/horizontal/input.html b/rest_framework/templates/rest_framework/fields/horizontal/input.html index 6f1a504b..6621c7e6 100644 --- a/rest_framework/templates/rest_framework/fields/horizontal/input.html +++ b/rest_framework/templates/rest_framework/fields/horizontal/input.html @@ -1,7 +1,9 @@
- {% include "rest_framework/fields/horizontal/label.html" %} + {% if field.label %} + + {% endif %}
- + {% if field.help_text %}

{{ field.help_text }}

{% endif %}
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/label.html b/rest_framework/templates/rest_framework/fields/horizontal/label.html deleted file mode 100644 index bf21f78c..00000000 --- a/rest_framework/templates/rest_framework/fields/horizontal/label.html +++ /dev/null @@ -1 +0,0 @@ -{% if field.label %}{% endif %} diff --git a/rest_framework/templates/rest_framework/fields/horizontal/select.html b/rest_framework/templates/rest_framework/fields/horizontal/select.html index 10b4b139..eaa6d575 100644 --- a/rest_framework/templates/rest_framework/fields/horizontal/select.html +++ b/rest_framework/templates/rest_framework/fields/horizontal/select.html @@ -1,10 +1,12 @@
- {% include "rest_framework/fields/horizontal/label.html" %} + {% if field.label %} + + {% endif %}
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/select_checkbox.html b/rest_framework/templates/rest_framework/fields/horizontal/select_checkbox.html index 6041fa74..ff3fab57 100644 --- a/rest_framework/templates/rest_framework/fields/horizontal/select_checkbox.html +++ b/rest_framework/templates/rest_framework/fields/horizontal/select_checkbox.html @@ -1,5 +1,7 @@
- {% include "rest_framework/fields/horizontal/label.html" %} + {% if field.label %} + + {% endif %}
{% if field.style.inline %} {% for key, text in field.choices.items %} diff --git a/rest_framework/templates/rest_framework/fields/horizontal/select_multiple.html b/rest_framework/templates/rest_framework/fields/horizontal/select_multiple.html index c0dbb989..3ed2874b 100644 --- a/rest_framework/templates/rest_framework/fields/horizontal/select_multiple.html +++ b/rest_framework/templates/rest_framework/fields/horizontal/select_multiple.html @@ -1,5 +1,7 @@
- {% include "rest_framework/fields/horizontal/label.html" %} + {% if field.label %} + + {% endif %}
+ {% if field.help_text %}

{{ field.help_text }}

{% endif %}
diff --git a/rest_framework/templates/rest_framework/fields/inline/input.html b/rest_framework/templates/rest_framework/fields/inline/input.html index e4a92ccd..bdf25ffe 100644 --- a/rest_framework/templates/rest_framework/fields/inline/input.html +++ b/rest_framework/templates/rest_framework/fields/inline/input.html @@ -1,4 +1,6 @@
- {% include "rest_framework/fields/inline/label.html" %} - + {% if field.label %} + + {% endif %} +
diff --git a/rest_framework/templates/rest_framework/fields/inline/label.html b/rest_framework/templates/rest_framework/fields/inline/label.html deleted file mode 100644 index 7d546a57..00000000 --- a/rest_framework/templates/rest_framework/fields/inline/label.html +++ /dev/null @@ -1 +0,0 @@ -{% if field.label %}{% endif %} diff --git a/rest_framework/templates/rest_framework/fields/inline/select.html b/rest_framework/templates/rest_framework/fields/inline/select.html index eebb91d2..730fcce6 100644 --- a/rest_framework/templates/rest_framework/fields/inline/select.html +++ b/rest_framework/templates/rest_framework/fields/inline/select.html @@ -1,5 +1,7 @@
- {% include "rest_framework/fields/inline/label.html" %} + {% if field.label %} + + {% endif %} {% for key, text in field.choices.items %} diff --git a/rest_framework/templates/rest_framework/fields/inline/select_radio.html b/rest_framework/templates/rest_framework/fields/inline/select_radio.html index 27927a62..3fffceac 100644 --- a/rest_framework/templates/rest_framework/fields/inline/select_radio.html +++ b/rest_framework/templates/rest_framework/fields/inline/select_radio.html @@ -1,5 +1,7 @@
- {% include "rest_framework/fields/inline/label.html" %} + {% if field.label %} + + {% endif %} {% for key, text in field.choices.items %}