From f27a28682bdb1b4eea0ec9afca2eb2835c735f55 Mon Sep 17 00:00:00 2001
From: Greg Doermann
Date: Wed, 20 Aug 2014 11:00:37 -0600
Subject: Frameworks throws AssertionError saying you cannot set required=True
and read_only=True on editable=False model fields. We should not make the
field required if editable=False.
---
rest_framework/serializers.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index be8ad3f2..27af7ef3 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -831,7 +831,7 @@ class ModelSerializer(Serializer):
}
if model_field:
- kwargs['required'] = not(model_field.null or model_field.blank)
+ kwargs['required'] = not(model_field.null or model_field.blank) and model_field.editable
if model_field.help_text is not None:
kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None:
@@ -854,7 +854,7 @@ class ModelSerializer(Serializer):
"""
kwargs = {}
- if model_field.null or model_field.blank:
+ if model_field.null or model_field.blank and model_field.editable:
kwargs['required'] = False
if isinstance(model_field, models.AutoField) or not model_field.editable:
@@ -1110,7 +1110,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
}
if model_field:
- kwargs['required'] = not(model_field.null or model_field.blank)
+ kwargs['required'] = not(model_field.null or model_field.blank) and model_field.editable
if model_field.help_text is not None:
kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None:
--
cgit v1.2.3
From f62c874ea9621ae67fb56e7e453dca8fd5039051 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 29 Aug 2014 10:48:40 +0100
Subject: Remove `filter_backend`.
Closes #1775.
---
rest_framework/generics.py | 20 +-------------------
rest_framework/settings.py | 10 ----------
2 files changed, 1 insertion(+), 29 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index a6f68657..8bacf470 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -83,7 +83,6 @@ class GenericAPIView(views.APIView):
slug_url_kwarg = 'slug'
slug_field = 'slug'
allow_empty = True
- filter_backend = api_settings.FILTER_BACKEND
def get_serializer_context(self):
"""
@@ -191,24 +190,7 @@ class GenericAPIView(views.APIView):
"""
Returns the list of filter backends that this view requires.
"""
- if self.filter_backends is None:
- filter_backends = []
- else:
- # Note that we are returning a *copy* of the class attribute,
- # so that it is safe for the view to mutate it if needed.
- filter_backends = list(self.filter_backends)
-
- if not filter_backends and self.filter_backend:
- warnings.warn(
- 'The `filter_backend` attribute and `FILTER_BACKEND` setting '
- 'are deprecated in favor of a `filter_backends` '
- 'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take '
- 'a *list* of filter backend classes.',
- DeprecationWarning, stacklevel=2
- )
- filter_backends = [self.filter_backend]
-
- return filter_backends
+ return list(self.filter_backends)
# The following methods provide default implementations
# that you may want to override for more complex cases.
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 644751f8..bbe7a56a 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -111,9 +111,6 @@ DEFAULTS = {
),
'TIME_FORMAT': None,
- # Pending deprecation
- 'FILTER_BACKEND': None,
-
}
@@ -129,7 +126,6 @@ IMPORT_STRINGS = (
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
'DEFAULT_FILTER_BACKENDS',
'EXCEPTION_HANDLER',
- 'FILTER_BACKEND',
'TEST_REQUEST_RENDERER_CLASSES',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
@@ -196,15 +192,9 @@ class APISettings(object):
if val and attr in self.import_strings:
val = perform_import(val, attr)
- self.validate_setting(attr, val)
-
# Cache the result
setattr(self, attr, val)
return val
- def validate_setting(self, attr, val):
- if attr == 'FILTER_BACKEND' and val is not None:
- # Make sure we can initialize the class
- val()
api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS)
--
cgit v1.2.3
From 0f8fdf4e72b67ff46474c13c8b532bf319a58099 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 29 Aug 2014 10:57:24 +0100
Subject: Remove `allow_empty`.
Closes #1774.
---
rest_framework/generics.py | 12 +-----------
1 file changed, 1 insertion(+), 11 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 8bacf470..cb8061b7 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -82,7 +82,6 @@ class GenericAPIView(views.APIView):
pk_url_kwarg = 'pk'
slug_url_kwarg = 'slug'
slug_field = 'slug'
- allow_empty = True
def get_serializer_context(self):
"""
@@ -140,16 +139,7 @@ class GenericAPIView(views.APIView):
if not page_size:
return None
- if not self.allow_empty:
- warnings.warn(
- 'The `allow_empty` parameter is deprecated. '
- 'To use `allow_empty=False` style behavior, You should override '
- '`get_queryset()` and explicitly raise a 404 on empty querysets.',
- DeprecationWarning, stacklevel=2
- )
-
- paginator = self.paginator_class(queryset, page_size,
- allow_empty_first_page=self.allow_empty)
+ paginator = self.paginator_class(queryset, page_size)
page_kwarg = self.kwargs.get(self.page_kwarg)
page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)
page = page_kwarg or page_query_param or 1
--
cgit v1.2.3
From b3bbf416707cf8c71861b0fd6e966a557acef412 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 29 Aug 2014 11:09:35 +0100
Subject: Remove `allow_empty`
---
rest_framework/mixins.py | 15 ---------------
1 file changed, 15 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 2cc87eef..dc4c9f35 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -70,24 +70,9 @@ class ListModelMixin(object):
"""
List a queryset.
"""
- empty_error = "Empty list and '%(class_name)s.allow_empty' is False."
-
def list(self, request, *args, **kwargs):
self.object_list = self.filter_queryset(self.get_queryset())
- # Default is to allow empty querysets. This can be altered by setting
- # `.allow_empty = False`, to raise 404 errors on empty querysets.
- if not self.allow_empty and not self.object_list:
- warnings.warn(
- 'The `allow_empty` parameter is deprecated. '
- 'To use `allow_empty=False` style behavior, You should override '
- '`get_queryset()` and explicitly raise a 404 on empty querysets.',
- DeprecationWarning
- )
- class_name = self.__class__.__name__
- error_msg = self.empty_error % {'class_name': class_name}
- raise Http404(error_msg)
-
# Switch between paginated or standard style responses
page = self.paginate_queryset(self.object_list)
if page is not None:
--
cgit v1.2.3
From e5e6329a222def3b0745f90fc55ee36de95ada83 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 29 Aug 2014 11:29:26 +0100
Subject: Remove `pk_url_field`, `slug_url_field`, `slug_field`.
Closes #1773.
---
rest_framework/generics.py | 36 ++------------
rest_framework/mixins.py | 36 +++-----------
rest_framework/relations.py | 117 ++------------------------------------------
3 files changed, 15 insertions(+), 174 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index cb8061b7..e21dc5c7 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -76,13 +76,6 @@ class GenericAPIView(views.APIView):
model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
paginator_class = Paginator
- ######################################
- # These are pending deprecation...
-
- pk_url_kwarg = 'pk'
- slug_url_kwarg = 'slug'
- slug_field = 'slug'
-
def get_serializer_context(self):
"""
Extra context provided to the serializer class.
@@ -270,7 +263,7 @@ class GenericAPIView(views.APIView):
error_format = "'%s' must define 'queryset' or 'model'"
raise ImproperlyConfigured(error_format % self.__class__.__name__)
- def get_object(self, queryset=None):
+ def get_object(self):
"""
Returns the object the view is displaying.
@@ -278,36 +271,14 @@ class GenericAPIView(views.APIView):
queryset lookups. Eg if objects are referenced using multiple
keyword arguments in the url conf.
"""
- # Determine the base queryset to use.
- if queryset is None:
- queryset = self.filter_queryset(self.get_queryset())
- else:
- pass # Deprecation warning
+ queryset = self.filter_queryset(self.get_queryset())
# Perform the lookup filtering.
# Note that `pk` and `slug` are deprecated styles of lookup filtering.
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
lookup = self.kwargs.get(lookup_url_kwarg, None)
- pk = self.kwargs.get(self.pk_url_kwarg, None)
- slug = self.kwargs.get(self.slug_url_kwarg, None)
- if lookup is not None:
- filter_kwargs = {self.lookup_field: lookup}
- elif pk is not None and self.lookup_field == 'pk':
- warnings.warn(
- 'The `pk_url_kwarg` attribute is deprecated. '
- 'Use the `lookup_field` attribute instead',
- DeprecationWarning
- )
- filter_kwargs = {'pk': pk}
- elif slug is not None and self.lookup_field == 'pk':
- warnings.warn(
- 'The `slug_url_kwarg` attribute is deprecated. '
- 'Use the `lookup_field` attribute instead',
- DeprecationWarning
- )
- filter_kwargs = {self.slug_field: slug}
- else:
+ if lookup is None:
raise ImproperlyConfigured(
'Expected view %s to be called with a URL keyword argument '
'named "%s". Fix your URL conf, or set the `.lookup_field` '
@@ -315,6 +286,7 @@ class GenericAPIView(views.APIView):
(self.__class__.__name__, self.lookup_field)
)
+ filter_kwargs = {self.lookup_field: lookup}
obj = get_object_or_404(queryset, **filter_kwargs)
# May raise a permission denied
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index dc4c9f35..ac59d979 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -12,10 +12,9 @@ from rest_framework import status
from rest_framework.response import Response
from rest_framework.request import clone_request
from rest_framework.settings import api_settings
-import warnings
-def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None):
+def _get_validation_exclusions(obj, lookup_field=None):
"""
Given a model instance, and an optional pk and slug field,
return the full list of all other field names on that model.
@@ -23,23 +22,13 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None)
For use when performing full_clean on a model instance,
so we only clean the required fields.
"""
- include = []
-
- if pk:
- # Deprecated
+ if lookup_field == 'pk':
pk_field = obj._meta.pk
while pk_field.rel:
pk_field = pk_field.rel.to._meta.pk
- include.append(pk_field.name)
-
- if slug_field:
- # Deprecated
- include.append(slug_field)
-
- if lookup_field and lookup_field != 'pk':
- include.append(lookup_field)
+ lookup_field = pk_field.name
- return [field.name for field in obj._meta.fields if field.name not in include]
+ return [field.name for field in obj._meta.fields if field.name != lookup_field]
class CreateModelMixin(object):
@@ -146,26 +135,15 @@ class UpdateModelMixin(object):
"""
Set any attributes on the object that are implicit in the request.
"""
- # pk and/or slug attributes are implicit in the URL.
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
- lookup = self.kwargs.get(lookup_url_kwarg, None)
- pk = self.kwargs.get(self.pk_url_kwarg, None)
- slug = self.kwargs.get(self.slug_url_kwarg, None)
- slug_field = slug and self.slug_field or None
-
- if lookup:
- setattr(obj, self.lookup_field, lookup)
-
- if pk:
- setattr(obj, 'pk', pk)
+ lookup_value = self.kwargs[lookup_url_kwarg]
- if slug:
- setattr(obj, slug_field, slug)
+ setattr(obj, self.lookup_field, lookup_value)
# Ensure we clean the attributes so that we don't eg return integer
# pk using a string representation, as provided by the url conf kwarg.
if hasattr(obj, 'full_clean'):
- exclude = _get_validation_exclusions(obj, pk, slug_field, self.lookup_field)
+ exclude = _get_validation_exclusions(obj, self.lookup_field)
obj.full_clean(exclude)
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 1acbdce2..56870b40 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -16,7 +16,6 @@ from rest_framework.fields import Field, WritableField, get_component, is_simple
from rest_framework.reverse import reverse
from rest_framework.compat import urlparse
from rest_framework.compat import smart_text
-import warnings
# Relational fields
@@ -320,11 +319,6 @@ class HyperlinkedRelatedField(RelatedField):
'incorrect_type': _('Incorrect type. Expected url string, received %s.'),
}
- # These are all deprecated
- pk_url_kwarg = 'pk'
- slug_field = 'slug'
- slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
-
def __init__(self, *args, **kwargs):
try:
self.view_name = kwargs.pop('view_name')
@@ -334,22 +328,6 @@ class HyperlinkedRelatedField(RelatedField):
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.format = kwargs.pop('format', None)
- # These are deprecated
- if 'pk_url_kwarg' in kwargs:
- msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
- if 'slug_url_kwarg' in kwargs:
- msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
- if 'slug_field' in kwargs:
- msg = 'slug_field is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
-
- self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
- self.slug_field = kwargs.pop('slug_field', self.slug_field)
- default_slug_kwarg = self.slug_url_kwarg or self.slug_field
- self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
-
super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
def get_url(self, obj, view_name, request, format):
@@ -361,39 +339,7 @@ class HyperlinkedRelatedField(RelatedField):
"""
lookup_field = getattr(obj, self.lookup_field)
kwargs = {self.lookup_field: lookup_field}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
-
- if self.pk_url_kwarg != 'pk':
- # Only try pk if it has been explicitly set.
- # Otherwise, the default `lookup_field = 'pk'` has us covered.
- pk = obj.pk
- kwargs = {self.pk_url_kwarg: pk}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
-
- slug = getattr(obj, self.slug_field, None)
- if slug is not None:
- # Only try slug if it corresponds to an attribute on the object.
- kwargs = {self.slug_url_kwarg: slug}
- try:
- ret = reverse(view_name, kwargs=kwargs, request=request, format=format)
- if self.slug_field == 'slug' and self.slug_url_kwarg == 'slug':
- # If the lookup succeeds using the default slug params,
- # then `slug_field` is being used implicitly, and we
- # we need to warn about the pending deprecation.
- msg = 'Implicit slug field hyperlinked fields are deprecated.' \
- 'You should set `lookup_field=slug` on the HyperlinkedRelatedField.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
- return ret
- except NoReverseMatch:
- pass
-
- raise NoReverseMatch()
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
def get_object(self, queryset, view_name, view_args, view_kwargs):
"""
@@ -402,19 +348,8 @@ class HyperlinkedRelatedField(RelatedField):
Takes the matched URL conf arguments, and the queryset, and should
return an object instance, or raise an `ObjectDoesNotExist` exception.
"""
- lookup = view_kwargs.get(self.lookup_field, None)
- pk = view_kwargs.get(self.pk_url_kwarg, None)
- slug = view_kwargs.get(self.slug_url_kwarg, None)
-
- if lookup is not None:
- filter_kwargs = {self.lookup_field: lookup}
- elif pk is not None:
- filter_kwargs = {'pk': pk}
- elif slug is not None:
- filter_kwargs = {self.slug_field: slug}
- else:
- raise ObjectDoesNotExist()
-
+ lookup_value = view_kwargs[self.lookup_field]
+ filter_kwargs = {self.lookup_field: lookup_value}
return queryset.get(**filter_kwargs)
def to_native(self, obj):
@@ -486,11 +421,6 @@ class HyperlinkedIdentityField(Field):
lookup_field = 'pk'
read_only = True
- # These are all deprecated
- pk_url_kwarg = 'pk'
- slug_field = 'slug'
- slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
-
def __init__(self, *args, **kwargs):
try:
self.view_name = kwargs.pop('view_name')
@@ -502,22 +432,6 @@ class HyperlinkedIdentityField(Field):
lookup_field = kwargs.pop('lookup_field', None)
self.lookup_field = lookup_field or self.lookup_field
- # These are deprecated
- if 'pk_url_kwarg' in kwargs:
- msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
- if 'slug_url_kwarg' in kwargs:
- msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
- if 'slug_field' in kwargs:
- msg = 'slug_field is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
-
- self.slug_field = kwargs.pop('slug_field', self.slug_field)
- default_slug_kwarg = self.slug_url_kwarg or self.slug_field
- self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
- self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
-
super(HyperlinkedIdentityField, self).__init__(*args, **kwargs)
def field_to_native(self, obj, field_name):
@@ -569,27 +483,4 @@ class HyperlinkedIdentityField(Field):
if lookup_field is None:
return None
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
-
- if self.pk_url_kwarg != 'pk':
- # Only try pk lookup if it has been explicitly set.
- # Otherwise, the default `lookup_field = 'pk'` has us covered.
- kwargs = {self.pk_url_kwarg: obj.pk}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
-
- slug = getattr(obj, self.slug_field, None)
- if slug:
- # Only use slug lookup if a slug field exists on the model
- kwargs = {self.slug_url_kwarg: slug}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
-
- raise NoReverseMatch()
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
--
cgit v1.2.3
From b8c8d10a18741b76355ed7035655d0101c1d778a Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 29 Aug 2014 11:38:54 +0100
Subject: Remove `page_size` argument.
`paginate_queryset` no longer takes an optional `page_size` argument.
---
rest_framework/generics.py | 22 ++++------------------
1 file changed, 4 insertions(+), 18 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index e21dc5c7..09035303 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -111,26 +111,14 @@ class GenericAPIView(views.APIView):
context = self.get_serializer_context()
return pagination_serializer_class(instance=page, context=context)
- def paginate_queryset(self, queryset, page_size=None):
+ def paginate_queryset(self, queryset):
"""
Paginate a queryset if required, either returning a page object,
or `None` if pagination is not configured for this view.
"""
- deprecated_style = False
- if page_size is not None:
- warnings.warn('The `page_size` parameter to `paginate_queryset()` '
- 'is deprecated. '
- 'Note that the return style of this method is also '
- 'changed, and will simply return a page object '
- 'when called without a `page_size` argument.',
- DeprecationWarning, stacklevel=2)
- deprecated_style = True
- else:
- # Determine the required page size.
- # If pagination is not configured, simply return None.
- page_size = self.get_paginate_by()
- if not page_size:
- return None
+ page_size = self.get_paginate_by()
+ if not page_size:
+ return None
paginator = self.paginator_class(queryset, page_size)
page_kwarg = self.kwargs.get(self.page_kwarg)
@@ -152,8 +140,6 @@ class GenericAPIView(views.APIView):
'message': str(exc)
})
- if deprecated_style:
- return (paginator, page, page.object_list, page.has_other_pages())
return page
def filter_queryset(self, queryset):
--
cgit v1.2.3
From b3253b42836acd123224e88c0927f1ee6a031d94 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 29 Aug 2014 12:35:53 +0100
Subject: Remove `.model` usage in tests.
Remove the shortcut `.model` view attribute usage from test cases.
---
rest_framework/generics.py | 49 ++++++++++++----------------------------------
1 file changed, 12 insertions(+), 37 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 09035303..68222864 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -51,11 +51,6 @@ class GenericAPIView(views.APIView):
queryset = None
serializer_class = None
- # This shortcut may be used instead of setting either or both
- # of the `queryset`/`serializer_class` attributes, although using
- # the explicit style is generally preferred.
- model = None
-
# If you want to use object lookups other than pk, set this attribute.
# For more complex lookup requirements override `get_object()`.
lookup_field = 'pk'
@@ -71,9 +66,8 @@ class GenericAPIView(views.APIView):
# The filter backend classes to use for queryset filtering
filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
- # The following attributes may be subject to change,
+ # The following attribute may be subject to change,
# and should be considered private API.
- model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
paginator_class = Paginator
def get_serializer_context(self):
@@ -199,26 +193,13 @@ class GenericAPIView(views.APIView):
(Eg. admins get full serialization, others get basic serialization)
"""
- serializer_class = self.serializer_class
- if serializer_class is not None:
- return serializer_class
-
- warnings.warn(
- 'The `.model` attribute on view classes is now deprecated in favor '
- 'of the more explicit `serializer_class` and `queryset` attributes.',
- DeprecationWarning, stacklevel=2
- )
-
- assert self.model is not None, \
- "'%s' should either include a 'serializer_class' attribute, " \
- "or use the 'model' attribute as a shortcut for " \
- "automatically generating a serializer class." \
+ assert self.serializer_class is not None, (
+ "'%s' should either include a `serializer_class` attribute, "
+ "or override the `get_serializer_class()` method."
% self.__class__.__name__
+ )
- class DefaultSerializer(self.model_serializer_class):
- class Meta:
- model = self.model
- return DefaultSerializer
+ return self.serializer_class
def get_queryset(self):
"""
@@ -235,19 +216,13 @@ class GenericAPIView(views.APIView):
(Eg. return a list of items that is specific to the user)
"""
- if self.queryset is not None:
- return self.queryset._clone()
-
- if self.model is not None:
- warnings.warn(
- 'The `.model` attribute on view classes is now deprecated in favor '
- 'of the more explicit `serializer_class` and `queryset` attributes.',
- DeprecationWarning, stacklevel=2
- )
- return self.model._default_manager.all()
+ assert self.queryset is not None, (
+ "'%s' should either include a `queryset` attribute, "
+ "or override the `get_queryset()` method."
+ % self.__class__.__name__
+ )
- error_format = "'%s' must define 'queryset' or 'model'"
- raise ImproperlyConfigured(error_format % self.__class__.__name__)
+ return self.queryset._clone()
def get_object(self):
"""
--
cgit v1.2.3
From 72c0811576feb89decf6fc6dc4ee5e25eca0aece Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 29 Aug 2014 12:48:04 +0100
Subject: Minor tidy up.
---
rest_framework/generics.py | 19 ++++++++-----------
1 file changed, 8 insertions(+), 11 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 68222864..d0adeaec 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -3,7 +3,7 @@ Generic views that provide commonly needed behaviour.
"""
from __future__ import unicode_literals
-from django.core.exceptions import ImproperlyConfigured, PermissionDenied
+from django.core.exceptions import PermissionDenied
from django.core.paginator import Paginator, InvalidPage
from django.http import Http404
from django.shortcuts import get_object_or_404 as _get_object_or_404
@@ -235,19 +235,16 @@ class GenericAPIView(views.APIView):
queryset = self.filter_queryset(self.get_queryset())
# Perform the lookup filtering.
- # Note that `pk` and `slug` are deprecated styles of lookup filtering.
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
- lookup = self.kwargs.get(lookup_url_kwarg, None)
- if lookup is None:
- raise ImproperlyConfigured(
- 'Expected view %s to be called with a URL keyword argument '
- 'named "%s". Fix your URL conf, or set the `.lookup_field` '
- 'attribute on the view correctly.' %
- (self.__class__.__name__, self.lookup_field)
- )
+ assert lookup_url_kwarg in self.kwargs, (
+ 'Expected view %s to be called with a URL keyword argument '
+ 'named "%s". Fix your URL conf, or set the `.lookup_field` '
+ 'attribute on the view correctly.' %
+ (self.__class__.__name__, lookup_url_kwarg)
+ )
- filter_kwargs = {self.lookup_field: lookup}
+ filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
obj = get_object_or_404(queryset, **filter_kwargs)
# May raise a permission denied
--
cgit v1.2.3
From ce7b2cded94abc12ae1be076642de96684d0927b Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 29 Aug 2014 12:48:49 +0100
Subject: Remove deprecated generic views.
`MultipleObjectAPIView` and `SingleObjectAPIView` are no longer
required.
---
rest_framework/generics.py | 22 ----------------------
1 file changed, 22 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index d0adeaec..e6cbfca9 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -442,25 +442,3 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
-
-
-# Deprecated classes
-
-class MultipleObjectAPIView(GenericAPIView):
- def __init__(self, *args, **kwargs):
- warnings.warn(
- 'Subclassing `MultipleObjectAPIView` is deprecated. '
- 'You should simply subclass `GenericAPIView` instead.',
- DeprecationWarning, stacklevel=2
- )
- super(MultipleObjectAPIView, self).__init__(*args, **kwargs)
-
-
-class SingleObjectAPIView(GenericAPIView):
- def __init__(self, *args, **kwargs):
- warnings.warn(
- 'Subclassing `SingleObjectAPIView` is deprecated. '
- 'You should simply subclass `GenericAPIView` instead.',
- DeprecationWarning, stacklevel=2
- )
- super(SingleObjectAPIView, self).__init__(*args, **kwargs)
--
cgit v1.2.3
From f87d32558eb3b36f14a798ec48e4943d25380b92 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 29 Aug 2014 12:53:45 +0100
Subject: Remove `.link()` and `.action()` decorators.
---
rest_framework/decorators.py | 34 ----------------------------------
1 file changed, 34 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 449ba0a2..cc5d92c2 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -130,37 +130,3 @@ def list_route(methods=['get'], **kwargs):
func.kwargs = kwargs
return func
return decorator
-
-
-# These are now pending deprecation, in favor of `detail_route` and `list_route`.
-
-def link(**kwargs):
- """
- Used to mark a method on a ViewSet that should be routed for detail GET requests.
- """
- msg = 'link is pending deprecation. Use detail_route instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
-
- def decorator(func):
- func.bind_to_methods = ['get']
- func.detail = True
- func.kwargs = kwargs
- return func
-
- return decorator
-
-
-def action(methods=['post'], **kwargs):
- """
- Used to mark a method on a ViewSet that should be routed for detail POST requests.
- """
- msg = 'action is pending deprecation. Use detail_route instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
-
- def decorator(func):
- func.bind_to_methods = methods
- func.detail = True
- func.kwargs = kwargs
- return func
-
- return decorator
--
cgit v1.2.3
From b552b62540e5144272c9c13c28f120ffe5fcbe45 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 29 Aug 2014 12:54:03 +0100
Subject: `get_paginate_by` no longer takes optional `.queryset`
---
rest_framework/generics.py | 7 +------
1 file changed, 1 insertion(+), 6 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index e6cbfca9..40c49844 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -158,7 +158,7 @@ class GenericAPIView(views.APIView):
# The following methods provide default implementations
# that you may want to override for more complex cases.
- def get_paginate_by(self, queryset=None):
+ def get_paginate_by(self):
"""
Return the size of pages to use with pagination.
@@ -167,11 +167,6 @@ class GenericAPIView(views.APIView):
Otherwise defaults to using `self.paginate_by`.
"""
- if queryset is not None:
- warnings.warn('The `queryset` parameter to `get_paginate_by()` '
- 'is deprecated.',
- DeprecationWarning, stacklevel=2)
-
if self.paginate_by_param:
try:
return strict_positive_int(
--
cgit v1.2.3
From 371d30aa8737c4b3aaf28ee10cc2b77a9c4d1fd9 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 29 Aug 2014 12:54:52 +0100
Subject: Remove unused imports.
---
rest_framework/decorators.py | 1 -
rest_framework/generics.py | 1 -
2 files changed, 2 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index cc5d92c2..d28d6e22 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -10,7 +10,6 @@ from __future__ import unicode_literals
from django.utils import six
from rest_framework.views import APIView
import types
-import warnings
def api_view(http_method_names):
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 40c49844..b3bd6ce9 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -11,7 +11,6 @@ from django.utils.translation import ugettext as _
from rest_framework import views, mixins, exceptions
from rest_framework.request import clone_request
from rest_framework.settings import api_settings
-import warnings
def strict_positive_int(integer_string, cutoff=None):
--
cgit v1.2.3
From 4ac4676a40b121d27cfd1173ff548d96b8d3de2f Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 29 Aug 2014 16:46:26 +0100
Subject: First pass
---
rest_framework/fields.py | 1166 +++++++-------------------------------
rest_framework/generics.py | 10 +-
rest_framework/mixins.py | 24 +-
rest_framework/pagination.py | 22 +-
rest_framework/relations.py | 486 ----------------
rest_framework/renderers.py | 10 +-
rest_framework/serializers.py | 1096 ++++++++++-------------------------
rest_framework/utils/encoders.py | 18 +-
rest_framework/utils/html.py | 86 +++
9 files changed, 632 insertions(+), 2286 deletions(-)
create mode 100644 rest_framework/utils/html.py
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 9d707c9b..a83bf94c 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,1038 +1,308 @@
-"""
-Serializer fields perform validation on incoming data.
+from rest_framework.utils import html
-They are very similar to Django's form fields.
-"""
-from __future__ import unicode_literals
-import copy
-import datetime
-import inspect
-import re
-import warnings
-from decimal import Decimal, DecimalException
-from django import forms
-from django.core import validators
-from django.core.exceptions import ValidationError
-from django.conf import settings
-from django.db.models.fields import BLANK_CHOICE_DASH
-from django.http import QueryDict
-from django.forms import widgets
-from django.utils import six, timezone
-from django.utils.encoding import is_protected_type
-from django.utils.translation import ugettext_lazy as _
-from django.utils.datastructures import SortedDict
-from django.utils.dateparse import parse_date, parse_datetime, parse_time
-from rest_framework import ISO_8601
-from rest_framework.compat import (
- BytesIO, smart_text,
- force_text, is_non_str_iterable
-)
-from rest_framework.settings import api_settings
-
-
-def is_simple_callable(obj):
+class empty:
"""
- True if the object is a callable that takes no arguments.
- """
- function = inspect.isfunction(obj)
- method = inspect.ismethod(obj)
-
- if not (function or method):
- return False
+ This class is used to represent no data being provided for a given input
+ or output value.
- args, _, _, defaults = inspect.getargspec(obj)
- len_args = len(args) if function else len(args) - 1
- len_defaults = len(defaults) if defaults else 0
- return len_args <= len_defaults
+ It is required because `None` may be a valid input or output value.
+ """
+ pass
-def get_component(obj, attr_name):
+def get_attribute(instance, attrs):
"""
- Given an object, and an attribute name,
- return that attribute on the object.
+ Similar to Python's built in `getattr(instance, attr)`,
+ but takes a list of nested attributes, instead of a single attribute.
"""
- if isinstance(obj, dict):
- val = obj.get(attr_name)
- else:
- val = getattr(obj, attr_name)
-
- if is_simple_callable(val):
- return val()
- return val
-
-
-def readable_datetime_formats(formats):
- format = ', '.join(formats).replace(
- ISO_8601,
- 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
- )
- return humanize_strptime(format)
-
-
-def readable_date_formats(formats):
- format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]')
- return humanize_strptime(format)
+ for attr in attrs:
+ instance = getattr(instance, attr)
+ return instance
-def readable_time_formats(formats):
- format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
- return humanize_strptime(format)
-
-
-def humanize_strptime(format_string):
- # Note that we're missing some of the locale specific mappings that
- # don't really make sense.
- mapping = {
- "%Y": "YYYY",
- "%y": "YY",
- "%m": "MM",
- "%b": "[Jan-Dec]",
- "%B": "[January-December]",
- "%d": "DD",
- "%H": "hh",
- "%I": "hh", # Requires '%p' to differentiate from '%H'.
- "%M": "mm",
- "%S": "ss",
- "%f": "uuuuuu",
- "%a": "[Mon-Sun]",
- "%A": "[Monday-Sunday]",
- "%p": "[AM|PM]",
- "%z": "[+HHMM|-HHMM]"
- }
- for key, val in mapping.items():
- format_string = format_string.replace(key, val)
- return format_string
-
-
-def strip_multiple_choice_msg(help_text):
+def set_value(dictionary, keys, value):
"""
- Remove the 'Hold down "control" ...' message that is Django enforces in
- select multiple fields on ModelForms. (Required for 1.5 and earlier)
+ Similar to Python's built in `dictionary[key] = value`,
+ but takes a list of nested keys instead of a single key.
- See https://code.djangoproject.com/ticket/9321
+ set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2}
+ set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2}
+ set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}}
"""
- multiple_choice_msg = _(' Hold down "Control", or "Command" on a Mac, to select more than one.')
- multiple_choice_msg = force_text(multiple_choice_msg)
+ if not keys:
+ dictionary.update(value)
+ return
- return help_text.replace(multiple_choice_msg, '')
+ for key in keys[:-1]:
+ if key not in dictionary:
+ dictionary[key] = {}
+ dictionary = dictionary[key]
+ dictionary[keys[-1]] = value
-class Field(object):
- read_only = True
- creation_counter = 0
- empty = ''
- type_name = None
- partial = False
- use_files = False
- form_field_class = forms.CharField
- type_label = 'field'
- widget = None
- def __init__(self, source=None, label=None, help_text=None):
- self.parent = None
+class ValidationError(Exception):
+ pass
- self.creation_counter = Field.creation_counter
- Field.creation_counter += 1
- self.source = source
+class SkipField(Exception):
+ pass
- if label is not None:
- self.label = smart_text(label)
- else:
- self.label = None
- if help_text is not None:
- self.help_text = strip_multiple_choice_msg(smart_text(help_text))
- else:
- self.help_text = None
+class Field(object):
+ _creation_counter = 0
- self._errors = []
- self._value = None
- self._name = None
+ MESSAGES = {
+ 'required': 'This field is required.'
+ }
- @property
- def errors(self):
- return self._errors
+ _NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
+ _NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
+ _NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
+ _NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
+ _MISSING_ERROR_MESSAGE = (
+ 'ValidationError raised by `{class_name}`, but error key `{key}` does '
+ 'not exist in the `MESSAGES` dictionary.'
+ )
- def widget_html(self):
- if not self.widget:
- return ''
+ def __init__(self, read_only=False, write_only=False,
+ required=None, default=empty, initial=None, source=None,
+ label=None, style=None):
+ self._creation_counter = Field._creation_counter
+ Field._creation_counter += 1
- attrs = {}
- if 'id' not in self.widget.attrs:
- attrs['id'] = self._name
+ # If `required` is unset, then use `True` unless a default is provided.
+ if required is None:
+ required = default is empty and not read_only
- return self.widget.render(self._name, self._value, attrs=attrs)
+ # Some combinations of keyword arguments do not make sense.
+ assert not (read_only and write_only), self._NOT_READ_ONLY_WRITE_ONLY
+ assert not (read_only and required), self._NOT_READ_ONLY_REQUIRED
+ assert not (read_only and default is not empty), self._NOT_READ_ONLY_DEFAULT
+ assert not (required and default is not empty), self._NOT_REQUIRED_DEFAULT
- def label_tag(self):
- return '' % (self._name, self.label)
+ self.read_only = read_only
+ self.write_only = write_only
+ self.required = required
+ self.default = default
+ self.source = source
+ self.initial = initial
+ self.label = label
+ self.style = {} if style is None else style
- def initialize(self, parent, field_name):
+ def bind(self, field_name, parent, root):
"""
- Called to set up a field prior to field_to_native or field_from_native.
-
- parent - The parent serializer.
- field_name - The name of the field being initialized.
+ Setup the context for the field instance.
"""
+ self.field_name = field_name
self.parent = parent
- self.root = parent.root or parent
- self.context = self.root.context
- self.partial = self.root.partial
- if self.partial:
- self.required = False
+ self.root = root
- def field_from_native(self, data, files, field_name, into):
- """
- Given a dictionary and a field name, updates the dictionary `into`,
- with the field and it's deserialized value.
- """
- return
+ # `self.label` should deafult to being based on the field name.
+ if self.label is None:
+ self.label = self.field_name.replace('_', ' ').capitalize()
- def field_to_native(self, obj, field_name):
- """
- Given an object and a field name, returns the value that should be
- serialized for that field.
- """
- if obj is None:
- return self.empty
+ # self.source should default to being the same as the field name.
+ if self.source is None:
+ self.source = field_name
+ # self.source_attrs is a list of attributes that need to be looked up
+ # when serializing the instance, or populating the validated data.
if self.source == '*':
- return self.to_native(obj)
-
- source = self.source or field_name
- value = obj
-
- for component in source.split('.'):
- value = get_component(value, component)
- if value is None:
- break
-
- return self.to_native(value)
+ self.source_attrs = []
+ else:
+ self.source_attrs = self.source.split('.')
- def to_native(self, value):
+ def get_initial(self):
"""
- Converts the field's value into it's simple representation.
+ Return a value to use when the field is being returned as a primative
+ value, without any object instance.
"""
- if is_simple_callable(value):
- value = value()
-
- if is_protected_type(value):
- return value
- elif (is_non_str_iterable(value) and
- not isinstance(value, (dict, six.string_types))):
- return [self.to_native(item) for item in value]
- elif isinstance(value, dict):
- # Make sure we preserve field ordering, if it exists
- ret = SortedDict()
- for key, val in value.items():
- ret[key] = self.to_native(val)
- return ret
- return force_text(value)
+ return self.initial
- def attributes(self):
+ def get_value(self, dictionary):
"""
- Returns a dictionary of attributes to be used when serializing to xml.
+ Given the *incoming* primative data, return the value for this field
+ that should be validated and transformed to a native value.
"""
- if self.type_name:
- return {'type': self.type_name}
- return {}
-
- def metadata(self):
- metadata = SortedDict()
- metadata['type'] = self.type_label
- metadata['required'] = getattr(self, 'required', False)
- optional_attrs = ['read_only', 'label', 'help_text',
- 'min_length', 'max_length']
- for attr in optional_attrs:
- value = getattr(self, attr, None)
- if value is not None and value != '':
- metadata[attr] = force_text(value, strings_only=True)
- return metadata
-
-
-class WritableField(Field):
- """
- Base for read/write fields.
- """
- write_only = False
- default_validators = []
- default_error_messages = {
- 'required': _('This field is required.'),
- 'invalid': _('Invalid value.'),
- }
- widget = widgets.TextInput
- default = None
-
- def __init__(self, source=None, label=None, help_text=None,
- read_only=False, write_only=False, required=None,
- validators=[], error_messages=None, widget=None,
- default=None, blank=None):
-
- super(WritableField, self).__init__(source=source, label=label, help_text=help_text)
-
- self.read_only = read_only
- self.write_only = write_only
-
- assert not (read_only and write_only), "Cannot set read_only=True and write_only=True"
-
- if required is None:
- self.required = not(read_only)
- else:
- assert not (read_only and required), "Cannot set required=True and read_only=True"
- self.required = required
-
- messages = {}
- for c in reversed(self.__class__.__mro__):
- messages.update(getattr(c, 'default_error_messages', {}))
- messages.update(error_messages or {})
- self.error_messages = messages
+ return dictionary.get(self.field_name, empty)
- self.validators = self.default_validators + validators
- self.default = default if default is not None else self.default
-
- # Widgets are only used for HTML forms.
- widget = widget or self.widget
- if isinstance(widget, type):
- widget = widget()
- self.widget = widget
+ def get_attribute(self, instance):
+ """
+ Given the *outgoing* object instance, return the value for this field
+ that should be returned as a primative value.
+ """
+ return get_attribute(instance, self.source_attrs)
- def __deepcopy__(self, memo):
- result = copy.copy(self)
- memo[id(self)] = result
- result.validators = self.validators[:]
- return result
+ def get_default(self):
+ """
+ Return the default value to use when validating data if no input
+ is provided for this field.
- def get_default_value(self):
- if is_simple_callable(self.default):
- return self.default()
+ If a default has not been set for this field then this will simply
+ return `empty`, indicating that no value should be set in the
+ validated data for this field.
+ """
+ if self.default is empty:
+ raise SkipField()
return self.default
- def validate(self, value):
- if value in validators.EMPTY_VALUES and self.required:
- raise ValidationError(self.error_messages['required'])
+ def validate(self, data=empty):
+ """
+ Validate a simple representation and return the internal value.
- def run_validators(self, value):
- if value in validators.EMPTY_VALUES:
- return
- errors = []
- for v in self.validators:
- try:
- v(value)
- except ValidationError as e:
- if hasattr(e, 'code') and e.code in self.error_messages:
- message = self.error_messages[e.code]
- if e.params:
- message = message % e.params
- errors.append(message)
- else:
- errors.extend(e.messages)
- if errors:
- raise ValidationError(errors)
+ The provided data may be `empty` if no representation was included.
+ May return `empty` if the field should not be included in the
+ validated data.
+ """
+ if data is empty:
+ if self.required:
+ self.fail('required')
+ return self.get_default()
- def field_to_native(self, obj, field_name):
- if self.write_only:
- return None
- return super(WritableField, self).field_to_native(obj, field_name)
+ return self.to_native(data)
- def field_from_native(self, data, files, field_name, into):
+ def to_native(self, data):
"""
- Given a dictionary and a field name, updates the dictionary `into`,
- with the field and it's deserialized value.
+ Transform the *incoming* primative data into a native value.
"""
- if self.read_only:
- return
-
- try:
- data = data or {}
- if self.use_files:
- files = files or {}
- try:
- native = files[field_name]
- except KeyError:
- native = data[field_name]
- else:
- native = data[field_name]
- except KeyError:
- if self.default is not None and not self.partial:
- # Note: partial updates shouldn't set defaults
- native = self.get_default_value()
- else:
- if self.required:
- raise ValidationError(self.error_messages['required'])
- return
-
- value = self.from_native(native)
- if self.source == '*':
- if value:
- into.update(value)
- else:
- self.validate(value)
- self.run_validators(value)
- into[self.source or field_name] = value
+ return data
- def from_native(self, value):
+ def to_primative(self, value):
"""
- Reverts a simple representation back to the field's value.
+ Transform the *outgoing* native value into primative data.
"""
return value
-
-class ModelField(WritableField):
- """
- A generic field that can be used against an arbitrary model field.
- """
- def __init__(self, *args, **kwargs):
+ def fail(self, key, **kwargs):
+ """
+ A helper method that simply raises a validation error.
+ """
try:
- self.model_field = kwargs.pop('model_field')
+ raise ValidationError(self.MESSAGES[key].format(**kwargs))
except KeyError:
- raise ValueError("ModelField requires 'model_field' kwarg")
-
- self.min_length = kwargs.pop('min_length',
- getattr(self.model_field, 'min_length', None))
- self.max_length = kwargs.pop('max_length',
- getattr(self.model_field, 'max_length', None))
- self.min_value = kwargs.pop('min_value',
- getattr(self.model_field, 'min_value', None))
- self.max_value = kwargs.pop('max_value',
- getattr(self.model_field, 'max_value', None))
-
- super(ModelField, self).__init__(*args, **kwargs)
-
- if self.min_length is not None:
- self.validators.append(validators.MinLengthValidator(self.min_length))
- if self.max_length is not None:
- self.validators.append(validators.MaxLengthValidator(self.max_length))
- if self.min_value is not None:
- self.validators.append(validators.MinValueValidator(self.min_value))
- if self.max_value is not None:
- self.validators.append(validators.MaxValueValidator(self.max_value))
-
- def from_native(self, value):
- rel = getattr(self.model_field, "rel", None)
- if rel is not None:
- return rel.to._meta.get_field(rel.field_name).to_python(value)
- else:
- return self.model_field.to_python(value)
+ class_name = self.__class__.__name__
+ msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
+ raise AssertionError(msg)
- def field_to_native(self, obj, field_name):
- value = self.model_field._get_val_from_obj(obj)
- if is_protected_type(value):
- return value
- return self.model_field.value_to_string(obj)
- def attributes(self):
- return {
- "type": self.model_field.get_internal_type()
- }
-
-
-# Typed Fields
-
-class BooleanField(WritableField):
- type_name = 'BooleanField'
- type_label = 'boolean'
- form_field_class = forms.BooleanField
- widget = widgets.CheckboxInput
- default_error_messages = {
- 'invalid': _("'%s' value must be either True or False."),
+class BooleanField(Field):
+ MESSAGES = {
+ 'required': 'This field is required.',
+ 'invalid_value': '`{input}` is not a valid boolean.'
}
- empty = False
-
- def field_from_native(self, data, files, field_name, into):
- # HTML checkboxes do not explicitly represent unchecked as `False`
- # we deal with that here...
- if isinstance(data, QueryDict) and self.default is None:
- self.default = False
-
- return super(BooleanField, self).field_from_native(
- data, files, field_name, into
- )
-
- def from_native(self, value):
- if value in ('true', 't', 'True', '1'):
+ TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True}
+ FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False}
+
+ def get_value(self, dictionary):
+ if html.is_html_input(dictionary):
+ # HTML forms do not send a `False` value on an empty checkbox,
+ # so we override the default empty value to be False.
+ return dictionary.get(self.field_name, False)
+ return dictionary.get(self.field_name, empty)
+
+ def to_native(self, data):
+ if data in self.TRUE_VALUES:
return True
- if value in ('false', 'f', 'False', '0'):
+ elif data in self.FALSE_VALUES:
return False
- return bool(value)
+ self.fail('invalid_value', input=data)
-class CharField(WritableField):
- type_name = 'CharField'
- type_label = 'string'
- form_field_class = forms.CharField
+class CharField(Field):
+ MESSAGES = {
+ 'required': 'This field is required.',
+ 'blank': 'This field may not be blank.'
+ }
- def __init__(self, max_length=None, min_length=None, allow_none=False, *args, **kwargs):
- self.max_length, self.min_length = max_length, min_length
- self.allow_none = allow_none
+ def __init__(self, *args, **kwargs):
+ self.allow_blank = kwargs.pop('allow_blank', False)
super(CharField, self).__init__(*args, **kwargs)
- if min_length is not None:
- self.validators.append(validators.MinLengthValidator(min_length))
- if max_length is not None:
- self.validators.append(validators.MaxLengthValidator(max_length))
-
- def from_native(self, value):
- if isinstance(value, six.string_types):
- return value
-
- if value is None and not self.allow_none:
- return ''
-
- return smart_text(value)
-
-class URLField(CharField):
- type_name = 'URLField'
- type_label = 'url'
-
- def __init__(self, **kwargs):
- if 'validators' not in kwargs:
- kwargs['validators'] = [validators.URLValidator()]
- super(URLField, self).__init__(**kwargs)
+ def to_native(self, data):
+ if data == '' and not self.allow_blank:
+ self.fail('blank')
+ return str(data)
-class SlugField(CharField):
- type_name = 'SlugField'
- type_label = 'slug'
- form_field_class = forms.SlugField
-
- default_error_messages = {
- 'invalid': _("Enter a valid 'slug' consisting of letters, numbers,"
- " underscores or hyphens."),
+class ChoiceField(Field):
+ MESSAGES = {
+ 'required': 'This field is required.',
+ 'invalid_choice': '`{input}` is not a valid choice.'
}
- default_validators = [validators.validate_slug]
+ coerce_to_type = str
def __init__(self, *args, **kwargs):
- super(SlugField, self).__init__(*args, **kwargs)
-
+ choices = kwargs.pop('choices')
+
+ assert choices, '`choices` argument is required and may not be empty'
+
+ # Allow either single or paired choices style:
+ # choices = [1, 2, 3]
+ # choices = [(1, 'First'), (2, 'Second'), (3, 'Third')]
+ pairs = [
+ isinstance(item, (list, tuple)) and len(item) == 2
+ for item in choices
+ ]
+ if all(pairs):
+ self.choices = {key: val for key, val in choices}
+ else:
+ self.choices = {item: item for item in choices}
-class ChoiceField(WritableField):
- type_name = 'ChoiceField'
- type_label = 'choice'
- form_field_class = forms.ChoiceField
- widget = widgets.Select
- default_error_messages = {
- 'invalid_choice': _('Select a valid choice. %(value)s is not one of '
- 'the available choices.'),
- }
+ # Map the string representation of choices to the underlying value.
+ # Allows us to deal with eg. integer choices while supporting either
+ # integer or string input, but still get the correct datatype out.
+ self.choice_strings_to_values = {
+ str(key): key for key in self.choices.keys()
+ }
- def __init__(self, choices=(), blank_display_value=None, *args, **kwargs):
- self.empty = kwargs.pop('empty', '')
super(ChoiceField, self).__init__(*args, **kwargs)
- self.choices = choices
- if not self.required:
- if blank_display_value is None:
- blank_choice = BLANK_CHOICE_DASH
- else:
- blank_choice = [('', blank_display_value)]
- self.choices = blank_choice + self.choices
-
- def _get_choices(self):
- return self._choices
-
- def _set_choices(self, value):
- # Setting choices also sets the choices on the widget.
- # choices can be any iterable, but we call list() on it because
- # it will be consumed more than once.
- self._choices = self.widget.choices = list(value)
-
- choices = property(_get_choices, _set_choices)
-
- def metadata(self):
- data = super(ChoiceField, self).metadata()
- data['choices'] = [{'value': v, 'display_name': n} for v, n in self.choices]
- return data
-
- def validate(self, value):
- """
- Validates that the input is in self.choices.
- """
- super(ChoiceField, self).validate(value)
- if value and not self.valid_value(value):
- raise ValidationError(self.error_messages['invalid_choice'] % {'value': value})
-
- def valid_value(self, value):
- """
- Check to see if the provided value is a valid choice.
- """
- for k, v in self.choices:
- if isinstance(v, (list, tuple)):
- # This is an optgroup, so look inside the group for options
- for k2, v2 in v:
- if value == smart_text(k2):
- return True
- else:
- if value == smart_text(k) or value == k:
- return True
- return False
-
- def from_native(self, value):
- value = super(ChoiceField, self).from_native(value)
- if value == self.empty or value in validators.EMPTY_VALUES:
- return self.empty
- return value
-
-
-class EmailField(CharField):
- type_name = 'EmailField'
- type_label = 'email'
- form_field_class = forms.EmailField
-
- default_error_messages = {
- 'invalid': _('Enter a valid email address.'),
- }
- default_validators = [validators.validate_email]
-
- def from_native(self, value):
- ret = super(EmailField, self).from_native(value)
- if ret is None:
- return None
- return ret.strip()
-
-
-class RegexField(CharField):
- type_name = 'RegexField'
- type_label = 'regex'
- form_field_class = forms.RegexField
-
- def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs):
- super(RegexField, self).__init__(max_length, min_length, *args, **kwargs)
- self.regex = regex
-
- def _get_regex(self):
- return self._regex
-
- def _set_regex(self, regex):
- if isinstance(regex, six.string_types):
- regex = re.compile(regex)
- self._regex = regex
- if hasattr(self, '_regex_validator') and self._regex_validator in self.validators:
- self.validators.remove(self._regex_validator)
- self._regex_validator = validators.RegexValidator(regex=regex)
- self.validators.append(self._regex_validator)
-
- regex = property(_get_regex, _set_regex)
-
-
-class DateField(WritableField):
- type_name = 'DateField'
- type_label = 'date'
- widget = widgets.DateInput
- form_field_class = forms.DateField
-
- default_error_messages = {
- 'invalid': _("Date has wrong format. Use one of these formats instead: %s"),
- }
- empty = None
- input_formats = api_settings.DATE_INPUT_FORMATS
- format = api_settings.DATE_FORMAT
-
- def __init__(self, input_formats=None, format=None, *args, **kwargs):
- self.input_formats = input_formats if input_formats is not None else self.input_formats
- self.format = format if format is not None else self.format
- super(DateField, self).__init__(*args, **kwargs)
-
- def from_native(self, value):
- if value in validators.EMPTY_VALUES:
- return None
-
- if isinstance(value, datetime.datetime):
- if timezone and settings.USE_TZ and timezone.is_aware(value):
- # Convert aware datetimes to the default time zone
- # before casting them to dates (#17742).
- default_timezone = timezone.get_default_timezone()
- value = timezone.make_naive(value, default_timezone)
- return value.date()
- if isinstance(value, datetime.date):
- return value
-
- for format in self.input_formats:
- if format.lower() == ISO_8601:
- try:
- parsed = parse_date(value)
- except (ValueError, TypeError):
- pass
- else:
- if parsed is not None:
- return parsed
- else:
- try:
- parsed = datetime.datetime.strptime(value, format)
- except (ValueError, TypeError):
- pass
- else:
- return parsed.date()
-
- msg = self.error_messages['invalid'] % readable_date_formats(self.input_formats)
- raise ValidationError(msg)
-
- def to_native(self, value):
- if value is None or self.format is None:
- return value
-
- if isinstance(value, datetime.datetime):
- value = value.date()
-
- if self.format.lower() == ISO_8601:
- return value.isoformat()
- return value.strftime(self.format)
-
-
-class DateTimeField(WritableField):
- type_name = 'DateTimeField'
- type_label = 'datetime'
- widget = widgets.DateTimeInput
- form_field_class = forms.DateTimeField
-
- default_error_messages = {
- 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"),
- }
- empty = None
- input_formats = api_settings.DATETIME_INPUT_FORMATS
- format = api_settings.DATETIME_FORMAT
-
- def __init__(self, input_formats=None, format=None, *args, **kwargs):
- self.input_formats = input_formats if input_formats is not None else self.input_formats
- self.format = format if format is not None else self.format
- super(DateTimeField, self).__init__(*args, **kwargs)
-
- def from_native(self, value):
- if value in validators.EMPTY_VALUES:
- return None
-
- if isinstance(value, datetime.datetime):
- return value
- if isinstance(value, datetime.date):
- value = datetime.datetime(value.year, value.month, value.day)
- if settings.USE_TZ:
- # For backwards compatibility, interpret naive datetimes in
- # local time. This won't work during DST change, but we can't
- # do much about it, so we let the exceptions percolate up the
- # call stack.
- warnings.warn("DateTimeField received a naive datetime (%s)"
- " while time zone support is active." % value,
- RuntimeWarning)
- default_timezone = timezone.get_default_timezone()
- value = timezone.make_aware(value, default_timezone)
- return value
-
- for format in self.input_formats:
- if format.lower() == ISO_8601:
- try:
- parsed = parse_datetime(value)
- except (ValueError, TypeError):
- pass
- else:
- if parsed is not None:
- return parsed
- else:
- try:
- parsed = datetime.datetime.strptime(value, format)
- except (ValueError, TypeError):
- pass
- else:
- return parsed
-
- msg = self.error_messages['invalid'] % readable_datetime_formats(self.input_formats)
- raise ValidationError(msg)
-
- def to_native(self, value):
- if value is None or self.format is None:
- return value
-
- if self.format.lower() == ISO_8601:
- ret = value.isoformat()
- if ret.endswith('+00:00'):
- ret = ret[:-6] + 'Z'
- return ret
- return value.strftime(self.format)
-
-
-class TimeField(WritableField):
- type_name = 'TimeField'
- type_label = 'time'
- widget = widgets.TimeInput
- form_field_class = forms.TimeField
-
- default_error_messages = {
- 'invalid': _("Time has wrong format. Use one of these formats instead: %s"),
- }
- empty = None
- input_formats = api_settings.TIME_INPUT_FORMATS
- format = api_settings.TIME_FORMAT
-
- def __init__(self, input_formats=None, format=None, *args, **kwargs):
- self.input_formats = input_formats if input_formats is not None else self.input_formats
- self.format = format if format is not None else self.format
- super(TimeField, self).__init__(*args, **kwargs)
-
- def from_native(self, value):
- if value in validators.EMPTY_VALUES:
- return None
-
- if isinstance(value, datetime.time):
- return value
-
- for format in self.input_formats:
- if format.lower() == ISO_8601:
- try:
- parsed = parse_time(value)
- except (ValueError, TypeError):
- pass
- else:
- if parsed is not None:
- return parsed
- else:
- try:
- parsed = datetime.datetime.strptime(value, format)
- except (ValueError, TypeError):
- pass
- else:
- return parsed.time()
-
- msg = self.error_messages['invalid'] % readable_time_formats(self.input_formats)
- raise ValidationError(msg)
-
- def to_native(self, value):
- if value is None or self.format is None:
- return value
-
- if isinstance(value, datetime.datetime):
- value = value.time()
-
- if self.format.lower() == ISO_8601:
- return value.isoformat()
- return value.strftime(self.format)
-
-
-class IntegerField(WritableField):
- type_name = 'IntegerField'
- type_label = 'integer'
- form_field_class = forms.IntegerField
- empty = 0
-
- default_error_messages = {
- 'invalid': _('Enter a whole number.'),
- 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
- 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
- }
-
- def __init__(self, max_value=None, min_value=None, *args, **kwargs):
- self.max_value, self.min_value = max_value, min_value
- super(IntegerField, self).__init__(*args, **kwargs)
-
- if max_value is not None:
- self.validators.append(validators.MaxValueValidator(max_value))
- if min_value is not None:
- self.validators.append(validators.MinValueValidator(min_value))
-
- def from_native(self, value):
- if value in validators.EMPTY_VALUES:
- return None
-
- try:
- value = int(str(value))
- except (ValueError, TypeError):
- raise ValidationError(self.error_messages['invalid'])
- return value
-
-
-class FloatField(WritableField):
- type_name = 'FloatField'
- type_label = 'float'
- form_field_class = forms.FloatField
- empty = 0
-
- default_error_messages = {
- 'invalid': _("'%s' value must be a float."),
- }
-
- def from_native(self, value):
- if value in validators.EMPTY_VALUES:
- return None
+ def to_native(self, data):
try:
- return float(value)
- except (TypeError, ValueError):
- msg = self.error_messages['invalid'] % value
- raise ValidationError(msg)
-
+ return self.choice_strings_to_values[str(data)]
+ except KeyError:
+ self.fail('invalid_choice', input=data)
-class DecimalField(WritableField):
- type_name = 'DecimalField'
- type_label = 'decimal'
- form_field_class = forms.DecimalField
- empty = Decimal('0')
- default_error_messages = {
- 'invalid': _('Enter a number.'),
- 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
- 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
- 'max_digits': _('Ensure that there are no more than %s digits in total.'),
- 'max_decimal_places': _('Ensure that there are no more than %s decimal places.'),
- 'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.')
+class MultipleChoiceField(ChoiceField):
+ MESSAGES = {
+ 'required': 'This field is required.',
+ 'invalid_choice': '`{input}` is not a valid choice.',
+ 'not_a_list': 'Expected a list of items but got type `{input_type}`'
}
- def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs):
- self.max_value, self.min_value = max_value, min_value
- self.max_digits, self.decimal_places = max_digits, decimal_places
- super(DecimalField, self).__init__(*args, **kwargs)
-
- if max_value is not None:
- self.validators.append(validators.MaxValueValidator(max_value))
- if min_value is not None:
- self.validators.append(validators.MinValueValidator(min_value))
-
- def from_native(self, value):
- """
- Validates that the input is a decimal number. Returns a Decimal
- instance. Returns None for empty values. Ensures that there are no more
- than max_digits in the number, and no more than decimal_places digits
- after the decimal point.
- """
- if value in validators.EMPTY_VALUES:
- return None
- value = smart_text(value).strip()
- try:
- value = Decimal(value)
- except DecimalException:
- raise ValidationError(self.error_messages['invalid'])
- return value
-
- def validate(self, value):
- super(DecimalField, self).validate(value)
- if value in validators.EMPTY_VALUES:
- return
- # Check for NaN, Inf and -Inf values. We can't compare directly for NaN,
- # since it is never equal to itself. However, NaN is the only value that
- # isn't equal to itself, so we can use this to identify NaN
- if value != value or value == Decimal("Inf") or value == Decimal("-Inf"):
- raise ValidationError(self.error_messages['invalid'])
- sign, digittuple, exponent = value.as_tuple()
- decimals = abs(exponent)
- # digittuple doesn't include any leading zeros.
- digits = len(digittuple)
- if decimals > digits:
- # We have leading zeros up to or past the decimal point. Count
- # everything past the decimal point as a digit. We do not count
- # 0 before the decimal point as a digit since that would mean
- # we would not allow max_digits = decimal_places.
- digits = decimals
- whole_digits = digits - decimals
-
- if self.max_digits is not None and digits > self.max_digits:
- raise ValidationError(self.error_messages['max_digits'] % self.max_digits)
- if self.decimal_places is not None and decimals > self.decimal_places:
- raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places)
- if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
- raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places))
- return value
-
+ def to_native(self, data):
+ if not hasattr(data, '__iter__'):
+ self.fail('not_a_list', input_type=type(data).__name__)
+ return set([
+ super(MultipleChoiceField, self).to_native(item)
+ for item in data
+ ])
-class FileField(WritableField):
- use_files = True
- type_name = 'FileField'
- type_label = 'file upload'
- form_field_class = forms.FileField
- widget = widgets.FileInput
- default_error_messages = {
- 'invalid': _("No file was submitted. Check the encoding type on the form."),
- 'missing': _("No file was submitted."),
- 'empty': _("The submitted file is empty."),
- 'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'),
- 'contradiction': _('Please either submit a file or check the clear checkbox, not both.')
+class IntegerField(Field):
+ MESSAGES = {
+ 'required': 'This field is required.',
+ 'invalid_integer': 'A valid integer is required.'
}
- def __init__(self, *args, **kwargs):
- self.max_length = kwargs.pop('max_length', None)
- self.allow_empty_file = kwargs.pop('allow_empty_file', False)
- super(FileField, self).__init__(*args, **kwargs)
-
- def from_native(self, data):
- if data in validators.EMPTY_VALUES:
- return None
-
- # UploadedFile objects should have name and size attributes.
+ def to_native(self, data):
try:
- file_name = data.name
- file_size = data.size
- except AttributeError:
- raise ValidationError(self.error_messages['invalid'])
-
- if self.max_length is not None and len(file_name) > self.max_length:
- error_values = {'max': self.max_length, 'length': len(file_name)}
- raise ValidationError(self.error_messages['max_length'] % error_values)
- if not file_name:
- raise ValidationError(self.error_messages['invalid'])
- if not self.allow_empty_file and not file_size:
- raise ValidationError(self.error_messages['empty'])
-
+ data = int(str(data))
+ except (ValueError, TypeError):
+ self.fail('invalid_integer')
return data
- def to_native(self, value):
- return value.name
-
-
-class ImageField(FileField):
- use_files = True
- type_name = 'ImageField'
- type_label = 'image upload'
- form_field_class = forms.ImageField
-
- default_error_messages = {
- 'invalid_image': _("Upload a valid image. The file you uploaded was "
- "either not an image or a corrupted image."),
- }
-
- def from_native(self, data):
- """
- Checks that the file-upload field data contains a valid image (GIF, JPG,
- PNG, possibly others -- whatever the Python Imaging Library supports).
- """
- f = super(ImageField, self).from_native(data)
- if f is None:
- return None
-
- from rest_framework.compat import Image
- assert Image is not None, 'Either Pillow or PIL must be installed for ImageField support.'
-
- # We need to get a file object for PIL. We might have a path or we might
- # have to read the data into memory.
- if hasattr(data, 'temporary_file_path'):
- file = data.temporary_file_path()
- else:
- if hasattr(data, 'read'):
- file = BytesIO(data.read())
- else:
- file = BytesIO(data['content'])
-
- try:
- # load() could spot a truncated JPEG, but it loads the entire
- # image in memory, which is a DoS vector. See #3848 and #18520.
- # verify() must be called immediately after the constructor.
- Image.open(file).verify()
- except ImportError:
- # Under PyPy, it is possible to import PIL. However, the underlying
- # _imaging C module isn't available, so an ImportError will be
- # raised. Catch and re-raise.
- raise
- except Exception: # Python Imaging Library doesn't recognize it as an image
- raise ValidationError(self.error_messages['invalid_image'])
- if hasattr(f, 'seek') and callable(f.seek):
- f.seek(0)
- return f
-
-class SerializerMethodField(Field):
- """
- A field that gets its value by calling a method on the serializer it's attached to.
- """
-
- def __init__(self, method_name, *args, **kwargs):
- self.method_name = method_name
- super(SerializerMethodField, self).__init__(*args, **kwargs)
-
- def field_to_native(self, obj, field_name):
- value = getattr(self.parent, self.method_name)(obj)
- return self.to_native(value)
+class MethodField(Field):
+ def __init__(self, **kwargs):
+ kwargs['source'] = '*'
+ kwargs['read_only'] = True
+ super(MethodField, self).__init__(**kwargs)
+
+ def to_primative(self, value):
+ attr = 'get_{field_name}'.format(field_name=self.field_name)
+ method = getattr(self.parent, attr)
+ return method(value)
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index b3bd6ce9..6705cbb2 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -79,18 +79,16 @@ class GenericAPIView(views.APIView):
'view': self
}
- def get_serializer(self, instance=None, data=None, files=None, many=False,
- partial=False, allow_add_remove=False):
+ def get_serializer(self, instance=None, data=None, many=False, partial=False):
"""
Return the serializer instance that should be used for validating and
deserializing input, and for serializing output.
"""
serializer_class = self.get_serializer_class()
context = self.get_serializer_context()
- return serializer_class(instance, data=data, files=files,
- many=many, partial=partial,
- allow_add_remove=allow_add_remove,
- context=context)
+ return serializer_class(
+ instance, data=data, many=many, partial=partial, context=context
+ )
def get_pagination_serializer(self, page):
"""
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index ac59d979..ee01cabc 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -36,12 +36,10 @@ class CreateModelMixin(object):
Create a model instance.
"""
def create(self, request, *args, **kwargs):
- serializer = self.get_serializer(data=request.DATA, files=request.FILES)
+ serializer = self.get_serializer(data=request.DATA)
if serializer.is_valid():
- self.pre_save(serializer.object)
- self.object = serializer.save(force_insert=True)
- self.post_save(self.object, created=True)
+ self.object = serializer.save()
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED,
headers=headers)
@@ -90,26 +88,20 @@ class UpdateModelMixin(object):
partial = kwargs.pop('partial', False)
self.object = self.get_object_or_none()
- serializer = self.get_serializer(self.object, data=request.DATA,
- files=request.FILES, partial=partial)
+ serializer = self.get_serializer(self.object, data=request.DATA, partial=partial)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- try:
- self.pre_save(serializer.object)
- except ValidationError as err:
- # full_clean on model instance may be called in pre_save,
- # so we have to handle eventual errors.
- return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST)
+ lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
+ lookup_value = self.kwargs[lookup_url_kwarg]
+ extras = {self.lookup_field: lookup_value}
if self.object is None:
- self.object = serializer.save(force_insert=True)
- self.post_save(self.object, created=True)
+ self.object = serializer.save(extras=extras)
return Response(serializer.data, status=status.HTTP_201_CREATED)
- self.object = serializer.save(force_update=True)
- self.post_save(self.object, created=False)
+ self.object = serializer.save(extras=extras)
return Response(serializer.data, status=status.HTTP_200_OK)
def partial_update(self, request, *args, **kwargs):
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index d51ea929..83ef97c5 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -48,17 +48,17 @@ class DefaultObjectSerializer(serializers.Field):
super(DefaultObjectSerializer, self).__init__(source=source)
-class PaginationSerializerOptions(serializers.SerializerOptions):
- """
- An object that stores the options that may be provided to a
- pagination serializer by using the inner `Meta` class.
+# class PaginationSerializerOptions(serializers.SerializerOptions):
+# """
+# An object that stores the options that may be provided to a
+# pagination serializer by using the inner `Meta` class.
- Accessible on the instance as `serializer.opts`.
- """
- def __init__(self, meta):
- super(PaginationSerializerOptions, self).__init__(meta)
- self.object_serializer_class = getattr(meta, 'object_serializer_class',
- DefaultObjectSerializer)
+# Accessible on the instance as `serializer.opts`.
+# """
+# def __init__(self, meta):
+# super(PaginationSerializerOptions, self).__init__(meta)
+# self.object_serializer_class = getattr(meta, 'object_serializer_class',
+# DefaultObjectSerializer)
class BasePaginationSerializer(serializers.Serializer):
@@ -66,7 +66,7 @@ class BasePaginationSerializer(serializers.Serializer):
A base class for pagination serializers to inherit from,
to make implementing custom serializers more easy.
"""
- _options_class = PaginationSerializerOptions
+ # _options_class = PaginationSerializerOptions
results_field = 'results'
def __init__(self, *args, **kwargs):
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 56870b40..e69de29b 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -1,486 +0,0 @@
-"""
-Serializer fields that deal with relationships.
-
-These fields allow you to specify the style that should be used to represent
-model relationships, including hyperlinks, primary keys, or slugs.
-"""
-from __future__ import unicode_literals
-from django.core.exceptions import ObjectDoesNotExist, ValidationError
-from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch
-from django import forms
-from django.db.models.fields import BLANK_CHOICE_DASH
-from django.forms import widgets
-from django.forms.models import ModelChoiceIterator
-from django.utils.translation import ugettext_lazy as _
-from rest_framework.fields import Field, WritableField, get_component, is_simple_callable
-from rest_framework.reverse import reverse
-from rest_framework.compat import urlparse
-from rest_framework.compat import smart_text
-
-
-# Relational fields
-
-# Not actually Writable, but subclasses may need to be.
-class RelatedField(WritableField):
- """
- Base class for related model fields.
-
- This represents a relationship using the unicode representation of the target.
- """
- widget = widgets.Select
- many_widget = widgets.SelectMultiple
- form_field_class = forms.ChoiceField
- many_form_field_class = forms.MultipleChoiceField
- null_values = (None, '', 'None')
-
- cache_choices = False
- empty_label = None
- read_only = True
- many = False
-
- def __init__(self, *args, **kwargs):
- queryset = kwargs.pop('queryset', None)
- self.many = kwargs.pop('many', self.many)
- if self.many:
- self.widget = self.many_widget
- self.form_field_class = self.many_form_field_class
-
- kwargs['read_only'] = kwargs.pop('read_only', self.read_only)
- super(RelatedField, self).__init__(*args, **kwargs)
-
- if not self.required:
- # Accessed in ModelChoiceIterator django/forms/models.py:1034
- # If set adds empty choice.
- self.empty_label = BLANK_CHOICE_DASH[0][1]
-
- self.queryset = queryset
-
- def initialize(self, parent, field_name):
- super(RelatedField, self).initialize(parent, field_name)
- if self.queryset is None and not self.read_only:
- manager = getattr(self.parent.opts.model, self.source or field_name)
- if hasattr(manager, 'related'): # Forward
- self.queryset = manager.related.model._default_manager.all()
- else: # Reverse
- self.queryset = manager.field.rel.to._default_manager.all()
-
- # We need this stuff to make form choices work...
-
- def prepare_value(self, obj):
- return self.to_native(obj)
-
- def label_from_instance(self, obj):
- """
- Return a readable representation for use with eg. select widgets.
- """
- desc = smart_text(obj)
- ident = smart_text(self.to_native(obj))
- if desc == ident:
- return desc
- return "%s - %s" % (desc, ident)
-
- def _get_queryset(self):
- return self._queryset
-
- def _set_queryset(self, queryset):
- self._queryset = queryset
- self.widget.choices = self.choices
-
- queryset = property(_get_queryset, _set_queryset)
-
- def _get_choices(self):
- # If self._choices is set, then somebody must have manually set
- # the property self.choices. In this case, just return self._choices.
- if hasattr(self, '_choices'):
- return self._choices
-
- # Otherwise, execute the QuerySet in self.queryset to determine the
- # choices dynamically. Return a fresh ModelChoiceIterator that has not been
- # consumed. Note that we're instantiating a new ModelChoiceIterator *each*
- # time _get_choices() is called (and, thus, each time self.choices is
- # accessed) so that we can ensure the QuerySet has not been consumed. This
- # construct might look complicated but it allows for lazy evaluation of
- # the queryset.
- return ModelChoiceIterator(self)
-
- def _set_choices(self, value):
- # Setting choices also sets the choices on the widget.
- # choices can be any iterable, but we call list() on it because
- # it will be consumed more than once.
- self._choices = self.widget.choices = list(value)
-
- choices = property(_get_choices, _set_choices)
-
- # Default value handling
-
- def get_default_value(self):
- default = super(RelatedField, self).get_default_value()
- if self.many and default is None:
- return []
- return default
-
- # Regular serializer stuff...
-
- def field_to_native(self, obj, field_name):
- try:
- if self.source == '*':
- return self.to_native(obj)
-
- source = self.source or field_name
- value = obj
-
- for component in source.split('.'):
- if value is None:
- break
- value = get_component(value, component)
- except ObjectDoesNotExist:
- return None
-
- if value is None:
- return None
-
- if self.many:
- if is_simple_callable(getattr(value, 'all', None)):
- return [self.to_native(item) for item in value.all()]
- else:
- # Also support non-queryset iterables.
- # This allows us to also support plain lists of related items.
- return [self.to_native(item) for item in value]
- return self.to_native(value)
-
- def field_from_native(self, data, files, field_name, into):
- if self.read_only:
- return
-
- try:
- if self.many:
- try:
- # Form data
- value = data.getlist(field_name)
- if value == [''] or value == []:
- raise KeyError
- except AttributeError:
- # Non-form data
- value = data[field_name]
- else:
- value = data[field_name]
- except KeyError:
- if self.partial:
- return
- value = self.get_default_value()
-
- if value in self.null_values:
- if self.required:
- raise ValidationError(self.error_messages['required'])
- into[(self.source or field_name)] = None
- elif self.many:
- into[(self.source or field_name)] = [self.from_native(item) for item in value]
- else:
- into[(self.source or field_name)] = self.from_native(value)
-
-
-# PrimaryKey relationships
-
-class PrimaryKeyRelatedField(RelatedField):
- """
- Represents a relationship as a pk value.
- """
- read_only = False
-
- default_error_messages = {
- 'does_not_exist': _("Invalid pk '%s' - object does not exist."),
- 'incorrect_type': _('Incorrect type. Expected pk value, received %s.'),
- }
-
- # TODO: Remove these field hacks...
- def prepare_value(self, obj):
- return self.to_native(obj.pk)
-
- def label_from_instance(self, obj):
- """
- Return a readable representation for use with eg. select widgets.
- """
- desc = smart_text(obj)
- ident = smart_text(self.to_native(obj.pk))
- if desc == ident:
- return desc
- return "%s - %s" % (desc, ident)
-
- # TODO: Possibly change this to just take `obj`, through prob less performant
- def to_native(self, pk):
- return pk
-
- def from_native(self, data):
- if self.queryset is None:
- raise Exception('Writable related fields must include a `queryset` argument')
-
- try:
- return self.queryset.get(pk=data)
- except ObjectDoesNotExist:
- msg = self.error_messages['does_not_exist'] % smart_text(data)
- raise ValidationError(msg)
- except (TypeError, ValueError):
- received = type(data).__name__
- msg = self.error_messages['incorrect_type'] % received
- raise ValidationError(msg)
-
- def field_to_native(self, obj, field_name):
- if self.many:
- # To-many relationship
-
- queryset = None
- if not self.source:
- # Prefer obj.serializable_value for performance reasons
- try:
- queryset = obj.serializable_value(field_name)
- except AttributeError:
- pass
- if queryset is None:
- # RelatedManager (reverse relationship)
- source = self.source or field_name
- queryset = obj
- for component in source.split('.'):
- if queryset is None:
- return []
- queryset = get_component(queryset, component)
-
- # Forward relationship
- if is_simple_callable(getattr(queryset, 'all', None)):
- return [self.to_native(item.pk) for item in queryset.all()]
- else:
- # Also support non-queryset iterables.
- # This allows us to also support plain lists of related items.
- return [self.to_native(item.pk) for item in queryset]
-
- # To-one relationship
- try:
- # Prefer obj.serializable_value for performance reasons
- pk = obj.serializable_value(self.source or field_name)
- except AttributeError:
- # RelatedObject (reverse relationship)
- try:
- pk = getattr(obj, self.source or field_name).pk
- except (ObjectDoesNotExist, AttributeError):
- return None
-
- # Forward relationship
- return self.to_native(pk)
-
-
-# Slug relationships
-
-class SlugRelatedField(RelatedField):
- """
- Represents a relationship using a unique field on the target.
- """
- read_only = False
-
- default_error_messages = {
- 'does_not_exist': _("Object with %s=%s does not exist."),
- 'invalid': _('Invalid value.'),
- }
-
- def __init__(self, *args, **kwargs):
- self.slug_field = kwargs.pop('slug_field', None)
- assert self.slug_field, 'slug_field is required'
- super(SlugRelatedField, self).__init__(*args, **kwargs)
-
- def to_native(self, obj):
- return getattr(obj, self.slug_field)
-
- def from_native(self, data):
- if self.queryset is None:
- raise Exception('Writable related fields must include a `queryset` argument')
-
- try:
- return self.queryset.get(**{self.slug_field: data})
- except ObjectDoesNotExist:
- raise ValidationError(self.error_messages['does_not_exist'] %
- (self.slug_field, smart_text(data)))
- except (TypeError, ValueError):
- msg = self.error_messages['invalid']
- raise ValidationError(msg)
-
-
-# Hyperlinked relationships
-
-class HyperlinkedRelatedField(RelatedField):
- """
- Represents a relationship using hyperlinking.
- """
- read_only = False
- lookup_field = 'pk'
-
- default_error_messages = {
- 'no_match': _('Invalid hyperlink - No URL match'),
- 'incorrect_match': _('Invalid hyperlink - Incorrect URL match'),
- 'configuration_error': _('Invalid hyperlink due to configuration error'),
- 'does_not_exist': _("Invalid hyperlink - object does not exist."),
- 'incorrect_type': _('Incorrect type. Expected url string, received %s.'),
- }
-
- def __init__(self, *args, **kwargs):
- try:
- self.view_name = kwargs.pop('view_name')
- except KeyError:
- raise ValueError("Hyperlinked field requires 'view_name' kwarg")
-
- self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
- self.format = kwargs.pop('format', None)
-
- super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
-
- def get_url(self, obj, view_name, request, format):
- """
- Given an object, return the URL that hyperlinks to the object.
-
- May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
- attributes are not configured to correctly match the URL conf.
- """
- lookup_field = getattr(obj, self.lookup_field)
- kwargs = {self.lookup_field: lookup_field}
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
-
- def get_object(self, queryset, view_name, view_args, view_kwargs):
- """
- Return the object corresponding to a matched URL.
-
- Takes the matched URL conf arguments, and the queryset, and should
- return an object instance, or raise an `ObjectDoesNotExist` exception.
- """
- lookup_value = view_kwargs[self.lookup_field]
- filter_kwargs = {self.lookup_field: lookup_value}
- return queryset.get(**filter_kwargs)
-
- def to_native(self, obj):
- view_name = self.view_name
- request = self.context.get('request', None)
- format = self.format or self.context.get('format', None)
-
- assert request is not None, (
- "`HyperlinkedRelatedField` requires the request in the serializer "
- "context. Add `context={'request': request}` when instantiating "
- "the serializer."
- )
-
- # If the object has not yet been saved then we cannot hyperlink to it.
- if getattr(obj, 'pk', None) is None:
- return
-
- # Return the hyperlink, or error if incorrectly configured.
- try:
- return self.get_url(obj, view_name, request, format)
- except NoReverseMatch:
- msg = (
- 'Could not resolve URL for hyperlinked relationship using '
- 'view name "%s". You may have failed to include the related '
- 'model in your API, or incorrectly configured the '
- '`lookup_field` attribute on this field.'
- )
- raise Exception(msg % view_name)
-
- def from_native(self, value):
- # Convert URL -> model instance pk
- # TODO: Use values_list
- queryset = self.queryset
- if queryset is None:
- raise Exception('Writable related fields must include a `queryset` argument')
-
- try:
- http_prefix = value.startswith(('http:', 'https:'))
- except AttributeError:
- msg = self.error_messages['incorrect_type']
- raise ValidationError(msg % type(value).__name__)
-
- if http_prefix:
- # If needed convert absolute URLs to relative path
- value = urlparse.urlparse(value).path
- prefix = get_script_prefix()
- if value.startswith(prefix):
- value = '/' + value[len(prefix):]
-
- try:
- match = resolve(value)
- except Exception:
- raise ValidationError(self.error_messages['no_match'])
-
- if match.view_name != self.view_name:
- raise ValidationError(self.error_messages['incorrect_match'])
-
- try:
- return self.get_object(queryset, match.view_name,
- match.args, match.kwargs)
- except (ObjectDoesNotExist, TypeError, ValueError):
- raise ValidationError(self.error_messages['does_not_exist'])
-
-
-class HyperlinkedIdentityField(Field):
- """
- Represents the instance, or a property on the instance, using hyperlinking.
- """
- lookup_field = 'pk'
- read_only = True
-
- def __init__(self, *args, **kwargs):
- try:
- self.view_name = kwargs.pop('view_name')
- except KeyError:
- msg = "HyperlinkedIdentityField requires 'view_name' argument"
- raise ValueError(msg)
-
- self.format = kwargs.pop('format', None)
- lookup_field = kwargs.pop('lookup_field', None)
- self.lookup_field = lookup_field or self.lookup_field
-
- super(HyperlinkedIdentityField, self).__init__(*args, **kwargs)
-
- def field_to_native(self, obj, field_name):
- request = self.context.get('request', None)
- format = self.context.get('format', None)
- view_name = self.view_name
-
- assert request is not None, (
- "`HyperlinkedIdentityField` requires the request in the serializer"
- " context. Add `context={'request': request}` when instantiating "
- "the serializer."
- )
-
- # By default use whatever format is given for the current context
- # unless the target is a different type to the source.
- #
- # Eg. Consider a HyperlinkedIdentityField pointing from a json
- # representation to an html property of that representation...
- #
- # '/snippets/1/' should link to '/snippets/1/highlight/'
- # ...but...
- # '/snippets/1/.json' should link to '/snippets/1/highlight/.html'
- if format and self.format and self.format != format:
- format = self.format
-
- # Return the hyperlink, or error if incorrectly configured.
- try:
- return self.get_url(obj, view_name, request, format)
- except NoReverseMatch:
- msg = (
- 'Could not resolve URL for hyperlinked relationship using '
- 'view name "%s". You may have failed to include the related '
- 'model in your API, or incorrectly configured the '
- '`lookup_field` attribute on this field.'
- )
- raise Exception(msg % view_name)
-
- def get_url(self, obj, view_name, request, format):
- """
- Given an object, return the URL that hyperlinks to the object.
-
- May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
- attributes are not configured to correctly match the URL conf.
- """
- lookup_field = getattr(obj, self.lookup_field, None)
- kwargs = {self.lookup_field: lookup_field}
-
- # Handle unsaved object case
- if lookup_field is None:
- return None
-
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 748ebac9..e8935b01 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -458,7 +458,7 @@ class BrowsableAPIRenderer(BaseRenderer):
):
return
- serializer = view.get_serializer(instance=obj, data=data, files=files)
+ serializer = view.get_serializer(instance=obj, data=data)
serializer.is_valid()
data = serializer.data
@@ -579,10 +579,10 @@ class BrowsableAPIRenderer(BaseRenderer):
'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes],
'response_headers': response_headers,
- 'put_form': self.get_rendered_html_form(view, 'PUT', request),
- 'post_form': self.get_rendered_html_form(view, 'POST', request),
- 'delete_form': self.get_rendered_html_form(view, 'DELETE', request),
- 'options_form': self.get_rendered_html_form(view, 'OPTIONS', request),
+ #'put_form': self.get_rendered_html_form(view, 'PUT', request),
+ #'post_form': self.get_rendered_html_form(view, 'POST', request),
+ #'delete_form': self.get_rendered_html_form(view, 'DELETE', request),
+ #'options_form': self.get_rendered_html_form(view, 'OPTIONS', request),
'raw_data_put_form': raw_data_put_form,
'raw_data_post_form': raw_data_post_form,
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index be8ad3f2..d121812d 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -10,21 +10,14 @@ python primitives.
2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers.
"""
-from __future__ import unicode_literals
-import copy
-import datetime
-import inspect
-import types
-from decimal import Decimal
-from django.contrib.contenttypes.generic import GenericForeignKey
-from django.core.paginator import Page
from django.db import models
-from django.forms import widgets
from django.utils import six
-from django.utils.datastructures import SortedDict
-from django.core.exceptions import ObjectDoesNotExist
+from collections import namedtuple, OrderedDict
+from rest_framework.fields import empty, set_value, Field, SkipField, ValidationError
from rest_framework.settings import api_settings
-
+from rest_framework.utils import html
+import copy
+import inspect
# Note: We do the following so that users of the framework can use this style:
#
@@ -37,635 +30,339 @@ from rest_framework.relations import * # NOQA
from rest_framework.fields import * # NOQA
-def _resolve_model(obj):
- """
- Resolve supplied `obj` to a Django model class.
+FieldResult = namedtuple('FieldResult', ['field', 'value', 'error'])
- `obj` must be a Django model class itself, or a string
- representation of one. Useful in situtations like GH #1225 where
- Django may not have resolved a string-based reference to a model in
- another model's foreign key definition.
-
- String representations should have the format:
- 'appname.ModelName'
- """
- if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
- app_name, model_name = obj.split('.')
- return models.get_model(app_name, model_name)
- elif inspect.isclass(obj) and issubclass(obj, models.Model):
- return obj
- else:
- raise ValueError("{0} is not a Django model".format(obj))
-
-
-def pretty_name(name):
- """Converts 'first_name' to 'First name'"""
- if not name:
- return ''
- return name.replace('_', ' ').capitalize()
+class BaseSerializer(Field):
+ def __init__(self, instance=None, data=None, **kwargs):
+ super(BaseSerializer, self).__init__(**kwargs)
+ self.instance = instance
+ self._initial_data = data
-class RelationsList(list):
- _deleted = []
+ def to_native(self, data):
+ raise NotImplementedError()
+ def to_primative(self, instance):
+ raise NotImplementedError()
-class NestedValidationError(ValidationError):
- """
- The default ValidationError behavior is to stringify each item in the list
- if the messages are a list of error messages.
+ def update(self, instance):
+ raise NotImplementedError()
- In the case of nested serializers, where the parent has many children,
- then the child's `serializer.errors` will be a list of dicts. In the case
- of a single child, the `serializer.errors` will be a dict.
+ def create(self):
+ raise NotImplementedError()
- We need to override the default behavior to get properly nested error dicts.
- """
+ def save(self, extras=None):
+ if extras is not None:
+ self._validated_data.update(extras)
- def __init__(self, message):
- if isinstance(message, dict):
- self._messages = [message]
+ if self.instance is not None:
+ self.update(self.instance)
else:
- self._messages = message
-
- @property
- def messages(self):
- return self._messages
+ self.instance = self.create()
+ return self.instance
-class DictWithMetadata(dict):
- """
- A dict-like object, that can have additional properties attached.
- """
- def __getstate__(self):
- """
- Used by pickle (e.g., caching).
- Overridden to remove the metadata from the dict, since it shouldn't be
- pickled and may in some instances be unpickleable.
- """
- return dict(self)
-
+ def is_valid(self):
+ try:
+ self._validated_data = self.to_native(self._initial_data)
+ except ValidationError as exc:
+ self._validated_data = {}
+ self._errors = exc.args[0]
+ return False
+ self._errors = {}
+ return True
-class SortedDictWithMetadata(SortedDict):
- """
- A sorted dict-like object, that can have additional properties attached.
- """
- def __getstate__(self):
- """
- Used by pickle (e.g., caching).
- Overriden to remove the metadata from the dict, since it shouldn't be
- pickle and may in some instances be unpickleable.
- """
- return SortedDict(self).__dict__
+ @property
+ def data(self):
+ if not hasattr(self, '_data'):
+ if self.instance is not None:
+ self._data = self.to_primative(self.instance)
+ elif self._initial_data is not None:
+ self._data = {
+ field_name: field.get_value(self._initial_data)
+ for field_name, field in self.fields.items()
+ }
+ else:
+ self._data = self.get_initial()
+ return self._data
+ @property
+ def errors(self):
+ if not hasattr(self, '_errors'):
+ msg = 'You must call `.is_valid()` before accessing `.errors`.'
+ raise AssertionError(msg)
+ return self._errors
-def _is_protected_type(obj):
- """
- True if the object is a native datatype that does not need to
- be serialized further.
- """
- return isinstance(obj, (
- types.NoneType,
- int, long,
- datetime.datetime, datetime.date, datetime.time,
- float, Decimal,
- basestring)
- )
+ @property
+ def validated_data(self):
+ if not hasattr(self, '_validated_data'):
+ msg = 'You must call `.is_valid()` before accessing `.validated_data`.'
+ raise AssertionError(msg)
+ return self._validated_data
-def _get_declared_fields(bases, attrs):
+class SerializerMetaclass(type):
"""
- Create a list of serializer field instances from the passed in 'attrs',
- plus any fields on the base classes (in 'bases').
+ This metaclass sets a dictionary named `base_fields` on the class.
- Note that all fields from the base classes are used.
+ Any fields included as attributes on either the class or it's superclasses
+ will be include in the `base_fields` dictionary.
"""
- fields = [(field_name, attrs.pop(field_name))
- for field_name, obj in list(six.iteritems(attrs))
- if isinstance(obj, Field)]
- fields.sort(key=lambda x: x[1].creation_counter)
- # If this class is subclassing another Serializer, add that Serializer's
- # fields. Note that we loop over the bases in *reverse*. This is necessary
- # in order to maintain the correct order of fields.
- for base in bases[::-1]:
- if hasattr(base, 'base_fields'):
- fields = list(base.base_fields.items()) + fields
+ @classmethod
+ def _get_fields(cls, bases, attrs):
+ fields = [(field_name, attrs.pop(field_name))
+ for field_name, obj in list(attrs.items())
+ if isinstance(obj, Field)]
+ fields.sort(key=lambda x: x[1]._creation_counter)
- return SortedDict(fields)
+ # If this class is subclassing another Serializer, add that Serializer's
+ # fields. Note that we loop over the bases in *reverse*. This is necessary
+ # in order to maintain the correct order of fields.
+ for base in bases[::-1]:
+ if hasattr(base, 'base_fields'):
+ fields = list(base.base_fields.items()) + fields
+ return OrderedDict(fields)
-class SerializerMetaclass(type):
def __new__(cls, name, bases, attrs):
- attrs['base_fields'] = _get_declared_fields(bases, attrs)
+ attrs['base_fields'] = cls._get_fields(bases, attrs)
return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
-class SerializerOptions(object):
- """
- Meta class options for Serializer
- """
- def __init__(self, meta):
- self.depth = getattr(meta, 'depth', 0)
- self.fields = getattr(meta, 'fields', ())
- self.exclude = getattr(meta, 'exclude', ())
+@six.add_metaclass(SerializerMetaclass)
+class Serializer(BaseSerializer):
+ def __new__(cls, *args, **kwargs):
+ many = kwargs.pop('many', False)
+ if many:
+ class DynamicListSerializer(ListSerializer):
+ child = cls()
+ return DynamicListSerializer(*args, **kwargs)
+ return super(Serializer, cls).__new__(cls)
-class BaseSerializer(WritableField):
- """
- This is the Serializer implementation.
- We need to implement it as `BaseSerializer` due to metaclass magicks.
- """
- class Meta(object):
- pass
-
- _options_class = SerializerOptions
- _dict_class = SortedDictWithMetadata
-
- def __init__(self, instance=None, data=None, files=None,
- context=None, partial=False, many=False,
- allow_add_remove=False, **kwargs):
- super(BaseSerializer, self).__init__(**kwargs)
- self.opts = self._options_class(self.Meta)
- self.parent = None
- self.root = None
- self.partial = partial
- self.many = many
- self.allow_add_remove = allow_add_remove
+ def __init__(self, *args, **kwargs):
+ kwargs.pop('context', None)
+ kwargs.pop('partial', None)
+ kwargs.pop('many', False)
- self.context = context or {}
+ super(Serializer, self).__init__(*args, **kwargs)
- self.init_data = data
- self.init_files = files
- self.object = instance
+ # Every new serializer is created with a clone of the field instances.
+ # This allows users to dynamically modify the fields on a serializer
+ # instance without affecting every other serializer class.
self.fields = self.get_fields()
- self._data = None
- self._files = None
- self._errors = None
-
- if many and instance is not None and not hasattr(instance, '__iter__'):
- raise ValueError('instance should be a queryset or other iterable with many=True')
-
- if allow_add_remove and not many:
- raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True')
-
- #####
- # Methods to determine which fields to use when (de)serializing objects.
-
- def get_default_fields(self):
- """
- Return the complete set of default fields for the object, as a dict.
- """
- return {}
-
- def get_fields(self):
- """
- Returns the complete set of fields for the object as a dict.
-
- This will be the set of any explicitly declared fields,
- plus the set of fields returned by get_default_fields().
- """
- ret = SortedDict()
-
- # Get the explicitly declared fields
- base_fields = copy.deepcopy(self.base_fields)
- for key, field in base_fields.items():
- ret[key] = field
-
- # Add in the default fields
- default_fields = self.get_default_fields()
- for key, val in default_fields.items():
- if key not in ret:
- ret[key] = val
-
- # If 'fields' is specified, use those fields, in that order.
- if self.opts.fields:
- assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple'
- new = SortedDict()
- for key in self.opts.fields:
- new[key] = ret[key]
- ret = new
-
- # Remove anything in 'exclude'
- if self.opts.exclude:
- assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple'
- for key in self.opts.exclude:
- ret.pop(key, None)
-
- for key, field in ret.items():
- field.initialize(parent=self, field_name=key)
-
- return ret
-
- #####
- # Methods to convert or revert from objects <--> primitive representations.
-
- def get_field_key(self, field_name):
- """
- Return the key that should be used for a given field.
- """
- return field_name
-
- def restore_fields(self, data, files):
- """
- Core of deserialization, together with `restore_object`.
- Converts a dictionary of data into a dictionary of deserialized fields.
- """
- reverted_data = {}
-
- if data is not None and not isinstance(data, dict):
- self._errors['non_field_errors'] = ['Invalid data']
- return None
-
+ # Setup all the child fields, to provide them with the current context.
for field_name, field in self.fields.items():
- field.initialize(parent=self, field_name=field_name)
- try:
- field.field_from_native(data, files, field_name, reverted_data)
- except ValidationError as err:
- self._errors[field_name] = list(err.messages)
+ field.bind(field_name, self, self)
- return reverted_data
+ def get_fields(self):
+ return copy.deepcopy(self.base_fields)
- def perform_validation(self, attrs):
- """
- Run `validate_()` and `validate()` methods on the serializer
- """
+ def bind(self, field_name, parent, root):
+ # If the serializer is used as a field then when it becomes bound
+ # it also needs to bind all its child fields.
+ super(Serializer, self).bind(field_name, parent, root)
for field_name, field in self.fields.items():
- if field_name in self._errors:
- continue
+ field.bind(field_name, self, root)
- source = field.source or field_name
- if self.partial and source not in attrs:
- continue
- try:
- validate_method = getattr(self, 'validate_%s' % field_name, None)
- if validate_method:
- attrs = validate_method(attrs, source)
- except ValidationError as err:
- self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages)
-
- # If there are already errors, we don't run .validate() because
- # field-validation failed and thus `attrs` may not be complete.
- # which in turn can cause inconsistent validation errors.
- if not self._errors:
- try:
- attrs = self.validate(attrs)
- except ValidationError as err:
- if hasattr(err, 'message_dict'):
- for field_name, error_messages in err.message_dict.items():
- self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages)
- elif hasattr(err, 'messages'):
- self._errors['non_field_errors'] = err.messages
-
- return attrs
+ def get_initial(self):
+ return {
+ field.field_name: field.get_initial()
+ for field in self.fields.values()
+ }
- def validate(self, attrs):
- """
- Stub method, to be overridden in Serializer subclasses
- """
- return attrs
+ def get_value(self, dictionary):
+ # We override the default field access in order to support
+ # nested HTML forms.
+ if html.is_html_input(dictionary):
+ return html.parse_html_dict(dictionary, prefix=self.field_name)
+ return dictionary.get(self.field_name, empty)
- def restore_object(self, attrs, instance=None):
+ def to_native(self, data):
"""
- Deserialize a dictionary of attributes into an object instance.
- You should override this method to control how deserialized objects
- are instantiated.
+ Dict of native values <- Dict of primitive datatypes.
"""
- if instance is not None:
- instance.update(attrs)
- return instance
- return attrs
+ ret = {}
+ errors = {}
+ fields = [field for field in self.fields.values() if not field.read_only]
- def to_native(self, obj):
- """
- Serialize objects -> primitives.
- """
- ret = self._dict_class()
- ret.fields = self._dict_class()
+ for field in fields:
+ primitive_value = field.get_value(data)
+ try:
+ validated_value = field.validate(primitive_value)
+ except ValidationError as exc:
+ errors[field.field_name] = str(exc)
+ except SkipField:
+ pass
+ else:
+ set_value(ret, field.source_attrs, validated_value)
- for field_name, field in self.fields.items():
- if field.read_only and obj is None:
- continue
- field.initialize(parent=self, field_name=field_name)
- key = self.get_field_key(field_name)
- value = field.field_to_native(obj, field_name)
- method = getattr(self, 'transform_%s' % field_name, None)
- if callable(method):
- value = method(obj, value)
- if not getattr(field, 'write_only', False):
- ret[key] = value
- ret.fields[key] = self.augment_field(field, field_name, key, value)
+ if errors:
+ raise ValidationError(errors)
return ret
- def from_native(self, data, files=None):
- """
- Deserialize primitives -> objects.
- """
- self._errors = {}
-
- if data is not None or files is not None:
- attrs = self.restore_fields(data, files)
- if attrs is not None:
- attrs = self.perform_validation(attrs)
- else:
- self._errors['non_field_errors'] = ['No input provided']
-
- if not self._errors:
- return self.restore_object(attrs, instance=getattr(self, 'object', None))
-
- def augment_field(self, field, field_name, key, value):
- # This horrible stuff is to manage serializers rendering to HTML
- field._errors = self._errors.get(key) if self._errors else None
- field._name = field_name
- field._value = self.init_data.get(key) if self._errors and self.init_data else value
- if not field.label:
- field.label = pretty_name(key)
- return field
-
- def field_to_native(self, obj, field_name):
+ def to_primative(self, instance):
"""
- Override default so that the serializer can be used as a nested field
- across relationships.
+ Object instance -> Dict of primitive datatypes.
"""
- if self.write_only:
- return None
+ ret = OrderedDict()
+ fields = [field for field in self.fields.values() if not field.write_only]
- if self.source == '*':
- return self.to_native(obj)
+ for field in fields:
+ native_value = field.get_attribute(instance)
+ ret[field.field_name] = field.to_primative(native_value)
- # Get the raw field value
- try:
- source = self.source or field_name
- value = obj
-
- for component in source.split('.'):
- if value is None:
- break
- value = get_component(value, component)
- except ObjectDoesNotExist:
- return None
+ return ret
- if is_simple_callable(getattr(value, 'all', None)):
- return [self.to_native(item) for item in value.all()]
+ def __iter__(self):
+ errors = self.errors if hasattr(self, '_errors') else {}
+ for field in self.fields.values():
+ value = self.data.get(field.field_name) if self.data else None
+ error = errors.get(field.field_name)
+ yield FieldResult(field, value, error)
- if value is None:
- return None
- if self.many:
- return [self.to_native(item) for item in value]
- return self.to_native(value)
+class ListSerializer(BaseSerializer):
+ child = None
+ initial = []
- def field_from_native(self, data, files, field_name, into):
- """
- Override default so that the serializer can be used as a writable
- nested field across relationships.
- """
- if self.read_only:
- return
+ def __init__(self, *args, **kwargs):
+ self.child = kwargs.pop('child', copy.deepcopy(self.child))
+ assert self.child is not None, '`child` is a required argument.'
- try:
- value = data[field_name]
- except KeyError:
- if self.default is not None and not self.partial:
- # Note: partial updates shouldn't set defaults
- value = copy.deepcopy(self.default)
- else:
- if self.required:
- raise ValidationError(self.error_messages['required'])
- return
-
- if self.source == '*':
- if value:
- reverted_data = self.restore_fields(value, {})
- if not self._errors:
- into.update(reverted_data)
- else:
- if value in (None, ''):
- into[(self.source or field_name)] = None
- else:
- # Set the serializer object if it exists
- obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None
-
- # If we have a model manager or similar object then we need
- # to iterate through each instance.
- if (
- self.many and
- not hasattr(obj, '__iter__') and
- is_simple_callable(getattr(obj, 'all', None))
- ):
- obj = obj.all()
-
- kwargs = {
- 'instance': obj,
- 'data': value,
- 'context': self.context,
- 'partial': self.partial,
- 'many': self.many,
- 'allow_add_remove': self.allow_add_remove
- }
- serializer = self.__class__(**kwargs)
+ kwargs.pop('context', None)
+ kwargs.pop('partial', None)
- if serializer.is_valid():
- into[self.source or field_name] = serializer.object
- else:
- # Propagate errors up to our parent
- raise NestedValidationError(serializer.errors)
+ super(ListSerializer, self).__init__(*args, **kwargs)
+ self.child.bind('', self, self)
- def get_identity(self, data):
- """
- This hook is required for bulk update.
- It is used to determine the canonical identity of a given object.
+ def bind(self, field_name, parent, root):
+ # If the list is used as a field then it needs to provide
+ # the current context to the child serializer.
+ super(ListSerializer, self).bind(field_name, parent, root)
+ self.child.bind(field_name, self, root)
- Note that the data has not been validated at this point, so we need
- to make sure that we catch any cases of incorrect datatypes being
- passed to this method.
- """
- try:
- return data.get('id', None)
- except AttributeError:
- return None
+ def get_value(self, dictionary):
+ # We override the default field access in order to support
+ # lists in HTML forms.
+ if is_html_input(dictionary):
+ return html.parse_html_list(dictionary, prefix=self.field_name)
+ return dictionary.get(self.field_name, empty)
- @property
- def errors(self):
+ def to_native(self, data):
"""
- Run deserialization and return error data,
- setting self.object if no errors occurred.
+ List of dicts of native values <- List of dicts of primitive datatypes.
"""
- if self._errors is None:
- data, files = self.init_data, self.init_files
+ if html.is_html_input(data):
+ data = html.parse_html_list(data)
- if self.many is not None:
- many = self.many
- else:
- many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type))
- if many:
- warnings.warn('Implicit list/queryset serialization is deprecated. '
- 'Use the `many=True` flag when instantiating the serializer.',
- DeprecationWarning, stacklevel=3)
-
- if many:
- ret = RelationsList()
- errors = []
- update = self.object is not None
-
- if update:
- # If this is a bulk update we need to map all the objects
- # to a canonical identity so we can determine which
- # individual object is being updated for each item in the
- # incoming data
- objects = self.object
- identities = [self.get_identity(self.to_native(obj)) for obj in objects]
- identity_to_objects = dict(zip(identities, objects))
-
- if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)):
- for item in data:
- if update:
- # Determine which object we're updating
- identity = self.get_identity(item)
- self.object = identity_to_objects.pop(identity, None)
- if self.object is None and not self.allow_add_remove:
- ret.append(None)
- errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']})
- continue
-
- ret.append(self.from_native(item, None))
- errors.append(self._errors)
-
- if update and self.allow_add_remove:
- ret._deleted = identity_to_objects.values()
-
- self._errors = any(errors) and errors or []
- else:
- self._errors = {'non_field_errors': ['Expected a list of items.']}
- else:
- ret = self.from_native(data, files)
-
- if not self._errors:
- self.object = ret
-
- return self._errors
-
- def is_valid(self):
- return not self.errors
+ return [self.child.validate(item) for item in data]
- @property
- def data(self):
+ def to_primative(self, data):
"""
- Returns the serialized data on the serializer.
+ List of object instances -> List of dicts of primitive datatypes.
"""
- if self._data is None:
- obj = self.object
+ return [self.child.to_primative(item) for item in data]
- if self.many is not None:
- many = self.many
- else:
- many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict))
- if many:
- warnings.warn('Implicit list/queryset serialization is deprecated. '
- 'Use the `many=True` flag when instantiating the serializer.',
- DeprecationWarning, stacklevel=2)
-
- if many:
- self._data = [self.to_native(item) for item in obj]
- else:
- self._data = self.to_native(obj)
+ def create(self, attrs_list):
+ return [self.child.create(attrs) for attrs in attrs_list]
- return self._data
+ def save(self):
+ if self.instance is not None:
+ self.update(self.instance, self.validated_data)
+ self.instance = self.create(self.validated_data)
+ return self.instance
- def save_object(self, obj, **kwargs):
- obj.save(**kwargs)
- def delete_object(self, obj):
- obj.delete()
-
- def save(self, **kwargs):
- """
- Save the deserialized object and return it.
- """
- # Clear cached _data, which may be invalidated by `save()`
- self._data = None
-
- if isinstance(self.object, list):
- [self.save_object(item, **kwargs) for item in self.object]
-
- if self.object._deleted:
- [self.delete_object(item) for item in self.object._deleted]
- else:
- self.save_object(self.object, **kwargs)
-
- return self.object
-
- def metadata(self):
- """
- Return a dictionary of metadata about the fields on the serializer.
- Useful for things like responding to OPTIONS requests, or generating
- API schemas for auto-documentation.
- """
- return SortedDict(
- [
- (field_name, field.metadata())
- for field_name, field in six.iteritems(self.fields)
- ]
- )
+def _resolve_model(obj):
+ """
+ Resolve supplied `obj` to a Django model class.
+ `obj` must be a Django model class itself, or a string
+ representation of one. Useful in situtations like GH #1225 where
+ Django may not have resolved a string-based reference to a model in
+ another model's foreign key definition.
-class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)):
- pass
+ String representations should have the format:
+ 'appname.ModelName'
+ """
+ if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
+ app_name, model_name = obj.split('.')
+ return models.get_model(app_name, model_name)
+ elif inspect.isclass(obj) and issubclass(obj, models.Model):
+ return obj
+ else:
+ raise ValueError("{0} is not a Django model".format(obj))
-class ModelSerializerOptions(SerializerOptions):
+class ModelSerializerOptions(object):
"""
Meta class options for ModelSerializer
"""
def __init__(self, meta):
- super(ModelSerializerOptions, self).__init__(meta)
- self.model = getattr(meta, 'model', None)
- self.read_only_fields = getattr(meta, 'read_only_fields', ())
- self.write_only_fields = getattr(meta, 'write_only_fields', ())
+ self.model = getattr(meta, 'model')
+ self.fields = getattr(meta, 'fields', ())
+ self.depth = getattr(meta, 'depth', 0)
class ModelSerializer(Serializer):
- """
- A serializer that deals with model instances and querysets.
- """
- _options_class = ModelSerializerOptions
-
field_mapping = {
models.AutoField: IntegerField,
- models.FloatField: FloatField,
+ # models.FloatField: FloatField,
models.IntegerField: IntegerField,
models.PositiveIntegerField: IntegerField,
models.SmallIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField,
- models.DateTimeField: DateTimeField,
- models.DateField: DateField,
- models.TimeField: TimeField,
- models.DecimalField: DecimalField,
- models.EmailField: EmailField,
+ # models.DateTimeField: DateTimeField,
+ # models.DateField: DateField,
+ # models.TimeField: TimeField,
+ # models.DecimalField: DecimalField,
+ # models.EmailField: EmailField,
models.CharField: CharField,
- models.URLField: URLField,
- models.SlugField: SlugField,
+ # models.URLField: URLField,
+ # models.SlugField: SlugField,
models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField,
models.NullBooleanField: BooleanField,
- models.FileField: FileField,
- models.ImageField: ImageField,
+ # models.FileField: FileField,
+ # models.ImageField: ImageField,
}
+ _options_class = ModelSerializerOptions
+
+ def __init__(self, *args, **kwargs):
+ self.opts = self._options_class(self.Meta)
+ super(ModelSerializer, self).__init__(*args, **kwargs)
+
+ def get_fields(self):
+ # Get the explicitly declared fields.
+ fields = copy.deepcopy(self.base_fields)
+
+ # Add in the default fields.
+ for key, val in self.get_default_fields().items():
+ if key not in fields:
+ fields[key] = val
+
+ # If `fields` is set on the `Meta` class,
+ # then use only those fields, and in that order.
+ if self.opts.fields:
+ fields = OrderedDict([
+ (key, fields[key]) for key in self.opts.fields
+ ])
+
+ return fields
+
def get_default_fields(self):
"""
Return all the fields that should be serialized for the model.
"""
-
cls = self.opts.model
- assert cls is not None, (
- "Serializer class '%s' is missing 'model' Meta option" %
- self.__class__.__name__
- )
opts = cls._meta.concrete_model._meta
- ret = SortedDict()
+ ret = OrderedDict()
nested = bool(self.opts.depth)
# Deal with adding the primary key field
@@ -694,29 +391,9 @@ class ModelSerializer(Serializer):
has_through_model = True
if model_field.rel and nested:
- if len(inspect.getargspec(self.get_nested_field).args) == 2:
- warnings.warn(
- 'The `get_nested_field(model_field)` call signature '
- 'is deprecated. '
- 'Use `get_nested_field(model_field, related_model, '
- 'to_many) instead',
- DeprecationWarning
- )
- field = self.get_nested_field(model_field)
- else:
- field = self.get_nested_field(model_field, related_model, to_many)
+ field = self.get_nested_field(model_field, related_model, to_many)
elif model_field.rel:
- if len(inspect.getargspec(self.get_nested_field).args) == 3:
- warnings.warn(
- 'The `get_related_field(model_field, to_many)` call '
- 'signature is deprecated. '
- 'Use `get_related_field(model_field, related_model, '
- 'to_many) instead',
- DeprecationWarning
- )
- field = self.get_related_field(model_field, to_many=to_many)
- else:
- field = self.get_related_field(model_field, related_model, to_many)
+ field = self.get_related_field(model_field, related_model, to_many)
else:
field = self.get_field(model_field)
@@ -763,38 +440,6 @@ class ModelSerializer(Serializer):
ret[accessor_name] = field
- # Ensure that 'read_only_fields' is an iterable
- assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'
-
- # Add the `read_only` flag to any fields that have been specified
- # in the `read_only_fields` option
- for field_name in self.opts.read_only_fields:
- assert field_name not in self.base_fields.keys(), (
- "field '%s' on serializer '%s' specified in "
- "`read_only_fields`, but also added "
- "as an explicit field. Remove it from `read_only_fields`." %
- (field_name, self.__class__.__name__))
- assert field_name in ret, (
- "Non-existant field '%s' specified in `read_only_fields` "
- "on serializer '%s'." %
- (field_name, self.__class__.__name__))
- ret[field_name].read_only = True
-
- # Ensure that 'write_only_fields' is an iterable
- assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'
-
- for field_name in self.opts.write_only_fields:
- assert field_name not in self.base_fields.keys(), (
- "field '%s' on serializer '%s' specified in "
- "`write_only_fields`, but also added "
- "as an explicit field. Remove it from `write_only_fields`." %
- (field_name, self.__class__.__name__))
- assert field_name in ret, (
- "Non-existant field '%s' specified in `write_only_fields` "
- "on serializer '%s'." %
- (field_name, self.__class__.__name__))
- ret[field_name].write_only = True
-
return ret
def get_pk_field(self, model_field):
@@ -825,28 +470,24 @@ class ModelSerializer(Serializer):
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
- kwargs = {
- 'queryset': related_model._default_manager,
- 'many': to_many
- }
+ kwargs = {}
+ # 'queryset': related_model._default_manager,
+ # 'many': to_many
+ # }
if model_field:
kwargs['required'] = not(model_field.null or model_field.blank)
- if model_field.help_text is not None:
- kwargs['help_text'] = model_field.help_text
+ # if model_field.help_text is not None:
+ # kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
-
if not model_field.editable:
kwargs['read_only'] = True
-
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
- if model_field.help_text is not None:
- kwargs['help_text'] = model_field.help_text
-
- return PrimaryKeyRelatedField(**kwargs)
+ return IntegerField(**kwargs)
+ # TODO: return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field):
"""
@@ -869,8 +510,8 @@ class ModelSerializer(Serializer):
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
- if model_field.help_text is not None:
- kwargs['help_text'] = model_field.help_text
+ # if model_field.help_text is not None:
+ # kwargs['help_text'] = model_field.help_text
# TODO: TypedChoiceField?
if model_field.flatchoices: # This ModelField contains choices
@@ -880,7 +521,7 @@ class ModelSerializer(Serializer):
return ChoiceField(**kwargs)
# put this below the ChoiceField because min_value isn't a valid initializer
- if issubclass(model_field.__class__, models.PositiveIntegerField) or\
+ if issubclass(model_field.__class__, models.PositiveIntegerField) or \
issubclass(model_field.__class__, models.PositiveSmallIntegerField):
kwargs['min_value'] = 0
@@ -888,170 +529,27 @@ class ModelSerializer(Serializer):
issubclass(model_field.__class__, (models.CharField, models.TextField)):
kwargs['allow_none'] = True
- attribute_dict = {
- models.CharField: ['max_length'],
- models.CommaSeparatedIntegerField: ['max_length'],
- models.DecimalField: ['max_digits', 'decimal_places'],
- models.EmailField: ['max_length'],
- models.FileField: ['max_length'],
- models.ImageField: ['max_length'],
- models.SlugField: ['max_length'],
- models.URLField: ['max_length'],
- }
-
- if model_field.__class__ in attribute_dict:
- attributes = attribute_dict[model_field.__class__]
- for attribute in attributes:
- kwargs.update({attribute: getattr(model_field, attribute)})
+ # attribute_dict = {
+ # models.CharField: ['max_length'],
+ # models.CommaSeparatedIntegerField: ['max_length'],
+ # models.DecimalField: ['max_digits', 'decimal_places'],
+ # models.EmailField: ['max_length'],
+ # models.FileField: ['max_length'],
+ # models.ImageField: ['max_length'],
+ # models.SlugField: ['max_length'],
+ # models.URLField: ['max_length'],
+ # }
+
+ # if model_field.__class__ in attribute_dict:
+ # attributes = attribute_dict[model_field.__class__]
+ # for attribute in attributes:
+ # kwargs.update({attribute: getattr(model_field, attribute)})
try:
return self.field_mapping[model_field.__class__](**kwargs)
except KeyError:
- return ModelField(model_field=model_field, **kwargs)
-
- def get_validation_exclusions(self, instance=None):
- """
- Return a list of field names to exclude from model validation.
- """
- cls = self.opts.model
- opts = cls._meta.concrete_model._meta
- exclusions = [field.name for field in opts.fields + opts.many_to_many]
-
- for field_name, field in self.fields.items():
- field_name = field.source or field_name
- if (
- field_name in exclusions
- and not field.read_only
- and (field.required or hasattr(instance, field_name))
- and not isinstance(field, Serializer)
- ):
- exclusions.remove(field_name)
- return exclusions
-
- def full_clean(self, instance):
- """
- Perform Django's full_clean, and populate the `errors` dictionary
- if any validation errors occur.
-
- Note that we don't perform this inside the `.restore_object()` method,
- so that subclasses can override `.restore_object()`, and still get
- the full_clean validation checking.
- """
- try:
- instance.full_clean(exclude=self.get_validation_exclusions(instance))
- except ValidationError as err:
- self._errors = err.message_dict
- return None
- return instance
-
- def restore_object(self, attrs, instance=None):
- """
- Restore the model instance.
- """
- m2m_data = {}
- related_data = {}
- nested_forward_relations = {}
- meta = self.opts.model._meta
-
- # Reverse fk or one-to-one relations
- for (obj, model) in meta.get_all_related_objects_with_model():
- field_name = obj.get_accessor_name()
- if field_name in attrs:
- related_data[field_name] = attrs.pop(field_name)
-
- # Reverse m2m relations
- for (obj, model) in meta.get_all_related_m2m_objects_with_model():
- field_name = obj.get_accessor_name()
- if field_name in attrs:
- m2m_data[field_name] = attrs.pop(field_name)
-
- # Forward m2m relations
- for field in meta.many_to_many + meta.virtual_fields:
- if isinstance(field, GenericForeignKey):
- continue
- if field.name in attrs:
- m2m_data[field.name] = attrs.pop(field.name)
-
- # Nested forward relations - These need to be marked so we can save
- # them before saving the parent model instance.
- for field_name in attrs.keys():
- if isinstance(self.fields.get(field_name, None), Serializer):
- nested_forward_relations[field_name] = attrs[field_name]
-
- # Create an empty instance of the model
- if instance is None:
- instance = self.opts.model()
-
- for key, val in attrs.items():
- try:
- setattr(instance, key, val)
- except ValueError:
- self._errors[key] = [self.error_messages['required']]
-
- # Any relations that cannot be set until we've
- # saved the model get hidden away on these
- # private attributes, so we can deal with them
- # at the point of save.
- instance._related_data = related_data
- instance._m2m_data = m2m_data
- instance._nested_forward_relations = nested_forward_relations
-
- return instance
-
- def from_native(self, data, files):
- """
- Override the default method to also include model field validation.
- """
- instance = super(ModelSerializer, self).from_native(data, files)
- if not self._errors:
- return self.full_clean(instance)
-
- def save_object(self, obj, **kwargs):
- """
- Save the deserialized object.
- """
- if getattr(obj, '_nested_forward_relations', None):
- # Nested relationships need to be saved before we can save the
- # parent instance.
- for field_name, sub_object in obj._nested_forward_relations.items():
- if sub_object:
- self.save_object(sub_object)
- setattr(obj, field_name, sub_object)
-
- obj.save(**kwargs)
-
- if getattr(obj, '_m2m_data', None):
- for accessor_name, object_list in obj._m2m_data.items():
- setattr(obj, accessor_name, object_list)
- del(obj._m2m_data)
-
- if getattr(obj, '_related_data', None):
- related_fields = dict([
- (field.get_accessor_name(), field)
- for field, model
- in obj._meta.get_all_related_objects_with_model()
- ])
- for accessor_name, related in obj._related_data.items():
- if isinstance(related, RelationsList):
- # Nested reverse fk relationship
- for related_item in related:
- fk_field = related_fields[accessor_name].field.name
- setattr(related_item, fk_field, obj)
- self.save_object(related_item)
-
- # Delete any removed objects
- if related._deleted:
- [self.delete_object(item) for item in related._deleted]
-
- elif isinstance(related, models.Model):
- # Nested reverse one-one relationship
- fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
- setattr(related, fk_field, obj)
- self.save_object(related)
- else:
- # Reverse FK or reverse one-one
- setattr(obj, accessor_name, related)
- del(obj._related_data)
+ # TODO: Change this to `return ModelField(model_field=model_field, **kwargs)`
+ return CharField(**kwargs)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
@@ -1066,14 +564,10 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
class HyperlinkedModelSerializer(ModelSerializer):
- """
- A subclass of ModelSerializer that uses hyperlinked relationships,
- instead of primary key relationships.
- """
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
- _hyperlink_field_class = HyperlinkedRelatedField
- _hyperlink_identify_field_class = HyperlinkedIdentityField
+ #_hyperlink_field_class = HyperlinkedRelatedField
+ #_hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields()
@@ -1081,15 +575,15 @@ class HyperlinkedModelSerializer(ModelSerializer):
if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name(self.opts.model)
- if self.opts.url_field_name not in fields:
- url_field = self._hyperlink_identify_field_class(
- view_name=self.opts.view_name,
- lookup_field=self.opts.lookup_field
- )
- ret = self._dict_class()
- ret[self.opts.url_field_name] = url_field
- ret.update(fields)
- fields = ret
+ # if self.opts.url_field_name not in fields:
+ # url_field = self._hyperlink_identify_field_class(
+ # view_name=self.opts.view_name,
+ # lookup_field=self.opts.lookup_field
+ # )
+ # ret = self._dict_class()
+ # ret[self.opts.url_field_name] = url_field
+ # ret.update(fields)
+ # fields = ret
return fields
@@ -1103,33 +597,25 @@ class HyperlinkedModelSerializer(ModelSerializer):
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
- kwargs = {
- 'queryset': related_model._default_manager,
- 'view_name': self._get_default_view_name(related_model),
- 'many': to_many
- }
+ # kwargs = {
+ # 'queryset': related_model._default_manager,
+ # 'view_name': self._get_default_view_name(related_model),
+ # 'many': to_many
+ # }
+ kwargs = {}
if model_field:
kwargs['required'] = not(model_field.null or model_field.blank)
- if model_field.help_text is not None:
- kwargs['help_text'] = model_field.help_text
+ # if model_field.help_text is not None:
+ # kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
- if self.opts.lookup_field:
- kwargs['lookup_field'] = self.opts.lookup_field
-
- return self._hyperlink_field_class(**kwargs)
+ return IntegerField(**kwargs)
+ # if self.opts.lookup_field:
+ # kwargs['lookup_field'] = self.opts.lookup_field
- def get_identity(self, data):
- """
- This hook is required for bulk update.
- We need to override the default, to use the url as the identity.
- """
- try:
- return data.get(self.opts.url_field_name, None)
- except AttributeError:
- return None
+ # return self._hyperlink_field_class(**kwargs)
def _get_default_view_name(self, model):
"""
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index 00ffdfba..6a2f6126 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -7,7 +7,7 @@ from django.db.models.query import QuerySet
from django.utils.datastructures import SortedDict
from django.utils.functional import Promise
from rest_framework.compat import force_text
-from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
+# from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
import datetime
import decimal
import types
@@ -106,14 +106,14 @@ else:
SortedDict,
yaml.representer.SafeRepresenter.represent_dict
)
- SafeDumper.add_representer(
- DictWithMetadata,
- yaml.representer.SafeRepresenter.represent_dict
- )
- SafeDumper.add_representer(
- SortedDictWithMetadata,
- yaml.representer.SafeRepresenter.represent_dict
- )
+ # SafeDumper.add_representer(
+ # DictWithMetadata,
+ # yaml.representer.SafeRepresenter.represent_dict
+ # )
+ # SafeDumper.add_representer(
+ # SortedDictWithMetadata,
+ # yaml.representer.SafeRepresenter.represent_dict
+ # )
SafeDumper.add_representer(
types.GeneratorType,
yaml.representer.SafeRepresenter.represent_list
diff --git a/rest_framework/utils/html.py b/rest_framework/utils/html.py
new file mode 100644
index 00000000..bf17050d
--- /dev/null
+++ b/rest_framework/utils/html.py
@@ -0,0 +1,86 @@
+"""
+Helpers for dealing with HTML input.
+"""
+
+def is_html_input(dictionary):
+ # MultiDict type datastructures are used to represent HTML form input,
+ # which may have more than one value for each key.
+ return hasattr(dictionary, 'getlist')
+
+
+def parse_html_list(dictionary, prefix=''):
+ """
+ Used to suport list values in HTML forms.
+ Supports lists of primitives and/or dictionaries.
+
+ * List of primitives.
+
+ {
+ '[0]': 'abc',
+ '[1]': 'def',
+ '[2]': 'hij'
+ }
+ -->
+ [
+ 'abc',
+ 'def',
+ 'hij'
+ ]
+
+ * List of dictionaries.
+
+ {
+ '[0]foo': 'abc',
+ '[0]bar': 'def',
+ '[1]foo': 'hij',
+ '[2]bar': 'klm',
+ }
+ -->
+ [
+ {'foo': 'abc', 'bar': 'def'},
+ {'foo': 'hij', 'bar': 'klm'}
+ ]
+ """
+ Dict = type(dictionary)
+ ret = {}
+ regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix))
+ for field, value in dictionary.items():
+ match = regex.match(field)
+ if not match:
+ continue
+ index, key = match.groups()
+ index = int(index)
+ if not key:
+ ret[index] = value
+ elif isinstance(ret.get(index), dict):
+ ret[index][key] = value
+ else:
+ ret[index] = Dict({key: value})
+ return [ret[item] for item in sorted(ret.keys())]
+
+
+def parse_html_dict(dictionary, prefix):
+ """
+ Used to support dictionary values in HTML forms.
+
+ {
+ 'profile.username': 'example',
+ 'profile.email': 'example@example.com',
+ }
+ -->
+ {
+ 'profile': {
+ 'username': 'example,
+ 'email': 'example@example.com'
+ }
+ }
+ """
+ ret = {}
+ regex = re.compile(r'^%s\.(.+)$' % re.escape(prefix))
+ for field, value in dictionary.items():
+ match = regex.match(field)
+ if not match:
+ continue
+ key = match.groups()[0]
+ ret[key] = value
+ return ret
--
cgit v1.2.3
From ec096a1caceff6a4f5c75a152dd1c7bea9ed281d Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Tue, 2 Sep 2014 15:07:56 +0100
Subject: Add relations and get tests running
---
rest_framework/fields.py | 30 ++++++++++--
rest_framework/mixins.py | 1 -
rest_framework/relations.py | 111 ++++++++++++++++++++++++++++++++++++++++++
rest_framework/renderers.py | 14 +++---
rest_framework/serializers.py | 8 +--
rest_framework/utils/html.py | 2 +
6 files changed, 149 insertions(+), 17 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index a83bf94c..3e0f7ca4 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -68,7 +68,7 @@ class Field(object):
def __init__(self, read_only=False, write_only=False,
required=None, default=empty, initial=None, source=None,
- label=None, style=None):
+ label=None, style=None, error_messages=None):
self._creation_counter = Field._creation_counter
Field._creation_counter += 1
@@ -216,9 +216,11 @@ class CharField(Field):
'blank': 'This field may not be blank.'
}
- def __init__(self, *args, **kwargs):
+ def __init__(self, **kwargs):
self.allow_blank = kwargs.pop('allow_blank', False)
- super(CharField, self).__init__(*args, **kwargs)
+ self.max_length = kwargs.pop('max_length', None)
+ self.min_length = kwargs.pop('min_length', None)
+ super(CharField, self).__init__(**kwargs)
def to_native(self, data):
if data == '' and not self.allow_blank:
@@ -233,7 +235,7 @@ class ChoiceField(Field):
}
coerce_to_type = str
- def __init__(self, *args, **kwargs):
+ def __init__(self, **kwargs):
choices = kwargs.pop('choices')
assert choices, '`choices` argument is required and may not be empty'
@@ -257,7 +259,7 @@ class ChoiceField(Field):
str(key): key for key in self.choices.keys()
}
- super(ChoiceField, self).__init__(*args, **kwargs)
+ super(ChoiceField, self).__init__(**kwargs)
def to_native(self, data):
try:
@@ -296,6 +298,24 @@ class IntegerField(Field):
return data
+class EmailField(CharField):
+ pass # TODO
+
+
+class RegexField(CharField):
+ def __init__(self, **kwargs):
+ self.regex = kwargs.pop('regex')
+ super(CharField, self).__init__(**kwargs)
+
+
+class DateTimeField(CharField):
+ pass # TODO
+
+
+class FileField(Field):
+ pass # TODO
+
+
class MethodField(Field):
def __init__(self, **kwargs):
kwargs['source'] = '*'
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index ee01cabc..3e9c9bb3 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -6,7 +6,6 @@ which allows mixin classes to be composed in interesting ways.
"""
from __future__ import unicode_literals
-from django.core.exceptions import ValidationError
from django.http import Http404
from rest_framework import status
from rest_framework.response import Response
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index e69de29b..42d2c121 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -0,0 +1,111 @@
+from rest_framework.fields import Field
+from django.core.exceptions import ObjectDoesNotExist
+from django.core.urlresolvers import resolve, get_script_prefix
+from rest_framework.compat import urlparse
+
+
+def get_default_queryset(serializer_class, field_name):
+ manager = getattr(serializer_class.opts.model, field_name)
+ if hasattr(manager, 'related'):
+ # Forward relationships
+ return manager.related.model._default_manager.all()
+ # Reverse relationships
+ return manager.field.rel.to._default_manager.all()
+
+
+class RelatedField(Field):
+ def __init__(self, **kwargs):
+ self.queryset = kwargs.pop('queryset', None)
+ self.many = kwargs.pop('many', False)
+ super(RelatedField, self).__init__(**kwargs)
+
+ def bind(self, field_name, parent, root):
+ super(RelatedField, self).bind(field_name, parent, root)
+ if self.queryset is None and not self.read_only:
+ self.queryset = get_default_queryset(parent, self.source)
+
+
+class PrimaryKeyRelatedField(RelatedField):
+ MESSAGES = {
+ 'required': 'This field is required.',
+ 'does_not_exist': "Invalid pk '{pk_value}' - object does not exist.",
+ 'incorrect_type': 'Incorrect type. Expected pk value, received {data_type}.',
+ }
+
+ def from_native(self, data):
+ try:
+ return self.queryset.get(pk=data)
+ except ObjectDoesNotExist:
+ self.fail('does_not_exist', pk_value=data)
+ except (TypeError, ValueError):
+ self.fail('incorrect_type', data_type=type(data).__name__)
+
+
+class HyperlinkedRelatedField(RelatedField):
+ lookup_field = 'pk'
+
+ MESSAGES = {
+ 'required': 'This field is required.',
+ 'no_match': 'Invalid hyperlink - No URL match',
+ 'incorrect_match': 'Invalid hyperlink - Incorrect URL match.',
+ 'does_not_exist': "Invalid hyperlink - Object does not exist.",
+ 'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.',
+ }
+
+ def __init__(self, **kwargs):
+ self.view_name = kwargs.pop('view_name')
+ self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
+ self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
+ super(HyperlinkedRelatedField, self).__init__(**kwargs)
+
+ def get_object(self, view_name, view_args, view_kwargs):
+ """
+ Return the object corresponding to a matched URL.
+
+ Takes the matched URL conf arguments, and should return an
+ object instance, or raise an `ObjectDoesNotExist` exception.
+ """
+ lookup_value = view_kwargs[self.lookup_url_kwarg]
+ lookup_kwargs = {self.lookup_field: lookup_value}
+ return self.queryset.get(**lookup_kwargs)
+
+ def from_native(self, value):
+ try:
+ http_prefix = value.startswith(('http:', 'https:'))
+ except AttributeError:
+ self.fail('incorrect_type', type(value).__name__)
+
+ if http_prefix:
+ # If needed convert absolute URLs to relative path
+ value = urlparse.urlparse(value).path
+ prefix = get_script_prefix()
+ if value.startswith(prefix):
+ value = '/' + value[len(prefix):]
+
+ try:
+ match = resolve(value)
+ except Exception:
+ self.fail('no_match')
+
+ if match.view_name != self.view_name:
+ self.fail('incorrect_match')
+
+ try:
+ return self.get_object(match.view_name, match.args, match.kwargs)
+ except (ObjectDoesNotExist, TypeError, ValueError):
+ self.fail('does_not_exist')
+
+
+class HyperlinkedIdentityField(RelatedField):
+ lookup_field = 'pk'
+
+ def __init__(self, **kwargs):
+ self.view_name = kwargs.pop('view_name')
+ self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
+ self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
+ super(HyperlinkedIdentityField, self).__init__(**kwargs)
+
+
+class SlugRelatedField(RelatedField):
+ def __init__(self, **kwargs):
+ self.slug_field = kwargs.pop('slug_field', None)
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index e8935b01..dfc5a39f 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -436,13 +436,13 @@ class BrowsableAPIRenderer(BaseRenderer):
if request.method == method:
try:
data = request.DATA
- files = request.FILES
+ # files = request.FILES
except ParseError:
data = None
- files = None
+ # files = None
else:
data = None
- files = None
+ # files = None
with override_method(view, request, method) as request:
obj = getattr(view, 'object', None)
@@ -579,10 +579,10 @@ class BrowsableAPIRenderer(BaseRenderer):
'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes],
'response_headers': response_headers,
- #'put_form': self.get_rendered_html_form(view, 'PUT', request),
- #'post_form': self.get_rendered_html_form(view, 'POST', request),
- #'delete_form': self.get_rendered_html_form(view, 'DELETE', request),
- #'options_form': self.get_rendered_html_form(view, 'OPTIONS', request),
+ # 'put_form': self.get_rendered_html_form(view, 'PUT', request),
+ # 'post_form': self.get_rendered_html_form(view, 'POST', request),
+ # 'delete_form': self.get_rendered_html_form(view, 'DELETE', request),
+ # 'options_form': self.get_rendered_html_form(view, 'OPTIONS', request),
'raw_data_put_form': raw_data_put_form,
'raw_data_post_form': raw_data_post_form,
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index d121812d..2f23b4d9 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -477,8 +477,8 @@ class ModelSerializer(Serializer):
if model_field:
kwargs['required'] = not(model_field.null or model_field.blank)
- # if model_field.help_text is not None:
- # kwargs['help_text'] = model_field.help_text
+ # if model_field.help_text is not None:
+ # kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
if not model_field.editable:
@@ -566,8 +566,8 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
class HyperlinkedModelSerializer(ModelSerializer):
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
- #_hyperlink_field_class = HyperlinkedRelatedField
- #_hyperlink_identify_field_class = HyperlinkedIdentityField
+ # _hyperlink_field_class = HyperlinkedRelatedField
+ # _hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields()
diff --git a/rest_framework/utils/html.py b/rest_framework/utils/html.py
index bf17050d..edc591e9 100644
--- a/rest_framework/utils/html.py
+++ b/rest_framework/utils/html.py
@@ -1,6 +1,8 @@
"""
Helpers for dealing with HTML input.
"""
+import re
+
def is_html_input(dictionary):
# MultiDict type datastructures are used to represent HTML form input,
--
cgit v1.2.3
From f2852811f93863f2eed04d51eeb7ef27716b2409 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Tue, 2 Sep 2014 17:41:23 +0100
Subject: Getting tests passing
---
rest_framework/authtoken/serializers.py | 5 ++--
rest_framework/authtoken/views.py | 3 +-
rest_framework/fields.py | 50 ++++++++++++++++++++++++++++++-
rest_framework/mixins.py | 41 +++----------------------
rest_framework/pagination.py | 36 ++++++++--------------
rest_framework/relations.py | 2 +-
rest_framework/serializers.py | 53 ++++++++++++++++++++-------------
7 files changed, 103 insertions(+), 87 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py
index 99e99ae3..edeae857 100644
--- a/rest_framework/authtoken/serializers.py
+++ b/rest_framework/authtoken/serializers.py
@@ -19,11 +19,12 @@ class AuthTokenSerializer(serializers.Serializer):
if not user.is_active:
msg = _('User account is disabled.')
raise serializers.ValidationError(msg)
- attrs['user'] = user
- return attrs
else:
msg = _('Unable to login with provided credentials.')
raise serializers.ValidationError(msg)
else:
msg = _('Must include "username" and "password"')
raise serializers.ValidationError(msg)
+
+ attrs['user'] = user
+ return attrs
diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py
index 7c03cb76..94e6f061 100644
--- a/rest_framework/authtoken/views.py
+++ b/rest_framework/authtoken/views.py
@@ -18,7 +18,8 @@ class ObtainAuthToken(APIView):
def post(self, request):
serializer = self.serializer_class(data=request.DATA)
if serializer.is_valid():
- token, created = Token.objects.get_or_create(user=serializer.object['user'])
+ user = serializer.validated_data['user']
+ token, created = Token.objects.get_or_create(user=user)
return Response({'token': token.key})
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 3e0f7ca4..838aa3b0 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,4 +1,5 @@
from rest_framework.utils import html
+import inspect
class empty:
@@ -11,6 +12,22 @@ class empty:
pass
+def is_simple_callable(obj):
+ """
+ True if the object is a callable that takes no arguments.
+ """
+ function = inspect.isfunction(obj)
+ method = inspect.ismethod(obj)
+
+ if not (function or method):
+ return False
+
+ args, _, _, defaults = inspect.getargspec(obj)
+ len_args = len(args) if function else len(args) - 1
+ len_defaults = len(defaults) if defaults else 0
+ return len_args <= len_defaults
+
+
def get_attribute(instance, attrs):
"""
Similar to Python's built in `getattr(instance, attr)`,
@@ -98,6 +115,7 @@ class Field(object):
self.field_name = field_name
self.parent = parent
self.root = root
+ self.context = parent.context
# `self.label` should deafult to being based on the field name.
if self.label is None:
@@ -297,25 +315,55 @@ class IntegerField(Field):
self.fail('invalid_integer')
return data
+ def to_primative(self, value):
+ if value is None:
+ return None
+ return int(value)
+
class EmailField(CharField):
pass # TODO
+class URLField(CharField):
+ pass # TODO
+
+
class RegexField(CharField):
def __init__(self, **kwargs):
self.regex = kwargs.pop('regex')
super(CharField, self).__init__(**kwargs)
+class DateField(CharField):
+ def __init__(self, **kwargs):
+ self.input_formats = kwargs.pop('input_formats', None)
+ super(DateField, self).__init__(**kwargs)
+
+
+class TimeField(CharField):
+ def __init__(self, **kwargs):
+ self.input_formats = kwargs.pop('input_formats', None)
+ super(TimeField, self).__init__(**kwargs)
+
+
class DateTimeField(CharField):
- pass # TODO
+ def __init__(self, **kwargs):
+ self.input_formats = kwargs.pop('input_formats', None)
+ super(DateTimeField, self).__init__(**kwargs)
class FileField(Field):
pass # TODO
+class ReadOnlyField(Field):
+ def to_primative(self, value):
+ if is_simple_callable(value):
+ return value()
+ return value
+
+
class MethodField(Field):
def __init__(self, **kwargs):
kwargs['source'] = '*'
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 3e9c9bb3..359740ce 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -13,23 +13,6 @@ from rest_framework.request import clone_request
from rest_framework.settings import api_settings
-def _get_validation_exclusions(obj, lookup_field=None):
- """
- Given a model instance, and an optional pk and slug field,
- return the full list of all other field names on that model.
-
- For use when performing full_clean on a model instance,
- so we only clean the required fields.
- """
- if lookup_field == 'pk':
- pk_field = obj._meta.pk
- while pk_field.rel:
- pk_field = pk_field.rel.to._meta.pk
- lookup_field = pk_field.name
-
- return [field.name for field in obj._meta.fields if field.name != lookup_field]
-
-
class CreateModelMixin(object):
"""
Create a model instance.
@@ -92,15 +75,14 @@ class UpdateModelMixin(object):
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
- lookup_value = self.kwargs[lookup_url_kwarg]
- extras = {self.lookup_field: lookup_value}
-
if self.object is None:
+ lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
+ lookup_value = self.kwargs[lookup_url_kwarg]
+ extras = {self.lookup_field: lookup_value}
self.object = serializer.save(extras=extras)
return Response(serializer.data, status=status.HTTP_201_CREATED)
- self.object = serializer.save(extras=extras)
+ self.object = serializer.save()
return Response(serializer.data, status=status.HTTP_200_OK)
def partial_update(self, request, *args, **kwargs):
@@ -122,21 +104,6 @@ class UpdateModelMixin(object):
# return a 404 response.
raise
- def pre_save(self, obj):
- """
- Set any attributes on the object that are implicit in the request.
- """
- lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
- lookup_value = self.kwargs[lookup_url_kwarg]
-
- setattr(obj, self.lookup_field, lookup_value)
-
- # Ensure we clean the attributes so that we don't eg return integer
- # pk using a string representation, as provided by the url conf kwarg.
- if hasattr(obj, 'full_clean'):
- exclude = _get_validation_exclusions(obj, self.lookup_field)
- obj.full_clean(exclude)
-
class DestroyModelMixin(object):
"""
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index 83ef97c5..478d32b4 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -13,7 +13,7 @@ class NextPageField(serializers.Field):
"""
page_field = 'page'
- def to_native(self, value):
+ def to_primative(self, value):
if not value.has_next():
return None
page = value.next_page_number()
@@ -28,7 +28,7 @@ class PreviousPageField(serializers.Field):
"""
page_field = 'page'
- def to_native(self, value):
+ def to_primative(self, value):
if not value.has_previous():
return None
page = value.previous_page_number()
@@ -48,25 +48,11 @@ class DefaultObjectSerializer(serializers.Field):
super(DefaultObjectSerializer, self).__init__(source=source)
-# class PaginationSerializerOptions(serializers.SerializerOptions):
-# """
-# An object that stores the options that may be provided to a
-# pagination serializer by using the inner `Meta` class.
-
-# Accessible on the instance as `serializer.opts`.
-# """
-# def __init__(self, meta):
-# super(PaginationSerializerOptions, self).__init__(meta)
-# self.object_serializer_class = getattr(meta, 'object_serializer_class',
-# DefaultObjectSerializer)
-
-
class BasePaginationSerializer(serializers.Serializer):
"""
A base class for pagination serializers to inherit from,
to make implementing custom serializers more easy.
"""
- # _options_class = PaginationSerializerOptions
results_field = 'results'
def __init__(self, *args, **kwargs):
@@ -75,14 +61,16 @@ class BasePaginationSerializer(serializers.Serializer):
"""
super(BasePaginationSerializer, self).__init__(*args, **kwargs)
results_field = self.results_field
- object_serializer = self.opts.object_serializer_class
-
- if 'context' in kwargs:
- context_kwarg = {'context': kwargs['context']}
- else:
- context_kwarg = {}
-
- self.fields[results_field] = object_serializer(source='object_list', **context_kwarg)
+ try:
+ object_serializer = self.Meta.object_serializer_class
+ except AttributeError:
+ object_serializer = DefaultObjectSerializer
+
+ self.fields[results_field] = serializers.ListSerializer(
+ child=object_serializer(),
+ source='object_list'
+ )
+ self.fields[results_field].bind(results_field, self, self) # TODO: Support automatic binding
class PaginationSerializer(BasePaginationSerializer):
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 42d2c121..0b01394a 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -73,7 +73,7 @@ class HyperlinkedRelatedField(RelatedField):
try:
http_prefix = value.startswith(('http:', 'https:'))
except AttributeError:
- self.fail('incorrect_type', type(value).__name__)
+ self.fail('incorrect_type', data_type=type(value).__name__)
if http_prefix:
# If needed convert absolute URLs to relative path
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 2f23b4d9..c38d8968 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -142,7 +142,7 @@ class Serializer(BaseSerializer):
return super(Serializer, cls).__new__(cls)
def __init__(self, *args, **kwargs):
- kwargs.pop('context', None)
+ self.context = kwargs.pop('context', {})
kwargs.pop('partial', None)
kwargs.pop('many', False)
@@ -202,7 +202,7 @@ class Serializer(BaseSerializer):
if errors:
raise ValidationError(errors)
- return ret
+ return self.validate(ret)
def to_primative(self, instance):
"""
@@ -217,6 +217,9 @@ class Serializer(BaseSerializer):
return ret
+ def validate(self, attrs):
+ return attrs
+
def __iter__(self):
errors = self.errors if hasattr(self, '_errors') else {}
for field in self.fields.values():
@@ -232,8 +235,7 @@ class ListSerializer(BaseSerializer):
def __init__(self, *args, **kwargs):
self.child = kwargs.pop('child', copy.deepcopy(self.child))
assert self.child is not None, '`child` is a required argument.'
-
- kwargs.pop('context', None)
+ self.context = kwargs.pop('context', {})
kwargs.pop('partial', None)
super(ListSerializer, self).__init__(*args, **kwargs)
@@ -316,19 +318,19 @@ class ModelSerializer(Serializer):
models.PositiveIntegerField: IntegerField,
models.SmallIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField,
- # models.DateTimeField: DateTimeField,
- # models.DateField: DateField,
- # models.TimeField: TimeField,
+ models.DateTimeField: DateTimeField,
+ models.DateField: DateField,
+ models.TimeField: TimeField,
# models.DecimalField: DecimalField,
- # models.EmailField: EmailField,
+ models.EmailField: EmailField,
models.CharField: CharField,
- # models.URLField: URLField,
+ models.URLField: URLField,
# models.SlugField: SlugField,
models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField,
models.NullBooleanField: BooleanField,
- # models.FileField: FileField,
+ models.FileField: FileField,
# models.ImageField: ImageField,
}
@@ -338,6 +340,15 @@ class ModelSerializer(Serializer):
self.opts = self._options_class(self.Meta)
super(ModelSerializer, self).__init__(*args, **kwargs)
+ def create(self):
+ ModelClass = self.opts.model
+ return ModelClass.objects.create(**self.validated_data)
+
+ def update(self, obj):
+ for attr, value in self.validated_data.items():
+ setattr(obj, attr, value)
+ obj.save()
+
def get_fields(self):
# Get the explicitly declared fields.
fields = copy.deepcopy(self.base_fields)
@@ -566,8 +577,8 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
class HyperlinkedModelSerializer(ModelSerializer):
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
- # _hyperlink_field_class = HyperlinkedRelatedField
- # _hyperlink_identify_field_class = HyperlinkedIdentityField
+ _hyperlink_field_class = HyperlinkedRelatedField
+ _hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields()
@@ -575,15 +586,15 @@ class HyperlinkedModelSerializer(ModelSerializer):
if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name(self.opts.model)
- # if self.opts.url_field_name not in fields:
- # url_field = self._hyperlink_identify_field_class(
- # view_name=self.opts.view_name,
- # lookup_field=self.opts.lookup_field
- # )
- # ret = self._dict_class()
- # ret[self.opts.url_field_name] = url_field
- # ret.update(fields)
- # fields = ret
+ if self.opts.url_field_name not in fields:
+ url_field = self._hyperlink_identify_field_class(
+ view_name=self.opts.view_name,
+ lookup_field=self.opts.lookup_field
+ )
+ ret = fields.__class__()
+ ret[self.opts.url_field_name] = url_field
+ ret.update(fields)
+ fields = ret
return fields
--
cgit v1.2.3
From c1036c17533a3091401ff90f825571f0e6125eca Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Wed, 3 Sep 2014 16:34:09 +0100
Subject: More test passing
---
rest_framework/relations.py | 56 ++++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 55 insertions(+), 1 deletion(-)
(limited to 'rest_framework')
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 0b01394a..661a1249 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -1,6 +1,7 @@
from rest_framework.fields import Field
+from rest_framework.reverse import reverse
from django.core.exceptions import ObjectDoesNotExist
-from django.core.urlresolvers import resolve, get_script_prefix
+from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch
from rest_framework.compat import urlparse
@@ -100,11 +101,64 @@ class HyperlinkedIdentityField(RelatedField):
lookup_field = 'pk'
def __init__(self, **kwargs):
+ kwargs['read_only'] = True
self.view_name = kwargs.pop('view_name')
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
super(HyperlinkedIdentityField, self).__init__(**kwargs)
+ def get_attribute(self, instance):
+ return instance
+
+ def to_primative(self, value):
+ request = self.context.get('request', None)
+ format = self.context.get('format', None)
+
+ assert request is not None, (
+ "`HyperlinkedIdentityField` requires the request in the serializer"
+ " context. Add `context={'request': request}` when instantiating "
+ "the serializer."
+ )
+
+ # By default use whatever format is given for the current context
+ # unless the target is a different type to the source.
+ #
+ # Eg. Consider a HyperlinkedIdentityField pointing from a json
+ # representation to an html property of that representation...
+ #
+ # '/snippets/1/' should link to '/snippets/1/highlight/'
+ # ...but...
+ # '/snippets/1/.json' should link to '/snippets/1/highlight/.html'
+ if format and self.format and self.format != format:
+ format = self.format
+
+ # Return the hyperlink, or error if incorrectly configured.
+ try:
+ return self.get_url(value, self.view_name, request, format)
+ except NoReverseMatch:
+ msg = (
+ 'Could not resolve URL for hyperlinked relationship using '
+ 'view name "%s". You may have failed to include the related '
+ 'model in your API, or incorrectly configured the '
+ '`lookup_field` attribute on this field.'
+ )
+ raise Exception(msg % self.view_name)
+
+ def get_url(self, obj, view_name, request, format):
+ """
+ Given an object, return the URL that hyperlinks to the object.
+
+ May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
+ attributes are not configured to correctly match the URL conf.
+ """
+ # Unsaved objects will not yet have a valid URL.
+ if obj.pk is None:
+ return None
+
+ lookup_value = getattr(obj, self.lookup_field)
+ kwargs = {self.lookup_url_kwarg: lookup_value}
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+
class SlugRelatedField(RelatedField):
def __init__(self, **kwargs):
--
cgit v1.2.3
From d934824bff21e4a11226af61efba319be227f4f0 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 5 Sep 2014 16:29:46 +0100
Subject: Workin on
---
rest_framework/exceptions.py | 9 +++--
rest_framework/fields.py | 26 +++++++++++---
rest_framework/generics.py | 30 +---------------
rest_framework/mixins.py | 83 ++++++++++++++++++++++++-------------------
rest_framework/serializers.py | 64 +++++++++++++++++++++------------
5 files changed, 118 insertions(+), 94 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index ad52d172..852a08b1 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -15,7 +15,7 @@ class APIException(Exception):
Subclasses should provide `.status_code` and `.default_detail` properties.
"""
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
- default_detail = ''
+ default_detail = 'A server error occured'
def __init__(self, detail=None):
self.detail = detail or self.default_detail
@@ -29,6 +29,11 @@ class ParseError(APIException):
default_detail = 'Malformed request.'
+class ValidationError(APIException):
+ status_code = status.HTTP_400_BAD_REQUEST
+ default_detail = 'Invalid data in request.'
+
+
class AuthenticationFailed(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = 'Incorrect authentication credentials.'
@@ -54,7 +59,7 @@ class MethodNotAllowed(APIException):
class NotAcceptable(APIException):
status_code = status.HTTP_406_NOT_ACCEPTABLE
- default_detail = "Could not satisfy the request's Accept header"
+ default_detail = "Could not satisfy the request Accept header"
def __init__(self, detail=None, available_renderers=None):
self.detail = detail or self.default_detail
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 838aa3b0..d18551b3 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,3 +1,4 @@
+from rest_framework.exceptions import ValidationError
from rest_framework.utils import html
import inspect
@@ -59,10 +60,6 @@ def set_value(dictionary, keys, value):
dictionary[keys[-1]] = value
-class ValidationError(Exception):
- pass
-
-
class SkipField(Exception):
pass
@@ -204,6 +201,22 @@ class Field(object):
msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg)
+ def __new__(cls, *args, **kwargs):
+ instance = super(Field, cls).__new__(cls)
+ instance._args = args
+ instance._kwargs = kwargs
+ return instance
+
+ def __repr__(self):
+ arg_string = ', '.join([repr(val) for val in self._args])
+ kwarg_string = ', '.join([
+ '%s=%s' % (key, repr(val)) for key, val in self._kwargs.items()
+ ])
+ if arg_string and kwarg_string:
+ arg_string += ', '
+ class_name = self.__class__.__name__
+ return "%s(%s%s)" % (class_name, arg_string, kwarg_string)
+
class BooleanField(Field):
MESSAGES = {
@@ -308,6 +321,11 @@ class IntegerField(Field):
'invalid_integer': 'A valid integer is required.'
}
+ def __init__(self, **kwargs):
+ self.max_value = kwargs.pop('max_value')
+ self.min_value = kwargs.pop('min_value')
+ super(CharField, self).__init__(**kwargs)
+
def to_native(self, data):
try:
data = int(str(data))
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 6705cbb2..c2c59154 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -27,7 +27,7 @@ def strict_positive_int(integer_string, cutoff=None):
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
"""
- Same as Django's standard shortcut, but make sure to raise 404
+ Same as Django's standard shortcut, but make sure to also raise 404
if the filter_kwargs don't match the required types.
"""
try:
@@ -249,34 +249,6 @@ class GenericAPIView(views.APIView):
#
# The are not called by GenericAPIView directly,
# but are used by the mixin methods.
-
- def pre_save(self, obj):
- """
- Placeholder method for calling before saving an object.
-
- May be used to set attributes on the object that are implicit
- in either the request, or the url.
- """
- pass
-
- def post_save(self, obj, created=False):
- """
- Placeholder method for calling after saving an object.
- """
- pass
-
- def pre_delete(self, obj):
- """
- Placeholder method for calling before deleting an object.
- """
- pass
-
- def post_delete(self, obj):
- """
- Placeholder method for calling after deleting an object.
- """
- pass
-
def metadata(self, request):
"""
Return a dictionary of metadata about the view.
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 359740ce..14a6b44b 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -19,14 +19,10 @@ class CreateModelMixin(object):
"""
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA)
-
- if serializer.is_valid():
- self.object = serializer.save()
- headers = self.get_success_headers(serializer.data)
- return Response(serializer.data, status=status.HTTP_201_CREATED,
- headers=headers)
-
- return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+ serializer.is_valid(raise_exception=True)
+ serializer.save()
+ headers = self.get_success_headers(serializer.data)
+ return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
def get_success_headers(self, data):
try:
@@ -40,15 +36,12 @@ class ListModelMixin(object):
List a queryset.
"""
def list(self, request, *args, **kwargs):
- self.object_list = self.filter_queryset(self.get_queryset())
-
- # Switch between paginated or standard style responses
- page = self.paginate_queryset(self.object_list)
+ instance = self.filter_queryset(self.get_queryset())
+ page = self.paginate_queryset(instance)
if page is not None:
serializer = self.get_pagination_serializer(page)
else:
- serializer = self.get_serializer(self.object_list, many=True)
-
+ serializer = self.get_serializer(instance, many=True)
return Response(serializer.data)
@@ -57,8 +50,8 @@ class RetrieveModelMixin(object):
Retrieve a model instance.
"""
def retrieve(self, request, *args, **kwargs):
- self.object = self.get_object()
- serializer = self.get_serializer(self.object)
+ instance = self.get_object()
+ serializer = self.get_serializer(instance)
return Response(serializer.data)
@@ -68,22 +61,52 @@ class UpdateModelMixin(object):
"""
def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
- self.object = self.get_object_or_none()
+ instance = self.get_object()
+ serializer = self.get_serializer(instance, data=request.DATA, partial=partial)
+ serializer.is_valid(raise_exception=True)
+ serializer.save()
+ return Response(serializer.data)
+
+ def partial_update(self, request, *args, **kwargs):
+ kwargs['partial'] = True
+ return self.update(request, *args, **kwargs)
+
- serializer = self.get_serializer(self.object, data=request.DATA, partial=partial)
+class DestroyModelMixin(object):
+ """
+ Destroy a model instance.
+ """
+ def destroy(self, request, *args, **kwargs):
+ instance = self.get_object()
+ instance.delete()
+ return Response(status=status.HTTP_204_NO_CONTENT)
- if not serializer.is_valid():
- return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- if self.object is None:
+# The AllowPUTAsCreateMixin was previously the default behaviour
+# for PUT requests. This has now been removed and must be *explictly*
+# included if it is the behavior that you want.
+# For more info see: ...
+
+class AllowPUTAsCreateMixin(object):
+ """
+ The following mixin class may be used in order to support PUT-as-create
+ behavior for incoming requests.
+ """
+ def update(self, request, *args, **kwargs):
+ partial = kwargs.pop('partial', False)
+ instance = self.get_object_or_none()
+ serializer = self.get_serializer(instance, data=request.DATA, partial=partial)
+ serializer.is_valid(raise_exception=True)
+
+ if instance is None:
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
lookup_value = self.kwargs[lookup_url_kwarg]
extras = {self.lookup_field: lookup_value}
- self.object = serializer.save(extras=extras)
+ serializer.save(extras=extras)
return Response(serializer.data, status=status.HTTP_201_CREATED)
- self.object = serializer.save()
- return Response(serializer.data, status=status.HTTP_200_OK)
+ serializer.save()
+ return Response(serializer.data)
def partial_update(self, request, *args, **kwargs):
kwargs['partial'] = True
@@ -103,15 +126,3 @@ class UpdateModelMixin(object):
# PATCH requests where the object does not exist should still
# return a 404 response.
raise
-
-
-class DestroyModelMixin(object):
- """
- Destroy a model instance.
- """
- def destroy(self, request, *args, **kwargs):
- obj = self.get_object()
- self.pre_delete(obj)
- obj.delete()
- self.post_delete(obj)
- return Response(status=status.HTTP_204_NO_CONTENT)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index c38d8968..49eb6ce9 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -13,7 +13,8 @@ response content is handled by parsers and renderers.
from django.db import models
from django.utils import six
from collections import namedtuple, OrderedDict
-from rest_framework.fields import empty, set_value, Field, SkipField, ValidationError
+from rest_framework.exceptions import ValidationError
+from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings
from rest_framework.utils import html
import copy
@@ -34,43 +35,53 @@ FieldResult = namedtuple('FieldResult', ['field', 'value', 'error'])
class BaseSerializer(Field):
+ """
+ The BaseSerializer class provides a minimal class which may be used
+ for writing custom serializer implementations.
+ """
+
def __init__(self, instance=None, data=None, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
self.instance = instance
self._initial_data = data
def to_native(self, data):
- raise NotImplementedError()
+ raise NotImplementedError('`to_native()` must be implemented.')
def to_primative(self, instance):
- raise NotImplementedError()
+ raise NotImplementedError('`to_primative()` must be implemented.')
- def update(self, instance):
- raise NotImplementedError()
+ def update(self, instance, attrs):
+ raise NotImplementedError('`update()` must be implemented.')
- def create(self):
- raise NotImplementedError()
+ def create(self, attrs):
+ raise NotImplementedError('`create()` must be implemented.')
def save(self, extras=None):
if extras is not None:
- self._validated_data.update(extras)
+ self.validated_data.update(extras)
if self.instance is not None:
- self.update(self.instance)
+ self.update(self.instance, self._validated_data)
else:
- self.instance = self.create()
+ self.instance = self.create(self._validated_data)
return self.instance
- def is_valid(self):
- try:
- self._validated_data = self.to_native(self._initial_data)
- except ValidationError as exc:
- self._validated_data = {}
- self._errors = exc.args[0]
- return False
- self._errors = {}
- return True
+ def is_valid(self, raise_exception=False):
+ if not hasattr(self, '_validated_data'):
+ try:
+ self._validated_data = self.to_native(self._initial_data)
+ except ValidationError as exc:
+ self._validated_data = {}
+ self._errors = exc.detail
+ else:
+ self._errors = {}
+
+ if self._errors and raise_exception:
+ raise ValidationError(self._errors)
+
+ return not bool(self._errors)
@property
def data(self):
@@ -184,14 +195,20 @@ class Serializer(BaseSerializer):
"""
Dict of native values <- Dict of primitive datatypes.
"""
+ if not isinstance(data, dict):
+ raise ValidationError({'non_field_errors': ['Invalid data']})
+
ret = {}
errors = {}
fields = [field for field in self.fields.values() if not field.read_only]
for field in fields:
+ validate_method = getattr(self, 'validate_' + field.field_name, None)
primitive_value = field.get_value(data)
try:
validated_value = field.validate(primitive_value)
+ if validate_method is not None:
+ validated_value = validate_method(validated_value)
except ValidationError as exc:
errors[field.field_name] = str(exc)
except SkipField:
@@ -202,6 +219,7 @@ class Serializer(BaseSerializer):
if errors:
raise ValidationError(errors)
+ # TODO: 'Non field errors'
return self.validate(ret)
def to_primative(self, instance):
@@ -340,12 +358,12 @@ class ModelSerializer(Serializer):
self.opts = self._options_class(self.Meta)
super(ModelSerializer, self).__init__(*args, **kwargs)
- def create(self):
+ def create(self, attrs):
ModelClass = self.opts.model
- return ModelClass.objects.create(**self.validated_data)
+ return ModelClass.objects.create(**attrs)
- def update(self, obj):
- for attr, value in self.validated_data.items():
+ def update(self, obj, attrs):
+ for attr, value in attrs.items():
setattr(obj, attr, value)
obj.save()
--
cgit v1.2.3
From 21980b800d04a1d82a6003823abfdf4ab80ae979 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 8 Sep 2014 14:24:05 +0100
Subject: More test sorting
---
rest_framework/exceptions.py | 5 ---
rest_framework/fields.py | 85 ++++++++++++++++++++++++++++++++++++++++---
rest_framework/serializers.py | 25 ++++++++-----
rest_framework/views.py | 9 ++++-
4 files changed, 101 insertions(+), 23 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index 852a08b1..06b5e8a2 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -29,11 +29,6 @@ class ParseError(APIException):
default_detail = 'Malformed request.'
-class ValidationError(APIException):
- status_code = status.HTTP_400_BAD_REQUEST
- default_detail = 'Invalid data in request.'
-
-
class AuthenticationFailed(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = 'Incorrect authentication credentials.'
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index d18551b3..250c0579 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,4 +1,6 @@
-from rest_framework.exceptions import ValidationError
+from django.core import validators
+from django.core.exceptions import ValidationError
+from django.utils.encoding import is_protected_type
from rest_framework.utils import html
import inspect
@@ -33,9 +35,14 @@ def get_attribute(instance, attrs):
"""
Similar to Python's built in `getattr(instance, attr)`,
but takes a list of nested attributes, instead of a single attribute.
+
+ Also accepts either attribute lookup on objects or dictionary lookups.
"""
for attr in attrs:
- instance = getattr(instance, attr)
+ try:
+ instance = getattr(instance, attr)
+ except AttributeError:
+ return instance[attr]
return instance
@@ -80,9 +87,11 @@ class Field(object):
'not exist in the `MESSAGES` dictionary.'
)
+ default_validators = []
+
def __init__(self, read_only=False, write_only=False,
required=None, default=empty, initial=None, source=None,
- label=None, style=None, error_messages=None):
+ label=None, style=None, error_messages=None, validators=[]):
self._creation_counter = Field._creation_counter
Field._creation_counter += 1
@@ -104,6 +113,7 @@ class Field(object):
self.initial = initial
self.label = label
self.style = {} if style is None else style
+ self.validators = self.default_validators + validators
def bind(self, field_name, parent, root):
"""
@@ -176,8 +186,21 @@ class Field(object):
self.fail('required')
return self.get_default()
+ self.run_validators(data)
return self.to_native(data)
+ def run_validators(self, value):
+ if value in validators.EMPTY_VALUES:
+ return
+ errors = []
+ for validator in self.validators:
+ try:
+ validator(value)
+ except ValidationError as exc:
+ errors.extend(exc.messages)
+ if errors:
+ raise ValidationError(errors)
+
def to_native(self, data):
"""
Transform the *incoming* primative data into a native value.
@@ -322,9 +345,13 @@ class IntegerField(Field):
}
def __init__(self, **kwargs):
- self.max_value = kwargs.pop('max_value')
- self.min_value = kwargs.pop('min_value')
- super(CharField, self).__init__(**kwargs)
+ max_value = kwargs.pop('max_value', None)
+ min_value = kwargs.pop('min_value', None)
+ super(IntegerField, self).__init__(**kwargs)
+ if max_value is not None:
+ self.validators.append(validators.MaxValueValidator(max_value))
+ if min_value is not None:
+ self.validators.append(validators.MinValueValidator(min_value))
def to_native(self, data):
try:
@@ -392,3 +419,49 @@ class MethodField(Field):
attr = 'get_{field_name}'.format(field_name=self.field_name)
method = getattr(self.parent, attr)
return method(value)
+
+
+class ModelField(Field):
+ """
+ A generic field that can be used against an arbitrary model field.
+ """
+ def __init__(self, *args, **kwargs):
+ try:
+ self.model_field = kwargs.pop('model_field')
+ except KeyError:
+ raise ValueError("ModelField requires 'model_field' kwarg")
+
+ self.min_length = kwargs.pop('min_length',
+ getattr(self.model_field, 'min_length', None))
+ self.max_length = kwargs.pop('max_length',
+ getattr(self.model_field, 'max_length', None))
+ self.min_value = kwargs.pop('min_value',
+ getattr(self.model_field, 'min_value', None))
+ self.max_value = kwargs.pop('max_value',
+ getattr(self.model_field, 'max_value', None))
+
+ super(ModelField, self).__init__(*args, **kwargs)
+
+ if self.min_length is not None:
+ self.validators.append(validators.MinLengthValidator(self.min_length))
+ if self.max_length is not None:
+ self.validators.append(validators.MaxLengthValidator(self.max_length))
+ if self.min_value is not None:
+ self.validators.append(validators.MinValueValidator(self.min_value))
+ if self.max_value is not None:
+ self.validators.append(validators.MaxValueValidator(self.max_value))
+
+ def get_attribute(self, instance):
+ return get_attribute(instance, self.source_attrs[:-1])
+
+ def to_native(self, data):
+ rel = getattr(self.model_field, 'rel', None)
+ if rel is not None:
+ return rel.to._meta.get_field(rel.field_name).to_python(data)
+ return self.model_field.to_python(data)
+
+ def to_primative(self, obj):
+ value = self.model_field._get_val_from_obj(obj)
+ if is_protected_type(value):
+ return value
+ return self.model_field.value_to_string(obj)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 49eb6ce9..93226d32 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -10,10 +10,10 @@ python primitives.
2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers.
"""
+from django.core.exceptions import ValidationError
from django.db import models
from django.utils import six
from collections import namedtuple, OrderedDict
-from rest_framework.exceptions import ValidationError
from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings
from rest_framework.utils import html
@@ -58,13 +58,14 @@ class BaseSerializer(Field):
raise NotImplementedError('`create()` must be implemented.')
def save(self, extras=None):
+ attrs = self.validated_data
if extras is not None:
- self.validated_data.update(extras)
+ attrs = dict(list(attrs.items()) + list(extras.items()))
if self.instance is not None:
- self.update(self.instance, self._validated_data)
+ self.update(self.instance, attrs)
else:
- self.instance = self.create(self._validated_data)
+ self.instance = self.create(attrs)
return self.instance
@@ -74,7 +75,7 @@ class BaseSerializer(Field):
self._validated_data = self.to_native(self._initial_data)
except ValidationError as exc:
self._validated_data = {}
- self._errors = exc.detail
+ self._errors = exc.message_dict
else:
self._errors = {}
@@ -210,7 +211,7 @@ class Serializer(BaseSerializer):
if validate_method is not None:
validated_value = validate_method(validated_value)
except ValidationError as exc:
- errors[field.field_name] = str(exc)
+ errors[field.field_name] = exc.messages
except SkipField:
pass
else:
@@ -219,8 +220,10 @@ class Serializer(BaseSerializer):
if errors:
raise ValidationError(errors)
- # TODO: 'Non field errors'
- return self.validate(ret)
+ try:
+ return self.validate(ret)
+ except ValidationError, exc:
+ raise ValidationError({'non_field_errors': exc.messages})
def to_primative(self, instance):
"""
@@ -539,6 +542,9 @@ class ModelSerializer(Serializer):
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
+ if model_field.validators is not None:
+ kwargs['validators'] = model_field.validators
+
# if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text
@@ -577,8 +583,7 @@ class ModelSerializer(Serializer):
try:
return self.field_mapping[model_field.__class__](**kwargs)
except KeyError:
- # TODO: Change this to `return ModelField(model_field=model_field, **kwargs)`
- return CharField(**kwargs)
+ return ModelField(model_field=model_field, **kwargs)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 23df3443..079e9285 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -3,7 +3,7 @@ Provides an APIView class that is the base of all views in REST framework.
"""
from __future__ import unicode_literals
-from django.core.exceptions import PermissionDenied
+from django.core.exceptions import PermissionDenied, ValidationError
from django.http import Http404
from django.utils.datastructures import SortedDict
from django.views.decorators.csrf import csrf_exempt
@@ -51,7 +51,8 @@ def exception_handler(exc):
Returns the response that should be used for any given exception.
By default we handle the REST framework `APIException`, and also
- Django's builtin `Http404` and `PermissionDenied` exceptions.
+ Django's built-in `ValidationError`, `Http404` and `PermissionDenied`
+ exceptions.
Any unhandled exceptions may return `None`, which will cause a 500 error
to be raised.
@@ -68,6 +69,10 @@ def exception_handler(exc):
status=exc.status_code,
headers=headers)
+ elif isinstance(exc, ValidationError):
+ return Response(exc.message_dict,
+ status=status.HTTP_400_BAD_REQUEST)
+
elif isinstance(exc, Http404):
return Response({'detail': 'Not found'},
status=status.HTTP_404_NOT_FOUND)
--
cgit v1.2.3
From b1c07670ca65084c5fef2bbb63d1f4163763014b Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Tue, 9 Sep 2014 17:46:28 +0100
Subject: Fleshing out serializer fields
---
rest_framework/fields.py | 591 +++++++++++++++++++++++-------
rest_framework/serializers.py | 380 +++++++++----------
rest_framework/utils/humanize_datetime.py | 47 +++
rest_framework/utils/modelinfo.py | 97 +++++
rest_framework/utils/representation.py | 72 ++++
5 files changed, 872 insertions(+), 315 deletions(-)
create mode 100644 rest_framework/utils/humanize_datetime.py
create mode 100644 rest_framework/utils/modelinfo.py
create mode 100644 rest_framework/utils/representation.py
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 250c0579..043a44ed 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,8 +1,18 @@
+from django.conf import settings
from django.core import validators
from django.core.exceptions import ValidationError
+from django.utils import timezone
+from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.encoding import is_protected_type
-from rest_framework.utils import html
+from django.utils.translation import ugettext_lazy as _
+from rest_framework import ISO_8601
+from rest_framework.compat import smart_text
+from rest_framework.settings import api_settings
+from rest_framework.utils import html, representation, humanize_datetime
+import datetime
+import decimal
import inspect
+import warnings
class empty:
@@ -71,22 +81,22 @@ class SkipField(Exception):
pass
+NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
+NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
+NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
+NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
+MISSING_ERROR_MESSAGE = (
+ 'ValidationError raised by `{class_name}`, but error key `{key}` does '
+ 'not exist in the `error_messages` dictionary.'
+)
+
+
class Field(object):
_creation_counter = 0
- MESSAGES = {
- 'required': 'This field is required.'
+ default_error_messages = {
+ 'required': _('This field is required.')
}
-
- _NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
- _NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
- _NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
- _NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
- _MISSING_ERROR_MESSAGE = (
- 'ValidationError raised by `{class_name}`, but error key `{key}` does '
- 'not exist in the `MESSAGES` dictionary.'
- )
-
default_validators = []
def __init__(self, read_only=False, write_only=False,
@@ -100,10 +110,10 @@ class Field(object):
required = default is empty and not read_only
# Some combinations of keyword arguments do not make sense.
- assert not (read_only and write_only), self._NOT_READ_ONLY_WRITE_ONLY
- assert not (read_only and required), self._NOT_READ_ONLY_REQUIRED
- assert not (read_only and default is not empty), self._NOT_READ_ONLY_DEFAULT
- assert not (required and default is not empty), self._NOT_REQUIRED_DEFAULT
+ assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY
+ assert not (read_only and required), NOT_READ_ONLY_REQUIRED
+ assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT
+ assert not (required and default is not empty), NOT_REQUIRED_DEFAULT
self.read_only = read_only
self.write_only = write_only
@@ -113,7 +123,14 @@ class Field(object):
self.initial = initial
self.label = label
self.style = {} if style is None else style
- self.validators = self.default_validators + validators
+ self.validators = validators or self.default_validators[:]
+
+ # Collect default error message from self and parent classes
+ messages = {}
+ for cls in reversed(self.__class__.__mro__):
+ messages.update(getattr(cls, 'default_error_messages', {}))
+ messages.update(error_messages or {})
+ self.error_messages = messages
def bind(self, field_name, parent, root):
"""
@@ -186,12 +203,14 @@ class Field(object):
self.fail('required')
return self.get_default()
- self.run_validators(data)
- return self.to_native(data)
+ value = self.to_native(data)
+ self.run_validators(value)
+ return value
def run_validators(self, value):
if value in validators.EMPTY_VALUES:
return
+
errors = []
for validator in self.validators:
try:
@@ -218,33 +237,32 @@ class Field(object):
A helper method that simply raises a validation error.
"""
try:
- raise ValidationError(self.MESSAGES[key].format(**kwargs))
+ msg = self.error_messages[key]
except KeyError:
class_name = self.__class__.__name__
- msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
+ msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg)
+ raise ValidationError(msg.format(**kwargs))
def __new__(cls, *args, **kwargs):
+ """
+ When a field is instantiated, we store the arguments that were used,
+ so that we can present a helpful representation of the object.
+ """
instance = super(Field, cls).__new__(cls)
instance._args = args
instance._kwargs = kwargs
return instance
def __repr__(self):
- arg_string = ', '.join([repr(val) for val in self._args])
- kwarg_string = ', '.join([
- '%s=%s' % (key, repr(val)) for key, val in self._kwargs.items()
- ])
- if arg_string and kwarg_string:
- arg_string += ', '
- class_name = self.__class__.__name__
- return "%s(%s%s)" % (class_name, arg_string, kwarg_string)
+ return representation.field_repr(self)
+# Boolean types...
+
class BooleanField(Field):
- MESSAGES = {
- 'required': 'This field is required.',
- 'invalid_value': '`{input}` is not a valid boolean.'
+ default_error_messages = {
+ 'invalid': _('`{input}` is not a valid boolean.')
}
TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True}
FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False}
@@ -261,13 +279,23 @@ class BooleanField(Field):
return True
elif data in self.FALSE_VALUES:
return False
- self.fail('invalid_value', input=data)
+ self.fail('invalid', input=data)
+
+ def to_primative(self, value):
+ if value is None:
+ return None
+ if value in self.TRUE_VALUES:
+ return True
+ elif value in self.FALSE_VALUES:
+ return False
+ return bool(value)
+# String types...
+
class CharField(Field):
- MESSAGES = {
- 'required': 'This field is required.',
- 'blank': 'This field may not be blank.'
+ default_error_messages = {
+ 'blank': _('This field may not be blank.')
}
def __init__(self, **kwargs):
@@ -281,19 +309,364 @@ class CharField(Field):
self.fail('blank')
return str(data)
+ def to_primative(self, value):
+ if value is None:
+ return None
+ return str(value)
-class ChoiceField(Field):
- MESSAGES = {
- 'required': 'This field is required.',
- 'invalid_choice': '`{input}` is not a valid choice.'
+
+class EmailField(CharField):
+ default_error_messages = {
+ 'invalid': _('Enter a valid email address.')
+ }
+ default_validators = [validators.validate_email]
+
+ def to_native(self, data):
+ ret = super(EmailField, self).to_native(data)
+ if ret is None:
+ return None
+ return ret.strip()
+
+ def to_primative(self, value):
+ ret = super(EmailField, self).to_primative(value)
+ if ret is None:
+ return None
+ return ret.strip()
+
+
+class RegexField(CharField):
+ def __init__(self, regex, **kwargs):
+ kwargs['validators'] = (
+ [validators.RegexValidator(regex)] +
+ kwargs.get('validators', [])
+ )
+ super(RegexField, self).__init__(**kwargs)
+
+
+class SlugField(CharField):
+ default_error_messages = {
+ 'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.")
+ }
+ default_validators = [validators.validate_slug]
+
+
+class URLField(CharField):
+ default_error_messages = {
+ 'invalid': _("Enter a valid URL.")
+ }
+ default_validators = [validators.URLValidator()]
+
+
+# Number types...
+
+class IntegerField(Field):
+ default_error_messages = {
+ 'invalid': _('A valid integer is required.')
+ }
+
+ def __init__(self, **kwargs):
+ max_value = kwargs.pop('max_value', None)
+ min_value = kwargs.pop('min_value', None)
+ super(IntegerField, self).__init__(**kwargs)
+ if max_value is not None:
+ self.validators.append(validators.MaxValueValidator(max_value))
+ if min_value is not None:
+ self.validators.append(validators.MinValueValidator(min_value))
+ print self.__class__.__name__, self.validators
+
+ def to_native(self, data):
+ try:
+ data = int(str(data))
+ except (ValueError, TypeError):
+ self.fail('invalid')
+ return data
+
+ def to_primative(self, value):
+ if value is None:
+ return None
+ return int(value)
+
+
+class FloatField(Field):
+ default_error_messages = {
+ 'invalid': _("'%s' value must be a float."),
}
- coerce_to_type = str
def __init__(self, **kwargs):
- choices = kwargs.pop('choices')
+ max_value = kwargs.pop('max_value', None)
+ min_value = kwargs.pop('min_value', None)
+ super(FloatField, self).__init__(**kwargs)
+ if max_value is not None:
+ self.validators.append(validators.MaxValueValidator(max_value))
+ if min_value is not None:
+ self.validators.append(validators.MinValueValidator(min_value))
+
+ def to_primative(self, value):
+ if value is None:
+ return None
+ try:
+ return float(value)
+ except (TypeError, ValueError):
+ self.fail('invalid', value=value)
+
+ def to_native(self, value):
+ if value is None:
+ return None
+ return float(value)
+
+
+class DecimalField(Field):
+ default_error_messages = {
+ 'invalid': _('Enter a number.'),
+ 'max_value': _('Ensure this value is less than or equal to {max_value}.'),
+ 'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
+ 'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'),
+ 'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'),
+ 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.')
+ }
+
+ def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, **kwargs):
+ self.max_value, self.min_value = max_value, min_value
+ self.max_digits, self.max_decimal_places = max_digits, decimal_places
+ super(DecimalField, self).__init__(**kwargs)
+ if max_value is not None:
+ self.validators.append(validators.MaxValueValidator(max_value))
+ if min_value is not None:
+ self.validators.append(validators.MinValueValidator(min_value))
+
+ def from_native(self, value):
+ """
+ Validates that the input is a decimal number. Returns a Decimal
+ instance. Returns None for empty values. Ensures that there are no more
+ than max_digits in the number, and no more than decimal_places digits
+ after the decimal point.
+ """
+ if value in validators.EMPTY_VALUES:
+ return None
+
+ value = smart_text(value).strip()
+ try:
+ value = decimal.Decimal(value)
+ except decimal.DecimalException:
+ self.fail('invalid')
+
+ # Check for NaN. It is the only value that isn't equal to itself,
+ # so we can use this to identify NaN values.
+ if value != value:
+ self.fail('invalid')
+
+ # Check for infinity and negative infinity.
+ if value in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')):
+ self.fail('invalid')
+
+ sign, digittuple, exponent = value.as_tuple()
+ decimals = abs(exponent)
+ # digittuple doesn't include any leading zeros.
+ digits = len(digittuple)
+ if decimals > digits:
+ # We have leading zeros up to or past the decimal point. Count
+ # everything past the decimal point as a digit. We do not count
+ # 0 before the decimal point as a digit since that would mean
+ # we would not allow max_digits = decimal_places.
+ digits = decimals
+ whole_digits = digits - decimals
+
+ if self.max_digits is not None and digits > self.max_digits:
+ self.fail('max_digits', max_digits=self.max_digits)
+ if self.decimal_places is not None and decimals > self.decimal_places:
+ self.fail('max_decimal_places', max_decimal_places=self.max_decimal_places)
+ if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
+ self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places)
+
+ return value
+
+
+# Date & time fields...
+
+class DateField(Field):
+ default_error_messages = {
+ 'invalid': _("Date has wrong format. Use one of these formats instead: %s"),
+ }
+ input_formats = api_settings.DATE_INPUT_FORMATS
+ format = api_settings.DATE_FORMAT
+
+ def __init__(self, input_formats=None, format=None, *args, **kwargs):
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
+ self.format = format if format is not None else self.format
+ super(DateField, self).__init__(*args, **kwargs)
+
+ def from_native(self, value):
+ if value in validators.EMPTY_VALUES:
+ return None
+
+ if isinstance(value, datetime.datetime):
+ if timezone and settings.USE_TZ and timezone.is_aware(value):
+ # Convert aware datetimes to the default time zone
+ # before casting them to dates (#17742).
+ default_timezone = timezone.get_default_timezone()
+ value = timezone.make_naive(value, default_timezone)
+ return value.date()
+ if isinstance(value, datetime.date):
+ return value
+
+ for format in self.input_formats:
+ if format.lower() == ISO_8601:
+ try:
+ parsed = parse_date(value)
+ except (ValueError, TypeError):
+ pass
+ else:
+ if parsed is not None:
+ return parsed
+ else:
+ try:
+ parsed = datetime.datetime.strptime(value, format)
+ except (ValueError, TypeError):
+ pass
+ else:
+ return parsed.date()
+
+ humanized_format = humanize_datetime.date_formats(self.input_formats)
+ msg = self.error_messages['invalid'] % humanized_format
+ raise ValidationError(msg)
+
+ def to_primative(self, value):
+ if value is None or self.format is None:
+ return value
+
+ if isinstance(value, datetime.datetime):
+ value = value.date()
+
+ if self.format.lower() == ISO_8601:
+ return value.isoformat()
+ return value.strftime(self.format)
+
+
+class DateTimeField(Field):
+ default_error_messages = {
+ 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"),
+ }
+ input_formats = api_settings.DATETIME_INPUT_FORMATS
+ format = api_settings.DATETIME_FORMAT
+
+ def __init__(self, input_formats=None, format=None, *args, **kwargs):
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
+ self.format = format if format is not None else self.format
+ super(DateTimeField, self).__init__(*args, **kwargs)
+
+ def from_native(self, value):
+ if value in validators.EMPTY_VALUES:
+ return None
+
+ if isinstance(value, datetime.datetime):
+ return value
+ if isinstance(value, datetime.date):
+ value = datetime.datetime(value.year, value.month, value.day)
+ if settings.USE_TZ:
+ # For backwards compatibility, interpret naive datetimes in
+ # local time. This won't work during DST change, but we can't
+ # do much about it, so we let the exceptions percolate up the
+ # call stack.
+ warnings.warn("DateTimeField received a naive datetime (%s)"
+ " while time zone support is active." % value,
+ RuntimeWarning)
+ default_timezone = timezone.get_default_timezone()
+ value = timezone.make_aware(value, default_timezone)
+ return value
+
+ for format in self.input_formats:
+ if format.lower() == ISO_8601:
+ try:
+ parsed = parse_datetime(value)
+ except (ValueError, TypeError):
+ pass
+ else:
+ if parsed is not None:
+ return parsed
+ else:
+ try:
+ parsed = datetime.datetime.strptime(value, format)
+ except (ValueError, TypeError):
+ pass
+ else:
+ return parsed
+
+ humanized_format = humanize_datetime.datetime_formats(self.input_formats)
+ msg = self.error_messages['invalid'] % humanized_format
+ raise ValidationError(msg)
+
+ def to_primative(self, value):
+ if value is None or self.format is None:
+ return value
+
+ if self.format.lower() == ISO_8601:
+ ret = value.isoformat()
+ if ret.endswith('+00:00'):
+ ret = ret[:-6] + 'Z'
+ return ret
+ return value.strftime(self.format)
- assert choices, '`choices` argument is required and may not be empty'
+class TimeField(Field):
+ default_error_messages = {
+ 'invalid': _("Time has wrong format. Use one of these formats instead: %s"),
+ }
+ input_formats = api_settings.TIME_INPUT_FORMATS
+ format = api_settings.TIME_FORMAT
+
+ def __init__(self, input_formats=None, format=None, *args, **kwargs):
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
+ self.format = format if format is not None else self.format
+ super(TimeField, self).__init__(*args, **kwargs)
+
+ def from_native(self, value):
+ if value in validators.EMPTY_VALUES:
+ return None
+
+ if isinstance(value, datetime.time):
+ return value
+
+ for format in self.input_formats:
+ if format.lower() == ISO_8601:
+ try:
+ parsed = parse_time(value)
+ except (ValueError, TypeError):
+ pass
+ else:
+ if parsed is not None:
+ return parsed
+ else:
+ try:
+ parsed = datetime.datetime.strptime(value, format)
+ except (ValueError, TypeError):
+ pass
+ else:
+ return parsed.time()
+
+ humanized_format = humanize_datetime.time_formats(self.input_formats)
+ msg = self.error_messages['invalid'] % humanized_format
+ raise ValidationError(msg)
+
+ def to_primative(self, value):
+ if value is None or self.format is None:
+ return value
+
+ if isinstance(value, datetime.datetime):
+ value = value.time()
+
+ if self.format.lower() == ISO_8601:
+ return value.isoformat()
+ return value.strftime(self.format)
+
+
+# Choice types...
+
+class ChoiceField(Field):
+ default_error_messages = {
+ 'invalid_choice': _('`{input}` is not a valid choice.')
+ }
+
+ def __init__(self, choices, **kwargs):
# Allow either single or paired choices style:
# choices = [1, 2, 3]
# choices = [(1, 'First'), (2, 'Second'), (3, 'Third')]
@@ -321,12 +694,14 @@ class ChoiceField(Field):
except KeyError:
self.fail('invalid_choice', input=data)
+ def to_primative(self, value):
+ return value
+
class MultipleChoiceField(ChoiceField):
- MESSAGES = {
- 'required': 'This field is required.',
- 'invalid_choice': '`{input}` is not a valid choice.',
- 'not_a_list': 'Expected a list of items but got type `{input_type}`'
+ default_error_messages = {
+ 'invalid_choice': _('`{input}` is not a valid choice.'),
+ 'not_a_list': _('Expected a list of items but got type `{input_type}`')
}
def to_native(self, data):
@@ -337,72 +712,42 @@ class MultipleChoiceField(ChoiceField):
for item in data
])
-
-class IntegerField(Field):
- MESSAGES = {
- 'required': 'This field is required.',
- 'invalid_integer': 'A valid integer is required.'
- }
-
- def __init__(self, **kwargs):
- max_value = kwargs.pop('max_value', None)
- min_value = kwargs.pop('min_value', None)
- super(IntegerField, self).__init__(**kwargs)
- if max_value is not None:
- self.validators.append(validators.MaxValueValidator(max_value))
- if min_value is not None:
- self.validators.append(validators.MinValueValidator(min_value))
-
- def to_native(self, data):
- try:
- data = int(str(data))
- except (ValueError, TypeError):
- self.fail('invalid_integer')
- return data
-
def to_primative(self, value):
- if value is None:
- return None
- return int(value)
+ return value
-class EmailField(CharField):
+# File types...
+
+class FileField(Field):
pass # TODO
-class URLField(CharField):
+class ImageField(Field):
pass # TODO
-class RegexField(CharField):
- def __init__(self, **kwargs):
- self.regex = kwargs.pop('regex')
- super(CharField, self).__init__(**kwargs)
-
+# Advanced field types...
-class DateField(CharField):
- def __init__(self, **kwargs):
- self.input_formats = kwargs.pop('input_formats', None)
- super(DateField, self).__init__(**kwargs)
+class ReadOnlyField(Field):
+ """
+ A read-only field that simply returns the field value.
+ If the field is a method with no parameters, the method will be called
+ and it's return value used as the representation.
-class TimeField(CharField):
- def __init__(self, **kwargs):
- self.input_formats = kwargs.pop('input_formats', None)
- super(TimeField, self).__init__(**kwargs)
+ For example, the following would call `get_expiry_date()` on the object:
+ class ExampleSerializer(self):
+ expiry_date = ReadOnlyField(source='get_expiry_date')
+ """
-class DateTimeField(CharField):
def __init__(self, **kwargs):
- self.input_formats = kwargs.pop('input_formats', None)
- super(DateTimeField, self).__init__(**kwargs)
-
-
-class FileField(Field):
- pass # TODO
+ kwargs['read_only'] = True
+ super(ReadOnlyField, self).__init__(**kwargs)
+ def to_native(self, data):
+ raise NotImplemented('.to_native() not supported.')
-class ReadOnlyField(Field):
def to_primative(self, value):
if is_simple_callable(value):
return value()
@@ -410,11 +755,28 @@ class ReadOnlyField(Field):
class MethodField(Field):
+ """
+ A read-only field that get its representation from calling a method on the
+ parent serializer class. The method called will be of the form
+ "get_{field_name}", and should take a single argument, which is the
+ object being serialized.
+
+ For example:
+
+ class ExampleSerializer(self):
+ extra_info = MethodField()
+
+ def get_extra_info(self, obj):
+ return ... # Calculate some data to return.
+ """
def __init__(self, **kwargs):
kwargs['source'] = '*'
kwargs['read_only'] = True
super(MethodField, self).__init__(**kwargs)
+ def to_native(self, data):
+ raise NotImplemented('.to_native() not supported.')
+
def to_primative(self, value):
attr = 'get_{field_name}'.format(field_name=self.field_name)
method = getattr(self.parent, attr)
@@ -424,35 +786,14 @@ class MethodField(Field):
class ModelField(Field):
"""
A generic field that can be used against an arbitrary model field.
- """
- def __init__(self, *args, **kwargs):
- try:
- self.model_field = kwargs.pop('model_field')
- except KeyError:
- raise ValueError("ModelField requires 'model_field' kwarg")
-
- self.min_length = kwargs.pop('min_length',
- getattr(self.model_field, 'min_length', None))
- self.max_length = kwargs.pop('max_length',
- getattr(self.model_field, 'max_length', None))
- self.min_value = kwargs.pop('min_value',
- getattr(self.model_field, 'min_value', None))
- self.max_value = kwargs.pop('max_value',
- getattr(self.model_field, 'max_value', None))
-
- super(ModelField, self).__init__(*args, **kwargs)
-
- if self.min_length is not None:
- self.validators.append(validators.MinLengthValidator(self.min_length))
- if self.max_length is not None:
- self.validators.append(validators.MaxLengthValidator(self.max_length))
- if self.min_value is not None:
- self.validators.append(validators.MinValueValidator(self.min_value))
- if self.max_value is not None:
- self.validators.append(validators.MaxValueValidator(self.max_value))
- def get_attribute(self, instance):
- return get_attribute(instance, self.source_attrs[:-1])
+ This is used by `ModelSerializer` when dealing with custom model fields,
+ that do not have a serializer field to be mapped to.
+ """
+ def __init__(self, model_field, **kwargs):
+ self.model_field = model_field
+ kwargs['source'] = '*'
+ super(ModelField, self).__init__(**kwargs)
def to_native(self, data):
rel = getattr(self.model_field, 'rel', None)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 93226d32..8ca28387 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -10,15 +10,15 @@ python primitives.
2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers.
"""
+from django.core import validators
from django.core.exceptions import ValidationError
from django.db import models
from django.utils import six
from collections import namedtuple, OrderedDict
from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings
-from rest_framework.utils import html
+from rest_framework.utils import html, modelinfo, representation
import copy
-import inspect
# Note: We do the following so that users of the framework can use this style:
#
@@ -146,12 +146,10 @@ class SerializerMetaclass(type):
class Serializer(BaseSerializer):
def __new__(cls, *args, **kwargs):
- many = kwargs.pop('many', False)
- if many:
- class DynamicListSerializer(ListSerializer):
- child = cls()
- return DynamicListSerializer(*args, **kwargs)
- return super(Serializer, cls).__new__(cls)
+ if kwargs.pop('many', False):
+ kwargs['child'] = cls()
+ return ListSerializer(*args, **kwargs)
+ return super(Serializer, cls).__new__(cls, *args, **kwargs)
def __init__(self, *args, **kwargs):
self.context = kwargs.pop('context', {})
@@ -248,6 +246,9 @@ class Serializer(BaseSerializer):
error = errors.get(field.field_name)
yield FieldResult(field, value, error)
+ def __repr__(self):
+ return representation.serializer_repr(self, indent=1)
+
class ListSerializer(BaseSerializer):
child = None
@@ -299,26 +300,8 @@ class ListSerializer(BaseSerializer):
self.instance = self.create(self.validated_data)
return self.instance
-
-def _resolve_model(obj):
- """
- Resolve supplied `obj` to a Django model class.
-
- `obj` must be a Django model class itself, or a string
- representation of one. Useful in situtations like GH #1225 where
- Django may not have resolved a string-based reference to a model in
- another model's foreign key definition.
-
- String representations should have the format:
- 'appname.ModelName'
- """
- if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
- app_name, model_name = obj.split('.')
- return models.get_model(app_name, model_name)
- elif inspect.isclass(obj) and issubclass(obj, models.Model):
- return obj
- else:
- raise ValueError("{0} is not a Django model".format(obj))
+ def __repr__(self):
+ return representation.list_repr(self, indent=1)
class ModelSerializerOptions(object):
@@ -334,24 +317,25 @@ class ModelSerializerOptions(object):
class ModelSerializer(Serializer):
field_mapping = {
models.AutoField: IntegerField,
- # models.FloatField: FloatField,
+ models.BigIntegerField: IntegerField,
+ models.BooleanField: BooleanField,
+ models.CharField: CharField,
+ models.CommaSeparatedIntegerField: CharField,
+ models.DateField: DateField,
+ models.DateTimeField: DateTimeField,
+ models.DecimalField: DecimalField,
+ models.EmailField: EmailField,
+ models.FileField: FileField,
+ models.FloatField: FloatField,
models.IntegerField: IntegerField,
+ models.NullBooleanField: BooleanField,
models.PositiveIntegerField: IntegerField,
- models.SmallIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField,
- models.DateTimeField: DateTimeField,
- models.DateField: DateField,
+ models.SlugField: SlugField,
+ models.SmallIntegerField: IntegerField,
+ models.TextField: CharField,
models.TimeField: TimeField,
- # models.DecimalField: DecimalField,
- models.EmailField: EmailField,
- models.CharField: CharField,
models.URLField: URLField,
- # models.SlugField: SlugField,
- models.TextField: CharField,
- models.CommaSeparatedIntegerField: CharField,
- models.BooleanField: BooleanField,
- models.NullBooleanField: BooleanField,
- models.FileField: FileField,
# models.ImageField: ImageField,
}
@@ -392,85 +376,31 @@ class ModelSerializer(Serializer):
"""
Return all the fields that should be serialized for the model.
"""
- cls = self.opts.model
- opts = cls._meta.concrete_model._meta
+ info = modelinfo.get_field_info(self.opts.model)
ret = OrderedDict()
- nested = bool(self.opts.depth)
- # Deal with adding the primary key field
- pk_field = opts.pk
- while pk_field.rel and pk_field.rel.parent_link:
- # If model is a child via multitable inheritance, use parent's pk
- pk_field = pk_field.rel.to._meta.pk
-
- serializer_pk_field = self.get_pk_field(pk_field)
+ serializer_pk_field = self.get_pk_field(info.pk)
if serializer_pk_field:
- ret[pk_field.name] = serializer_pk_field
-
- # Deal with forward relationships
- forward_rels = [field for field in opts.fields if field.serialize]
- forward_rels += [field for field in opts.many_to_many if field.serialize]
-
- for model_field in forward_rels:
- has_through_model = False
-
- if model_field.rel:
- to_many = isinstance(model_field,
- models.fields.related.ManyToManyField)
- related_model = _resolve_model(model_field.rel.to)
-
- if to_many and not model_field.rel.through._meta.auto_created:
- has_through_model = True
+ ret[info.pk.name] = serializer_pk_field
- if model_field.rel and nested:
- field = self.get_nested_field(model_field, related_model, to_many)
- elif model_field.rel:
- field = self.get_related_field(model_field, related_model, to_many)
- else:
- field = self.get_field(model_field)
-
- if field:
- if has_through_model:
- field.read_only = True
+ # Regular fields
+ for field_name, field in info.fields.items():
+ ret[field_name] = self.get_field(field)
- ret[model_field.name] = field
-
- # Deal with reverse relationships
- if not self.opts.fields:
- reverse_rels = []
- else:
- # Reverse relationships are only included if they are explicitly
- # present in the `fields` option on the serializer
- reverse_rels = opts.get_all_related_objects()
- reverse_rels += opts.get_all_related_many_to_many_objects()
-
- for relation in reverse_rels:
- accessor_name = relation.get_accessor_name()
- if not self.opts.fields or accessor_name not in self.opts.fields:
- continue
- related_model = relation.model
- to_many = relation.field.rel.multiple
- has_through_model = False
- is_m2m = isinstance(relation.field,
- models.fields.related.ManyToManyField)
-
- if (
- is_m2m and
- hasattr(relation.field.rel, 'through') and
- not relation.field.rel.through._meta.auto_created
- ):
- has_through_model = True
-
- if nested:
- field = self.get_nested_field(None, related_model, to_many)
+ # Forward relations
+ for field_name, relation_info in info.forward_relations.items():
+ if self.opts.depth:
+ ret[field_name] = self.get_nested_field(*relation_info)
else:
- field = self.get_related_field(None, related_model, to_many)
+ ret[field_name] = self.get_related_field(*relation_info)
- if field:
- if has_through_model:
- field.read_only = True
-
- ret[accessor_name] = field
+ # Reverse relations
+ for accessor_name, relation_info in info.reverse_relations.items():
+ if accessor_name in self.opts.fields:
+ if self.opts.depth:
+ ret[field_name] = self.get_nested_field(*relation_info)
+ else:
+ ret[field_name] = self.get_related_field(*relation_info)
return ret
@@ -480,7 +410,7 @@ class ModelSerializer(Serializer):
"""
return self.get_field(model_field)
- def get_nested_field(self, model_field, related_model, to_many):
+ def get_nested_field(self, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a nested relational field.
@@ -491,59 +421,148 @@ class ModelSerializer(Serializer):
model = related_model
depth = self.opts.depth - 1
- return NestedModelSerializer(many=to_many)
+ kwargs = {'read_only': True}
+ if to_many:
+ kwargs['many'] = True
+ return NestedModelSerializer(**kwargs)
- def get_related_field(self, model_field, related_model, to_many):
+ def get_related_field(self, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a flat relational field.
Note that model_field will be `None` for reverse relationships.
"""
- # TODO: filter queryset using:
- # .using(db).complex_filter(self.rel.limit_choices_to)
+ kwargs = {
+ 'queryset': related_model._default_manager,
+ }
- kwargs = {}
- # 'queryset': related_model._default_manager,
- # 'many': to_many
- # }
+ if to_many:
+ kwargs['many'] = True
+
+ if has_through_model:
+ kwargs['read_only'] = True
+ kwargs.pop('queryset', None)
if model_field:
- kwargs['required'] = not(model_field.null or model_field.blank)
+ if model_field.null or model_field.blank:
+ kwargs['required'] = False
# if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
if not model_field.editable:
kwargs['read_only'] = True
- if model_field.verbose_name is not None:
- kwargs['label'] = model_field.verbose_name
+ kwargs.pop('queryset', None)
- return IntegerField(**kwargs)
- # TODO: return PrimaryKeyRelatedField(**kwargs)
+ return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field):
"""
Creates a default instance of a basic non-relational field.
"""
kwargs = {}
+ validator_kwarg = model_field.validators
if model_field.null or model_field.blank:
kwargs['required'] = False
+ if model_field.verbose_name is not None:
+ kwargs['label'] = model_field.verbose_name
+
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
+ # Read only implies that the field is not required.
+ # We have a cleaner repr on the instance if we don't set it.
+ kwargs.pop('required', None)
if model_field.has_default():
kwargs['default'] = model_field.get_default()
-
- if issubclass(model_field.__class__, models.TextField):
- kwargs['widget'] = widgets.Textarea
-
- if model_field.verbose_name is not None:
- kwargs['label'] = model_field.verbose_name
-
- if model_field.validators is not None:
- kwargs['validators'] = model_field.validators
+ # Having a default implies that the field is not required.
+ # We have a cleaner repr on the instance if we don't set it.
+ kwargs.pop('required', None)
+
+ # Ensure that max_length is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ max_length = getattr(model_field, 'max_length', None)
+ if max_length is not None:
+ kwargs['max_length'] = max_length
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MaxLengthValidator)
+ ]
+
+ # Ensure that min_length is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ min_length = getattr(model_field, 'min_length', None)
+ if min_length is not None:
+ kwargs['min_length'] = min_length
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MinLengthValidator)
+ ]
+
+ # Ensure that max_value is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ max_value = next((
+ validator.limit_value for validator in validator_kwarg
+ if isinstance(validator, validators.MaxValueValidator)
+ ), None)
+ if max_value is not None:
+ kwargs['max_value'] = max_value
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MaxValueValidator)
+ ]
+
+ # Ensure that max_value is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ min_value = next((
+ validator.limit_value for validator in validator_kwarg
+ if isinstance(validator, validators.MinValueValidator)
+ ), None)
+ if min_value is not None:
+ kwargs['min_value'] = min_value
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MinValueValidator)
+ ]
+
+ # URLField does not need to include the URLValidator argument,
+ # as it is explicitly added in.
+ if isinstance(model_field, models.URLField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.URLValidator)
+ ]
+
+ # EmailField does not need to include the validate_email argument,
+ # as it is explicitly added in.
+ if isinstance(model_field, models.EmailField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if validator is not validators.validate_email
+ ]
+
+ # SlugField do not need to include the 'validate_slug' argument,
+ if isinstance(model_field, models.SlugField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if validator is not validators.validate_slug
+ ]
+
+ max_digits = getattr(model_field, 'max_digits', None)
+ if max_digits is not None:
+ kwargs['max_digits'] = max_digits
+
+ decimal_places = getattr(model_field, 'decimal_places', None)
+ if decimal_places is not None:
+ kwargs['decimal_places'] = decimal_places
+
+ if validator_kwarg:
+ kwargs['validators'] = validator_kwarg
+
+ # if issubclass(model_field.__class__, models.TextField):
+ # kwargs['widget'] = widgets.Textarea
# if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text
@@ -555,31 +574,10 @@ class ModelSerializer(Serializer):
kwargs['empty'] = None
return ChoiceField(**kwargs)
- # put this below the ChoiceField because min_value isn't a valid initializer
- if issubclass(model_field.__class__, models.PositiveIntegerField) or \
- issubclass(model_field.__class__, models.PositiveSmallIntegerField):
- kwargs['min_value'] = 0
-
if model_field.null and \
issubclass(model_field.__class__, (models.CharField, models.TextField)):
kwargs['allow_none'] = True
- # attribute_dict = {
- # models.CharField: ['max_length'],
- # models.CommaSeparatedIntegerField: ['max_length'],
- # models.DecimalField: ['max_digits', 'decimal_places'],
- # models.EmailField: ['max_length'],
- # models.FileField: ['max_length'],
- # models.ImageField: ['max_length'],
- # models.SlugField: ['max_length'],
- # models.URLField: ['max_length'],
- # }
-
- # if model_field.__class__ in attribute_dict:
- # attributes = attribute_dict[model_field.__class__]
- # for attribute in attributes:
- # kwargs.update({attribute: getattr(model_field, attribute)})
-
try:
return self.field_mapping[model_field.__class__](**kwargs)
except KeyError:
@@ -594,28 +592,21 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
super(HyperlinkedModelSerializerOptions, self).__init__(meta)
self.view_name = getattr(meta, 'view_name', None)
self.lookup_field = getattr(meta, 'lookup_field', None)
- self.url_field_name = getattr(meta, 'url_field_name', api_settings.URL_FIELD_NAME)
class HyperlinkedModelSerializer(ModelSerializer):
_options_class = HyperlinkedModelSerializerOptions
- _default_view_name = '%(model_name)s-detail'
- _hyperlink_field_class = HyperlinkedRelatedField
- _hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields()
if self.opts.view_name is None:
- self.opts.view_name = self._get_default_view_name(self.opts.model)
+ self.opts.view_name = self.get_default_view_name(self.opts.model)
- if self.opts.url_field_name not in fields:
- url_field = self._hyperlink_identify_field_class(
- view_name=self.opts.view_name,
- lookup_field=self.opts.lookup_field
- )
+ url_field_name = api_settings.URL_FIELD_NAME
+ if url_field_name not in fields:
ret = fields.__class__()
- ret[self.opts.url_field_name] = url_field
+ ret[url_field_name] = self.get_url_field()
ret.update(fields)
fields = ret
@@ -625,39 +616,48 @@ class HyperlinkedModelSerializer(ModelSerializer):
if self.opts.fields and model_field.name in self.opts.fields:
return self.get_field(model_field)
- def get_related_field(self, model_field, related_model, to_many):
+ def get_url_field(self):
+ kwargs = {
+ 'view_name': self.get_default_view_name(self.opts.model)
+ }
+ if self.opts.lookup_field:
+ kwargs['lookup_field'] = self.opts.lookup_field
+ return HyperlinkedIdentityField(**kwargs)
+
+ def get_related_field(self, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a flat relational field.
"""
- # TODO: filter queryset using:
- # .using(db).complex_filter(self.rel.limit_choices_to)
- # kwargs = {
- # 'queryset': related_model._default_manager,
- # 'view_name': self._get_default_view_name(related_model),
- # 'many': to_many
- # }
- kwargs = {}
+ kwargs = {
+ 'queryset': related_model._default_manager,
+ 'view_name': self.get_default_view_name(related_model),
+ }
+
+ if to_many:
+ kwargs['many'] = True
+
+ if has_through_model:
+ kwargs['read_only'] = True
+ kwargs.pop('queryset', None)
if model_field:
- kwargs['required'] = not(model_field.null or model_field.blank)
+ if model_field.null or model_field.blank:
+ kwargs['required'] = False
# if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
+ if not model_field.editable:
+ kwargs['read_only'] = True
+ kwargs.pop('queryset', None)
- return IntegerField(**kwargs)
- # if self.opts.lookup_field:
- # kwargs['lookup_field'] = self.opts.lookup_field
-
- # return self._hyperlink_field_class(**kwargs)
+ return HyperlinkedRelatedField(**kwargs)
- def _get_default_view_name(self, model):
+ def get_default_view_name(self, model):
"""
- Return the view name to use if 'view_name' is not specified in 'Meta'
+ Return the view name to use for related models.
"""
- model_meta = model._meta
- format_kwargs = {
- 'app_label': model_meta.app_label,
- 'model_name': model_meta.object_name.lower()
+ return '%(model_name)s-detail' % {
+ 'app_label': model._meta.app_label,
+ 'model_name': model._meta.object_name.lower()
}
- return self._default_view_name % format_kwargs
diff --git a/rest_framework/utils/humanize_datetime.py b/rest_framework/utils/humanize_datetime.py
new file mode 100644
index 00000000..649f2abc
--- /dev/null
+++ b/rest_framework/utils/humanize_datetime.py
@@ -0,0 +1,47 @@
+"""
+Helper functions that convert strftime formats into more readable representations.
+"""
+from rest_framework import ISO_8601
+
+
+def datetime_formats(formats):
+ format = ', '.join(formats).replace(
+ ISO_8601,
+ 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
+ )
+ return humanize_strptime(format)
+
+
+def date_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]')
+ return humanize_strptime(format)
+
+
+def time_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
+ return humanize_strptime(format)
+
+
+def humanize_strptime(format_string):
+ # Note that we're missing some of the locale specific mappings that
+ # don't really make sense.
+ mapping = {
+ "%Y": "YYYY",
+ "%y": "YY",
+ "%m": "MM",
+ "%b": "[Jan-Dec]",
+ "%B": "[January-December]",
+ "%d": "DD",
+ "%H": "hh",
+ "%I": "hh", # Requires '%p' to differentiate from '%H'.
+ "%M": "mm",
+ "%S": "ss",
+ "%f": "uuuuuu",
+ "%a": "[Mon-Sun]",
+ "%A": "[Monday-Sunday]",
+ "%p": "[AM|PM]",
+ "%z": "[+HHMM|-HHMM]"
+ }
+ for key, val in mapping.items():
+ format_string = format_string.replace(key, val)
+ return format_string
diff --git a/rest_framework/utils/modelinfo.py b/rest_framework/utils/modelinfo.py
new file mode 100644
index 00000000..c0513886
--- /dev/null
+++ b/rest_framework/utils/modelinfo.py
@@ -0,0 +1,97 @@
+"""
+Helper functions for returning the field information that is associated
+with a model class.
+"""
+from collections import namedtuple, OrderedDict
+from django.db import models
+from django.utils import six
+import inspect
+
+FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations'])
+RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model'])
+
+
+def _resolve_model(obj):
+ """
+ Resolve supplied `obj` to a Django model class.
+
+ `obj` must be a Django model class itself, or a string
+ representation of one. Useful in situtations like GH #1225 where
+ Django may not have resolved a string-based reference to a model in
+ another model's foreign key definition.
+
+ String representations should have the format:
+ 'appname.ModelName'
+ """
+ if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
+ app_name, model_name = obj.split('.')
+ return models.get_model(app_name, model_name)
+ elif inspect.isclass(obj) and issubclass(obj, models.Model):
+ return obj
+ raise ValueError("{0} is not a Django model".format(obj))
+
+
+def get_field_info(model):
+ """
+ Given a model class, returns a `FieldInfo` instance containing metadata
+ about the various field types on the model.
+ """
+ opts = model._meta.concrete_model._meta
+
+ # Deal with the primary key.
+ pk = opts.pk
+ while pk.rel and pk.rel.parent_link:
+ # If model is a child via multitable inheritance, use parent's pk.
+ pk = pk.rel.to._meta.pk
+
+ # Deal with regular fields.
+ fields = OrderedDict()
+ for field in [field for field in opts.fields if field.serialize and not field.rel]:
+ fields[field.name] = field
+
+ # Deal with forward relationships.
+ forward_relations = OrderedDict()
+ for field in [field for field in opts.fields if field.serialize and field.rel]:
+ forward_relations[field.name] = RelationInfo(
+ field=field,
+ related=_resolve_model(field.rel.to),
+ to_many=False,
+ has_through_model=False
+ )
+
+ # Deal with forward many-to-many relationships.
+ for field in [field for field in opts.many_to_many if field.serialize]:
+ forward_relations[field.name] = RelationInfo(
+ field=field,
+ related=_resolve_model(field.rel.to),
+ to_many=True,
+ has_through_model=(
+ not field.rel.through._meta.auto_created
+ )
+ )
+
+ # Deal with reverse relationships.
+ reverse_relations = OrderedDict()
+ for relation in opts.get_all_related_objects():
+ accessor_name = relation.get_accessor_name()
+ reverse_relations[accessor_name] = RelationInfo(
+ field=None,
+ related=relation.model,
+ to_many=relation.field.rel.multiple,
+ has_through_model=False
+ )
+
+ # Deal with reverse many-to-many relationships.
+ for relation in opts.get_all_related_many_to_many_objects():
+ accessor_name = relation.get_accessor_name()
+ reverse_relations[accessor_name] = RelationInfo(
+ field=None,
+ related=relation.model,
+ to_many=True,
+ has_through_model=(
+ hasattr(relation.field.rel, 'through') and
+ not relation.field.rel.through._meta.auto_created
+ )
+ )
+
+ return FieldInfo(pk, fields, forward_relations, reverse_relations)
diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py
new file mode 100644
index 00000000..1de21597
--- /dev/null
+++ b/rest_framework/utils/representation.py
@@ -0,0 +1,72 @@
+"""
+Helper functions for creating user-friendly representations
+of serializer classes and serializer fields.
+"""
+import re
+
+
+def smart_repr(value):
+ value = repr(value)
+
+ # Representations like u'help text'
+ # should simply be presented as 'help text'
+ if value.startswith("u'") and value.endswith("'"):
+ return value[1:]
+
+ # Representations like
+ #
+ # Should be presented as
+ #
+ value = re.sub(' at 0x[0-9a-f]{8,10}>', '>', value)
+
+ return value
+
+
+def field_repr(field, force_many=False):
+ kwargs = field._kwargs
+ if force_many:
+ kwargs = kwargs.copy()
+ kwargs['many'] = True
+ kwargs.pop('child', None)
+
+ arg_string = ', '.join([smart_repr(val) for val in field._args])
+ kwarg_string = ', '.join([
+ '%s=%s' % (key, smart_repr(val))
+ for key, val in sorted(kwargs.items())
+ ])
+ if arg_string and kwarg_string:
+ arg_string += ', '
+
+ if force_many:
+ class_name = force_many.__class__.__name__
+ else:
+ class_name = field.__class__.__name__
+
+ return "%s(%s%s)" % (class_name, arg_string, kwarg_string)
+
+
+def serializer_repr(serializer, indent, force_many=None):
+ ret = field_repr(serializer, force_many) + ':'
+ indent_str = ' ' * indent
+
+ if force_many:
+ fields = force_many.fields
+ else:
+ fields = serializer.fields
+
+ for field_name, field in fields.items():
+ ret += '\n' + indent_str + field_name + ' = '
+ if hasattr(field, 'fields'):
+ ret += serializer_repr(field, indent + 1)
+ elif hasattr(field, 'child'):
+ ret += list_repr(field, indent + 1)
+ else:
+ ret += field_repr(field)
+ return ret
+
+
+def list_repr(serializer, indent):
+ child = serializer.child
+ if hasattr(child, 'fields'):
+ return serializer_repr(serializer, indent, force_many=child)
+ return field_repr(serializer)
--
cgit v1.2.3
From 234369aefdf08d7d0161d851866990754c00d31f Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Wed, 10 Sep 2014 08:53:33 +0100
Subject: Tweaks
---
rest_framework/fields.py | 6 ++++--
rest_framework/serializers.py | 4 ++--
2 files changed, 6 insertions(+), 4 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 043a44ed..e2bd5700 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -190,7 +190,7 @@ class Field(object):
raise SkipField()
return self.default
- def validate(self, data=empty):
+ def validate_value(self, data=empty):
"""
Validate a simple representation and return the internal value.
@@ -506,6 +506,7 @@ class DateField(Field):
default_timezone = timezone.get_default_timezone()
value = timezone.make_naive(value, default_timezone)
return value.date()
+
if isinstance(value, datetime.date):
return value
@@ -560,6 +561,7 @@ class DateTimeField(Field):
if isinstance(value, datetime.datetime):
return value
+
if isinstance(value, datetime.date):
value = datetime.datetime(value.year, value.month, value.day)
if settings.USE_TZ:
@@ -675,7 +677,7 @@ class ChoiceField(Field):
for item in choices
]
if all(pairs):
- self.choices = {key: val for key, val in choices}
+ self.choices = {key: display_value for key, display_value in choices}
else:
self.choices = {item: item for item in choices}
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 8ca28387..0727b8cd 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -205,7 +205,7 @@ class Serializer(BaseSerializer):
validate_method = getattr(self, 'validate_' + field.field_name, None)
primitive_value = field.get_value(data)
try:
- validated_value = field.validate(primitive_value)
+ validated_value = field.validate_value(primitive_value)
if validate_method is not None:
validated_value = validate_method(validated_value)
except ValidationError as exc:
@@ -327,6 +327,7 @@ class ModelSerializer(Serializer):
models.EmailField: EmailField,
models.FileField: FileField,
models.FloatField: FloatField,
+ models.ImageField: ImageField,
models.IntegerField: IntegerField,
models.NullBooleanField: BooleanField,
models.PositiveIntegerField: IntegerField,
@@ -336,7 +337,6 @@ class ModelSerializer(Serializer):
models.TextField: CharField,
models.TimeField: TimeField,
models.URLField: URLField,
- # models.ImageField: ImageField,
}
_options_class = ModelSerializerOptions
--
cgit v1.2.3
From 01c8c0cad977fc0787dbfc78bd34f4fd37e613f4 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Wed, 10 Sep 2014 13:52:16 +0100
Subject: Added help_text argument to fields
---
rest_framework/compat.py | 23 ++++++------
rest_framework/fields.py | 5 ++-
rest_framework/serializers.py | 86 +++++++++++++++++++++----------------------
3 files changed, 56 insertions(+), 58 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index fa0f0bfb..70b38df9 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -39,6 +39,17 @@ except ImportError:
django_filters = None
+if django.VERSION >= (1, 6):
+ def clean_manytomany_helptext(text):
+ return text
+else:
+ # Up to version 1.5 many to many fields automatically suffix
+ # the `help_text` attribute with hardcoded text.
+ def clean_manytomany_helptext(text):
+ if text.endswith(' Hold down "Control", or "Command" on a Mac, to select more than one.'):
+ text = text[:-69]
+ return text
+
# Django-guardian is optional. Import only if guardian is in INSTALLED_APPS
# Fixes (#1712). We keep the try/except for the test suite.
guardian = None
@@ -99,18 +110,8 @@ def get_concrete_model(model_cls):
return model_cls
-# View._allowed_methods only present from 1.5 onwards
-if django.VERSION >= (1, 5):
- from django.views.generic import View
-else:
- from django.views.generic import View as DjangoView
-
- class View(DjangoView):
- def _allowed_methods(self):
- return [m.upper() for m in self.http_method_names if hasattr(self, m)]
-
-
# PATCH method is not implemented by Django
+from django.views.generic import View
if 'patch' not in View.http_method_names:
View.http_method_names = View.http_method_names + ['patch']
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index e2bd5700..e71dce1d 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -101,7 +101,8 @@ class Field(object):
def __init__(self, read_only=False, write_only=False,
required=None, default=empty, initial=None, source=None,
- label=None, style=None, error_messages=None, validators=[]):
+ label=None, help_text=None, style=None,
+ error_messages=None, validators=[]):
self._creation_counter = Field._creation_counter
Field._creation_counter += 1
@@ -122,6 +123,7 @@ class Field(object):
self.source = source
self.initial = initial
self.label = label
+ self.help_text = help_text
self.style = {} if style is None else style
self.validators = validators or self.default_validators[:]
@@ -372,7 +374,6 @@ class IntegerField(Field):
self.validators.append(validators.MaxValueValidator(max_value))
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
- print self.__class__.__name__, self.validators
def to_native(self, data):
try:
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 0727b8cd..459f8a8c 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -15,6 +15,7 @@ from django.core.exceptions import ValidationError
from django.db import models
from django.utils import six
from collections import namedtuple, OrderedDict
+from rest_framework.compat import clean_manytomany_helptext
from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings
from rest_framework.utils import html, modelinfo, representation
@@ -117,8 +118,9 @@ class SerializerMetaclass(type):
"""
This metaclass sets a dictionary named `base_fields` on the class.
- Any fields included as attributes on either the class or it's superclasses
- will be include in the `base_fields` dictionary.
+ Any instances of `Field` included as attributes on either the class
+ or on any of its superclasses will be include in the
+ `base_fields` dictionary.
"""
@classmethod
@@ -379,6 +381,10 @@ class ModelSerializer(Serializer):
info = modelinfo.get_field_info(self.opts.model)
ret = OrderedDict()
+ serializer_url_field = self.get_url_field()
+ if serializer_url_field:
+ ret[api_settings.URL_FIELD_NAME] = serializer_url_field
+
serializer_pk_field = self.get_pk_field(info.pk)
if serializer_pk_field:
ret[info.pk.name] = serializer_pk_field
@@ -404,6 +410,9 @@ class ModelSerializer(Serializer):
return ret
+ def get_url_field(self):
+ return None
+
def get_pk_field(self, model_field):
"""
Returns a default instance of the pk field.
@@ -446,13 +455,14 @@ class ModelSerializer(Serializer):
if model_field:
if model_field.null or model_field.blank:
kwargs['required'] = False
- # if model_field.help_text is not None:
- # kwargs['help_text'] = model_field.help_text
- if model_field.verbose_name is not None:
+ if model_field.verbose_name:
kwargs['label'] = model_field.verbose_name
if not model_field.editable:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
+ help_text = clean_manytomany_helptext(model_field.help_text)
+ if help_text:
+ kwargs['help_text'] = help_text
return PrimaryKeyRelatedField(**kwargs)
@@ -469,6 +479,9 @@ class ModelSerializer(Serializer):
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
+ if model_field.help_text:
+ kwargs['help_text'] = model_field.help_text
+
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
# Read only implies that the field is not required.
@@ -481,6 +494,14 @@ class ModelSerializer(Serializer):
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
+ if model_field.flatchoices:
+ # If this model field contains choices, then use a ChoiceField,
+ # rather than the standard serializer field for this type.
+ # Note that we return this prior to setting any validation type
+ # keyword arguments, as those are not valid initializers.
+ kwargs['choices'] = model_field.flatchoices
+ return ChoiceField(**kwargs)
+
# Ensure that max_length is passed explicitly as a keyword arg,
# rather than as a validator.
max_length = getattr(model_field, 'max_length', None)
@@ -561,23 +582,6 @@ class ModelSerializer(Serializer):
if validator_kwarg:
kwargs['validators'] = validator_kwarg
- # if issubclass(model_field.__class__, models.TextField):
- # kwargs['widget'] = widgets.Textarea
-
- # if model_field.help_text is not None:
- # kwargs['help_text'] = model_field.help_text
-
- # TODO: TypedChoiceField?
- if model_field.flatchoices: # This ModelField contains choices
- kwargs['choices'] = model_field.flatchoices
- if model_field.null:
- kwargs['empty'] = None
- return ChoiceField(**kwargs)
-
- if model_field.null and \
- issubclass(model_field.__class__, (models.CharField, models.TextField)):
- kwargs['allow_none'] = True
-
try:
return self.field_mapping[model_field.__class__](**kwargs)
except KeyError:
@@ -597,33 +601,24 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
class HyperlinkedModelSerializer(ModelSerializer):
_options_class = HyperlinkedModelSerializerOptions
- def get_default_fields(self):
- fields = super(HyperlinkedModelSerializer, self).get_default_fields()
-
- if self.opts.view_name is None:
- self.opts.view_name = self.get_default_view_name(self.opts.model)
-
- url_field_name = api_settings.URL_FIELD_NAME
- if url_field_name not in fields:
- ret = fields.__class__()
- ret[url_field_name] = self.get_url_field()
- ret.update(fields)
- fields = ret
-
- return fields
-
- def get_pk_field(self, model_field):
- if self.opts.fields and model_field.name in self.opts.fields:
- return self.get_field(model_field)
-
def get_url_field(self):
+ if self.opts.view_name is not None:
+ view_name = self.opts.view_name
+ else:
+ view_name = self.get_default_view_name(self.opts.model)
+
kwargs = {
- 'view_name': self.get_default_view_name(self.opts.model)
+ 'view_name': view_name
}
if self.opts.lookup_field:
kwargs['lookup_field'] = self.opts.lookup_field
+
return HyperlinkedIdentityField(**kwargs)
+ def get_pk_field(self, model_field):
+ if self.opts.fields and model_field.name in self.opts.fields:
+ return self.get_field(model_field)
+
def get_related_field(self, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a flat relational field.
@@ -643,13 +638,14 @@ class HyperlinkedModelSerializer(ModelSerializer):
if model_field:
if model_field.null or model_field.blank:
kwargs['required'] = False
- # if model_field.help_text is not None:
- # kwargs['help_text'] = model_field.help_text
- if model_field.verbose_name is not None:
+ if model_field.verbose_name:
kwargs['label'] = model_field.verbose_name
if not model_field.editable:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
+ help_text = clean_manytomany_helptext(model_field.help_text)
+ if help_text:
+ kwargs['help_text'] = help_text
return HyperlinkedRelatedField(**kwargs)
--
cgit v1.2.3
From 80ba0473473501968154c5cc5dd5922e53d96a70 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Wed, 10 Sep 2014 16:57:22 +0100
Subject: Compat fixes
---
rest_framework/compat.py | 12 +++++++++++-
rest_framework/fields.py | 14 +++++++-------
rest_framework/serializers.py | 25 +++++++++++++------------
rest_framework/utils/modelinfo.py | 9 +++++----
4 files changed, 36 insertions(+), 24 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 70b38df9..7c05bed9 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -110,8 +110,18 @@ def get_concrete_model(model_cls):
return model_cls
+# View._allowed_methods only present from 1.5 onwards
+if django.VERSION >= (1, 5):
+ from django.views.generic import View
+else:
+ from django.views.generic import View as DjangoView
+
+ class View(DjangoView):
+ def _allowed_methods(self):
+ return [m.upper() for m in self.http_method_names if hasattr(self, m)]
+
+
# PATCH method is not implemented by Django
-from django.views.generic import View
if 'patch' not in View.http_method_names:
View.http_method_names = View.http_method_names + ['patch']
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index e71dce1d..3ec28908 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -266,8 +266,8 @@ class BooleanField(Field):
default_error_messages = {
'invalid': _('`{input}` is not a valid boolean.')
}
- TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True}
- FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False}
+ TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True))
+ FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False))
def get_value(self, dictionary):
if html.is_html_input(dictionary):
@@ -678,16 +678,16 @@ class ChoiceField(Field):
for item in choices
]
if all(pairs):
- self.choices = {key: display_value for key, display_value in choices}
+ self.choices = dict([(key, display_value) for key, display_value in choices])
else:
- self.choices = {item: item for item in choices}
+ self.choices = dict([(item, item) for item in choices])
# Map the string representation of choices to the underlying value.
# Allows us to deal with eg. integer choices while supporting either
# integer or string input, but still get the correct datatype out.
- self.choice_strings_to_values = {
- str(key): key for key in self.choices.keys()
- }
+ self.choice_strings_to_values = dict([
+ (str(key), key) for key in self.choices.keys()
+ ])
super(ChoiceField, self).__init__(**kwargs)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 459f8a8c..13e57939 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -14,7 +14,8 @@ from django.core import validators
from django.core.exceptions import ValidationError
from django.db import models
from django.utils import six
-from collections import namedtuple, OrderedDict
+from django.utils.datastructures import SortedDict
+from collections import namedtuple
from rest_framework.compat import clean_manytomany_helptext
from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings
@@ -91,10 +92,10 @@ class BaseSerializer(Field):
if self.instance is not None:
self._data = self.to_primative(self.instance)
elif self._initial_data is not None:
- self._data = {
- field_name: field.get_value(self._initial_data)
+ self._data = dict([
+ (field_name, field.get_value(self._initial_data))
for field_name, field in self.fields.items()
- }
+ ])
else:
self._data = self.get_initial()
return self._data
@@ -137,7 +138,7 @@ class SerializerMetaclass(type):
if hasattr(base, 'base_fields'):
fields = list(base.base_fields.items()) + fields
- return OrderedDict(fields)
+ return SortedDict(fields)
def __new__(cls, name, bases, attrs):
attrs['base_fields'] = cls._get_fields(bases, attrs)
@@ -180,10 +181,10 @@ class Serializer(BaseSerializer):
field.bind(field_name, self, root)
def get_initial(self):
- return {
- field.field_name: field.get_initial()
+ return dict([
+ (field.field_name, field.get_initial())
for field in self.fields.values()
- }
+ ])
def get_value(self, dictionary):
# We override the default field access in order to support
@@ -222,14 +223,14 @@ class Serializer(BaseSerializer):
try:
return self.validate(ret)
- except ValidationError, exc:
+ except ValidationError as exc:
raise ValidationError({'non_field_errors': exc.messages})
def to_primative(self, instance):
"""
Object instance -> Dict of primitive datatypes.
"""
- ret = OrderedDict()
+ ret = SortedDict()
fields = [field for field in self.fields.values() if not field.write_only]
for field in fields:
@@ -368,7 +369,7 @@ class ModelSerializer(Serializer):
# If `fields` is set on the `Meta` class,
# then use only those fields, and in that order.
if self.opts.fields:
- fields = OrderedDict([
+ fields = SortedDict([
(key, fields[key]) for key in self.opts.fields
])
@@ -379,7 +380,7 @@ class ModelSerializer(Serializer):
Return all the fields that should be serialized for the model.
"""
info = modelinfo.get_field_info(self.opts.model)
- ret = OrderedDict()
+ ret = SortedDict()
serializer_url_field = self.get_url_field()
if serializer_url_field:
diff --git a/rest_framework/utils/modelinfo.py b/rest_framework/utils/modelinfo.py
index c0513886..a7a0346c 100644
--- a/rest_framework/utils/modelinfo.py
+++ b/rest_framework/utils/modelinfo.py
@@ -2,9 +2,10 @@
Helper functions for returning the field information that is associated
with a model class.
"""
-from collections import namedtuple, OrderedDict
+from collections import namedtuple
from django.db import models
from django.utils import six
+from django.utils.datastructures import SortedDict
import inspect
FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations'])
@@ -45,12 +46,12 @@ def get_field_info(model):
pk = pk.rel.to._meta.pk
# Deal with regular fields.
- fields = OrderedDict()
+ fields = SortedDict()
for field in [field for field in opts.fields if field.serialize and not field.rel]:
fields[field.name] = field
# Deal with forward relationships.
- forward_relations = OrderedDict()
+ forward_relations = SortedDict()
for field in [field for field in opts.fields if field.serialize and field.rel]:
forward_relations[field.name] = RelationInfo(
field=field,
@@ -71,7 +72,7 @@ def get_field_info(model):
)
# Deal with reverse relationships.
- reverse_relations = OrderedDict()
+ reverse_relations = SortedDict()
for relation in opts.get_all_related_objects():
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
--
cgit v1.2.3
From 54ccf7230d0fcdabe8c2457539e314893915a34b Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 11 Sep 2014 13:43:46 +0100
Subject: Improve memory address removal for serializer representations
---
rest_framework/utils/representation.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'rest_framework')
diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py
index 1de21597..e2a37497 100644
--- a/rest_framework/utils/representation.py
+++ b/rest_framework/utils/representation.py
@@ -17,7 +17,7 @@ def smart_repr(value):
#
# Should be presented as
#
- value = re.sub(' at 0x[0-9a-f]{8,10}>', '>', value)
+ value = re.sub(' at 0x[0-9a-f]{8,32}>', '>', value)
return value
--
cgit v1.2.3
From 3318f75a7166cbac76a40d0461ca7b3e4640d3a2 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 11 Sep 2014 13:50:53 +0100
Subject: Improve memory address removal for serializer representations
---
rest_framework/utils/representation.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'rest_framework')
diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py
index e2a37497..1a4d1a62 100644
--- a/rest_framework/utils/representation.py
+++ b/rest_framework/utils/representation.py
@@ -17,7 +17,7 @@ def smart_repr(value):
#
# Should be presented as
#
- value = re.sub(' at 0x[0-9a-f]{8,32}>', '>', value)
+ value = re.sub(' at 0x[0-9a-f]{4,32}>', '>', value)
return value
--
cgit v1.2.3
From ab40780dc2f341a271c2f489659dcd48eb47c07d Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 11 Sep 2014 20:22:32 +0100
Subject: Tidy up lookup_class
---
rest_framework/serializers.py | 21 +++++++++++----------
1 file changed, 11 insertions(+), 10 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 8fe999ae..4322f213 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -317,17 +317,17 @@ class ModelSerializerOptions(object):
self.depth = getattr(meta, 'depth', 0)
-def lookup_class(mapping, obj):
+def lookup_class(mapping, instance):
"""
Takes a dictionary with classes as keys, and an object.
Traverses the object's inheritance hierarchy in method
resolution order, and returns the first matching value
- from the dictionary or None.
+ from the dictionary or raises a KeyError if nothing matches.
"""
- return next(
- (mapping[cls] for cls in inspect.getmro(obj.__class__) if cls in mapping),
- None
- )
+ for cls in inspect.getmro(instance.__class__):
+ if cls in mapping:
+ return mapping[cls]
+ raise KeyError('Class %s not found in lookup.', cls.__name__)
class ModelSerializer(Serializer):
@@ -341,6 +341,7 @@ class ModelSerializer(Serializer):
models.DateTimeField: DateTimeField,
models.DecimalField: DecimalField,
models.EmailField: EmailField,
+ models.Field: ModelField,
models.FileField: FileField,
models.FloatField: FloatField,
models.ImageField: ImageField,
@@ -484,6 +485,7 @@ class ModelSerializer(Serializer):
"""
Creates a default instance of a basic non-relational field.
"""
+ serializer_cls = lookup_class(self.field_mapping, model_field)
kwargs = {}
validator_kwarg = model_field.validators
@@ -602,11 +604,10 @@ class ModelSerializer(Serializer):
if validator_kwarg:
kwargs['validators'] = validator_kwarg
- cls = lookup_class(self.field_mapping, model_field)
- if cls is None:
- cls = ModelField
+ if issubclass(serializer_cls, ModelField):
kwargs['model_field'] = model_field
- return cls(**kwargs)
+
+ return serializer_cls(**kwargs)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
--
cgit v1.2.3
From bf52d04f4c370d6917599d26c84b73124d5ef366 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 11 Sep 2014 20:37:27 +0100
Subject: Nice manager representations on serializer classes
---
rest_framework/utils/representation.py | 13 +++++++++++++
1 file changed, 13 insertions(+)
(limited to 'rest_framework')
diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py
index 1a4d1a62..71db1886 100644
--- a/rest_framework/utils/representation.py
+++ b/rest_framework/utils/representation.py
@@ -2,10 +2,23 @@
Helper functions for creating user-friendly representations
of serializer classes and serializer fields.
"""
+from django.db import models
import re
+def manager_repr(value):
+ model = value.model
+ opts = model._meta
+ for _, name, manager in opts.concrete_managers + opts.abstract_managers:
+ if manager == value:
+ return '%s.%s.all()' % (model._meta.object_name, name)
+ return repr(value)
+
+
def smart_repr(value):
+ if isinstance(value, models.Manager):
+ return manager_repr(value)
+
value = repr(value)
# Representations like u'help text'
--
cgit v1.2.3
From 19b8f779de82fa4737b37fb4359145af0b07a56c Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 11 Sep 2014 20:43:44 +0100
Subject: Throttles now use Retry-After header and no longer support the custom
style
---
rest_framework/views.py | 1 -
1 file changed, 1 deletion(-)
(limited to 'rest_framework')
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 3b7b1c16..cd394b2d 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -62,7 +62,6 @@ def exception_handler(exc):
if getattr(exc, 'auth_header', None):
headers['WWW-Authenticate'] = exc.auth_header
if getattr(exc, 'wait', None):
- headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
headers['Retry-After'] = '%d' % exc.wait
return Response({'detail': exc.detail},
--
cgit v1.2.3
From 55650a743d579e0bc1643c8812428746b0271984 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 11 Sep 2014 20:49:10 +0100
Subject: no longer tightly coupled to private queryset API
---
rest_framework/generics.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
(limited to 'rest_framework')
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index c2c59154..408b1246 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -3,6 +3,7 @@ Generic views that provide commonly needed behaviour.
"""
from __future__ import unicode_literals
+from django.db.models.query import QuerySet
from django.core.exceptions import PermissionDenied
from django.core.paginator import Paginator, InvalidPage
from django.http import Http404
@@ -214,7 +215,9 @@ class GenericAPIView(views.APIView):
% self.__class__.__name__
)
- return self.queryset._clone()
+ if isinstance(self.queryset, QuerySet):
+ return self.queryset.all()
+ return self.queryset
def get_object(self):
"""
--
cgit v1.2.3
From a7518719917c7ad8e699119b442cfeb568ba1dde Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 11 Sep 2014 20:50:26 +0100
Subject: no longer tightly coupled to private queryset API
---
rest_framework/generics.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 408b1246..338d56a6 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -215,9 +215,11 @@ class GenericAPIView(views.APIView):
% self.__class__.__name__
)
+ queryset = self.queryset
if isinstance(self.queryset, QuerySet):
- return self.queryset.all()
- return self.queryset
+ # Ensure queryset is re-evaluated on each request.
+ queryset = queryset.all()
+ return queryset
def get_object(self):
"""
--
cgit v1.2.3
From 040bfcc09c851bb3dadd60558c78a1f7937e9fbd Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 11 Sep 2014 21:48:54 +0100
Subject: NotImplemented stubs for Field, and DecimalField improvements
---
rest_framework/fields.py | 27 +++++++++++++++++++++------
rest_framework/pagination.py | 4 ++--
rest_framework/utils/encoders.py | 2 +-
3 files changed, 24 insertions(+), 9 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 7496a629..20b8ffbf 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -229,13 +229,13 @@ class Field(object):
"""
Transform the *incoming* primative data into a native value.
"""
- return data
+ raise NotImplementedError('to_native() must be implemented.')
def to_primative(self, value):
"""
Transform the *outgoing* native value into primative data.
"""
- return value
+ raise NotImplementedError('to_primative() must be implemented.')
def fail(self, key, **kwargs):
"""
@@ -429,9 +429,10 @@ class DecimalField(Field):
'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.')
}
- def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, **kwargs):
- self.max_value, self.min_value = max_value, min_value
- self.max_digits, self.max_decimal_places = max_digits, decimal_places
+ def __init__(self, max_digits, decimal_places, coerce_to_string=True, max_value=None, min_value=None, **kwargs):
+ self.max_digits = max_digits
+ self.decimal_places = decimal_places
+ self.coerce_to_string = coerce_to_string
super(DecimalField, self).__init__(**kwargs)
if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value))
@@ -478,12 +479,26 @@ class DecimalField(Field):
if self.max_digits is not None and digits > self.max_digits:
self.fail('max_digits', max_digits=self.max_digits)
if self.decimal_places is not None and decimals > self.decimal_places:
- self.fail('max_decimal_places', max_decimal_places=self.max_decimal_places)
+ self.fail('max_decimal_places', max_decimal_places=self.decimal_places)
if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places)
return value
+ def to_primative(self, value):
+ if not self.coerce_to_string:
+ return value
+
+ if isinstance(value, decimal.Decimal):
+ context = decimal.getcontext().copy()
+ context.prec = self.max_digits
+ quantized = value.quantize(
+ decimal.Decimal('.1') ** self.decimal_places,
+ context=context
+ )
+ return '{0:f}'.format(quantized)
+ return '%.*f' % (self.max_decimal_places, value)
+
# Date & time fields...
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index 9cf31629..d82d2d3b 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -37,7 +37,7 @@ class PreviousPageField(serializers.Field):
return replace_query_param(url, self.page_field, page)
-class DefaultObjectSerializer(serializers.Field):
+class DefaultObjectSerializer(serializers.ReadOnlyField):
"""
If no object serializer is specified, then this serializer will be applied
as the default.
@@ -79,6 +79,6 @@ class PaginationSerializer(BasePaginationSerializer):
"""
A default implementation of a pagination serializer.
"""
- count = serializers.Field(source='paginator.count')
+ count = serializers.ReadOnlyField(source='paginator.count')
next = NextPageField(source='*')
previous = PreviousPageField(source='*')
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index 6a2f6126..7992b6b1 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -43,7 +43,7 @@ class JSONEncoder(json.JSONEncoder):
elif isinstance(o, datetime.timedelta):
return str(o.total_seconds())
elif isinstance(o, decimal.Decimal):
- return str(o)
+ return float(o)
elif isinstance(o, QuerySet):
return list(o)
elif hasattr(o, 'tolist'):
--
cgit v1.2.3
From 1e53eb0aa2998385e26aa0a4d8542013bc7b575b Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 11 Sep 2014 21:57:32 +0100
Subject: DecimalFields should still be quantized even without coerce_to_string
---
rest_framework/fields.py | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 20b8ffbf..a56ea96b 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -486,9 +486,6 @@ class DecimalField(Field):
return value
def to_primative(self, value):
- if not self.coerce_to_string:
- return value
-
if isinstance(value, decimal.Decimal):
context = decimal.getcontext().copy()
context.prec = self.max_digits
@@ -496,7 +493,12 @@ class DecimalField(Field):
decimal.Decimal('.1') ** self.decimal_places,
context=context
)
+ if not self.coerce_to_string:
+ return quantized
return '{0:f}'.format(quantized)
+
+ if not self.coerce_to_string:
+ return value
return '%.*f' % (self.max_decimal_places, value)
--
cgit v1.2.3
From adcb64ab4198f35c61d5be68956d201685ed3538 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 12 Sep 2014 09:12:56 +0100
Subject: MethodField -> SerializerMethodField
---
rest_framework/fields.py | 21 +++++++++------------
rest_framework/serializers.py | 3 ++-
rest_framework/utils/modelinfo.py | 3 ++-
3 files changed, 13 insertions(+), 14 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index a56ea96b..5f198767 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -768,16 +768,13 @@ class ReadOnlyField(Field):
kwargs['read_only'] = True
super(ReadOnlyField, self).__init__(**kwargs)
- def to_native(self, data):
- raise NotImplemented('.to_native() not supported.')
-
def to_primative(self, value):
if is_simple_callable(value):
return value()
return value
-class MethodField(Field):
+class SerializerMethodField(Field):
"""
A read-only field that get its representation from calling a method on the
parent serializer class. The method called will be of the form
@@ -787,22 +784,22 @@ class MethodField(Field):
For example:
class ExampleSerializer(self):
- extra_info = MethodField()
+ extra_info = SerializerMethodField()
def get_extra_info(self, obj):
return ... # Calculate some data to return.
"""
- def __init__(self, **kwargs):
+ def __init__(self, method_attr=None, **kwargs):
+ self.method_attr = method_attr
kwargs['source'] = '*'
kwargs['read_only'] = True
- super(MethodField, self).__init__(**kwargs)
-
- def to_native(self, data):
- raise NotImplemented('.to_native() not supported.')
+ super(SerializerMethodField, self).__init__(**kwargs)
def to_primative(self, value):
- attr = 'get_{field_name}'.format(field_name=self.field_name)
- method = getattr(self.parent, attr)
+ method_attr = self.method_attr
+ if method_attr is None:
+ method_attr = 'get_{field_name}'.format(field_name=self.field_name)
+ method = getattr(self.parent, method_attr)
return method(value)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 4322f213..388fe29f 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -598,7 +598,8 @@ class ModelSerializer(Serializer):
if isinstance(model_field, models.BooleanField):
# models.BooleanField has `blank=True`, but *is* actually
# required *unless* a default is provided.
- # Also note that <1.6 `default=False`, >=1.6 `default=None`.
+ # Also note that Django<1.6 uses `default=False` for
+ # models.BooleanField, but Django>=1.6 uses `default=None`.
kwargs.pop('required', None)
if validator_kwarg:
diff --git a/rest_framework/utils/modelinfo.py b/rest_framework/utils/modelinfo.py
index a7a0346c..960fa4d0 100644
--- a/rest_framework/utils/modelinfo.py
+++ b/rest_framework/utils/modelinfo.py
@@ -1,6 +1,7 @@
"""
Helper functions for returning the field information that is associated
-with a model class.
+with a model class. This includes returning all the forward and reverse
+relationships and their associated metadata.
"""
from collections import namedtuple
from django.db import models
--
cgit v1.2.3
From 0d354e8f92c7daaf8dac3b80f0fd64f983f21e0b Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 12 Sep 2014 09:49:35 +0100
Subject: to_internal_value() and to_representation()
---
rest_framework/fields.py | 86 ++++++++++++++++++++++---------------------
rest_framework/pagination.py | 4 +-
rest_framework/relations.py | 2 +-
rest_framework/serializers.py | 28 +++++++-------
4 files changed, 61 insertions(+), 59 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 5f198767..a96f9ba8 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -195,7 +195,7 @@ class Field(object):
raise SkipField()
return self.default
- def validate_value(self, data=empty):
+ def run_validation(self, data=empty):
"""
Validate a simple representation and return the internal value.
@@ -208,7 +208,7 @@ class Field(object):
self.fail('required')
return self.get_default()
- value = self.to_native(data)
+ value = self.to_internal_value(data)
self.run_validators(value)
return value
@@ -225,17 +225,17 @@ class Field(object):
if errors:
raise ValidationError(errors)
- def to_native(self, data):
+ def to_internal_value(self, data):
"""
Transform the *incoming* primative data into a native value.
"""
- raise NotImplementedError('to_native() must be implemented.')
+ raise NotImplementedError('to_internal_value() must be implemented.')
- def to_primative(self, value):
+ def to_representation(self, value):
"""
Transform the *outgoing* native value into primative data.
"""
- raise NotImplementedError('to_primative() must be implemented.')
+ raise NotImplementedError('to_representation() must be implemented.')
def fail(self, key, **kwargs):
"""
@@ -279,14 +279,14 @@ class BooleanField(Field):
return dictionary.get(self.field_name, False)
return dictionary.get(self.field_name, empty)
- def to_native(self, data):
+ def to_internal_value(self, data):
if data in self.TRUE_VALUES:
return True
elif data in self.FALSE_VALUES:
return False
self.fail('invalid', input=data)
- def to_primative(self, value):
+ def to_representation(self, value):
if value is None:
return None
if value in self.TRUE_VALUES:
@@ -309,12 +309,14 @@ class CharField(Field):
self.min_length = kwargs.pop('min_length', None)
super(CharField, self).__init__(**kwargs)
- def to_native(self, data):
+ def to_internal_value(self, data):
if data == '' and not self.allow_blank:
self.fail('blank')
+ if data is None:
+ return None
return str(data)
- def to_primative(self, value):
+ def to_representation(self, value):
if value is None:
return None
return str(value)
@@ -326,17 +328,17 @@ class EmailField(CharField):
}
default_validators = [validators.validate_email]
- def to_native(self, data):
- ret = super(EmailField, self).to_native(data)
- if ret is None:
+ def to_internal_value(self, data):
+ if data == '' and not self.allow_blank:
+ self.fail('blank')
+ if data is None:
return None
- return ret.strip()
+ return str(data).strip()
- def to_primative(self, value):
- ret = super(EmailField, self).to_primative(value)
- if ret is None:
+ def to_representation(self, value):
+ if value is None:
return None
- return ret.strip()
+ return str(value).strip()
class RegexField(CharField):
@@ -378,14 +380,14 @@ class IntegerField(Field):
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
- def to_native(self, data):
+ def to_internal_value(self, data):
try:
data = int(str(data))
except (ValueError, TypeError):
self.fail('invalid')
return data
- def to_primative(self, value):
+ def to_representation(self, value):
if value is None:
return None
return int(value)
@@ -405,7 +407,12 @@ class FloatField(Field):
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
- def to_primative(self, value):
+ def to_internal_value(self, value):
+ if value is None:
+ return None
+ return float(value)
+
+ def to_representation(self, value):
if value is None:
return None
try:
@@ -413,11 +420,6 @@ class FloatField(Field):
except (TypeError, ValueError):
self.fail('invalid', value=value)
- def to_native(self, value):
- if value is None:
- return None
- return float(value)
-
class DecimalField(Field):
default_error_messages = {
@@ -439,7 +441,7 @@ class DecimalField(Field):
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
- def from_native(self, value):
+ def to_internal_value(self, value):
"""
Validates that the input is a decimal number. Returns a Decimal
instance. Returns None for empty values. Ensures that there are no more
@@ -485,7 +487,7 @@ class DecimalField(Field):
return value
- def to_primative(self, value):
+ def to_representation(self, value):
if isinstance(value, decimal.Decimal):
context = decimal.getcontext().copy()
context.prec = self.max_digits
@@ -516,7 +518,7 @@ class DateField(Field):
self.format = format if format is not None else self.format
super(DateField, self).__init__(*args, **kwargs)
- def from_native(self, value):
+ def to_internal_value(self, value):
if value in validators.EMPTY_VALUES:
return None
@@ -552,7 +554,7 @@ class DateField(Field):
msg = self.error_messages['invalid'] % humanized_format
raise ValidationError(msg)
- def to_primative(self, value):
+ def to_representation(self, value):
if value is None or self.format is None:
return value
@@ -576,7 +578,7 @@ class DateTimeField(Field):
self.format = format if format is not None else self.format
super(DateTimeField, self).__init__(*args, **kwargs)
- def from_native(self, value):
+ def to_internal_value(self, value):
if value in validators.EMPTY_VALUES:
return None
@@ -618,7 +620,7 @@ class DateTimeField(Field):
msg = self.error_messages['invalid'] % humanized_format
raise ValidationError(msg)
- def to_primative(self, value):
+ def to_representation(self, value):
if value is None or self.format is None:
return value
@@ -670,7 +672,7 @@ class TimeField(Field):
msg = self.error_messages['invalid'] % humanized_format
raise ValidationError(msg)
- def to_primative(self, value):
+ def to_representation(self, value):
if value is None or self.format is None:
return value
@@ -711,13 +713,13 @@ class ChoiceField(Field):
super(ChoiceField, self).__init__(**kwargs)
- def to_native(self, data):
+ def to_internal_value(self, data):
try:
return self.choice_strings_to_values[str(data)]
except KeyError:
self.fail('invalid_choice', input=data)
- def to_primative(self, value):
+ def to_representation(self, value):
return value
@@ -727,15 +729,15 @@ class MultipleChoiceField(ChoiceField):
'not_a_list': _('Expected a list of items but got type `{input_type}`')
}
- def to_native(self, data):
+ def to_internal_value(self, data):
if not hasattr(data, '__iter__'):
self.fail('not_a_list', input_type=type(data).__name__)
return set([
- super(MultipleChoiceField, self).to_native(item)
+ super(MultipleChoiceField, self).to_internal_value(item)
for item in data
])
- def to_primative(self, value):
+ def to_representation(self, value):
return value
@@ -768,7 +770,7 @@ class ReadOnlyField(Field):
kwargs['read_only'] = True
super(ReadOnlyField, self).__init__(**kwargs)
- def to_primative(self, value):
+ def to_representation(self, value):
if is_simple_callable(value):
return value()
return value
@@ -795,7 +797,7 @@ class SerializerMethodField(Field):
kwargs['read_only'] = True
super(SerializerMethodField, self).__init__(**kwargs)
- def to_primative(self, value):
+ def to_representation(self, value):
method_attr = self.method_attr
if method_attr is None:
method_attr = 'get_{field_name}'.format(field_name=self.field_name)
@@ -815,13 +817,13 @@ class ModelField(Field):
kwargs['source'] = '*'
super(ModelField, self).__init__(**kwargs)
- def to_native(self, data):
+ def to_internal_value(self, data):
rel = getattr(self.model_field, 'rel', None)
if rel is not None:
return rel.to._meta.get_field(rel.field_name).to_python(data)
return self.model_field.to_python(data)
- def to_primative(self, obj):
+ def to_representation(self, obj):
value = self.model_field._get_val_from_obj(obj)
if is_protected_type(value):
return value
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index d82d2d3b..c5a9270a 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -13,7 +13,7 @@ class NextPageField(serializers.Field):
"""
page_field = 'page'
- def to_primative(self, value):
+ def to_representation(self, value):
if not value.has_next():
return None
page = value.next_page_number()
@@ -28,7 +28,7 @@ class PreviousPageField(serializers.Field):
"""
page_field = 'page'
- def to_primative(self, value):
+ def to_representation(self, value):
if not value.has_previous():
return None
page = value.previous_page_number()
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 661a1249..30a252db 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -110,7 +110,7 @@ class HyperlinkedIdentityField(RelatedField):
def get_attribute(self, instance):
return instance
- def to_primative(self, value):
+ def to_representation(self, value):
request = self.context.get('request', None)
format = self.context.get('format', None)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 388fe29f..502b1e19 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -47,11 +47,11 @@ class BaseSerializer(Field):
self.instance = instance
self._initial_data = data
- def to_native(self, data):
- raise NotImplementedError('`to_native()` must be implemented.')
+ def to_internal_value(self, data):
+ raise NotImplementedError('`to_internal_value()` must be implemented.')
- def to_primative(self, instance):
- raise NotImplementedError('`to_primative()` must be implemented.')
+ def to_representation(self, instance):
+ raise NotImplementedError('`to_representation()` must be implemented.')
def update(self, instance, attrs):
raise NotImplementedError('`update()` must be implemented.')
@@ -74,7 +74,7 @@ class BaseSerializer(Field):
def is_valid(self, raise_exception=False):
if not hasattr(self, '_validated_data'):
try:
- self._validated_data = self.to_native(self._initial_data)
+ self._validated_data = self.to_internal_value(self._initial_data)
except ValidationError as exc:
self._validated_data = {}
self._errors = exc.message_dict
@@ -90,7 +90,7 @@ class BaseSerializer(Field):
def data(self):
if not hasattr(self, '_data'):
if self.instance is not None:
- self._data = self.to_primative(self.instance)
+ self._data = self.to_representation(self.instance)
elif self._initial_data is not None:
self._data = dict([
(field_name, field.get_value(self._initial_data))
@@ -193,7 +193,7 @@ class Serializer(BaseSerializer):
return html.parse_html_dict(dictionary, prefix=self.field_name)
return dictionary.get(self.field_name, empty)
- def to_native(self, data):
+ def to_internal_value(self, data):
"""
Dict of native values <- Dict of primitive datatypes.
"""
@@ -208,7 +208,7 @@ class Serializer(BaseSerializer):
validate_method = getattr(self, 'validate_' + field.field_name, None)
primitive_value = field.get_value(data)
try:
- validated_value = field.validate_value(primitive_value)
+ validated_value = field.run_validation(primitive_value)
if validate_method is not None:
validated_value = validate_method(validated_value)
except ValidationError as exc:
@@ -226,7 +226,7 @@ class Serializer(BaseSerializer):
except ValidationError as exc:
raise ValidationError({'non_field_errors': exc.messages})
- def to_primative(self, instance):
+ def to_representation(self, instance):
"""
Object instance -> Dict of primitive datatypes.
"""
@@ -235,7 +235,7 @@ class Serializer(BaseSerializer):
for field in fields:
native_value = field.get_attribute(instance)
- ret[field.field_name] = field.to_primative(native_value)
+ ret[field.field_name] = field.to_representation(native_value)
return ret
@@ -279,20 +279,20 @@ class ListSerializer(BaseSerializer):
return html.parse_html_list(dictionary, prefix=self.field_name)
return dictionary.get(self.field_name, empty)
- def to_native(self, data):
+ def to_internal_value(self, data):
"""
List of dicts of native values <- List of dicts of primitive datatypes.
"""
if html.is_html_input(data):
data = html.parse_html_list(data)
- return [self.child.validate(item) for item in data]
+ return [self.child.run_validation(item) for item in data]
- def to_primative(self, data):
+ def to_representation(self, data):
"""
List of object instances -> List of dicts of primitive datatypes.
"""
- return [self.child.to_primative(item) for item in data]
+ return [self.child.to_representation(item) for item in data]
def create(self, attrs_list):
return [self.child.create(attrs) for attrs in attrs_list]
--
cgit v1.2.3
From 6db3356c4d1aa4f9a042b0ec67d47238abc16dd7 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 12 Sep 2014 10:21:35 +0100
Subject: NON_FIELD_ERRORS_KEY setting
---
rest_framework/serializers.py | 8 ++++++--
rest_framework/settings.py | 1 +
rest_framework/views.py | 8 +++++++-
3 files changed, 14 insertions(+), 3 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 502b1e19..0c2aedfa 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -198,7 +198,9 @@ class Serializer(BaseSerializer):
Dict of native values <- Dict of primitive datatypes.
"""
if not isinstance(data, dict):
- raise ValidationError({'non_field_errors': ['Invalid data']})
+ raise ValidationError({
+ api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data']
+ })
ret = {}
errors = {}
@@ -224,7 +226,9 @@ class Serializer(BaseSerializer):
try:
return self.validate(ret)
except ValidationError as exc:
- raise ValidationError({'non_field_errors': exc.messages})
+ raise ValidationError({
+ api_settings.NON_FIELD_ERRORS_KEY: exc.messages
+ })
def to_representation(self, instance):
"""
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index bbe7a56a..f48643b5 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -77,6 +77,7 @@ DEFAULTS = {
# Exception handling
'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler',
+ 'NON_FIELD_ERRORS_KEY': 'non_field_errors',
# Testing
'TEST_REQUEST_RENDERER_CLASSES': (
diff --git a/rest_framework/views.py b/rest_framework/views.py
index cd394b2d..9f08a4ad 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -3,7 +3,7 @@ Provides an APIView class that is the base of all views in REST framework.
"""
from __future__ import unicode_literals
-from django.core.exceptions import PermissionDenied, ValidationError
+from django.core.exceptions import PermissionDenied, ValidationError, NON_FIELD_ERRORS
from django.http import Http404
from django.utils.datastructures import SortedDict
from django.views.decorators.csrf import csrf_exempt
@@ -69,6 +69,12 @@ def exception_handler(exc):
headers=headers)
elif isinstance(exc, ValidationError):
+ # ValidationErrors may include the non-field key named '__all__'.
+ # When returning a response we map this to a key name that can be
+ # modified in settings.
+ if NON_FIELD_ERRORS in exc.message_dict:
+ errors = exc.message_dict.pop(NON_FIELD_ERRORS)
+ exc.message_dict[api_settings.NON_FIELD_ERRORS_KEY] = errors
return Response(exc.message_dict,
status=status.HTTP_400_BAD_REQUEST)
--
cgit v1.2.3
From 250755def707e1397876614fa0c08130d9fcc449 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 12 Sep 2014 10:59:51 +0100
Subject: Clean up relational fields queryset usage
---
rest_framework/fields.py | 15 ++++------
rest_framework/generics.py | 2 +-
rest_framework/relations.py | 73 ++++++++++++++++++++++++---------------------
3 files changed, 46 insertions(+), 44 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index a96f9ba8..4f06d186 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -508,7 +508,7 @@ class DecimalField(Field):
class DateField(Field):
default_error_messages = {
- 'invalid': _("Date has wrong format. Use one of these formats instead: %s"),
+ 'invalid': _('Date has wrong format. Use one of these formats instead: {format}'),
}
input_formats = api_settings.DATE_INPUT_FORMATS
format = api_settings.DATE_FORMAT
@@ -551,8 +551,7 @@ class DateField(Field):
return parsed.date()
humanized_format = humanize_datetime.date_formats(self.input_formats)
- msg = self.error_messages['invalid'] % humanized_format
- raise ValidationError(msg)
+ self.fail('invalid', format=humanized_format)
def to_representation(self, value):
if value is None or self.format is None:
@@ -568,7 +567,7 @@ class DateField(Field):
class DateTimeField(Field):
default_error_messages = {
- 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"),
+ 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'),
}
input_formats = api_settings.DATETIME_INPUT_FORMATS
format = api_settings.DATETIME_FORMAT
@@ -617,8 +616,7 @@ class DateTimeField(Field):
return parsed
humanized_format = humanize_datetime.datetime_formats(self.input_formats)
- msg = self.error_messages['invalid'] % humanized_format
- raise ValidationError(msg)
+ self.fail('invalid', format=humanized_format)
def to_representation(self, value):
if value is None or self.format is None:
@@ -634,7 +632,7 @@ class DateTimeField(Field):
class TimeField(Field):
default_error_messages = {
- 'invalid': _("Time has wrong format. Use one of these formats instead: %s"),
+ 'invalid': _('Time has wrong format. Use one of these formats instead: {format}'),
}
input_formats = api_settings.TIME_INPUT_FORMATS
format = api_settings.TIME_FORMAT
@@ -669,8 +667,7 @@ class TimeField(Field):
return parsed.time()
humanized_format = humanize_datetime.time_formats(self.input_formats)
- msg = self.error_messages['invalid'] % humanized_format
- raise ValidationError(msg)
+ self.fail('invalid', format=humanized_format)
def to_representation(self, value):
if value is None or self.format is None:
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 338d56a6..eb6b64ef 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -216,7 +216,7 @@ class GenericAPIView(views.APIView):
)
queryset = self.queryset
- if isinstance(self.queryset, QuerySet):
+ if isinstance(queryset, QuerySet):
# Ensure queryset is re-evaluated on each request.
queryset = queryset.all()
return queryset
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 30a252db..e23a4152 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -2,28 +2,35 @@ from rest_framework.fields import Field
from rest_framework.reverse import reverse
from django.core.exceptions import ObjectDoesNotExist
from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch
+from django.db.models.query import QuerySet
from rest_framework.compat import urlparse
-def get_default_queryset(serializer_class, field_name):
- manager = getattr(serializer_class.opts.model, field_name)
- if hasattr(manager, 'related'):
- # Forward relationships
- return manager.related.model._default_manager.all()
- # Reverse relationships
- return manager.field.rel.to._default_manager.all()
-
-
class RelatedField(Field):
def __init__(self, **kwargs):
self.queryset = kwargs.pop('queryset', None)
self.many = kwargs.pop('many', False)
+ assert self.queryset is not None or kwargs.get('read_only', False), (
+ 'Relational field must provide a `queryset` argument, '
+ 'or set read_only=`True`.'
+ )
super(RelatedField, self).__init__(**kwargs)
- def bind(self, field_name, parent, root):
- super(RelatedField, self).bind(field_name, parent, root)
- if self.queryset is None and not self.read_only:
- self.queryset = get_default_queryset(parent, self.source)
+ def get_queryset(self):
+ queryset = self.queryset
+ if isinstance(queryset, QuerySet):
+ # Ensure queryset is re-evaluated whenever used.
+ queryset = queryset.all()
+ return queryset
+
+
+class StringRelatedField(Field):
+ def __init__(self, **kwargs):
+ kwargs['read_only'] = True
+ super(StringRelatedField, self).__init__(**kwargs)
+
+ def to_representation(self, value):
+ return str(value)
class PrimaryKeyRelatedField(RelatedField):
@@ -33,9 +40,9 @@ class PrimaryKeyRelatedField(RelatedField):
'incorrect_type': 'Incorrect type. Expected pk value, received {data_type}.',
}
- def from_native(self, data):
+ def to_internal_value(self, data):
try:
- return self.queryset.get(pk=data)
+ return self.get_queryset().get(pk=data)
except ObjectDoesNotExist:
self.fail('does_not_exist', pk_value=data)
except (TypeError, ValueError):
@@ -68,9 +75,9 @@ class HyperlinkedRelatedField(RelatedField):
"""
lookup_value = view_kwargs[self.lookup_url_kwarg]
lookup_kwargs = {self.lookup_field: lookup_value}
- return self.queryset.get(**lookup_kwargs)
+ return self.get_queryset().get(**lookup_kwargs)
- def from_native(self, value):
+ def to_internal_value(self, value):
try:
http_prefix = value.startswith(('http:', 'https:'))
except AttributeError:
@@ -102,13 +109,26 @@ class HyperlinkedIdentityField(RelatedField):
def __init__(self, **kwargs):
kwargs['read_only'] = True
+ kwargs['source'] = '*'
self.view_name = kwargs.pop('view_name')
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
super(HyperlinkedIdentityField, self).__init__(**kwargs)
- def get_attribute(self, instance):
- return instance
+ def get_url(self, obj, view_name, request, format):
+ """
+ Given an object, return the URL that hyperlinks to the object.
+
+ May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
+ attributes are not configured to correctly match the URL conf.
+ """
+ # Unsaved objects will not yet have a valid URL.
+ if obj.pk is None:
+ return None
+
+ lookup_value = getattr(obj, self.lookup_field)
+ kwargs = {self.lookup_url_kwarg: lookup_value}
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
def to_representation(self, value):
request = self.context.get('request', None)
@@ -144,21 +164,6 @@ class HyperlinkedIdentityField(RelatedField):
)
raise Exception(msg % self.view_name)
- def get_url(self, obj, view_name, request, format):
- """
- Given an object, return the URL that hyperlinks to the object.
-
- May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
- attributes are not configured to correctly match the URL conf.
- """
- # Unsaved objects will not yet have a valid URL.
- if obj.pk is None:
- return None
-
- lookup_value = getattr(obj, self.lookup_field)
- kwargs = {self.lookup_url_kwarg: lookup_value}
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
-
class SlugRelatedField(RelatedField):
def __init__(self, **kwargs):
--
cgit v1.2.3
From 5e39e159ee6aee90755709cfecec7d22c7ea3049 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 12 Sep 2014 11:38:22 +0100
Subject: UNICODE_JSON and COMPACT_JSON settings
---
rest_framework/fields.py | 20 ++++++++++----------
rest_framework/parsers.py | 2 +-
rest_framework/renderers.py | 33 ++++++++++++---------------------
rest_framework/settings.py | 3 +++
4 files changed, 26 insertions(+), 32 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 4f06d186..9d96cf5c 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -137,6 +137,16 @@ class Field(object):
messages.update(error_messages or {})
self.error_messages = messages
+ def __new__(cls, *args, **kwargs):
+ """
+ When a field is instantiated, we store the arguments that were used,
+ so that we can present a helpful representation of the object.
+ """
+ instance = super(Field, cls).__new__(cls)
+ instance._args = args
+ instance._kwargs = kwargs
+ return instance
+
def bind(self, field_name, parent, root):
"""
Setup the context for the field instance.
@@ -249,16 +259,6 @@ class Field(object):
raise AssertionError(msg)
raise ValidationError(msg.format(**kwargs))
- def __new__(cls, *args, **kwargs):
- """
- When a field is instantiated, we store the arguments that were used,
- so that we can present a helpful representation of the object.
- """
- instance = super(Field, cls).__new__(cls)
- instance._args = args
- instance._kwargs = kwargs
- return instance
-
def __repr__(self):
return representation.field_repr(self)
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index c287908d..fa02ecf1 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -48,7 +48,7 @@ class JSONParser(BaseParser):
"""
media_type = 'application/json'
- renderer_class = renderers.UnicodeJSONRenderer
+ renderer_class = renderers.JSONRenderer
def parse(self, stream, media_type=None, parser_context=None):
"""
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index dfc5a39f..3bf03e62 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -26,6 +26,10 @@ from rest_framework.utils.breadcrumbs import get_breadcrumbs
from rest_framework import exceptions, status, VERSION
+def zero_as_none(value):
+ return None if value == 0 else value
+
+
class BaseRenderer(object):
"""
All renderers should extend this class, setting the `media_type`
@@ -44,13 +48,13 @@ class BaseRenderer(object):
class JSONRenderer(BaseRenderer):
"""
Renderer which serializes to JSON.
- Applies JSON's backslash-u character escaping for non-ascii characters.
"""
media_type = 'application/json'
format = 'json'
encoder_class = encoders.JSONEncoder
- ensure_ascii = True
+ ensure_ascii = not api_settings.UNICODE_JSON
+ compact = api_settings.COMPACT_JSON
# We don't set a charset because JSON is a binary encoding,
# that can be encoded as utf-8, utf-16 or utf-32.
@@ -62,9 +66,10 @@ class JSONRenderer(BaseRenderer):
if accepted_media_type:
# If the media type looks like 'application/json; indent=4',
# then pretty print the result.
+ # Note that we coerce `indent=0` into `indent=None`.
base_media_type, params = parse_header(accepted_media_type.encode('ascii'))
try:
- return max(min(int(params['indent']), 8), 0)
+ return zero_as_none(max(min(int(params['indent']), 8), 0))
except (KeyError, ValueError, TypeError):
pass
@@ -81,10 +86,12 @@ class JSONRenderer(BaseRenderer):
renderer_context = renderer_context or {}
indent = self.get_indent(accepted_media_type, renderer_context)
+ separators = (',', ':') if (indent is None and self.compact) else (', ', ': ')
ret = json.dumps(
data, cls=self.encoder_class,
- indent=indent, ensure_ascii=self.ensure_ascii
+ indent=indent, ensure_ascii=self.ensure_ascii,
+ separators=separators
)
# On python 2.x json.dumps() returns bytestrings if ensure_ascii=True,
@@ -96,14 +103,6 @@ class JSONRenderer(BaseRenderer):
return ret
-class UnicodeJSONRenderer(JSONRenderer):
- ensure_ascii = False
- """
- Renderer which serializes to JSON.
- Does *not* apply JSON's character escaping for non-ascii characters.
- """
-
-
class JSONPRenderer(JSONRenderer):
"""
Renderer which serializes to json,
@@ -196,7 +195,7 @@ class YAMLRenderer(BaseRenderer):
format = 'yaml'
encoder = encoders.SafeDumper
charset = 'utf-8'
- ensure_ascii = True
+ ensure_ascii = False
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
@@ -210,14 +209,6 @@ class YAMLRenderer(BaseRenderer):
return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder, allow_unicode=not self.ensure_ascii)
-class UnicodeYAMLRenderer(YAMLRenderer):
- """
- Renderer which serializes to YAML.
- Does *not* apply character escaping for non-ascii characters.
- """
- ensure_ascii = False
-
-
class TemplateHTMLRenderer(BaseRenderer):
"""
An HTML renderer for use with templates.
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index f48643b5..e55610bb 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -112,6 +112,9 @@ DEFAULTS = {
),
'TIME_FORMAT': None,
+ # Encoding
+ 'UNICODE_JSON': True,
+ 'COMPACT_JSON': True
}
--
cgit v1.2.3
From 22af49bf8ffc73afc9b638f1b9cd2e909c6c89a8 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 12 Sep 2014 11:50:20 +0100
Subject: Tidy up JSONEncoder
---
rest_framework/utils/encoders.py | 67 ++++++++++++++++++++--------------------
1 file changed, 34 insertions(+), 33 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index 7992b6b1..174b08b8 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -7,7 +7,6 @@ from django.db.models.query import QuerySet
from django.utils.datastructures import SortedDict
from django.utils.functional import Promise
from rest_framework.compat import force_text
-# from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
import datetime
import decimal
import types
@@ -17,45 +16,47 @@ import json
class JSONEncoder(json.JSONEncoder):
"""
JSONEncoder subclass that knows how to encode date/time/timedelta,
- decimal types, and generators.
+ decimal types, generators and other basic python objects.
"""
- def default(self, o):
+ def default(self, obj):
# For Date Time string spec, see ECMA 262
# http://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15
- if isinstance(o, Promise):
- return force_text(o)
- elif isinstance(o, datetime.datetime):
- r = o.isoformat()
- if o.microsecond:
- r = r[:23] + r[26:]
- if r.endswith('+00:00'):
- r = r[:-6] + 'Z'
- return r
- elif isinstance(o, datetime.date):
- return o.isoformat()
- elif isinstance(o, datetime.time):
- if timezone and timezone.is_aware(o):
+ if isinstance(obj, Promise):
+ return force_text(obj)
+ elif isinstance(obj, datetime.datetime):
+ representation = obj.isoformat()
+ if obj.microsecond:
+ representation = representation[:23] + representation[26:]
+ if representation.endswith('+00:00'):
+ representation = representation[:-6] + 'Z'
+ return representation
+ elif isinstance(obj, datetime.date):
+ return obj.isoformat()
+ elif isinstance(obj, datetime.time):
+ if timezone and timezone.is_aware(obj):
raise ValueError("JSON can't represent timezone-aware times.")
- r = o.isoformat()
- if o.microsecond:
- r = r[:12]
- return r
- elif isinstance(o, datetime.timedelta):
- return str(o.total_seconds())
- elif isinstance(o, decimal.Decimal):
- return float(o)
- elif isinstance(o, QuerySet):
- return list(o)
- elif hasattr(o, 'tolist'):
- return o.tolist()
- elif hasattr(o, '__getitem__'):
+ representation = obj.isoformat()
+ if obj.microsecond:
+ representation = representation[:12]
+ return representation
+ elif isinstance(obj, datetime.timedelta):
+ return str(obj.total_seconds())
+ elif isinstance(obj, decimal.Decimal):
+ # Serializers will coerce decimals to strings by default.
+ return float(obj)
+ elif isinstance(obj, QuerySet):
+ return list(obj)
+ elif hasattr(obj, 'tolist'):
+ # Numpy arrays and array scalars.
+ return obj.tolist()
+ elif hasattr(obj, '__getitem__'):
try:
- return dict(o)
+ return dict(obj)
except:
pass
- elif hasattr(o, '__iter__'):
- return [i for i in o]
- return super(JSONEncoder, self).default(o)
+ elif hasattr(obj, '__iter__'):
+ return [item for item in obj]
+ return super(JSONEncoder, self).default(obj)
try:
--
cgit v1.2.3
From 79715f01f8c34fdd55c2291b6b21d09fa3a8153e Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 12 Sep 2014 12:10:22 +0100
Subject: Coerce dates etc to ISO_8601 in seralizer, by default.
---
rest_framework/fields.py | 24 +++++++++++++-----------
rest_framework/settings.py | 21 ++++++++-------------
2 files changed, 21 insertions(+), 24 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 9d96cf5c..e1855ff7 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -431,10 +431,12 @@ class DecimalField(Field):
'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.')
}
- def __init__(self, max_digits, decimal_places, coerce_to_string=True, max_value=None, min_value=None, **kwargs):
+ coerce_to_string = api_settings.COERCE_DECIMAL_TO_STRING
+
+ def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, **kwargs):
self.max_digits = max_digits
self.decimal_places = decimal_places
- self.coerce_to_string = coerce_to_string
+ self.coerce_to_string = coerce_to_string if (coerce_to_string is not None) else self.coerce_to_string
super(DecimalField, self).__init__(**kwargs)
if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value))
@@ -510,12 +512,12 @@ class DateField(Field):
default_error_messages = {
'invalid': _('Date has wrong format. Use one of these formats instead: {format}'),
}
- input_formats = api_settings.DATE_INPUT_FORMATS
format = api_settings.DATE_FORMAT
+ input_formats = api_settings.DATE_INPUT_FORMATS
- def __init__(self, input_formats=None, format=None, *args, **kwargs):
- self.input_formats = input_formats if input_formats is not None else self.input_formats
+ def __init__(self, format=None, input_formats=None, *args, **kwargs):
self.format = format if format is not None else self.format
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
super(DateField, self).__init__(*args, **kwargs)
def to_internal_value(self, value):
@@ -569,12 +571,12 @@ class DateTimeField(Field):
default_error_messages = {
'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'),
}
- input_formats = api_settings.DATETIME_INPUT_FORMATS
format = api_settings.DATETIME_FORMAT
+ input_formats = api_settings.DATETIME_INPUT_FORMATS
- def __init__(self, input_formats=None, format=None, *args, **kwargs):
- self.input_formats = input_formats if input_formats is not None else self.input_formats
+ def __init__(self, format=None, input_formats=None, *args, **kwargs):
self.format = format if format is not None else self.format
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
super(DateTimeField, self).__init__(*args, **kwargs)
def to_internal_value(self, value):
@@ -634,12 +636,12 @@ class TimeField(Field):
default_error_messages = {
'invalid': _('Time has wrong format. Use one of these formats instead: {format}'),
}
- input_formats = api_settings.TIME_INPUT_FORMATS
format = api_settings.TIME_FORMAT
+ input_formats = api_settings.TIME_INPUT_FORMATS
- def __init__(self, input_formats=None, format=None, *args, **kwargs):
- self.input_formats = input_formats if input_formats is not None else self.input_formats
+ def __init__(self, format=None, input_formats=None, *args, **kwargs):
self.format = format if format is not None else self.format
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
super(TimeField, self).__init__(*args, **kwargs)
def from_native(self, value):
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index e55610bb..421e146c 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -97,24 +97,19 @@ DEFAULTS = {
'URL_FIELD_NAME': 'url',
# Input and output formats
- 'DATE_INPUT_FORMATS': (
- ISO_8601,
- ),
- 'DATE_FORMAT': None,
+ 'DATE_FORMAT': ISO_8601,
+ 'DATE_INPUT_FORMATS': (ISO_8601,),
- 'DATETIME_INPUT_FORMATS': (
- ISO_8601,
- ),
- 'DATETIME_FORMAT': None,
+ 'DATETIME_FORMAT': ISO_8601,
+ 'DATETIME_INPUT_FORMATS': (ISO_8601,),
- 'TIME_INPUT_FORMATS': (
- ISO_8601,
- ),
- 'TIME_FORMAT': None,
+ 'TIME_FORMAT': ISO_8601,
+ 'TIME_INPUT_FORMATS': (ISO_8601,),
# Encoding
'UNICODE_JSON': True,
- 'COMPACT_JSON': True
+ 'COMPACT_JSON': True,
+ 'COERCE_DECIMAL_TO_STRING': True
}
--
cgit v1.2.3
From b73a205cc021983d9a508b447f30e144a1ce4129 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 12 Sep 2014 17:03:42 +0100
Subject: Tests for relational fields (not including many=True)
---
rest_framework/relations.py | 143 +++++++++++++++++++++++++++++---------------
1 file changed, 94 insertions(+), 49 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index e23a4152..75ec89a8 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -1,19 +1,24 @@
+from rest_framework.compat import smart_text, urlparse
from rest_framework.fields import Field
from rest_framework.reverse import reverse
-from django.core.exceptions import ObjectDoesNotExist
+from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured
from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch
from django.db.models.query import QuerySet
-from rest_framework.compat import urlparse
+from django.utils.translation import ugettext_lazy as _
class RelatedField(Field):
def __init__(self, **kwargs):
self.queryset = kwargs.pop('queryset', None)
self.many = kwargs.pop('many', False)
- assert self.queryset is not None or kwargs.get('read_only', False), (
+ assert self.queryset is not None or kwargs.get('read_only', None), (
'Relational field must provide a `queryset` argument, '
'or set read_only=`True`.'
)
+ assert not (self.queryset is not None and kwargs.get('read_only', None)), (
+ 'Relational fields should not provide a `queryset` argument, '
+ 'when setting read_only=`True`.'
+ )
super(RelatedField, self).__init__(**kwargs)
def get_queryset(self):
@@ -25,6 +30,11 @@ class RelatedField(Field):
class StringRelatedField(Field):
+ """
+ A read only field that represents its targets using their
+ plain string representation.
+ """
+
def __init__(self, **kwargs):
kwargs['read_only'] = True
super(StringRelatedField, self).__init__(**kwargs)
@@ -34,10 +44,10 @@ class StringRelatedField(Field):
class PrimaryKeyRelatedField(RelatedField):
- MESSAGES = {
+ default_error_messages = {
'required': 'This field is required.',
'does_not_exist': "Invalid pk '{pk_value}' - object does not exist.",
- 'incorrect_type': 'Incorrect type. Expected pk value, received {data_type}.',
+ 'incorrect_type': 'Incorrect type. Expected pk value, received {data_type}.',
}
def to_internal_value(self, data):
@@ -48,22 +58,33 @@ class PrimaryKeyRelatedField(RelatedField):
except (TypeError, ValueError):
self.fail('incorrect_type', data_type=type(data).__name__)
+ def to_representation(self, value):
+ return value.pk
+
class HyperlinkedRelatedField(RelatedField):
lookup_field = 'pk'
- MESSAGES = {
+ default_error_messages = {
'required': 'This field is required.',
'no_match': 'Invalid hyperlink - No URL match',
'incorrect_match': 'Invalid hyperlink - Incorrect URL match.',
- 'does_not_exist': "Invalid hyperlink - Object does not exist.",
- 'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.',
+ 'does_not_exist': 'Invalid hyperlink - Object does not exist.',
+ 'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.',
}
- def __init__(self, **kwargs):
- self.view_name = kwargs.pop('view_name')
+ def __init__(self, view_name, **kwargs):
+ self.view_name = view_name
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
+ self.format = kwargs.pop('format', None)
+
+ # We include these simply for dependancy injection in tests.
+ # We can't add them as class attributes or they would expect an
+ # implict `self` argument to be passed.
+ self.reverse = reverse
+ self.resolve = resolve
+
super(HyperlinkedRelatedField, self).__init__(**kwargs)
def get_object(self, view_name, view_args, view_kwargs):
@@ -77,21 +98,36 @@ class HyperlinkedRelatedField(RelatedField):
lookup_kwargs = {self.lookup_field: lookup_value}
return self.get_queryset().get(**lookup_kwargs)
- def to_internal_value(self, value):
+ def get_url(self, obj, view_name, request, format):
+ """
+ Given an object, return the URL that hyperlinks to the object.
+
+ May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
+ attributes are not configured to correctly match the URL conf.
+ """
+ # Unsaved objects will not yet have a valid URL.
+ if obj.pk is None:
+ return None
+
+ lookup_value = getattr(obj, self.lookup_field)
+ kwargs = {self.lookup_url_kwarg: lookup_value}
+ return self.reverse(view_name, kwargs=kwargs, request=request, format=format)
+
+ def to_internal_value(self, data):
try:
- http_prefix = value.startswith(('http:', 'https:'))
+ http_prefix = data.startswith(('http:', 'https:'))
except AttributeError:
- self.fail('incorrect_type', data_type=type(value).__name__)
+ self.fail('incorrect_type', data_type=type(data).__name__)
if http_prefix:
# If needed convert absolute URLs to relative path
- value = urlparse.urlparse(value).path
+ data = urlparse.urlparse(data).path
prefix = get_script_prefix()
- if value.startswith(prefix):
- value = '/' + value[len(prefix):]
+ if data.startswith(prefix):
+ data = '/' + data[len(prefix):]
try:
- match = resolve(value)
+ match = self.resolve(data)
except Exception:
self.fail('no_match')
@@ -103,41 +139,14 @@ class HyperlinkedRelatedField(RelatedField):
except (ObjectDoesNotExist, TypeError, ValueError):
self.fail('does_not_exist')
-
-class HyperlinkedIdentityField(RelatedField):
- lookup_field = 'pk'
-
- def __init__(self, **kwargs):
- kwargs['read_only'] = True
- kwargs['source'] = '*'
- self.view_name = kwargs.pop('view_name')
- self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
- self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
- super(HyperlinkedIdentityField, self).__init__(**kwargs)
-
- def get_url(self, obj, view_name, request, format):
- """
- Given an object, return the URL that hyperlinks to the object.
-
- May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
- attributes are not configured to correctly match the URL conf.
- """
- # Unsaved objects will not yet have a valid URL.
- if obj.pk is None:
- return None
-
- lookup_value = getattr(obj, self.lookup_field)
- kwargs = {self.lookup_url_kwarg: lookup_value}
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
-
def to_representation(self, value):
request = self.context.get('request', None)
format = self.context.get('format', None)
assert request is not None, (
- "`HyperlinkedIdentityField` requires the request in the serializer"
+ "`%s` requires the request in the serializer"
" context. Add `context={'request': request}` when instantiating "
- "the serializer."
+ "the serializer." % self.__class__.__name__
)
# By default use whatever format is given for the current context
@@ -162,9 +171,45 @@ class HyperlinkedIdentityField(RelatedField):
'model in your API, or incorrectly configured the '
'`lookup_field` attribute on this field.'
)
- raise Exception(msg % self.view_name)
+ raise ImproperlyConfigured(msg % self.view_name)
+
+
+class HyperlinkedIdentityField(HyperlinkedRelatedField):
+ """
+ A read-only field that represents the identity URL for an object, itself.
+
+ This is in contrast to `HyperlinkedRelatedField` which represents the
+ URL of relationships to other objects.
+ """
+
+ def __init__(self, view_name, **kwargs):
+ kwargs['read_only'] = True
+ kwargs['source'] = '*'
+ super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)
class SlugRelatedField(RelatedField):
- def __init__(self, **kwargs):
- self.slug_field = kwargs.pop('slug_field', None)
+ """
+ A read-write field the represents the target of the relationship
+ by a unique 'slug' attribute.
+ """
+
+ default_error_messages = {
+ 'does_not_exist': _("Object with {slug_name}={value} does not exist."),
+ 'invalid': _('Invalid value.'),
+ }
+
+ def __init__(self, slug_field, **kwargs):
+ self.slug_field = slug_field
+ super(SlugRelatedField, self).__init__(**kwargs)
+
+ def to_internal_value(self, data):
+ try:
+ return self.get_queryset().get(**{self.slug_field: data})
+ except ObjectDoesNotExist:
+ self.fail('does_not_exist', slug_name=self.slug_field, value=smart_text(data))
+ except (TypeError, ValueError):
+ self.fail('invalid')
+
+ def to_representation(self, obj):
+ return getattr(obj, self.slug_field)
--
cgit v1.2.3
From 0ac52e0808288892717c017e57c57aa8ad81e6d3 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 12 Sep 2014 17:06:37 +0100
Subject: Use Resolver404 instead of base Exception
---
rest_framework/relations.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 75ec89a8..46fe55ef 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -2,7 +2,7 @@ from rest_framework.compat import smart_text, urlparse
from rest_framework.fields import Field
from rest_framework.reverse import reverse
from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured
-from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch
+from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch, Resolver404
from django.db.models.query import QuerySet
from django.utils.translation import ugettext_lazy as _
@@ -128,7 +128,7 @@ class HyperlinkedRelatedField(RelatedField):
try:
match = self.resolve(data)
- except Exception:
+ except Resolver404:
self.fail('no_match')
if match.view_name != self.view_name:
--
cgit v1.2.3
From e6c88a423361b084ba171af7a74a183bd557e73e Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 12 Sep 2014 19:54:27 +0100
Subject: Drop usage of validatiors.EMPTY_VALUES
---
rest_framework/fields.py | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index e1855ff7..33ab0682 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -223,7 +223,7 @@ class Field(object):
return value
def run_validators(self, value):
- if value in validators.EMPTY_VALUES:
+ if value in (None, '', [], (), {}):
return
errors = []
@@ -450,7 +450,7 @@ class DecimalField(Field):
than max_digits in the number, and no more than decimal_places digits
after the decimal point.
"""
- if value in validators.EMPTY_VALUES:
+ if value in (None, ''):
return None
value = smart_text(value).strip()
@@ -521,7 +521,7 @@ class DateField(Field):
super(DateField, self).__init__(*args, **kwargs)
def to_internal_value(self, value):
- if value in validators.EMPTY_VALUES:
+ if value in (None, ''):
return None
if isinstance(value, datetime.datetime):
@@ -580,7 +580,7 @@ class DateTimeField(Field):
super(DateTimeField, self).__init__(*args, **kwargs)
def to_internal_value(self, value):
- if value in validators.EMPTY_VALUES:
+ if value in (None, ''):
return None
if isinstance(value, datetime.datetime):
@@ -645,7 +645,7 @@ class TimeField(Field):
super(TimeField, self).__init__(*args, **kwargs)
def from_native(self, value):
- if value in validators.EMPTY_VALUES:
+ if value in (None, ''):
return None
if isinstance(value, datetime.time):
--
cgit v1.2.3
From afb28a44ad1737cd6fcd6da50ba9552f38293368 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 12 Sep 2014 21:32:20 +0100
Subject: Dealing with reverse relationships
---
rest_framework/serializers.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 0c2aedfa..ecb2829b 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -157,7 +157,7 @@ class Serializer(BaseSerializer):
def __init__(self, *args, **kwargs):
self.context = kwargs.pop('context', {})
kwargs.pop('partial', None)
- kwargs.pop('many', False)
+ kwargs.pop('many', None)
super(Serializer, self).__init__(*args, **kwargs)
@@ -423,9 +423,9 @@ class ModelSerializer(Serializer):
for accessor_name, relation_info in info.reverse_relations.items():
if accessor_name in self.opts.fields:
if self.opts.depth:
- ret[field_name] = self.get_nested_field(*relation_info)
+ ret[accessor_name] = self.get_nested_field(*relation_info)
else:
- ret[field_name] = self.get_related_field(*relation_info)
+ ret[accessor_name] = self.get_related_field(*relation_info)
return ret
@@ -444,7 +444,7 @@ class ModelSerializer(Serializer):
Note that model_field will be `None` for reverse relationships.
"""
- class NestedModelSerializer(ModelSerializer):
+ class NestedModelSerializer(ModelSerializer): # Not right!
class Meta:
model = related_model
depth = self.opts.depth - 1
--
cgit v1.2.3
From 40dc588a372375608701f7e521dea6d860a49eb2 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 15 Sep 2014 09:50:51 +0100
Subject: Drop label from serializer fields when not needed
---
rest_framework/fields.py | 6 ++-
rest_framework/serializers.py | 56 ++++++++++++---------
rest_framework/utils/model_meta.py | 99 ++++++++++++++++++++++++++++++++++++++
rest_framework/utils/modelinfo.py | 99 --------------------------------------
4 files changed, 137 insertions(+), 123 deletions(-)
create mode 100644 rest_framework/utils/model_meta.py
delete mode 100644 rest_framework/utils/modelinfo.py
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 33ab0682..1818e705 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -80,6 +80,10 @@ def set_value(dictionary, keys, value):
dictionary[keys[-1]] = value
+def field_name_to_label(field_name):
+ return field_name.replace('_', ' ').capitalize()
+
+
class SkipField(Exception):
pass
@@ -158,7 +162,7 @@ class Field(object):
# `self.label` should deafult to being based on the field name.
if self.label is None:
- self.label = self.field_name.replace('_', ' ').capitalize()
+ self.label = field_name_to_label(self.field_name)
# self.source should default to being the same as the field name.
if self.source is None:
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index ecb2829b..ba8d475f 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -15,11 +15,12 @@ from django.core.exceptions import ValidationError
from django.db import models
from django.utils import six
from django.utils.datastructures import SortedDict
+from django.utils.text import capfirst
from collections import namedtuple
from rest_framework.compat import clean_manytomany_helptext
from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings
-from rest_framework.utils import html, modelinfo, representation
+from rest_framework.utils import html, model_meta, representation
import copy
# Note: We do the following so that users of the framework can use this style:
@@ -334,6 +335,14 @@ def lookup_class(mapping, instance):
raise KeyError('Class %s not found in lookup.', cls.__name__)
+def needs_label(model_field, field_name):
+ """
+ Returns `True` if the label based on the model's verbose name
+ is not equal to the default label it would have based on it's field name.
+ """
+ return capfirst(model_field.verbose_name) != field_name_to_label(field_name)
+
+
class ModelSerializer(Serializer):
field_mapping = {
models.AutoField: IntegerField,
@@ -397,54 +406,55 @@ class ModelSerializer(Serializer):
"""
Return all the fields that should be serialized for the model.
"""
- info = modelinfo.get_field_info(self.opts.model)
+ info = model_meta.get_field_info(self.opts.model)
ret = SortedDict()
serializer_url_field = self.get_url_field()
if serializer_url_field:
ret[api_settings.URL_FIELD_NAME] = serializer_url_field
- serializer_pk_field = self.get_pk_field(info.pk)
+ field_name = info.pk.name
+ serializer_pk_field = self.get_pk_field(field_name, info.pk)
if serializer_pk_field:
- ret[info.pk.name] = serializer_pk_field
+ ret[field_name] = serializer_pk_field
# Regular fields
for field_name, field in info.fields.items():
- ret[field_name] = self.get_field(field)
+ ret[field_name] = self.get_field(field_name, field)
# Forward relations
for field_name, relation_info in info.forward_relations.items():
if self.opts.depth:
- ret[field_name] = self.get_nested_field(*relation_info)
+ ret[field_name] = self.get_nested_field(field_name, *relation_info)
else:
- ret[field_name] = self.get_related_field(*relation_info)
+ ret[field_name] = self.get_related_field(field_name, *relation_info)
# Reverse relations
for accessor_name, relation_info in info.reverse_relations.items():
if accessor_name in self.opts.fields:
if self.opts.depth:
- ret[accessor_name] = self.get_nested_field(*relation_info)
+ ret[accessor_name] = self.get_nested_field(accessor_name, *relation_info)
else:
- ret[accessor_name] = self.get_related_field(*relation_info)
+ ret[accessor_name] = self.get_related_field(accessor_name, *relation_info)
return ret
def get_url_field(self):
return None
- def get_pk_field(self, model_field):
+ def get_pk_field(self, field_name, model_field):
"""
Returns a default instance of the pk field.
"""
- return self.get_field(model_field)
+ return self.get_field(field_name, model_field)
- def get_nested_field(self, model_field, related_model, to_many, has_through_model):
+ def get_nested_field(self, field_name, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a nested relational field.
Note that model_field will be `None` for reverse relationships.
"""
- class NestedModelSerializer(ModelSerializer): # Not right!
+ class NestedModelSerializer(ModelSerializer):
class Meta:
model = related_model
depth = self.opts.depth - 1
@@ -454,7 +464,7 @@ class ModelSerializer(Serializer):
kwargs['many'] = True
return NestedModelSerializer(**kwargs)
- def get_related_field(self, model_field, related_model, to_many, has_through_model):
+ def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a flat relational field.
@@ -474,8 +484,8 @@ class ModelSerializer(Serializer):
if model_field:
if model_field.null or model_field.blank:
kwargs['required'] = False
- if model_field.verbose_name:
- kwargs['label'] = model_field.verbose_name
+ if model_field.verbose_name and needs_label(model_field, field_name):
+ kwargs['label'] = capfirst(model_field.verbose_name)
if not model_field.editable:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
@@ -485,7 +495,7 @@ class ModelSerializer(Serializer):
return PrimaryKeyRelatedField(**kwargs)
- def get_field(self, model_field):
+ def get_field(self, field_name, model_field):
"""
Creates a default instance of a basic non-relational field.
"""
@@ -496,8 +506,8 @@ class ModelSerializer(Serializer):
if model_field.null or model_field.blank:
kwargs['required'] = False
- if model_field.verbose_name is not None:
- kwargs['label'] = model_field.verbose_name
+ if model_field.verbose_name and needs_label(model_field, field_name):
+ kwargs['label'] = capfirst(model_field.verbose_name)
if model_field.help_text:
kwargs['help_text'] = model_field.help_text
@@ -642,11 +652,11 @@ class HyperlinkedModelSerializer(ModelSerializer):
return HyperlinkedIdentityField(**kwargs)
- def get_pk_field(self, model_field):
+ def get_pk_field(self, field_name, model_field):
if self.opts.fields and model_field.name in self.opts.fields:
return self.get_field(model_field)
- def get_related_field(self, model_field, related_model, to_many, has_through_model):
+ def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a flat relational field.
"""
@@ -665,8 +675,8 @@ class HyperlinkedModelSerializer(ModelSerializer):
if model_field:
if model_field.null or model_field.blank:
kwargs['required'] = False
- if model_field.verbose_name:
- kwargs['label'] = model_field.verbose_name
+ if model_field.verbose_name and needs_label(model_field, field_name):
+ kwargs['label'] = capfirst(model_field.verbose_name)
if not model_field.editable:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py
new file mode 100644
index 00000000..960fa4d0
--- /dev/null
+++ b/rest_framework/utils/model_meta.py
@@ -0,0 +1,99 @@
+"""
+Helper functions for returning the field information that is associated
+with a model class. This includes returning all the forward and reverse
+relationships and their associated metadata.
+"""
+from collections import namedtuple
+from django.db import models
+from django.utils import six
+from django.utils.datastructures import SortedDict
+import inspect
+
+FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations'])
+RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model'])
+
+
+def _resolve_model(obj):
+ """
+ Resolve supplied `obj` to a Django model class.
+
+ `obj` must be a Django model class itself, or a string
+ representation of one. Useful in situtations like GH #1225 where
+ Django may not have resolved a string-based reference to a model in
+ another model's foreign key definition.
+
+ String representations should have the format:
+ 'appname.ModelName'
+ """
+ if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
+ app_name, model_name = obj.split('.')
+ return models.get_model(app_name, model_name)
+ elif inspect.isclass(obj) and issubclass(obj, models.Model):
+ return obj
+ raise ValueError("{0} is not a Django model".format(obj))
+
+
+def get_field_info(model):
+ """
+ Given a model class, returns a `FieldInfo` instance containing metadata
+ about the various field types on the model.
+ """
+ opts = model._meta.concrete_model._meta
+
+ # Deal with the primary key.
+ pk = opts.pk
+ while pk.rel and pk.rel.parent_link:
+ # If model is a child via multitable inheritance, use parent's pk.
+ pk = pk.rel.to._meta.pk
+
+ # Deal with regular fields.
+ fields = SortedDict()
+ for field in [field for field in opts.fields if field.serialize and not field.rel]:
+ fields[field.name] = field
+
+ # Deal with forward relationships.
+ forward_relations = SortedDict()
+ for field in [field for field in opts.fields if field.serialize and field.rel]:
+ forward_relations[field.name] = RelationInfo(
+ field=field,
+ related=_resolve_model(field.rel.to),
+ to_many=False,
+ has_through_model=False
+ )
+
+ # Deal with forward many-to-many relationships.
+ for field in [field for field in opts.many_to_many if field.serialize]:
+ forward_relations[field.name] = RelationInfo(
+ field=field,
+ related=_resolve_model(field.rel.to),
+ to_many=True,
+ has_through_model=(
+ not field.rel.through._meta.auto_created
+ )
+ )
+
+ # Deal with reverse relationships.
+ reverse_relations = SortedDict()
+ for relation in opts.get_all_related_objects():
+ accessor_name = relation.get_accessor_name()
+ reverse_relations[accessor_name] = RelationInfo(
+ field=None,
+ related=relation.model,
+ to_many=relation.field.rel.multiple,
+ has_through_model=False
+ )
+
+ # Deal with reverse many-to-many relationships.
+ for relation in opts.get_all_related_many_to_many_objects():
+ accessor_name = relation.get_accessor_name()
+ reverse_relations[accessor_name] = RelationInfo(
+ field=None,
+ related=relation.model,
+ to_many=True,
+ has_through_model=(
+ hasattr(relation.field.rel, 'through') and
+ not relation.field.rel.through._meta.auto_created
+ )
+ )
+
+ return FieldInfo(pk, fields, forward_relations, reverse_relations)
diff --git a/rest_framework/utils/modelinfo.py b/rest_framework/utils/modelinfo.py
deleted file mode 100644
index 960fa4d0..00000000
--- a/rest_framework/utils/modelinfo.py
+++ /dev/null
@@ -1,99 +0,0 @@
-"""
-Helper functions for returning the field information that is associated
-with a model class. This includes returning all the forward and reverse
-relationships and their associated metadata.
-"""
-from collections import namedtuple
-from django.db import models
-from django.utils import six
-from django.utils.datastructures import SortedDict
-import inspect
-
-FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations'])
-RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model'])
-
-
-def _resolve_model(obj):
- """
- Resolve supplied `obj` to a Django model class.
-
- `obj` must be a Django model class itself, or a string
- representation of one. Useful in situtations like GH #1225 where
- Django may not have resolved a string-based reference to a model in
- another model's foreign key definition.
-
- String representations should have the format:
- 'appname.ModelName'
- """
- if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
- app_name, model_name = obj.split('.')
- return models.get_model(app_name, model_name)
- elif inspect.isclass(obj) and issubclass(obj, models.Model):
- return obj
- raise ValueError("{0} is not a Django model".format(obj))
-
-
-def get_field_info(model):
- """
- Given a model class, returns a `FieldInfo` instance containing metadata
- about the various field types on the model.
- """
- opts = model._meta.concrete_model._meta
-
- # Deal with the primary key.
- pk = opts.pk
- while pk.rel and pk.rel.parent_link:
- # If model is a child via multitable inheritance, use parent's pk.
- pk = pk.rel.to._meta.pk
-
- # Deal with regular fields.
- fields = SortedDict()
- for field in [field for field in opts.fields if field.serialize and not field.rel]:
- fields[field.name] = field
-
- # Deal with forward relationships.
- forward_relations = SortedDict()
- for field in [field for field in opts.fields if field.serialize and field.rel]:
- forward_relations[field.name] = RelationInfo(
- field=field,
- related=_resolve_model(field.rel.to),
- to_many=False,
- has_through_model=False
- )
-
- # Deal with forward many-to-many relationships.
- for field in [field for field in opts.many_to_many if field.serialize]:
- forward_relations[field.name] = RelationInfo(
- field=field,
- related=_resolve_model(field.rel.to),
- to_many=True,
- has_through_model=(
- not field.rel.through._meta.auto_created
- )
- )
-
- # Deal with reverse relationships.
- reverse_relations = SortedDict()
- for relation in opts.get_all_related_objects():
- accessor_name = relation.get_accessor_name()
- reverse_relations[accessor_name] = RelationInfo(
- field=None,
- related=relation.model,
- to_many=relation.field.rel.multiple,
- has_through_model=False
- )
-
- # Deal with reverse many-to-many relationships.
- for relation in opts.get_all_related_many_to_many_objects():
- accessor_name = relation.get_accessor_name()
- reverse_relations[accessor_name] = RelationInfo(
- field=None,
- related=relation.model,
- to_many=True,
- has_through_model=(
- hasattr(relation.field.rel, 'through') and
- not relation.field.rel.through._meta.auto_created
- )
- )
-
- return FieldInfo(pk, fields, forward_relations, reverse_relations)
--
cgit v1.2.3
From d196608d5af912057baba79ab13d05d876368ad2 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 15 Sep 2014 13:55:09 +0100
Subject: Fix nested model serializer base class
---
rest_framework/serializers.py | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
(limited to 'rest_framework')
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index ba8d475f..40d76897 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -368,6 +368,7 @@ class ModelSerializer(Serializer):
models.TimeField: TimeField,
models.URLField: URLField,
}
+ nested_class = None # We fill this in at the end of this module.
_options_class = ModelSerializerOptions
@@ -454,7 +455,7 @@ class ModelSerializer(Serializer):
Note that model_field will be `None` for reverse relationships.
"""
- class NestedModelSerializer(ModelSerializer):
+ class NestedModelSerializer(self.nested_class):
class Meta:
model = related_model
depth = self.opts.depth - 1
@@ -694,3 +695,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
'app_label': model._meta.app_label,
'model_name': model._meta.object_name.lower()
}
+
+
+ModelSerializer.nested_class = ModelSerializer
+HyperlinkedModelSerializer.nested_class = HyperlinkedModelSerializer
--
cgit v1.2.3
From c0155fd9dc654dc5932effd46a00f66495ce700b Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Wed, 17 Sep 2014 14:11:53 +0100
Subject: Update comments
---
rest_framework/serializers.py | 2 ++
1 file changed, 2 insertions(+)
(limited to 'rest_framework')
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 40d76897..1fea1380 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -410,10 +410,12 @@ class ModelSerializer(Serializer):
info = model_meta.get_field_info(self.opts.model)
ret = SortedDict()
+ # URL field
serializer_url_field = self.get_url_field()
if serializer_url_field:
ret[api_settings.URL_FIELD_NAME] = serializer_url_field
+ # Primary key field
field_name = info.pk.name
serializer_pk_field = self.get_pk_field(field_name, info.pk)
if serializer_pk_field:
--
cgit v1.2.3
From 5b7e4af0d657a575cb15eea85a63a7100c636085 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 18 Sep 2014 11:20:56 +0100
Subject: get_base_field() refactor
---
rest_framework/fields.py | 6 +-
rest_framework/relations.py | 9 +-
rest_framework/serializers.py | 464 ++++++++--------------------------
rest_framework/utils/field_mapping.py | 215 ++++++++++++++++
rest_framework/utils/model_meta.py | 46 +++-
5 files changed, 364 insertions(+), 376 deletions(-)
create mode 100644 rest_framework/utils/field_mapping.py
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 1818e705..0c78b3fb 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -80,10 +80,6 @@ def set_value(dictionary, keys, value):
dictionary[keys[-1]] = value
-def field_name_to_label(field_name):
- return field_name.replace('_', ' ').capitalize()
-
-
class SkipField(Exception):
pass
@@ -162,7 +158,7 @@ class Field(object):
# `self.label` should deafult to being based on the field name.
if self.label is None:
- self.label = field_name_to_label(self.field_name)
+ self.label = field_name.replace('_', ' ').capitalize()
# self.source should default to being the same as the field name.
if self.source is None:
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 46fe55ef..9f44ab63 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -73,7 +73,8 @@ class HyperlinkedRelatedField(RelatedField):
'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.',
}
- def __init__(self, view_name, **kwargs):
+ def __init__(self, view_name=None, **kwargs):
+ assert view_name is not None, 'The `view_name` argument is required.'
self.view_name = view_name
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
@@ -182,7 +183,8 @@ class HyperlinkedIdentityField(HyperlinkedRelatedField):
URL of relationships to other objects.
"""
- def __init__(self, view_name, **kwargs):
+ def __init__(self, view_name=None, **kwargs):
+ assert view_name is not None, 'The `view_name` argument is required.'
kwargs['read_only'] = True
kwargs['source'] = '*'
super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)
@@ -199,7 +201,8 @@ class SlugRelatedField(RelatedField):
'invalid': _('Invalid value.'),
}
- def __init__(self, slug_field, **kwargs):
+ def __init__(self, slug_field=None, **kwargs):
+ assert slug_field is not None, 'The `slug_field` argument is required.'
self.slug_field = slug_field
super(SlugRelatedField, self).__init__(**kwargs)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 1fea1380..99dcc349 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -10,17 +10,19 @@ python primitives.
2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers.
"""
-from django.core import validators
from django.core.exceptions import ValidationError
from django.db import models
from django.utils import six
from django.utils.datastructures import SortedDict
-from django.utils.text import capfirst
from collections import namedtuple
-from rest_framework.compat import clean_manytomany_helptext
from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings
from rest_framework.utils import html, model_meta, representation
+from rest_framework.utils.field_mapping import (
+ get_url_kwargs, get_field_kwargs,
+ get_relation_kwargs, get_nested_relation_kwargs,
+ lookup_class
+)
import copy
# Note: We do the following so that users of the framework can use this style:
@@ -126,7 +128,7 @@ class SerializerMetaclass(type):
"""
@classmethod
- def _get_fields(cls, bases, attrs):
+ def _get_declared_fields(cls, bases, attrs):
fields = [(field_name, attrs.pop(field_name))
for field_name, obj in list(attrs.items())
if isinstance(obj, Field)]
@@ -136,25 +138,18 @@ class SerializerMetaclass(type):
# fields. Note that we loop over the bases in *reverse*. This is necessary
# in order to maintain the correct order of fields.
for base in bases[::-1]:
- if hasattr(base, 'base_fields'):
- fields = list(base.base_fields.items()) + fields
+ if hasattr(base, '_declared_fields'):
+ fields = list(base._declared_fields.items()) + fields
return SortedDict(fields)
def __new__(cls, name, bases, attrs):
- attrs['base_fields'] = cls._get_fields(bases, attrs)
+ attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
@six.add_metaclass(SerializerMetaclass)
class Serializer(BaseSerializer):
-
- def __new__(cls, *args, **kwargs):
- if kwargs.pop('many', False):
- kwargs['child'] = cls()
- return ListSerializer(*args, **kwargs)
- return super(Serializer, cls).__new__(cls, *args, **kwargs)
-
def __init__(self, *args, **kwargs):
self.context = kwargs.pop('context', {})
kwargs.pop('partial', None)
@@ -165,14 +160,22 @@ class Serializer(BaseSerializer):
# Every new serializer is created with a clone of the field instances.
# This allows users to dynamically modify the fields on a serializer
# instance without affecting every other serializer class.
- self.fields = self.get_fields()
+ self.fields = self._get_base_fields()
# Setup all the child fields, to provide them with the current context.
for field_name, field in self.fields.items():
field.bind(field_name, self, self)
- def get_fields(self):
- return copy.deepcopy(self.base_fields)
+ def __new__(cls, *args, **kwargs):
+ # We override this method in order to automagically create
+ # `ListSerializer` classes instead when `many=True` is set.
+ if kwargs.pop('many', False):
+ kwargs['child'] = cls()
+ return ListSerializer(*args, **kwargs)
+ return super(Serializer, cls).__new__(cls, *args, **kwargs)
+
+ def _get_base_fields(self):
+ return copy.deepcopy(self._declared_fields)
def bind(self, field_name, parent, root):
# If the serializer is used as a field then when it becomes bound
@@ -312,39 +315,8 @@ class ListSerializer(BaseSerializer):
return representation.list_repr(self, indent=1)
-class ModelSerializerOptions(object):
- """
- Meta class options for ModelSerializer
- """
- def __init__(self, meta):
- self.model = getattr(meta, 'model')
- self.fields = getattr(meta, 'fields', ())
- self.depth = getattr(meta, 'depth', 0)
-
-
-def lookup_class(mapping, instance):
- """
- Takes a dictionary with classes as keys, and an object.
- Traverses the object's inheritance hierarchy in method
- resolution order, and returns the first matching value
- from the dictionary or raises a KeyError if nothing matches.
- """
- for cls in inspect.getmro(instance.__class__):
- if cls in mapping:
- return mapping[cls]
- raise KeyError('Class %s not found in lookup.', cls.__name__)
-
-
-def needs_label(model_field, field_name):
- """
- Returns `True` if the label based on the model's verbose name
- is not equal to the default label it would have based on it's field name.
- """
- return capfirst(model_field.verbose_name) != field_name_to_label(field_name)
-
-
class ModelSerializer(Serializer):
- field_mapping = {
+ _field_mapping = {
models.AutoField: IntegerField,
models.BigIntegerField: IntegerField,
models.BooleanField: BooleanField,
@@ -368,16 +340,10 @@ class ModelSerializer(Serializer):
models.TimeField: TimeField,
models.URLField: URLField,
}
- nested_class = None # We fill this in at the end of this module.
-
- _options_class = ModelSerializerOptions
-
- def __init__(self, *args, **kwargs):
- self.opts = self._options_class(self.Meta)
- super(ModelSerializer, self).__init__(*args, **kwargs)
+ _related_class = PrimaryKeyRelatedField
def create(self, attrs):
- ModelClass = self.opts.model
+ ModelClass = self.Meta.model
return ModelClass.objects.create(**attrs)
def update(self, obj, attrs):
@@ -385,319 +351,97 @@ class ModelSerializer(Serializer):
setattr(obj, attr, value)
obj.save()
- def get_fields(self):
- # Get the explicitly declared fields.
- fields = copy.deepcopy(self.base_fields)
+ def _get_base_fields(self):
+ declared_fields = copy.deepcopy(self._declared_fields)
- # Add in the default fields.
- for key, val in self.get_default_fields().items():
- if key not in fields:
- fields[key] = val
-
- # If `fields` is set on the `Meta` class,
- # then use only those fields, and in that order.
- if self.opts.fields:
- fields = SortedDict([
- (key, fields[key]) for key in self.opts.fields
- ])
-
- return fields
-
- def get_default_fields(self):
- """
- Return all the fields that should be serialized for the model.
- """
- info = model_meta.get_field_info(self.opts.model)
ret = SortedDict()
+ model = getattr(self.Meta, 'model')
+ fields = getattr(self.Meta, 'fields', None)
+ depth = getattr(self.Meta, 'depth', 0)
+
+ # Retrieve metadata about fields & relationships on the model class.
+ info = model_meta.get_field_info(model)
+
+ # Use the default set of fields if none is supplied explicitly.
+ if fields is None:
+ fields = self._get_default_field_names(declared_fields, info)
+
+ for field_name in fields:
+ if field_name in declared_fields:
+ # Field is explicitly declared on the class, use that.
+ ret[field_name] = declared_fields[field_name]
+ continue
+
+ elif field_name == api_settings.URL_FIELD_NAME:
+ # Create the URL field.
+ field_cls = HyperlinkedIdentityField
+ kwargs = get_url_kwargs(model)
+
+ elif field_name in info.fields_and_pk:
+ # Create regular model fields.
+ model_field = info.fields_and_pk[field_name]
+ field_cls = lookup_class(self._field_mapping, model_field)
+ kwargs = get_field_kwargs(field_name, model_field)
+ if 'choices' in kwargs:
+ # Fields with choices get coerced into `ChoiceField`
+ # instead of using their regular typed field.
+ field_cls = ChoiceField
+ if not issubclass(field_cls, ModelField):
+ # `model_field` is only valid for the fallback case of
+ # `ModelField`, which is used when no other typed field
+ # matched to the model field.
+ kwargs.pop('model_field', None)
+
+ elif field_name in info.relations:
+ # Create forward and reverse relationships.
+ relation_info = info.relations[field_name]
+ if depth:
+ field_cls = self._get_nested_class(depth, relation_info)
+ kwargs = get_nested_relation_kwargs(relation_info)
+ else:
+ field_cls = self._related_class
+ kwargs = get_relation_kwargs(field_name, relation_info)
+ # `view_name` is only valid for hyperlinked relationships.
+ if not issubclass(field_cls, HyperlinkedRelatedField):
+ kwargs.pop('view_name', None)
- # URL field
- serializer_url_field = self.get_url_field()
- if serializer_url_field:
- ret[api_settings.URL_FIELD_NAME] = serializer_url_field
-
- # Primary key field
- field_name = info.pk.name
- serializer_pk_field = self.get_pk_field(field_name, info.pk)
- if serializer_pk_field:
- ret[field_name] = serializer_pk_field
-
- # Regular fields
- for field_name, field in info.fields.items():
- ret[field_name] = self.get_field(field_name, field)
-
- # Forward relations
- for field_name, relation_info in info.forward_relations.items():
- if self.opts.depth:
- ret[field_name] = self.get_nested_field(field_name, *relation_info)
else:
- ret[field_name] = self.get_related_field(field_name, *relation_info)
+ assert False, 'Field name `%s` is not valid.' % field_name
- # Reverse relations
- for accessor_name, relation_info in info.reverse_relations.items():
- if accessor_name in self.opts.fields:
- if self.opts.depth:
- ret[accessor_name] = self.get_nested_field(accessor_name, *relation_info)
- else:
- ret[accessor_name] = self.get_related_field(accessor_name, *relation_info)
+ ret[field_name] = field_cls(**kwargs)
return ret
- def get_url_field(self):
- return None
-
- def get_pk_field(self, field_name, model_field):
- """
- Returns a default instance of the pk field.
- """
- return self.get_field(field_name, model_field)
-
- def get_nested_field(self, field_name, model_field, related_model, to_many, has_through_model):
- """
- Creates a default instance of a nested relational field.
+ def _get_default_field_names(self, declared_fields, model_info):
+ return (
+ [model_info.pk.name] +
+ list(declared_fields.keys()) +
+ list(model_info.fields.keys()) +
+ list(model_info.forward_relations.keys())
+ )
- Note that model_field will be `None` for reverse relationships.
- """
- class NestedModelSerializer(self.nested_class):
+ def _get_nested_class(self, nested_depth, relation_info):
+ class NestedSerializer(ModelSerializer):
class Meta:
- model = related_model
- depth = self.opts.depth - 1
-
- kwargs = {'read_only': True}
- if to_many:
- kwargs['many'] = True
- return NestedModelSerializer(**kwargs)
-
- def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model):
- """
- Creates a default instance of a flat relational field.
-
- Note that model_field will be `None` for reverse relationships.
- """
- kwargs = {
- 'queryset': related_model._default_manager,
- }
-
- if to_many:
- kwargs['many'] = True
-
- if has_through_model:
- kwargs['read_only'] = True
- kwargs.pop('queryset', None)
-
- if model_field:
- if model_field.null or model_field.blank:
- kwargs['required'] = False
- if model_field.verbose_name and needs_label(model_field, field_name):
- kwargs['label'] = capfirst(model_field.verbose_name)
- if not model_field.editable:
- kwargs['read_only'] = True
- kwargs.pop('queryset', None)
- help_text = clean_manytomany_helptext(model_field.help_text)
- if help_text:
- kwargs['help_text'] = help_text
-
- return PrimaryKeyRelatedField(**kwargs)
-
- def get_field(self, field_name, model_field):
- """
- Creates a default instance of a basic non-relational field.
- """
- serializer_cls = lookup_class(self.field_mapping, model_field)
- kwargs = {}
- validator_kwarg = model_field.validators
-
- if model_field.null or model_field.blank:
- kwargs['required'] = False
-
- if model_field.verbose_name and needs_label(model_field, field_name):
- kwargs['label'] = capfirst(model_field.verbose_name)
-
- if model_field.help_text:
- kwargs['help_text'] = model_field.help_text
-
- if isinstance(model_field, models.AutoField) or not model_field.editable:
- kwargs['read_only'] = True
- # Read only implies that the field is not required.
- # We have a cleaner repr on the instance if we don't set it.
- kwargs.pop('required', None)
-
- if model_field.has_default():
- kwargs['default'] = model_field.get_default()
- # Having a default implies that the field is not required.
- # We have a cleaner repr on the instance if we don't set it.
- kwargs.pop('required', None)
-
- if model_field.flatchoices:
- # If this model field contains choices, then use a ChoiceField,
- # rather than the standard serializer field for this type.
- # Note that we return this prior to setting any validation type
- # keyword arguments, as those are not valid initializers.
- kwargs['choices'] = model_field.flatchoices
- return ChoiceField(**kwargs)
-
- # Ensure that max_length is passed explicitly as a keyword arg,
- # rather than as a validator.
- max_length = getattr(model_field, 'max_length', None)
- if max_length is not None:
- kwargs['max_length'] = max_length
- validator_kwarg = [
- validator for validator in validator_kwarg
- if not isinstance(validator, validators.MaxLengthValidator)
- ]
-
- # Ensure that min_length is passed explicitly as a keyword arg,
- # rather than as a validator.
- min_length = getattr(model_field, 'min_length', None)
- if min_length is not None:
- kwargs['min_length'] = min_length
- validator_kwarg = [
- validator for validator in validator_kwarg
- if not isinstance(validator, validators.MinLengthValidator)
- ]
-
- # Ensure that max_value is passed explicitly as a keyword arg,
- # rather than as a validator.
- max_value = next((
- validator.limit_value for validator in validator_kwarg
- if isinstance(validator, validators.MaxValueValidator)
- ), None)
- if max_value is not None:
- kwargs['max_value'] = max_value
- validator_kwarg = [
- validator for validator in validator_kwarg
- if not isinstance(validator, validators.MaxValueValidator)
- ]
-
- # Ensure that max_value is passed explicitly as a keyword arg,
- # rather than as a validator.
- min_value = next((
- validator.limit_value for validator in validator_kwarg
- if isinstance(validator, validators.MinValueValidator)
- ), None)
- if min_value is not None:
- kwargs['min_value'] = min_value
- validator_kwarg = [
- validator for validator in validator_kwarg
- if not isinstance(validator, validators.MinValueValidator)
- ]
-
- # URLField does not need to include the URLValidator argument,
- # as it is explicitly added in.
- if isinstance(model_field, models.URLField):
- validator_kwarg = [
- validator for validator in validator_kwarg
- if not isinstance(validator, validators.URLValidator)
- ]
-
- # EmailField does not need to include the validate_email argument,
- # as it is explicitly added in.
- if isinstance(model_field, models.EmailField):
- validator_kwarg = [
- validator for validator in validator_kwarg
- if validator is not validators.validate_email
- ]
-
- # SlugField do not need to include the 'validate_slug' argument,
- if isinstance(model_field, models.SlugField):
- validator_kwarg = [
- validator for validator in validator_kwarg
- if validator is not validators.validate_slug
- ]
-
- max_digits = getattr(model_field, 'max_digits', None)
- if max_digits is not None:
- kwargs['max_digits'] = max_digits
-
- decimal_places = getattr(model_field, 'decimal_places', None)
- if decimal_places is not None:
- kwargs['decimal_places'] = decimal_places
-
- if isinstance(model_field, models.BooleanField):
- # models.BooleanField has `blank=True`, but *is* actually
- # required *unless* a default is provided.
- # Also note that Django<1.6 uses `default=False` for
- # models.BooleanField, but Django>=1.6 uses `default=None`.
- kwargs.pop('required', None)
-
- if validator_kwarg:
- kwargs['validators'] = validator_kwarg
-
- if issubclass(serializer_cls, ModelField):
- kwargs['model_field'] = model_field
-
- return serializer_cls(**kwargs)
-
-
-class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
- """
- Options for HyperlinkedModelSerializer
- """
- def __init__(self, meta):
- super(HyperlinkedModelSerializerOptions, self).__init__(meta)
- self.view_name = getattr(meta, 'view_name', None)
- self.lookup_field = getattr(meta, 'lookup_field', None)
+ model = relation_info.related
+ depth = nested_depth
+ return NestedSerializer
class HyperlinkedModelSerializer(ModelSerializer):
- _options_class = HyperlinkedModelSerializerOptions
-
- def get_url_field(self):
- if self.opts.view_name is not None:
- view_name = self.opts.view_name
- else:
- view_name = self.get_default_view_name(self.opts.model)
-
- kwargs = {
- 'view_name': view_name
- }
- if self.opts.lookup_field:
- kwargs['lookup_field'] = self.opts.lookup_field
-
- return HyperlinkedIdentityField(**kwargs)
-
- def get_pk_field(self, field_name, model_field):
- if self.opts.fields and model_field.name in self.opts.fields:
- return self.get_field(model_field)
-
- def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model):
- """
- Creates a default instance of a flat relational field.
- """
- kwargs = {
- 'queryset': related_model._default_manager,
- 'view_name': self.get_default_view_name(related_model),
- }
-
- if to_many:
- kwargs['many'] = True
-
- if has_through_model:
- kwargs['read_only'] = True
- kwargs.pop('queryset', None)
-
- if model_field:
- if model_field.null or model_field.blank:
- kwargs['required'] = False
- if model_field.verbose_name and needs_label(model_field, field_name):
- kwargs['label'] = capfirst(model_field.verbose_name)
- if not model_field.editable:
- kwargs['read_only'] = True
- kwargs.pop('queryset', None)
- help_text = clean_manytomany_helptext(model_field.help_text)
- if help_text:
- kwargs['help_text'] = help_text
-
- return HyperlinkedRelatedField(**kwargs)
-
- def get_default_view_name(self, model):
- """
- Return the view name to use for related models.
- """
- return '%(model_name)s-detail' % {
- 'app_label': model._meta.app_label,
- 'model_name': model._meta.object_name.lower()
- }
-
-
-ModelSerializer.nested_class = ModelSerializer
-HyperlinkedModelSerializer.nested_class = HyperlinkedModelSerializer
+ _related_class = HyperlinkedRelatedField
+
+ def _get_default_field_names(self, declared_fields, model_info):
+ return (
+ [api_settings.URL_FIELD_NAME] +
+ list(declared_fields.keys()) +
+ list(model_info.fields.keys()) +
+ list(model_info.forward_relations.keys())
+ )
+
+ def _get_nested_class(self, nested_depth, relation_info):
+ class NestedSerializer(HyperlinkedModelSerializer):
+ class Meta:
+ model = relation_info.related
+ depth = nested_depth
+ return NestedSerializer
diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py
new file mode 100644
index 00000000..be72e444
--- /dev/null
+++ b/rest_framework/utils/field_mapping.py
@@ -0,0 +1,215 @@
+"""
+Helper functions for mapping model fields to a dictionary of default
+keyword arguments that should be used for their equivelent serializer fields.
+"""
+from django.core import validators
+from django.db import models
+from django.utils.text import capfirst
+from rest_framework.compat import clean_manytomany_helptext
+import inspect
+
+
+def lookup_class(mapping, instance):
+ """
+ Takes a dictionary with classes as keys, and an object.
+ Traverses the object's inheritance hierarchy in method
+ resolution order, and returns the first matching value
+ from the dictionary or raises a KeyError if nothing matches.
+ """
+ for cls in inspect.getmro(instance.__class__):
+ if cls in mapping:
+ return mapping[cls]
+ raise KeyError('Class %s not found in lookup.', cls.__name__)
+
+
+def needs_label(model_field, field_name):
+ """
+ Returns `True` if the label based on the model's verbose name
+ is not equal to the default label it would have based on it's field name.
+ """
+ default_label = field_name.replace('_', ' ').capitalize()
+ return capfirst(model_field.verbose_name) != default_label
+
+
+def get_detail_view_name(model):
+ """
+ Given a model class, return the view name to use for URL relationships
+ that refer to instances of the model.
+ """
+ return '%(model_name)s-detail' % {
+ 'app_label': model._meta.app_label,
+ 'model_name': model._meta.object_name.lower()
+ }
+
+
+def get_field_kwargs(field_name, model_field):
+ """
+ Creates a default instance of a basic non-relational field.
+ """
+ kwargs = {}
+ validator_kwarg = model_field.validators
+
+ if model_field.null or model_field.blank:
+ kwargs['required'] = False
+
+ if model_field.verbose_name and needs_label(model_field, field_name):
+ kwargs['label'] = capfirst(model_field.verbose_name)
+
+ if model_field.help_text:
+ kwargs['help_text'] = model_field.help_text
+
+ if isinstance(model_field, models.AutoField) or not model_field.editable:
+ kwargs['read_only'] = True
+ # Read only implies that the field is not required.
+ # We have a cleaner repr on the instance if we don't set it.
+ kwargs.pop('required', None)
+
+ if model_field.has_default():
+ kwargs['default'] = model_field.get_default()
+ # Having a default implies that the field is not required.
+ # We have a cleaner repr on the instance if we don't set it.
+ kwargs.pop('required', None)
+
+ if model_field.flatchoices:
+ # If this model field contains choices, then return now,
+ # any further keyword arguments are not valid.
+ kwargs['choices'] = model_field.flatchoices
+ return kwargs
+
+ # Ensure that max_length is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ max_length = getattr(model_field, 'max_length', None)
+ if max_length is not None:
+ kwargs['max_length'] = max_length
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MaxLengthValidator)
+ ]
+
+ # Ensure that min_length is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ min_length = getattr(model_field, 'min_length', None)
+ if min_length is not None:
+ kwargs['min_length'] = min_length
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MinLengthValidator)
+ ]
+
+ # Ensure that max_value is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ max_value = next((
+ validator.limit_value for validator in validator_kwarg
+ if isinstance(validator, validators.MaxValueValidator)
+ ), None)
+ if max_value is not None:
+ kwargs['max_value'] = max_value
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MaxValueValidator)
+ ]
+
+ # Ensure that max_value is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ min_value = next((
+ validator.limit_value for validator in validator_kwarg
+ if isinstance(validator, validators.MinValueValidator)
+ ), None)
+ if min_value is not None:
+ kwargs['min_value'] = min_value
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MinValueValidator)
+ ]
+
+ # URLField does not need to include the URLValidator argument,
+ # as it is explicitly added in.
+ if isinstance(model_field, models.URLField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.URLValidator)
+ ]
+
+ # EmailField does not need to include the validate_email argument,
+ # as it is explicitly added in.
+ if isinstance(model_field, models.EmailField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if validator is not validators.validate_email
+ ]
+
+ # SlugField do not need to include the 'validate_slug' argument,
+ if isinstance(model_field, models.SlugField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if validator is not validators.validate_slug
+ ]
+
+ max_digits = getattr(model_field, 'max_digits', None)
+ if max_digits is not None:
+ kwargs['max_digits'] = max_digits
+
+ decimal_places = getattr(model_field, 'decimal_places', None)
+ if decimal_places is not None:
+ kwargs['decimal_places'] = decimal_places
+
+ if isinstance(model_field, models.BooleanField):
+ # models.BooleanField has `blank=True`, but *is* actually
+ # required *unless* a default is provided.
+ # Also note that Django<1.6 uses `default=False` for
+ # models.BooleanField, but Django>=1.6 uses `default=None`.
+ kwargs.pop('required', None)
+
+ if validator_kwarg:
+ kwargs['validators'] = validator_kwarg
+
+ # The following will only be used by ModelField classes.
+ # Gets removed for everything else.
+ kwargs['model_field'] = model_field
+
+ return kwargs
+
+
+def get_relation_kwargs(field_name, relation_info):
+ """
+ Creates a default instance of a flat relational field.
+ """
+ model_field, related_model, to_many, has_through_model = relation_info
+ kwargs = {
+ 'queryset': related_model._default_manager,
+ 'view_name': get_detail_view_name(related_model)
+ }
+
+ if to_many:
+ kwargs['many'] = True
+
+ if has_through_model:
+ kwargs['read_only'] = True
+ kwargs.pop('queryset', None)
+
+ if model_field:
+ if model_field.null or model_field.blank:
+ kwargs['required'] = False
+ if model_field.verbose_name and needs_label(model_field, field_name):
+ kwargs['label'] = capfirst(model_field.verbose_name)
+ if not model_field.editable:
+ kwargs['read_only'] = True
+ kwargs.pop('queryset', None)
+ help_text = clean_manytomany_helptext(model_field.help_text)
+ if help_text:
+ kwargs['help_text'] = help_text
+
+ return kwargs
+
+
+def get_nested_relation_kwargs(relation_info):
+ kwargs = {'read_only': True}
+ if relation_info.to_many:
+ kwargs['many'] = True
+ return kwargs
+
+
+def get_url_kwargs(model_field):
+ return {
+ 'view_name': get_detail_view_name(model_field)
+ }
diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py
index 960fa4d0..b6c41174 100644
--- a/rest_framework/utils/model_meta.py
+++ b/rest_framework/utils/model_meta.py
@@ -1,7 +1,9 @@
"""
-Helper functions for returning the field information that is associated
+Helper function for returning the field information that is associated
with a model class. This includes returning all the forward and reverse
relationships and their associated metadata.
+
+Usage: `get_field_info(model)` returns a `FieldInfo` instance.
"""
from collections import namedtuple
from django.db import models
@@ -9,8 +11,22 @@ from django.utils import six
from django.utils.datastructures import SortedDict
import inspect
-FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations'])
-RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model'])
+
+FieldInfo = namedtuple('FieldResult', [
+ 'pk', # Model field instance
+ 'fields', # Dict of field name -> model field instance
+ 'forward_relations', # Dict of field name -> RelationInfo
+ 'reverse_relations', # Dict of field name -> RelationInfo
+ 'fields_and_pk', # Shortcut for 'pk' + 'fields'
+ 'relations' # Shortcut for 'forward_relations' + 'reverse_relations'
+])
+
+RelationInfo = namedtuple('RelationInfo', [
+ 'model_field',
+ 'related',
+ 'to_many',
+ 'has_through_model'
+])
def _resolve_model(obj):
@@ -55,7 +71,7 @@ def get_field_info(model):
forward_relations = SortedDict()
for field in [field for field in opts.fields if field.serialize and field.rel]:
forward_relations[field.name] = RelationInfo(
- field=field,
+ model_field=field,
related=_resolve_model(field.rel.to),
to_many=False,
has_through_model=False
@@ -64,7 +80,7 @@ def get_field_info(model):
# Deal with forward many-to-many relationships.
for field in [field for field in opts.many_to_many if field.serialize]:
forward_relations[field.name] = RelationInfo(
- field=field,
+ model_field=field,
related=_resolve_model(field.rel.to),
to_many=True,
has_through_model=(
@@ -77,7 +93,7 @@ def get_field_info(model):
for relation in opts.get_all_related_objects():
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
- field=None,
+ model_field=None,
related=relation.model,
to_many=relation.field.rel.multiple,
has_through_model=False
@@ -87,7 +103,7 @@ def get_field_info(model):
for relation in opts.get_all_related_many_to_many_objects():
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
- field=None,
+ model_field=None,
related=relation.model,
to_many=True,
has_through_model=(
@@ -96,4 +112,18 @@ def get_field_info(model):
)
)
- return FieldInfo(pk, fields, forward_relations, reverse_relations)
+ # Shortcut that merges both regular fields and the pk,
+ # for simplifying regular field lookup.
+ fields_and_pk = SortedDict()
+ fields_and_pk['pk'] = pk
+ fields_and_pk[pk.name] = pk
+ fields_and_pk.update(fields)
+
+ # Shortcut that merges both forward and reverse relationships
+
+ relations = SortedDict(
+ list(forward_relations.items()) +
+ list(reverse_relations.items())
+ )
+
+ return FieldInfo(pk, fields, forward_relations, reverse_relations, fields_and_pk, relations)
--
cgit v1.2.3
From 87734be5f41de921ac32ad1f6664db243aab6d07 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 18 Sep 2014 12:17:21 +0100
Subject: Configuration correctness tests on ModelSerializer
---
rest_framework/serializers.py | 30 +++++++++++++++++++++++++++---
1 file changed, 27 insertions(+), 3 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 99dcc349..9f3e53fd 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -10,7 +10,7 @@ python primitives.
2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers.
"""
-from django.core.exceptions import ValidationError
+from django.core.exceptions import ImproperlyConfigured, ValidationError
from django.db import models
from django.utils import six
from django.utils.datastructures import SortedDict
@@ -358,6 +358,7 @@ class ModelSerializer(Serializer):
model = getattr(self.Meta, 'model')
fields = getattr(self.Meta, 'fields', None)
depth = getattr(self.Meta, 'depth', 0)
+ extra_kwargs = getattr(self.Meta, 'extra_kwargs', {})
# Retrieve metadata about fields & relationships on the model class.
info = model_meta.get_field_info(model)
@@ -405,9 +406,32 @@ class ModelSerializer(Serializer):
if not issubclass(field_cls, HyperlinkedRelatedField):
kwargs.pop('view_name', None)
- else:
- assert False, 'Field name `%s` is not valid.' % field_name
+ elif hasattr(model, field_name):
+ # Create a read only field for model methods and properties.
+ field_cls = ReadOnlyField
+ kwargs = {}
+ else:
+ raise ImproperlyConfigured(
+ 'Field name `%s` is not valid for model `%s`.' %
+ (field_name, model.__class__.__name__)
+ )
+
+ # Check that any fields declared on the class are
+ # also explicity included in `Meta.fields`.
+ missing_fields = set(declared_fields.keys()) - set(fields)
+ if missing_fields:
+ missing_field = list(missing_fields)[0]
+ raise ImproperlyConfigured(
+ 'Field `%s` has been declared on serializer `%s`, but '
+ 'is missing from `Meta.fields`.' %
+ (missing_field, self.__class__.__name__)
+ )
+
+ # Populate any kwargs defined in `Meta.extra_kwargs`
+ kwargs.update(extra_kwargs.get(field_name, {}))
+
+ # Create the serializer field.
ret[field_name] = field_cls(**kwargs)
return ret
--
cgit v1.2.3
From 9fdb2280d11db126771686d626aa8a0247b8a46c Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 18 Sep 2014 14:23:00 +0100
Subject: First pass on ManyRelation
---
rest_framework/relations.py | 42 +++++++++++++++++++++++++++++++++-
rest_framework/utils/representation.py | 2 ++
2 files changed, 43 insertions(+), 1 deletion(-)
(limited to 'rest_framework')
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 9f44ab63..474d3e75 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -10,7 +10,6 @@ from django.utils.translation import ugettext_lazy as _
class RelatedField(Field):
def __init__(self, **kwargs):
self.queryset = kwargs.pop('queryset', None)
- self.many = kwargs.pop('many', False)
assert self.queryset is not None or kwargs.get('read_only', None), (
'Relational field must provide a `queryset` argument, '
'or set read_only=`True`.'
@@ -21,6 +20,13 @@ class RelatedField(Field):
)
super(RelatedField, self).__init__(**kwargs)
+ def __new__(cls, *args, **kwargs):
+ # We override this method in order to automagically create
+ # `ManyRelation` classes instead when `many=True` is set.
+ if kwargs.pop('many', False):
+ return ManyRelation(child_relation=cls(*args, **kwargs))
+ return super(RelatedField, cls).__new__(cls, *args, **kwargs)
+
def get_queryset(self):
queryset = self.queryset
if isinstance(queryset, QuerySet):
@@ -216,3 +222,37 @@ class SlugRelatedField(RelatedField):
def to_representation(self, obj):
return getattr(obj, self.slug_field)
+
+
+class ManyRelation(Field):
+ """
+ Relationships with `many=True` transparently get coerced into instead being
+ a ManyRelation with a child relationship.
+
+ The `ManyRelation` class is responsible for handling iterating through
+ the values and passing each one to the child relationship.
+
+ You shouldn't need to be using this class directly yourself.
+ """
+
+ def __init__(self, child_relation=None, *args, **kwargs):
+ self.child_relation = child_relation
+ assert child_relation is not None, '`child_relation` is a required argument.'
+ super(ManyRelation, self).__init__(*args, **kwargs)
+
+ def bind(self, field_name, parent, root):
+ # ManyRelation needs to provide the current context to the child relation.
+ super(ManyRelation, self).bind(field_name, parent, root)
+ self.child_relation.bind(field_name, parent, root)
+
+ def to_internal_value(self, data):
+ return [
+ self.child_relation.to_internal_value(item)
+ for item in data
+ ]
+
+ def to_representation(self, obj):
+ return [
+ self.child_relation.to_representation(value)
+ for value in obj.all()
+ ]
diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py
index 71db1886..e64fdd22 100644
--- a/rest_framework/utils/representation.py
+++ b/rest_framework/utils/representation.py
@@ -73,6 +73,8 @@ def serializer_repr(serializer, indent, force_many=None):
ret += serializer_repr(field, indent + 1)
elif hasattr(field, 'child'):
ret += list_repr(field, indent + 1)
+ elif hasattr(field, 'child_relation'):
+ ret += field_repr(field.child_relation, force_many=field.child_relation)
else:
ret += field_repr(field)
return ret
--
cgit v1.2.3
From 106362b437f45e04faaea759df57a66a8a2d7cfd Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 18 Sep 2014 14:58:08 +0100
Subject: ModelSerializer.create() to handle many to many by default
---
rest_framework/relations.py | 5 ++++-
rest_framework/serializers.py | 20 +++++++++++++++++++-
2 files changed, 23 insertions(+), 2 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 474d3e75..5aa1f8bd 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -24,7 +24,10 @@ class RelatedField(Field):
# We override this method in order to automagically create
# `ManyRelation` classes instead when `many=True` is set.
if kwargs.pop('many', False):
- return ManyRelation(child_relation=cls(*args, **kwargs))
+ return ManyRelation(
+ child_relation=cls(*args, **kwargs),
+ read_only=kwargs.get('read_only', False)
+ )
return super(RelatedField, cls).__new__(cls, *args, **kwargs)
def get_queryset(self):
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 9f3e53fd..03e20df8 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -344,7 +344,25 @@ class ModelSerializer(Serializer):
def create(self, attrs):
ModelClass = self.Meta.model
- return ModelClass.objects.create(**attrs)
+
+ # Remove many-to-many relationships from attrs.
+ # They are not valid arguments to the default `.create()` method,
+ # as they require that the instance has already been saved.
+ info = model_meta.get_field_info(ModelClass)
+ many_to_many = {}
+ for key, relation_info in info.relations.items():
+ if relation_info.to_many and (key in attrs):
+ many_to_many[key] = attrs.pop(key)
+
+ instance = ModelClass.objects.create(**attrs)
+
+ # Save many to many relationships after the instance is created.
+ if many_to_many:
+ for key, value in many_to_many.items():
+ setattr(instance, key, value)
+ instance.save()
+
+ return instance
def update(self, obj, attrs):
for attr, value in attrs.items():
--
cgit v1.2.3
From f90049316a3ecca6c92e10b57bfa5becbceff386 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 18 Sep 2014 15:47:27 +0100
Subject: Added a model update integration test
---
rest_framework/serializers.py | 13 ++++++-------
1 file changed, 6 insertions(+), 7 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 03e20df8..d2740fc2 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -350,17 +350,16 @@ class ModelSerializer(Serializer):
# as they require that the instance has already been saved.
info = model_meta.get_field_info(ModelClass)
many_to_many = {}
- for key, relation_info in info.relations.items():
- if relation_info.to_many and (key in attrs):
- many_to_many[key] = attrs.pop(key)
+ for field_name, relation_info in info.relations.items():
+ if relation_info.to_many and (field_name in attrs):
+ many_to_many[field_name] = attrs.pop(field_name)
instance = ModelClass.objects.create(**attrs)
- # Save many to many relationships after the instance is created.
+ # Save many-to-many relationships after the instance is created.
if many_to_many:
- for key, value in many_to_many.items():
- setattr(instance, key, value)
- instance.save()
+ for field_name, value in many_to_many.items():
+ setattr(instance, field_name, value)
return instance
--
cgit v1.2.3
From cf72b9a8b755652cec4ad19a27488e3a79c2e401 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 19 Sep 2014 16:43:13 +0100
Subject: Moar tests
---
rest_framework/serializers.py | 2 ++
1 file changed, 2 insertions(+)
(limited to 'rest_framework')
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index d2740fc2..d9f9c8cb 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -24,6 +24,7 @@ from rest_framework.utils.field_mapping import (
lookup_class
)
import copy
+import inspect
# Note: We do the following so that users of the framework can use this style:
#
@@ -268,6 +269,7 @@ class ListSerializer(BaseSerializer):
def __init__(self, *args, **kwargs):
self.child = kwargs.pop('child', copy.deepcopy(self.child))
assert self.child is not None, '`child` is a required argument.'
+ assert not inspect.isclass(self.child), '`child` has not been instantiated.'
self.context = kwargs.pop('context', {})
kwargs.pop('partial', None)
--
cgit v1.2.3
From af46fd6b00f1d7f018049c19094af58acb1415fb Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 22 Sep 2014 12:25:57 +0100
Subject: Field tests and associated cleanup
---
rest_framework/fields.py | 64 +++++++++++++++++++++++-------------------------
1 file changed, 31 insertions(+), 33 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 0c78b3fb..db75ddf9 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -12,7 +12,6 @@ from rest_framework.utils import html, representation, humanize_datetime
import datetime
import decimal
import inspect
-import warnings
class empty:
@@ -395,7 +394,7 @@ class IntegerField(Field):
class FloatField(Field):
default_error_messages = {
- 'invalid': _("'%s' value must be a float."),
+ 'invalid': _("A valid number is required."),
}
def __init__(self, **kwargs):
@@ -410,20 +409,20 @@ class FloatField(Field):
def to_internal_value(self, value):
if value is None:
return None
- return float(value)
+ try:
+ return float(value)
+ except (TypeError, ValueError):
+ self.fail('invalid')
def to_representation(self, value):
if value is None:
return None
- try:
- return float(value)
- except (TypeError, ValueError):
- self.fail('invalid', value=value)
+ return float(value)
class DecimalField(Field):
default_error_messages = {
- 'invalid': _('Enter a number.'),
+ 'invalid': _('A valid number is required.'),
'max_value': _('Ensure this value is less than or equal to {max_value}.'),
'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'),
@@ -485,7 +484,7 @@ class DecimalField(Field):
if self.decimal_places is not None and decimals > self.decimal_places:
self.fail('max_decimal_places', max_decimal_places=self.decimal_places)
if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
- self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places)
+ self.fail('max_whole_digits', max_whole_digits=self.max_digits - self.decimal_places)
return value
@@ -511,6 +510,7 @@ class DecimalField(Field):
class DateField(Field):
default_error_messages = {
'invalid': _('Date has wrong format. Use one of these formats instead: {format}'),
+ 'datetime': _('Expected a date but got a datetime.'),
}
format = api_settings.DATE_FORMAT
input_formats = api_settings.DATE_INPUT_FORMATS
@@ -525,12 +525,7 @@ class DateField(Field):
return None
if isinstance(value, datetime.datetime):
- if timezone and settings.USE_TZ and timezone.is_aware(value):
- # Convert aware datetimes to the default time zone
- # before casting them to dates (#17742).
- default_timezone = timezone.get_default_timezone()
- value = timezone.make_naive(value, default_timezone)
- return value.date()
+ self.fail('datetime')
if isinstance(value, datetime.date):
return value
@@ -570,35 +565,38 @@ class DateField(Field):
class DateTimeField(Field):
default_error_messages = {
'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'),
+ 'date': _('Expected a datetime but got a date.'),
}
format = api_settings.DATETIME_FORMAT
input_formats = api_settings.DATETIME_INPUT_FORMATS
+ default_timezone = timezone.get_default_timezone() if settings.USE_TZ else None
- def __init__(self, format=None, input_formats=None, *args, **kwargs):
+ def __init__(self, format=None, input_formats=None, default_timezone=None, *args, **kwargs):
self.format = format if format is not None else self.format
self.input_formats = input_formats if input_formats is not None else self.input_formats
+ self.default_timezone = default_timezone if default_timezone is not None else self.default_timezone
super(DateTimeField, self).__init__(*args, **kwargs)
+ def enforce_timezone(self, value):
+ """
+ When `self.default_timezone` is `None`, always return naive datetimes.
+ When `self.default_timezone` is not `None`, always return aware datetimes.
+ """
+ if (self.default_timezone is not None) and not timezone.is_aware(value):
+ return timezone.make_aware(value, self.default_timezone)
+ elif (self.default_timezone is None) and timezone.is_aware(value):
+ return timezone.make_naive(value, timezone.UTC())
+ return value
+
def to_internal_value(self, value):
if value in (None, ''):
return None
- if isinstance(value, datetime.datetime):
- return value
+ if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime):
+ self.fail('date')
- if isinstance(value, datetime.date):
- value = datetime.datetime(value.year, value.month, value.day)
- if settings.USE_TZ:
- # For backwards compatibility, interpret naive datetimes in
- # local time. This won't work during DST change, but we can't
- # do much about it, so we let the exceptions percolate up the
- # call stack.
- warnings.warn("DateTimeField received a naive datetime (%s)"
- " while time zone support is active." % value,
- RuntimeWarning)
- default_timezone = timezone.get_default_timezone()
- value = timezone.make_aware(value, default_timezone)
- return value
+ if isinstance(value, datetime.datetime):
+ return self.enforce_timezone(value)
for format in self.input_formats:
if format.lower() == ISO_8601:
@@ -608,14 +606,14 @@ class DateTimeField(Field):
pass
else:
if parsed is not None:
- return parsed
+ return self.enforce_timezone(parsed)
else:
try:
parsed = datetime.datetime.strptime(value, format)
except (ValueError, TypeError):
pass
else:
- return parsed
+ return self.enforce_timezone(parsed)
humanized_format = humanize_datetime.datetime_formats(self.input_formats)
self.fail('invalid', format=humanized_format)
--
cgit v1.2.3
From afb3f8ab0ad6c33b147292e9777ba0ddf3871d14 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 22 Sep 2014 13:26:47 +0100
Subject: Tests and tweaks for text fields
---
rest_framework/fields.py | 32 ++++++++++++++++++++++++--------
1 file changed, 24 insertions(+), 8 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index db75ddf9..35bd5c4b 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -12,6 +12,7 @@ from rest_framework.utils import html, representation, humanize_datetime
import datetime
import decimal
import inspect
+import re
class empty:
@@ -325,7 +326,11 @@ class EmailField(CharField):
default_error_messages = {
'invalid': _('Enter a valid email address.')
}
- default_validators = [validators.validate_email]
+
+ def __init__(self, **kwargs):
+ super(EmailField, self).__init__(**kwargs)
+ validator = validators.EmailValidator(message=self.error_messages['invalid'])
+ self.validators = [validator] + self.validators
def to_internal_value(self, data):
if data == '' and not self.allow_blank:
@@ -341,26 +346,37 @@ class EmailField(CharField):
class RegexField(CharField):
+ default_error_messages = {
+ 'invalid': _('This value does not match the required pattern.')
+ }
+
def __init__(self, regex, **kwargs):
- kwargs['validators'] = (
- [validators.RegexValidator(regex)] +
- kwargs.get('validators', [])
- )
super(RegexField, self).__init__(**kwargs)
+ validator = validators.RegexValidator(regex, message=self.error_messages['invalid'])
+ self.validators = [validator] + self.validators
class SlugField(CharField):
default_error_messages = {
'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.")
}
- default_validators = [validators.validate_slug]
+
+ def __init__(self, **kwargs):
+ super(SlugField, self).__init__(**kwargs)
+ slug_regex = re.compile(r'^[-a-zA-Z0-9_]+$')
+ validator = validators.RegexValidator(slug_regex, message=self.error_messages['invalid'])
+ self.validators = [validator] + self.validators
class URLField(CharField):
default_error_messages = {
'invalid': _("Enter a valid URL.")
}
- default_validators = [validators.URLValidator()]
+
+ def __init__(self, **kwargs):
+ super(URLField, self).__init__(**kwargs)
+ validator = validators.URLValidator(message=self.error_messages['invalid'])
+ self.validators = [validator] + self.validators
# Number types...
@@ -642,7 +658,7 @@ class TimeField(Field):
self.input_formats = input_formats if input_formats is not None else self.input_formats
super(TimeField, self).__init__(*args, **kwargs)
- def from_native(self, value):
+ def to_internal_value(self, value):
if value in (None, ''):
return None
--
cgit v1.2.3
From c54f394904c3f93211b8aa073de4e9e50110f831 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 22 Sep 2014 13:57:45 +0100
Subject: Ensure 'messages' in fields are respected in preference to default
validator messages
---
rest_framework/compat.py | 19 +++++++++++++++++++
rest_framework/fields.py | 34 ++++++++++++++++++++++------------
2 files changed, 41 insertions(+), 12 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 7c05bed9..2b4ddb02 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -121,6 +121,25 @@ else:
return [m.upper() for m in self.http_method_names if hasattr(self, m)]
+
+# MinValueValidator and MaxValueValidator only accept `message` in 1.8+
+if django.VERSION >= (1, 8):
+ from django.core.validators import MinValueValidator, MaxValueValidator
+else:
+ from django.core.validators import MinValueValidator as DjangoMinValueValidator
+ from django.core.validators import MaxValueValidator as DjangoMaxValueValidator
+
+ class MinValueValidator(DjangoMinValueValidator):
+ def __init__(self, *args, **kwargs):
+ self.message = kwargs.pop('message', self.message)
+ super(MinValueValidator, self).__init__(*args, **kwargs)
+
+ class MaxValueValidator(DjangoMaxValueValidator):
+ def __init__(self, *args, **kwargs):
+ self.message = kwargs.pop('message', self.message)
+ super(MaxValueValidator, self).__init__(*args, **kwargs)
+
+
# PATCH method is not implemented by Django
if 'patch' not in View.http_method_names:
View.http_method_names = View.http_method_names + ['patch']
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 35bd5c4b..5105dfcb 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -6,7 +6,7 @@ from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
from rest_framework import ISO_8601
-from rest_framework.compat import smart_text
+from rest_framework.compat import smart_text, MinValueValidator, MaxValueValidator
from rest_framework.settings import api_settings
from rest_framework.utils import html, representation, humanize_datetime
import datetime
@@ -330,7 +330,7 @@ class EmailField(CharField):
def __init__(self, **kwargs):
super(EmailField, self).__init__(**kwargs)
validator = validators.EmailValidator(message=self.error_messages['invalid'])
- self.validators = [validator] + self.validators
+ self.validators.append(validator)
def to_internal_value(self, data):
if data == '' and not self.allow_blank:
@@ -353,7 +353,7 @@ class RegexField(CharField):
def __init__(self, regex, **kwargs):
super(RegexField, self).__init__(**kwargs)
validator = validators.RegexValidator(regex, message=self.error_messages['invalid'])
- self.validators = [validator] + self.validators
+ self.validators.append(validator)
class SlugField(CharField):
@@ -365,7 +365,7 @@ class SlugField(CharField):
super(SlugField, self).__init__(**kwargs)
slug_regex = re.compile(r'^[-a-zA-Z0-9_]+$')
validator = validators.RegexValidator(slug_regex, message=self.error_messages['invalid'])
- self.validators = [validator] + self.validators
+ self.validators.append(validator)
class URLField(CharField):
@@ -376,14 +376,16 @@ class URLField(CharField):
def __init__(self, **kwargs):
super(URLField, self).__init__(**kwargs)
validator = validators.URLValidator(message=self.error_messages['invalid'])
- self.validators = [validator] + self.validators
+ self.validators.append(validator)
# Number types...
class IntegerField(Field):
default_error_messages = {
- 'invalid': _('A valid integer is required.')
+ 'invalid': _('A valid integer is required.'),
+ 'max_value': _('Ensure this value is less than or equal to {max_value}.'),
+ 'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
}
def __init__(self, **kwargs):
@@ -391,9 +393,11 @@ class IntegerField(Field):
min_value = kwargs.pop('min_value', None)
super(IntegerField, self).__init__(**kwargs)
if max_value is not None:
- self.validators.append(validators.MaxValueValidator(max_value))
+ message = self.error_messages['max_value'].format(max_value=max_value)
+ self.validators.append(MaxValueValidator(max_value, message=message))
if min_value is not None:
- self.validators.append(validators.MinValueValidator(min_value))
+ message = self.error_messages['min_value'].format(min_value=min_value)
+ self.validators.append(MinValueValidator(min_value, message=message))
def to_internal_value(self, data):
try:
@@ -411,6 +415,8 @@ class IntegerField(Field):
class FloatField(Field):
default_error_messages = {
'invalid': _("A valid number is required."),
+ 'max_value': _('Ensure this value is less than or equal to {max_value}.'),
+ 'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
}
def __init__(self, **kwargs):
@@ -418,9 +424,11 @@ class FloatField(Field):
min_value = kwargs.pop('min_value', None)
super(FloatField, self).__init__(**kwargs)
if max_value is not None:
- self.validators.append(validators.MaxValueValidator(max_value))
+ message = self.error_messages['max_value'].format(max_value=max_value)
+ self.validators.append(MaxValueValidator(max_value, message=message))
if min_value is not None:
- self.validators.append(validators.MinValueValidator(min_value))
+ message = self.error_messages['min_value'].format(min_value=min_value)
+ self.validators.append(MinValueValidator(min_value, message=message))
def to_internal_value(self, value):
if value is None:
@@ -454,9 +462,11 @@ class DecimalField(Field):
self.coerce_to_string = coerce_to_string if (coerce_to_string is not None) else self.coerce_to_string
super(DecimalField, self).__init__(**kwargs)
if max_value is not None:
- self.validators.append(validators.MaxValueValidator(max_value))
+ message = self.error_messages['max_value'].format(max_value=max_value)
+ self.validators.append(MaxValueValidator(max_value, message=message))
if min_value is not None:
- self.validators.append(validators.MinValueValidator(min_value))
+ message = self.error_messages['min_value'].format(min_value=min_value)
+ self.validators.append(MinValueValidator(min_value, message=message))
def to_internal_value(self, value):
"""
--
cgit v1.2.3
From 249253a144ba4381581809fb3f27959c7bd6e577 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 22 Sep 2014 14:54:33 +0100
Subject: Fix compat issues
---
rest_framework/fields.py | 13 ++++++++++---
1 file changed, 10 insertions(+), 3 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 5105dfcb..5fb99a42 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -209,8 +209,10 @@ class Field(object):
"""
Validate a simple representation and return the internal value.
- The provided data may be `empty` if no representation was included.
- May return `empty` if the field should not be included in the
+ The provided data may be `empty` if no representation was included
+ in the input.
+
+ May raise `SkipField` if the field should not be included in the
validated data.
"""
if data is empty:
@@ -223,6 +225,10 @@ class Field(object):
return value
def run_validators(self, value):
+ """
+ Test the given value against all the validators on the field,
+ and either raise a `ValidationError` or simply return.
+ """
if value in (None, '', [], (), {}):
return
@@ -753,8 +759,9 @@ class MultipleChoiceField(ChoiceField):
}
def to_internal_value(self, data):
- if not hasattr(data, '__iter__'):
+ if isinstance(data, type('')) or not hasattr(data, '__iter__'):
self.fail('not_a_list', input_type=type(data).__name__)
+
return set([
super(MultipleChoiceField, self).to_internal_value(item)
for item in data
--
cgit v1.2.3
From 4db23cae213decc3e8a8613ad5c76a545f8cfb1a Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 22 Sep 2014 15:34:06 +0100
Subject: Tweaks to DecimalField
---
rest_framework/fields.py | 25 +++++++++++++------------
1 file changed, 13 insertions(+), 12 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 5fb99a42..db7ceabb 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -521,20 +521,21 @@ class DecimalField(Field):
return value
def to_representation(self, value):
- if isinstance(value, decimal.Decimal):
- context = decimal.getcontext().copy()
- context.prec = self.max_digits
- quantized = value.quantize(
- decimal.Decimal('.1') ** self.decimal_places,
- context=context
- )
- if not self.coerce_to_string:
- return quantized
- return '{0:f}'.format(quantized)
+ if value in (None, ''):
+ return None
+ if not isinstance(value, decimal.Decimal):
+ value = decimal.Decimal(value)
+
+ context = decimal.getcontext().copy()
+ context.prec = self.max_digits
+ quantized = value.quantize(
+ decimal.Decimal('.1') ** self.decimal_places,
+ context=context
+ )
if not self.coerce_to_string:
- return value
- return '%.*f' % (self.max_decimal_places, value)
+ return quantized
+ return '{0:f}'.format(quantized)
# Date & time fields...
--
cgit v1.2.3
From 5586b6581d9d8db05276c08f2c6deffec04ade4f Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 22 Sep 2014 16:02:59 +0100
Subject: Support format=None for date/time fields
---
rest_framework/fields.py | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index db7ceabb..cbd3334a 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -548,8 +548,8 @@ class DateField(Field):
format = api_settings.DATE_FORMAT
input_formats = api_settings.DATE_INPUT_FORMATS
- def __init__(self, format=None, input_formats=None, *args, **kwargs):
- self.format = format if format is not None else self.format
+ def __init__(self, format=empty, input_formats=None, *args, **kwargs):
+ self.format = format if format is not empty else self.format
self.input_formats = input_formats if input_formats is not None else self.input_formats
super(DateField, self).__init__(*args, **kwargs)
@@ -604,8 +604,8 @@ class DateTimeField(Field):
input_formats = api_settings.DATETIME_INPUT_FORMATS
default_timezone = timezone.get_default_timezone() if settings.USE_TZ else None
- def __init__(self, format=None, input_formats=None, default_timezone=None, *args, **kwargs):
- self.format = format if format is not None else self.format
+ def __init__(self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs):
+ self.format = format if format is not empty else self.format
self.input_formats = input_formats if input_formats is not None else self.input_formats
self.default_timezone = default_timezone if default_timezone is not None else self.default_timezone
super(DateTimeField, self).__init__(*args, **kwargs)
@@ -670,8 +670,8 @@ class TimeField(Field):
format = api_settings.TIME_FORMAT
input_formats = api_settings.TIME_INPUT_FORMATS
- def __init__(self, format=None, input_formats=None, *args, **kwargs):
- self.format = format if format is not None else self.format
+ def __init__(self, format=empty, input_formats=None, *args, **kwargs):
+ self.format = format if format is not empty else self.format
self.input_formats = input_formats if input_formats is not None else self.input_formats
super(TimeField, self).__init__(*args, **kwargs)
--
cgit v1.2.3
From e5f0a97595ff9280c7876fc917f6feb27b5ea95d Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 22 Sep 2014 16:45:06 +0100
Subject: More compat fixes
---
rest_framework/compat.py | 23 +++++++++++++++++++++++
rest_framework/fields.py | 8 ++++----
2 files changed, 27 insertions(+), 4 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 2b4ddb02..7303c32a 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -139,6 +139,29 @@ else:
self.message = kwargs.pop('message', self.message)
super(MaxValueValidator, self).__init__(*args, **kwargs)
+# URLValidator only accept `message` in 1.6+
+if django.VERSION >= (1, 6):
+ from django.core.validators import URLValidator
+else:
+ from django.core.validators import URLValidator as DjangoURLValidator
+
+ class URLValidator(DjangoURLValidator):
+ def __init__(self, *args, **kwargs):
+ self.message = kwargs.pop('message', self.message)
+ super(URLValidator, self).__init__(*args, **kwargs)
+
+
+# EmailValidator requires explicit regex prior to 1.6+
+if django.VERSION >= (1, 6):
+ from django.core.validators import EmailValidator
+else:
+ from django.core.validators import EmailValidator as DjangoEmailValidator
+ from django.core.validators import email_re
+
+ class EmailValidator(DjangoEmailValidator):
+ def __init__(self, *args, **kwargs):
+ super(EmailValidator, self).__init__(email_re, *args, **kwargs)
+
# PATCH method is not implemented by Django
if 'patch' not in View.http_method_names:
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index cbd3334a..12975ae4 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -6,7 +6,7 @@ from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
from rest_framework import ISO_8601
-from rest_framework.compat import smart_text, MinValueValidator, MaxValueValidator
+from rest_framework.compat import smart_text, EmailValidator, MinValueValidator, MaxValueValidator, URLValidator
from rest_framework.settings import api_settings
from rest_framework.utils import html, representation, humanize_datetime
import datetime
@@ -335,7 +335,7 @@ class EmailField(CharField):
def __init__(self, **kwargs):
super(EmailField, self).__init__(**kwargs)
- validator = validators.EmailValidator(message=self.error_messages['invalid'])
+ validator = EmailValidator(message=self.error_messages['invalid'])
self.validators.append(validator)
def to_internal_value(self, data):
@@ -381,7 +381,7 @@ class URLField(CharField):
def __init__(self, **kwargs):
super(URLField, self).__init__(**kwargs)
- validator = validators.URLValidator(message=self.error_messages['invalid'])
+ validator = URLValidator(message=self.error_messages['invalid'])
self.validators.append(validator)
@@ -525,7 +525,7 @@ class DecimalField(Field):
return None
if not isinstance(value, decimal.Decimal):
- value = decimal.Decimal(value)
+ value = decimal.Decimal(str(value).strip())
context = decimal.getcontext().copy()
context.prec = self.max_digits
--
cgit v1.2.3
From b5454dd02290130a7fb0a0e375f3efecc58edc6d Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 22 Sep 2014 16:50:04 +0100
Subject: Tests and tweaks for choice fields
---
rest_framework/fields.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 12975ae4..500018f3 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -750,7 +750,7 @@ class ChoiceField(Field):
self.fail('invalid_choice', input=data)
def to_representation(self, value):
- return value
+ return self.choice_strings_to_values[str(value)]
class MultipleChoiceField(ChoiceField):
@@ -769,7 +769,7 @@ class MultipleChoiceField(ChoiceField):
])
def to_representation(self, value):
- return value
+ return [self.choice_strings_to_values[str(item)] for item in value]
# File types...
--
cgit v1.2.3
From 5a95baf2a2258fb5297062ac18582129c05fb320 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 22 Sep 2014 16:52:57 +0100
Subject: Tests & tweaks for ChoiceField
---
rest_framework/fields.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 500018f3..80eadf1e 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -769,7 +769,9 @@ class MultipleChoiceField(ChoiceField):
])
def to_representation(self, value):
- return [self.choice_strings_to_values[str(item)] for item in value]
+ return set([
+ self.choice_strings_to_values[str(item)] for item in value
+ ])
# File types...
--
cgit v1.2.3
From 5d80f7f932bfcc0630ac0fdbf07072a53197b98f Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 22 Sep 2014 17:46:02 +0100
Subject: allow_blank, allow_null
---
rest_framework/fields.py | 40 ++++++++++++++++------------------------
1 file changed, 16 insertions(+), 24 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 80eadf1e..48a3e1ab 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -98,14 +98,15 @@ class Field(object):
_creation_counter = 0
default_error_messages = {
- 'required': _('This field is required.')
+ 'required': _('This field is required.'),
+ 'null': _('This field may not be null.')
}
default_validators = []
def __init__(self, read_only=False, write_only=False,
required=None, default=empty, initial=None, source=None,
label=None, help_text=None, style=None,
- error_messages=None, validators=[]):
+ error_messages=None, validators=[], allow_null=False):
self._creation_counter = Field._creation_counter
Field._creation_counter += 1
@@ -129,6 +130,7 @@ class Field(object):
self.help_text = help_text
self.style = {} if style is None else style
self.validators = validators or self.default_validators[:]
+ self.allow_null = allow_null
# Collect default error message from self and parent classes
messages = {}
@@ -220,6 +222,11 @@ class Field(object):
self.fail('required')
return self.get_default()
+ if data is None:
+ if not self.allow_null:
+ self.fail('null')
+ return None
+
value = self.to_internal_value(data)
self.run_validators(value)
return value
@@ -315,11 +322,14 @@ class CharField(Field):
self.min_length = kwargs.pop('min_length', None)
super(CharField, self).__init__(**kwargs)
+ def run_validation(self, data=empty):
+ if data == '':
+ if not self.allow_blank:
+ self.fail('blank')
+ return ''
+ return super(CharField, self).run_validation(data)
+
def to_internal_value(self, data):
- if data == '' and not self.allow_blank:
- self.fail('blank')
- if data is None:
- return None
return str(data)
def to_representation(self, value):
@@ -339,10 +349,6 @@ class EmailField(CharField):
self.validators.append(validator)
def to_internal_value(self, data):
- if data == '' and not self.allow_blank:
- self.fail('blank')
- if data is None:
- return None
return str(data).strip()
def to_representation(self, value):
@@ -437,8 +443,6 @@ class FloatField(Field):
self.validators.append(MinValueValidator(min_value, message=message))
def to_internal_value(self, value):
- if value is None:
- return None
try:
return float(value)
except (TypeError, ValueError):
@@ -481,9 +485,6 @@ class DecimalField(Field):
than max_digits in the number, and no more than decimal_places digits
after the decimal point.
"""
- if value in (None, ''):
- return None
-
value = smart_text(value).strip()
try:
value = decimal.Decimal(value)
@@ -554,9 +555,6 @@ class DateField(Field):
super(DateField, self).__init__(*args, **kwargs)
def to_internal_value(self, value):
- if value in (None, ''):
- return None
-
if isinstance(value, datetime.datetime):
self.fail('datetime')
@@ -622,9 +620,6 @@ class DateTimeField(Field):
return value
def to_internal_value(self, value):
- if value in (None, ''):
- return None
-
if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime):
self.fail('date')
@@ -676,9 +671,6 @@ class TimeField(Field):
super(TimeField, self).__init__(*args, **kwargs)
def to_internal_value(self, value):
- if value in (None, ''):
- return None
-
if isinstance(value, datetime.time):
return value
--
cgit v1.2.3
From b187f53453d3885cd918f5f9f4490bcc8e3e2410 Mon Sep 17 00:00:00 2001
From: Danilo Bargen
Date: Mon, 2 Jun 2014 00:41:58 +0200
Subject: Changed return status for CSRF failures to HTTP 403
By default, Django returns "HTTP 403 Forbidden" responses when CSRF
validation failed[1]. CSRF is a case of authorization, not of
authentication. Therefore `PermissionDenied` should be raised instead
of `AuthenticationFailed`.
[1] https://docs.djangoproject.com/en/dev/ref/contrib/csrf/#rejected-requests
---
rest_framework/authentication.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'rest_framework')
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index f3fec05e..36d74dd9 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -129,7 +129,7 @@ class SessionAuthentication(BaseAuthentication):
reason = CSRFCheck().process_view(request, None, (), {})
if reason:
# CSRF failed, bail with explicit error message
- raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason)
+ raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
class TokenAuthentication(BaseAuthentication):
--
cgit v1.2.3
From f22d0afc3dfc7478e084d1d6ed6b53f71641dec6 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Tue, 23 Sep 2014 14:15:00 +0100
Subject: Tests for field choices
---
rest_framework/fields.py | 21 +++++++------
rest_framework/serializers.py | 3 ++
rest_framework/utils/field_mapping.py | 58 ++++++++++++++++++-----------------
3 files changed, 44 insertions(+), 38 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 48a3e1ab..f5bae734 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -102,6 +102,7 @@ class Field(object):
'null': _('This field may not be null.')
}
default_validators = []
+ default_empty_html = None
def __init__(self, read_only=False, write_only=False,
required=None, default=empty, initial=None, source=None,
@@ -185,6 +186,11 @@ class Field(object):
Given the *incoming* primative data, return the value for this field
that should be validated and transformed to a native value.
"""
+ if html.is_html_input(dictionary):
+ # HTML forms will represent empty fields as '', and cannot
+ # represent None or False values directly.
+ ret = dictionary.get(self.field_name, '')
+ return self.default_empty_html if (ret == '') else ret
return dictionary.get(self.field_name, empty)
def get_attribute(self, instance):
@@ -236,9 +242,6 @@ class Field(object):
Test the given value against all the validators on the field,
and either raise a `ValidationError` or simply return.
"""
- if value in (None, '', [], (), {}):
- return
-
errors = []
for validator in self.validators:
try:
@@ -282,16 +285,10 @@ class BooleanField(Field):
default_error_messages = {
'invalid': _('`{input}` is not a valid boolean.')
}
+ default_empty_html = False
TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True))
FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False))
- def get_value(self, dictionary):
- if html.is_html_input(dictionary):
- # HTML forms do not send a `False` value on an empty checkbox,
- # so we override the default empty value to be False.
- return dictionary.get(self.field_name, False)
- return dictionary.get(self.field_name, empty)
-
def to_internal_value(self, data):
if data in self.TRUE_VALUES:
return True
@@ -315,6 +312,7 @@ class CharField(Field):
default_error_messages = {
'blank': _('This field may not be blank.')
}
+ default_empty_html = ''
def __init__(self, **kwargs):
self.allow_blank = kwargs.pop('allow_blank', False)
@@ -323,6 +321,9 @@ class CharField(Field):
super(CharField, self).__init__(**kwargs)
def run_validation(self, data=empty):
+ # Test for the empty string here so that it does not get validated,
+ # and so that subclasses do not need to handle it explicitly
+ # inside the `to_internal_value()` method.
if data == '':
if not self.allow_blank:
self.fail('blank')
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index d9f9c8cb..949f5915 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -411,6 +411,9 @@ class ModelSerializer(Serializer):
# `ModelField`, which is used when no other typed field
# matched to the model field.
kwargs.pop('model_field', None)
+ if not issubclass(field_cls, CharField):
+ # `allow_blank` is only valid for textual fields.
+ kwargs.pop('allow_blank', None)
elif field_name in info.relations:
# Create forward and reverse relationships.
diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py
index be72e444..1c718ccb 100644
--- a/rest_framework/utils/field_mapping.py
+++ b/rest_framework/utils/field_mapping.py
@@ -49,8 +49,9 @@ def get_field_kwargs(field_name, model_field):
kwargs = {}
validator_kwarg = model_field.validators
- if model_field.null or model_field.blank:
- kwargs['required'] = False
+ # The following will only be used by ModelField classes.
+ # Gets removed for everything else.
+ kwargs['model_field'] = model_field
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
@@ -59,23 +60,26 @@ def get_field_kwargs(field_name, model_field):
kwargs['help_text'] = model_field.help_text
if isinstance(model_field, models.AutoField) or not model_field.editable:
+ # If this field is read-only, then return early.
+ # Further keyword arguments are not valid.
kwargs['read_only'] = True
- # Read only implies that the field is not required.
- # We have a cleaner repr on the instance if we don't set it.
- kwargs.pop('required', None)
+ return kwargs
if model_field.has_default():
- kwargs['default'] = model_field.get_default()
- # Having a default implies that the field is not required.
- # We have a cleaner repr on the instance if we don't set it.
- kwargs.pop('required', None)
+ kwargs['required'] = False
if model_field.flatchoices:
- # If this model field contains choices, then return now,
- # any further keyword arguments are not valid.
+ # If this model field contains choices, then return early.
+ # Further keyword arguments are not valid.
kwargs['choices'] = model_field.flatchoices
return kwargs
+ if model_field.null:
+ kwargs['allow_null'] = True
+
+ if model_field.blank:
+ kwargs['allow_blank'] = True
+
# Ensure that max_length is passed explicitly as a keyword arg,
# rather than as a validator.
max_length = getattr(model_field, 'max_length', None)
@@ -88,7 +92,10 @@ def get_field_kwargs(field_name, model_field):
# Ensure that min_length is passed explicitly as a keyword arg,
# rather than as a validator.
- min_length = getattr(model_field, 'min_length', None)
+ min_length = next((
+ validator.limit_value for validator in validator_kwarg
+ if isinstance(validator, validators.MinLengthValidator)
+ ), None)
if min_length is not None:
kwargs['min_length'] = min_length
validator_kwarg = [
@@ -153,20 +160,9 @@ def get_field_kwargs(field_name, model_field):
if decimal_places is not None:
kwargs['decimal_places'] = decimal_places
- if isinstance(model_field, models.BooleanField):
- # models.BooleanField has `blank=True`, but *is* actually
- # required *unless* a default is provided.
- # Also note that Django<1.6 uses `default=False` for
- # models.BooleanField, but Django>=1.6 uses `default=None`.
- kwargs.pop('required', None)
-
if validator_kwarg:
kwargs['validators'] = validator_kwarg
- # The following will only be used by ModelField classes.
- # Gets removed for everything else.
- kwargs['model_field'] = model_field
-
return kwargs
@@ -188,16 +184,22 @@ def get_relation_kwargs(field_name, relation_info):
kwargs.pop('queryset', None)
if model_field:
- if model_field.null or model_field.blank:
- kwargs['required'] = False
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
- if not model_field.editable:
- kwargs['read_only'] = True
- kwargs.pop('queryset', None)
help_text = clean_manytomany_helptext(model_field.help_text)
if help_text:
kwargs['help_text'] = help_text
+ if not model_field.editable:
+ kwargs['read_only'] = True
+ kwargs.pop('queryset', None)
+ if kwargs.get('read_only', False):
+ # If this field is read-only, then return early.
+ # No further keyword arguments are valid.
+ return kwargs
+ if model_field.has_default():
+ kwargs['required'] = False
+ if model_field.null:
+ kwargs['allow_null'] = True
return kwargs
--
cgit v1.2.3
From 0404f09a7e69f533038d47ca25caad90c0c2659f Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Tue, 23 Sep 2014 14:30:17 +0100
Subject: NullBooleanField
---
rest_framework/fields.py | 37 ++++++++++++++++++++++++++++++++++-
rest_framework/serializers.py | 2 +-
rest_framework/utils/field_mapping.py | 2 +-
3 files changed, 38 insertions(+), 3 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index f5bae734..f859658a 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -289,6 +289,10 @@ class BooleanField(Field):
TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True))
FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False))
+ def __init__(self, **kwargs):
+ assert 'allow_null' not in kwargs, '`allow_null` is not a valid option. Use `NullBooleanField` instead.'
+ super(BooleanField, self).__init__(**kwargs)
+
def to_internal_value(self, data):
if data in self.TRUE_VALUES:
return True
@@ -297,7 +301,38 @@ class BooleanField(Field):
self.fail('invalid', input=data)
def to_representation(self, value):
- if value is None:
+ if value in self.TRUE_VALUES:
+ return True
+ elif value in self.FALSE_VALUES:
+ return False
+ return bool(value)
+
+
+class NullBooleanField(Field):
+ default_error_messages = {
+ 'invalid': _('`{input}` is not a valid boolean.')
+ }
+ default_empty_html = None
+ TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True))
+ FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False))
+ NULL_VALUES = set(('n', 'N', 'null', 'Null', 'NULL', '', None))
+
+ def __init__(self, **kwargs):
+ assert 'allow_null' not in kwargs, '`allow_null` is not a valid option.'
+ kwargs['allow_null'] = True
+ super(NullBooleanField, self).__init__(**kwargs)
+
+ def to_internal_value(self, data):
+ if data in self.TRUE_VALUES:
+ return True
+ elif data in self.FALSE_VALUES:
+ return False
+ elif data in self.NULL_VALUES:
+ return None
+ self.fail('invalid', input=data)
+
+ def to_representation(self, value):
+ if value in self.NULL_VALUES:
return None
if value in self.TRUE_VALUES:
return True
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 949f5915..d8d72a4c 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -333,7 +333,7 @@ class ModelSerializer(Serializer):
models.FloatField: FloatField,
models.ImageField: ImageField,
models.IntegerField: IntegerField,
- models.NullBooleanField: BooleanField,
+ models.NullBooleanField: NullBooleanField,
models.PositiveIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField,
models.SlugField: SlugField,
diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py
index 1c718ccb..c208afdc 100644
--- a/rest_framework/utils/field_mapping.py
+++ b/rest_framework/utils/field_mapping.py
@@ -74,7 +74,7 @@ def get_field_kwargs(field_name, model_field):
kwargs['choices'] = model_field.flatchoices
return kwargs
- if model_field.null:
+ if model_field.null and not isinstance(model_field, models.NullBooleanField):
kwargs['allow_null'] = True
if model_field.blank:
--
cgit v1.2.3
From f4b1dcb167be0bbdaae2cc2a92f651536896dc16 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Wed, 24 Sep 2014 14:09:49 +0100
Subject: OPTIONS support
---
rest_framework/generics.py | 51 +-------------
rest_framework/metadata.py | 126 ++++++++++++++++++++++++++++++++++
rest_framework/serializers.py | 8 +--
rest_framework/settings.py | 2 +
rest_framework/utils/field_mapping.py | 20 +++---
rest_framework/views.py | 24 ++-----
6 files changed, 150 insertions(+), 81 deletions(-)
create mode 100644 rest_framework/metadata.py
(limited to 'rest_framework')
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index eb6b64ef..f49b0a43 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -4,13 +4,11 @@ Generic views that provide commonly needed behaviour.
from __future__ import unicode_literals
from django.db.models.query import QuerySet
-from django.core.exceptions import PermissionDenied
from django.core.paginator import Paginator, InvalidPage
from django.http import Http404
from django.shortcuts import get_object_or_404 as _get_object_or_404
from django.utils.translation import ugettext as _
-from rest_framework import views, mixins, exceptions
-from rest_framework.request import clone_request
+from rest_framework import views, mixins
from rest_framework.settings import api_settings
@@ -249,53 +247,6 @@ class GenericAPIView(views.APIView):
return obj
- # The following are placeholder methods,
- # and are intended to be overridden.
- #
- # The are not called by GenericAPIView directly,
- # but are used by the mixin methods.
- def metadata(self, request):
- """
- Return a dictionary of metadata about the view.
- Used to return responses for OPTIONS requests.
-
- We override the default behavior, and add some extra information
- about the required request body for POST and PUT operations.
- """
- ret = super(GenericAPIView, self).metadata(request)
-
- actions = {}
- for method in ('PUT', 'POST'):
- if method not in self.allowed_methods:
- continue
-
- cloned_request = clone_request(request, method)
- try:
- # Test global permissions
- self.check_permissions(cloned_request)
- # Test object permissions
- if method == 'PUT':
- try:
- self.get_object()
- except Http404:
- # Http404 should be acceptable and the serializer
- # metadata should be populated. Except this so the
- # outer "else" clause of the try-except-else block
- # will be executed.
- pass
- except (exceptions.APIException, PermissionDenied):
- pass
- else:
- # If user has appropriate permissions for the view, include
- # appropriate metadata about the fields that should be supplied.
- serializer = self.get_serializer()
- actions[method] = serializer.metadata()
-
- if actions:
- ret['actions'] = actions
-
- return ret
-
# Concrete view classes that provide method handlers
# by composing the mixin classes with the base view.
diff --git a/rest_framework/metadata.py b/rest_framework/metadata.py
new file mode 100644
index 00000000..580259de
--- /dev/null
+++ b/rest_framework/metadata.py
@@ -0,0 +1,126 @@
+"""
+The metadata API is used to allow cusomization of how `OPTIONS` requests
+are handled. We currently provide a single default implementation that returns
+some fairly ad-hoc information about the view.
+
+Future implementations might use JSON schema or other definations in order
+to return this information in a more standardized way.
+"""
+from __future__ import unicode_literals
+
+from django.core.exceptions import PermissionDenied
+from django.http import Http404
+from django.utils import six
+from django.utils.datastructures import SortedDict
+from rest_framework import exceptions, serializers
+from rest_framework.compat import force_text
+from rest_framework.request import clone_request
+from rest_framework.utils.field_mapping import ClassLookupDict
+
+
+class BaseMetadata(object):
+ def determine_metadata(self, request, view):
+ """
+ Return a dictionary of metadata about the view.
+ Used to return responses for OPTIONS requests.
+ """
+ raise NotImplementedError(".determine_metadata() must be overridden.")
+
+
+class SimpleMetadata(BaseMetadata):
+ """
+ This is the default metadata implementation.
+ It returns an ad-hoc set of information about the view.
+ There are not any formalized standards for `OPTIONS` responses
+ for us to base this on.
+ """
+ label_lookup = ClassLookupDict({
+ serializers.Field: 'field',
+ serializers.BooleanField: 'boolean',
+ serializers.CharField: 'string',
+ serializers.URLField: 'url',
+ serializers.EmailField: 'email',
+ serializers.RegexField: 'regex',
+ serializers.SlugField: 'slug',
+ serializers.IntegerField: 'integer',
+ serializers.FloatField: 'float',
+ serializers.DecimalField: 'decimal',
+ serializers.DateField: 'date',
+ serializers.DateTimeField: 'datetime',
+ serializers.TimeField: 'time',
+ serializers.ChoiceField: 'choice',
+ serializers.MultipleChoiceField: 'multiple choice',
+ serializers.FileField: 'file upload',
+ serializers.ImageField: 'image upload',
+ })
+
+ def determine_metadata(self, request, view):
+ metadata = SortedDict()
+ metadata['name'] = view.get_view_name()
+ metadata['description'] = view.get_view_description()
+ metadata['renders'] = [renderer.media_type for renderer in view.renderer_classes]
+ metadata['parses'] = [parser.media_type for parser in view.parser_classes]
+ if hasattr(view, 'get_serializer'):
+ actions = self.determine_actions(request, view)
+ if actions:
+ metadata['actions'] = actions
+ return metadata
+
+ def determine_actions(self, request, view):
+ """
+ For generic class based views we return information about
+ the fields that are accepted for 'PUT' and 'POST' methods.
+ """
+ actions = {}
+ for method in set(['PUT', 'POST']) & set(view.allowed_methods):
+ view.request = clone_request(request, method)
+ try:
+ # Test global permissions
+ if hasattr(view, 'check_permissions'):
+ view.check_permissions(view.request)
+ # Test object permissions
+ if method == 'PUT' and hasattr(view, 'get_object'):
+ view.get_object()
+ except (exceptions.APIException, PermissionDenied, Http404):
+ pass
+ else:
+ # If user has appropriate permissions for the view, include
+ # appropriate metadata about the fields that should be supplied.
+ serializer = view.get_serializer()
+ actions[method] = self.get_serializer_info(serializer)
+ finally:
+ view.request = request
+
+ return actions
+
+ def get_serializer_info(self, serializer):
+ """
+ Given an instance of a serializer, return a dictionary of metadata
+ about its fields.
+ """
+ return SortedDict([
+ (field_name, self.get_field_info(field))
+ for field_name, field in six.iteritems(serializer.fields)
+ ])
+
+ def get_field_info(self, field):
+ """
+ Given an instance of a serializer field, return a dictionary
+ of metadata about it.
+ """
+ field_info = SortedDict()
+ field_info['type'] = self.label_lookup[field]
+ field_info['required'] = getattr(field, 'required', False)
+
+ for attr in ['read_only', 'label', 'help_text', 'min_length', 'max_length']:
+ value = getattr(field, attr, None)
+ if value is not None and value != '':
+ field_info[attr] = force_text(value, strings_only=True)
+
+ if hasattr(field, 'choices'):
+ field_info['choices'] = [
+ {'value': choice_value, 'display_name': choice_name}
+ for choice_value, choice_name in field.choices.items()
+ ]
+
+ return field_info
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index d8d72a4c..8902294b 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -21,7 +21,7 @@ from rest_framework.utils import html, model_meta, representation
from rest_framework.utils.field_mapping import (
get_url_kwargs, get_field_kwargs,
get_relation_kwargs, get_nested_relation_kwargs,
- lookup_class
+ ClassLookupDict
)
import copy
import inspect
@@ -318,7 +318,7 @@ class ListSerializer(BaseSerializer):
class ModelSerializer(Serializer):
- _field_mapping = {
+ _field_mapping = ClassLookupDict({
models.AutoField: IntegerField,
models.BigIntegerField: IntegerField,
models.BooleanField: BooleanField,
@@ -341,7 +341,7 @@ class ModelSerializer(Serializer):
models.TextField: CharField,
models.TimeField: TimeField,
models.URLField: URLField,
- }
+ })
_related_class = PrimaryKeyRelatedField
def create(self, attrs):
@@ -400,7 +400,7 @@ class ModelSerializer(Serializer):
elif field_name in info.fields_and_pk:
# Create regular model fields.
model_field = info.fields_and_pk[field_name]
- field_cls = lookup_class(self._field_mapping, model_field)
+ field_cls = self._field_mapping[model_field]
kwargs = get_field_kwargs(field_name, model_field)
if 'choices' in kwargs:
# Fields with choices get coerced into `ChoiceField`
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 421e146c..d7fb0a43 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -45,6 +45,7 @@ DEFAULTS = {
),
'DEFAULT_THROTTLE_CLASSES': (),
'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',
+ 'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata',
# Genric view behavior
'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.serializers.ModelSerializer',
@@ -121,6 +122,7 @@ IMPORT_STRINGS = (
'DEFAULT_PERMISSION_CLASSES',
'DEFAULT_THROTTLE_CLASSES',
'DEFAULT_CONTENT_NEGOTIATION_CLASS',
+ 'DEFAULT_METADATA_CLASS',
'DEFAULT_MODEL_SERIALIZER_CLASS',
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
'DEFAULT_FILTER_BACKENDS',
diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py
index c208afdc..c3794083 100644
--- a/rest_framework/utils/field_mapping.py
+++ b/rest_framework/utils/field_mapping.py
@@ -9,17 +9,21 @@ from rest_framework.compat import clean_manytomany_helptext
import inspect
-def lookup_class(mapping, instance):
+class ClassLookupDict(object):
"""
- Takes a dictionary with classes as keys, and an object.
- Traverses the object's inheritance hierarchy in method
- resolution order, and returns the first matching value
+ Takes a dictionary with classes as keys.
+ Lookups against this object will traverses the object's inheritance
+ hierarchy in method resolution order, and returns the first matching value
from the dictionary or raises a KeyError if nothing matches.
"""
- for cls in inspect.getmro(instance.__class__):
- if cls in mapping:
- return mapping[cls]
- raise KeyError('Class %s not found in lookup.', cls.__name__)
+ def __init__(self, mapping):
+ self.mapping = mapping
+
+ def __getitem__(self, key):
+ for cls in inspect.getmro(key.__class__):
+ if cls in self.mapping:
+ return self.mapping[cls]
+ raise KeyError('Class %s not found in lookup.', cls.__name__)
def needs_label(model_field, field_name):
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 9f08a4ad..835e223a 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -5,7 +5,6 @@ from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied, ValidationError, NON_FIELD_ERRORS
from django.http import Http404
-from django.utils.datastructures import SortedDict
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status, exceptions
from rest_framework.compat import smart_text, HttpResponseBase, View
@@ -99,6 +98,7 @@ class APIView(View):
throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
+ metadata_class = api_settings.DEFAULT_METADATA_CLASS
# Allow dependancy injection of other settings to make testing easier.
settings = api_settings
@@ -418,22 +418,8 @@ class APIView(View):
def options(self, request, *args, **kwargs):
"""
Handler method for HTTP 'OPTIONS' request.
- We may as well implement this as Django will otherwise provide
- a less useful default implementation.
"""
- return Response(self.metadata(request), status=status.HTTP_200_OK)
-
- def metadata(self, request):
- """
- Return a dictionary of metadata about the view.
- Used to return responses for OPTIONS requests.
- """
- # By default we can't provide any form-like information, however the
- # generic views override this implementation and add additional
- # information for POST and PUT methods, based on the serializer.
- ret = SortedDict()
- ret['name'] = self.get_view_name()
- ret['description'] = self.get_view_description()
- ret['renders'] = [renderer.media_type for renderer in self.renderer_classes]
- ret['parses'] = [parser.media_type for parser in self.parser_classes]
- return ret
+ if self.metadata_class is None:
+ return self.http_method_not_allowed(request, *args, **kwargs)
+ data = self.metadata_class().determine_metadata(request, self)
+ return Response(data, status=status.HTTP_200_OK)
--
cgit v1.2.3
From 127c0bd3d68860dd6567d81047257fbc3e70b4b9 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Wed, 24 Sep 2014 20:25:59 +0100
Subject: Custom deepcopy on Field classes
---
rest_framework/fields.py | 6 ++++++
1 file changed, 6 insertions(+)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index f859658a..1f7d964a 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -9,6 +9,7 @@ from rest_framework import ISO_8601
from rest_framework.compat import smart_text, EmailValidator, MinValueValidator, MaxValueValidator, URLValidator
from rest_framework.settings import api_settings
from rest_framework.utils import html, representation, humanize_datetime
+import copy
import datetime
import decimal
import inspect
@@ -150,6 +151,11 @@ class Field(object):
instance._kwargs = kwargs
return instance
+ def __deepcopy__(self, memo):
+ args = copy.deepcopy(self._args)
+ kwargs = copy.deepcopy(self._kwargs)
+ return self.__class__(*args, **kwargs)
+
def bind(self, field_name, parent, root):
"""
Setup the context for the field instance.
--
cgit v1.2.3
From fb1546ee50953faae8af505a0c90da00ac08ad92 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Wed, 24 Sep 2014 20:53:37 +0100
Subject: Enforce field_name != source
---
rest_framework/fields.py | 11 +++++++++++
1 file changed, 11 insertions(+)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 1f7d964a..9280ea3a 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -160,6 +160,17 @@ class Field(object):
"""
Setup the context for the field instance.
"""
+
+ # In order to enforce a consistent style, we error if a redundant
+ # 'source' argument has been used. For example:
+ # my_field = serializer.CharField(source='my_field')
+ assert self._kwargs.get('source') != field_name, (
+ "It is redundant to specify `source='%s'` on field '%s' in "
+ "serializer '%s', as it is the same the field name. "
+ "Remove the `source` keyword argument." %
+ (field_name, self.__class__.__name__, parent.__class__.__name__)
+ )
+
self.field_name = field_name
self.parent = parent
self.root = root
--
cgit v1.2.3
From 1420c76453c37c023a901dd0938d717b7b5e52ca Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 25 Sep 2014 10:49:25 +0100
Subject: Ensure proper sorting of 'choices' attribute on ChoiceField
---
rest_framework/fields.py | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 9280ea3a..d1aebbaf 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -2,6 +2,7 @@ from django.conf import settings
from django.core import validators
from django.core.exceptions import ValidationError
from django.utils import timezone
+from django.utils.datastructures import SortedDict
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
@@ -166,7 +167,7 @@ class Field(object):
# my_field = serializer.CharField(source='my_field')
assert self._kwargs.get('source') != field_name, (
"It is redundant to specify `source='%s'` on field '%s' in "
- "serializer '%s', as it is the same the field name. "
+ "serializer '%s', because it is the same as the field name. "
"Remove the `source` keyword argument." %
(field_name, self.__class__.__name__, parent.__class__.__name__)
)
@@ -303,6 +304,7 @@ class BooleanField(Field):
'invalid': _('`{input}` is not a valid boolean.')
}
default_empty_html = False
+ initial = False
TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True))
FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False))
@@ -365,6 +367,7 @@ class CharField(Field):
'blank': _('This field may not be blank.')
}
default_empty_html = ''
+ initial = ''
def __init__(self, **kwargs):
self.allow_blank = kwargs.pop('allow_blank', False)
@@ -775,9 +778,9 @@ class ChoiceField(Field):
for item in choices
]
if all(pairs):
- self.choices = dict([(key, display_value) for key, display_value in choices])
+ self.choices = SortedDict([(key, display_value) for key, display_value in choices])
else:
- self.choices = dict([(item, item) for item in choices])
+ self.choices = SortedDict([(item, item) for item in choices])
# Map the string representation of choices to the underlying value.
# Allows us to deal with eg. integer choices while supporting either
--
cgit v1.2.3
From b22c9602fa0f717b688fdb35e4f6f42c189af3f3 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 25 Sep 2014 11:04:18 +0100
Subject: Automatic field binding
---
rest_framework/metadata.py | 3 +--
rest_framework/pagination.py | 1 -
rest_framework/serializers.py | 30 +++++++++++++++++++++++++-----
3 files changed, 26 insertions(+), 8 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/metadata.py b/rest_framework/metadata.py
index 580259de..af4bc396 100644
--- a/rest_framework/metadata.py
+++ b/rest_framework/metadata.py
@@ -10,7 +10,6 @@ from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied
from django.http import Http404
-from django.utils import six
from django.utils.datastructures import SortedDict
from rest_framework import exceptions, serializers
from rest_framework.compat import force_text
@@ -100,7 +99,7 @@ class SimpleMetadata(BaseMetadata):
"""
return SortedDict([
(field_name, self.get_field_info(field))
- for field_name, field in six.iteritems(serializer.fields)
+ for field_name, field in serializer.fields.items()
])
def get_field_info(self, field):
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index c5a9270a..fb451285 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -72,7 +72,6 @@ class BasePaginationSerializer(serializers.Serializer):
child=object_serializer(),
source='object_list'
)
- self.fields[results_field].bind(results_field, self, self)
class PaginationSerializer(BasePaginationSerializer):
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 8902294b..12e38090 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -149,6 +149,28 @@ class SerializerMetaclass(type):
return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
+class BindingDict(object):
+ def __init__(self, serializer):
+ self.serializer = serializer
+ self.fields = SortedDict()
+
+ def __setitem__(self, key, field):
+ self.fields[key] = field
+ field.bind(field_name=key, parent=self.serializer, root=self.serializer)
+
+ def __getitem__(self, key):
+ return self.fields[key]
+
+ def __delitem__(self, key):
+ del self.fields[key]
+
+ def items(self):
+ return self.fields.items()
+
+ def values(self):
+ return self.fields.values()
+
+
@six.add_metaclass(SerializerMetaclass)
class Serializer(BaseSerializer):
def __init__(self, *args, **kwargs):
@@ -161,11 +183,9 @@ class Serializer(BaseSerializer):
# Every new serializer is created with a clone of the field instances.
# This allows users to dynamically modify the fields on a serializer
# instance without affecting every other serializer class.
- self.fields = self._get_base_fields()
-
- # Setup all the child fields, to provide them with the current context.
- for field_name, field in self.fields.items():
- field.bind(field_name, self, self)
+ self.fields = BindingDict(self)
+ for key, value in self._get_base_fields().items():
+ self.fields[key] = value
def __new__(cls, *args, **kwargs):
# We override this method in order to automagically create
--
cgit v1.2.3
From 64632da3718f501cb8174243385d38b547c2fefd Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 25 Sep 2014 11:40:32 +0100
Subject: Clean up bind - no longer needs to be called multiple times in nested
fields
---
rest_framework/fields.py | 21 ++++++++++++++++-----
rest_framework/relations.py | 5 -----
rest_framework/serializers.py | 26 +++++++++-----------------
3 files changed, 25 insertions(+), 27 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index d1aebbaf..446732c3 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -109,7 +109,8 @@ class Field(object):
def __init__(self, read_only=False, write_only=False,
required=None, default=empty, initial=None, source=None,
label=None, help_text=None, style=None,
- error_messages=None, validators=[], allow_null=False):
+ error_messages=None, validators=[], allow_null=False,
+ context=None):
self._creation_counter = Field._creation_counter
Field._creation_counter += 1
@@ -135,6 +136,11 @@ class Field(object):
self.validators = validators or self.default_validators[:]
self.allow_null = allow_null
+ # These are set up by `.bind()` when the field is added to a serializer.
+ self.field_name = None
+ self.parent = None
+ self._context = {} if (context is None) else context
+
# Collect default error message from self and parent classes
messages = {}
for cls in reversed(self.__class__.__mro__):
@@ -157,7 +163,14 @@ class Field(object):
kwargs = copy.deepcopy(self._kwargs)
return self.__class__(*args, **kwargs)
- def bind(self, field_name, parent, root):
+ @property
+ def context(self):
+ root = self
+ while root.parent is not None:
+ root = root.parent
+ return root._context
+
+ def bind(self, field_name, parent):
"""
Setup the context for the field instance.
"""
@@ -174,10 +187,8 @@ class Field(object):
self.field_name = field_name
self.parent = parent
- self.root = root
- self.context = parent.context
- # `self.label` should deafult to being based on the field name.
+ # `self.label` should default to being based on the field name.
if self.label is None:
self.label = field_name.replace('_', ' ').capitalize()
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 5aa1f8bd..b37a6fed 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -243,11 +243,6 @@ class ManyRelation(Field):
assert child_relation is not None, '`child_relation` is a required argument.'
super(ManyRelation, self).__init__(*args, **kwargs)
- def bind(self, field_name, parent, root):
- # ManyRelation needs to provide the current context to the child relation.
- super(ManyRelation, self).bind(field_name, parent, root)
- self.child_relation.bind(field_name, parent, root)
-
def to_internal_value(self, data):
return [
self.child_relation.to_internal_value(item)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 12e38090..04721c7a 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -150,13 +150,20 @@ class SerializerMetaclass(type):
class BindingDict(object):
+ """
+ This dict-like object is used to store fields on a serializer.
+
+ This ensures that whenever fields are added to the serializer we call
+ `field.bind()` so that the `field_name` and `parent` attributes
+ can be set correctly.
+ """
def __init__(self, serializer):
self.serializer = serializer
self.fields = SortedDict()
def __setitem__(self, key, field):
self.fields[key] = field
- field.bind(field_name=key, parent=self.serializer, root=self.serializer)
+ field.bind(field_name=key, parent=self.serializer)
def __getitem__(self, key):
return self.fields[key]
@@ -174,7 +181,6 @@ class BindingDict(object):
@six.add_metaclass(SerializerMetaclass)
class Serializer(BaseSerializer):
def __init__(self, *args, **kwargs):
- self.context = kwargs.pop('context', {})
kwargs.pop('partial', None)
kwargs.pop('many', None)
@@ -198,13 +204,6 @@ class Serializer(BaseSerializer):
def _get_base_fields(self):
return copy.deepcopy(self._declared_fields)
- def bind(self, field_name, parent, root):
- # If the serializer is used as a field then when it becomes bound
- # it also needs to bind all its child fields.
- super(Serializer, self).bind(field_name, parent, root)
- for field_name, field in self.fields.items():
- field.bind(field_name, self, root)
-
def get_initial(self):
return dict([
(field.field_name, field.get_initial())
@@ -290,17 +289,10 @@ class ListSerializer(BaseSerializer):
self.child = kwargs.pop('child', copy.deepcopy(self.child))
assert self.child is not None, '`child` is a required argument.'
assert not inspect.isclass(self.child), '`child` has not been instantiated.'
- self.context = kwargs.pop('context', {})
kwargs.pop('partial', None)
super(ListSerializer, self).__init__(*args, **kwargs)
- self.child.bind('', self, self)
-
- def bind(self, field_name, parent, root):
- # If the list is used as a field then it needs to provide
- # the current context to the child serializer.
- super(ListSerializer, self).bind(field_name, parent, root)
- self.child.bind(field_name, self, root)
+ self.child.bind(field_name='', parent=self)
def get_value(self, dictionary):
# We override the default field access in order to support
--
cgit v1.2.3
From b47ca158b9ba9733baad080e648d24b0465ec697 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 25 Sep 2014 12:09:12 +0100
Subject: Check for redundant on SerializerMethodField
---
rest_framework/fields.py | 29 ++++++++++++++++++++++-------
1 file changed, 22 insertions(+), 7 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 446732c3..328e93ef 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -178,7 +178,7 @@ class Field(object):
# In order to enforce a consistent style, we error if a redundant
# 'source' argument has been used. For example:
# my_field = serializer.CharField(source='my_field')
- assert self._kwargs.get('source') != field_name, (
+ assert self.source != field_name, (
"It is redundant to specify `source='%s'` on field '%s' in "
"serializer '%s', because it is the same as the field name. "
"Remove the `source` keyword argument." %
@@ -883,17 +883,32 @@ class SerializerMethodField(Field):
def get_extra_info(self, obj):
return ... # Calculate some data to return.
"""
- def __init__(self, method_attr=None, **kwargs):
- self.method_attr = method_attr
+ def __init__(self, method_name=None, **kwargs):
+ self.method_name = method_name
kwargs['source'] = '*'
kwargs['read_only'] = True
super(SerializerMethodField, self).__init__(**kwargs)
+ def bind(self, field_name, parent):
+ # In order to enforce a consistent style, we error if a redundant
+ # 'method_name' argument has been used. For example:
+ # my_field = serializer.CharField(source='my_field')
+ default_method_name = 'get_{field_name}'.format(field_name=field_name)
+ assert self.method_name != default_method_name, (
+ "It is redundant to specify `%s` on SerializerMethodField '%s' in "
+ "serializer '%s', because it is the same as the default method name. "
+ "Remove the `method_name` argument." %
+ (self.method_name, field_name, parent.__class__.__name__)
+ )
+
+ # The method name should default to `get_{field_name}`.
+ if self.method_name is None:
+ self.method_name = default_method_name
+
+ super(SerializerMethodField, self).bind(field_name, parent)
+
def to_representation(self, value):
- method_attr = self.method_attr
- if method_attr is None:
- method_attr = 'get_{field_name}'.format(field_name=self.field_name)
- method = getattr(self.parent, method_attr)
+ method = getattr(self.parent, self.method_name)
return method(value)
--
cgit v1.2.3
From 8ee92f8a18c3a31a2a95233f36754203dc60bb18 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 25 Sep 2014 13:10:33 +0100
Subject: Refuse to downcast from datetime to date or time
---
rest_framework/fields.py | 118 ++++++++++++++++++++++++++---------------------
1 file changed, 65 insertions(+), 53 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 328e93ef..d855e6fd 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -608,120 +608,126 @@ class DecimalField(Field):
# Date & time fields...
-class DateField(Field):
+class DateTimeField(Field):
default_error_messages = {
- 'invalid': _('Date has wrong format. Use one of these formats instead: {format}'),
- 'datetime': _('Expected a date but got a datetime.'),
+ 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'),
+ 'date': _('Expected a datetime but got a date.'),
}
- format = api_settings.DATE_FORMAT
- input_formats = api_settings.DATE_INPUT_FORMATS
+ format = api_settings.DATETIME_FORMAT
+ input_formats = api_settings.DATETIME_INPUT_FORMATS
+ default_timezone = timezone.get_default_timezone() if settings.USE_TZ else None
- def __init__(self, format=empty, input_formats=None, *args, **kwargs):
+ def __init__(self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs):
self.format = format if format is not empty else self.format
self.input_formats = input_formats if input_formats is not None else self.input_formats
- super(DateField, self).__init__(*args, **kwargs)
+ self.default_timezone = default_timezone if default_timezone is not None else self.default_timezone
+ super(DateTimeField, self).__init__(*args, **kwargs)
+
+ def enforce_timezone(self, value):
+ """
+ When `self.default_timezone` is `None`, always return naive datetimes.
+ When `self.default_timezone` is not `None`, always return aware datetimes.
+ """
+ if (self.default_timezone is not None) and not timezone.is_aware(value):
+ return timezone.make_aware(value, self.default_timezone)
+ elif (self.default_timezone is None) and timezone.is_aware(value):
+ return timezone.make_naive(value, timezone.UTC())
+ return value
def to_internal_value(self, value):
- if isinstance(value, datetime.datetime):
- self.fail('datetime')
+ if (isinstance(value, datetime.date) and not isinstance(value, datetime.datetime):
+ self.fail('date')
- if isinstance(value, datetime.date):
- return value
+ if isinstance(value, datetime.datetime):
+ return self.enforce_timezone(value)
for format in self.input_formats:
if format.lower() == ISO_8601:
try:
- parsed = parse_date(value)
+ parsed = parse_datetime(value)
except (ValueError, TypeError):
pass
else:
if parsed is not None:
- return parsed
+ return self.enforce_timezone(parsed)
else:
try:
parsed = datetime.datetime.strptime(value, format)
except (ValueError, TypeError):
pass
else:
- return parsed.date()
+ return self.enforce_timezone(parsed)
- humanized_format = humanize_datetime.date_formats(self.input_formats)
+ humanized_format = humanize_datetime.datetime_formats(self.input_formats)
self.fail('invalid', format=humanized_format)
def to_representation(self, value):
if value is None or self.format is None:
return value
- if isinstance(value, datetime.datetime):
- value = value.date()
-
if self.format.lower() == ISO_8601:
- return value.isoformat()
+ ret = value.isoformat()
+ if ret.endswith('+00:00'):
+ ret = ret[:-6] + 'Z'
+ return ret
return value.strftime(self.format)
-class DateTimeField(Field):
+class DateField(Field):
default_error_messages = {
- 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'),
- 'date': _('Expected a datetime but got a date.'),
+ 'invalid': _('Date has wrong format. Use one of these formats instead: {format}'),
+ 'datetime': _('Expected a date but got a datetime.'),
}
- format = api_settings.DATETIME_FORMAT
- input_formats = api_settings.DATETIME_INPUT_FORMATS
- default_timezone = timezone.get_default_timezone() if settings.USE_TZ else None
+ format = api_settings.DATE_FORMAT
+ input_formats = api_settings.DATE_INPUT_FORMATS
- def __init__(self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs):
+ def __init__(self, format=empty, input_formats=None, *args, **kwargs):
self.format = format if format is not empty else self.format
self.input_formats = input_formats if input_formats is not None else self.input_formats
- self.default_timezone = default_timezone if default_timezone is not None else self.default_timezone
- super(DateTimeField, self).__init__(*args, **kwargs)
-
- def enforce_timezone(self, value):
- """
- When `self.default_timezone` is `None`, always return naive datetimes.
- When `self.default_timezone` is not `None`, always return aware datetimes.
- """
- if (self.default_timezone is not None) and not timezone.is_aware(value):
- return timezone.make_aware(value, self.default_timezone)
- elif (self.default_timezone is None) and timezone.is_aware(value):
- return timezone.make_naive(value, timezone.UTC())
- return value
+ super(DateField, self).__init__(*args, **kwargs)
def to_internal_value(self, value):
- if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime):
- self.fail('date')
-
if isinstance(value, datetime.datetime):
- return self.enforce_timezone(value)
+ self.fail('datetime')
+
+ if isinstance(value, datetime.date):
+ return value
for format in self.input_formats:
if format.lower() == ISO_8601:
try:
- parsed = parse_datetime(value)
+ parsed = parse_date(value)
except (ValueError, TypeError):
pass
else:
if parsed is not None:
- return self.enforce_timezone(parsed)
+ return parsed
else:
try:
parsed = datetime.datetime.strptime(value, format)
except (ValueError, TypeError):
pass
else:
- return self.enforce_timezone(parsed)
+ return parsed.date()
- humanized_format = humanize_datetime.datetime_formats(self.input_formats)
+ humanized_format = humanize_datetime.date_formats(self.input_formats)
self.fail('invalid', format=humanized_format)
def to_representation(self, value):
if value is None or self.format is None:
return value
+ # Applying a `DateField` to a datetime value is almost always
+ # not a sensible thing to do, as it means naively dropping
+ # any explicit or implicit timezone info.
+ assert not isinstance(value, datetime.datetime), (
+ 'Expected a `date`, but got a `datetime`. Refusing to coerce, '
+ 'as this may mean losing timezone information. Use a custom '
+ 'read-only field and deal with timezone issues explicitly.'
+ )
+
if self.format.lower() == ISO_8601:
- ret = value.isoformat()
- if ret.endswith('+00:00'):
- ret = ret[:-6] + 'Z'
- return ret
+ return value.isoformat()
return value.strftime(self.format)
@@ -765,8 +771,14 @@ class TimeField(Field):
if value is None or self.format is None:
return value
- if isinstance(value, datetime.datetime):
- value = value.time()
+ # Applying a `TimeField` to a datetime value is almost always
+ # not a sensible thing to do, as it means naively dropping
+ # any explicit or implicit timezone info.
+ assert not isinstance(value, datetime.datetime), (
+ 'Expected a `time`, but got a `datetime`. Refusing to coerce, '
+ 'as this may mean losing timezone information. Use a custom '
+ 'read-only field and deal with timezone issues explicitly.'
+ )
if self.format.lower() == ISO_8601:
return value.isoformat()
--
cgit v1.2.3
From 3a5335f09f58439f8e3c0bddbed8e4c7eeb32482 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 25 Sep 2014 13:12:02 +0100
Subject: Fix syntax error
---
rest_framework/fields.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index d855e6fd..7beccbb7 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -635,7 +635,7 @@ class DateTimeField(Field):
return value
def to_internal_value(self, value):
- if (isinstance(value, datetime.date) and not isinstance(value, datetime.datetime):
+ if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime):
self.fail('date')
if isinstance(value, datetime.datetime):
--
cgit v1.2.3
From 417fe1b675bd1d42518fb89a6f81547caef5b735 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 25 Sep 2014 13:37:26 +0100
Subject: Partial support
---
rest_framework/fields.py | 35 +++++++++++++++++++++++++----------
rest_framework/serializers.py | 6 ++++--
2 files changed, 29 insertions(+), 12 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 7beccbb7..032bfd04 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -109,8 +109,7 @@ class Field(object):
def __init__(self, read_only=False, write_only=False,
required=None, default=empty, initial=None, source=None,
label=None, help_text=None, style=None,
- error_messages=None, validators=[], allow_null=False,
- context=None):
+ error_messages=None, validators=[], allow_null=False):
self._creation_counter = Field._creation_counter
Field._creation_counter += 1
@@ -139,7 +138,6 @@ class Field(object):
# These are set up by `.bind()` when the field is added to a serializer.
self.field_name = None
self.parent = None
- self._context = {} if (context is None) else context
# Collect default error message from self and parent classes
messages = {}
@@ -163,13 +161,6 @@ class Field(object):
kwargs = copy.deepcopy(self._kwargs)
return self.__class__(*args, **kwargs)
- @property
- def context(self):
- root = self
- while root.parent is not None:
- root = root.parent
- return root._context
-
def bind(self, field_name, parent):
"""
Setup the context for the field instance.
@@ -254,6 +245,8 @@ class Field(object):
"""
if data is empty:
if self.required:
+ if getattr(self.root, 'partial', False):
+ raise SkipField()
self.fail('required')
return self.get_default()
@@ -304,7 +297,29 @@ class Field(object):
raise AssertionError(msg)
raise ValidationError(msg.format(**kwargs))
+ @property
+ def root(self):
+ """
+ Returns the top-level serializer for this field.
+ """
+ root = self
+ while root.parent is not None:
+ root = root.parent
+ return root
+
+ @property
+ def context(self):
+ """
+ Returns the context as passed to the root serializer on initialization.
+ """
+ return getattr(self.root, '_context', {})
+
def __repr__(self):
+ """
+ Fields are represented using their initial calling arguments.
+ This allows us to create descriptive representations for serializer
+ instances that show all the declared fields on the serializer.
+ """
return representation.field_repr(self)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 04721c7a..b6a1898c 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -181,8 +181,9 @@ class BindingDict(object):
@six.add_metaclass(SerializerMetaclass)
class Serializer(BaseSerializer):
def __init__(self, *args, **kwargs):
- kwargs.pop('partial', None)
kwargs.pop('many', None)
+ self.partial = kwargs.pop('partial', False)
+ self._context = kwargs.pop('context', {})
super(Serializer, self).__init__(*args, **kwargs)
@@ -289,7 +290,8 @@ class ListSerializer(BaseSerializer):
self.child = kwargs.pop('child', copy.deepcopy(self.child))
assert self.child is not None, '`child` is a required argument.'
assert not inspect.isclass(self.child), '`child` has not been instantiated.'
- kwargs.pop('partial', None)
+ self.partial = kwargs.pop('partial', False)
+ self._context = kwargs.pop('context', {})
super(ListSerializer, self).__init__(*args, **kwargs)
self.child.bind(field_name='', parent=self)
--
cgit v1.2.3
From 2859eaf524bca23f27e666d24a0b63ba61698a76 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 26 Sep 2014 10:46:52 +0100
Subject: request.data attribute
---
rest_framework/authtoken/views.py | 2 +-
rest_framework/fields.py | 51 +++++++++++++++++++++++----------------
rest_framework/filters.py | 6 ++---
rest_framework/generics.py | 4 +--
rest_framework/mixins.py | 6 ++---
rest_framework/negotiation.py | 4 +--
rest_framework/renderers.py | 4 +--
rest_framework/request.py | 37 +++++++++++++++++++++++++---
rest_framework/serializers.py | 21 +++++++---------
9 files changed, 86 insertions(+), 49 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py
index 94e6f061..103abb27 100644
--- a/rest_framework/authtoken/views.py
+++ b/rest_framework/authtoken/views.py
@@ -16,7 +16,7 @@ class ObtainAuthToken(APIView):
model = Token
def post(self, request):
- serializer = self.serializer_class(data=request.DATA)
+ serializer = self.serializer_class(data=request.data)
if serializer.is_valid():
user = serializer.validated_data['user']
token, created = Token.objects.get_or_create(user=user)
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 032bfd04..ec07a413 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -56,7 +56,7 @@ def get_attribute(instance, attrs):
except AttributeError as exc:
try:
return instance[attr]
- except (KeyError, TypeError):
+ except (KeyError, TypeError, AttributeError):
raise exc
return instance
@@ -90,6 +90,7 @@ NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
+USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField'
MISSING_ERROR_MESSAGE = (
'ValidationError raised by `{class_name}`, but error key `{key}` does '
'not exist in the `error_messages` dictionary.'
@@ -105,9 +106,10 @@ class Field(object):
}
default_validators = []
default_empty_html = None
+ initial = None
def __init__(self, read_only=False, write_only=False,
- required=None, default=empty, initial=None, source=None,
+ required=None, default=empty, initial=empty, source=None,
label=None, help_text=None, style=None,
error_messages=None, validators=[], allow_null=False):
self._creation_counter = Field._creation_counter
@@ -122,13 +124,14 @@ class Field(object):
assert not (read_only and required), NOT_READ_ONLY_REQUIRED
assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT
assert not (required and default is not empty), NOT_REQUIRED_DEFAULT
+ assert not (read_only and self.__class__ == Field), USE_READONLYFIELD
self.read_only = read_only
self.write_only = write_only
self.required = required
self.default = default
self.source = source
- self.initial = initial
+ self.initial = self.initial if (initial is empty) else initial
self.label = label
self.help_text = help_text
self.style = {} if style is None else style
@@ -146,24 +149,10 @@ class Field(object):
messages.update(error_messages or {})
self.error_messages = messages
- def __new__(cls, *args, **kwargs):
- """
- When a field is instantiated, we store the arguments that were used,
- so that we can present a helpful representation of the object.
- """
- instance = super(Field, cls).__new__(cls)
- instance._args = args
- instance._kwargs = kwargs
- return instance
-
- def __deepcopy__(self, memo):
- args = copy.deepcopy(self._args)
- kwargs = copy.deepcopy(self._kwargs)
- return self.__class__(*args, **kwargs)
-
def bind(self, field_name, parent):
"""
- Setup the context for the field instance.
+ Initializes the field name and parent for the field instance.
+ Called when a field is added to the parent serializer instance.
"""
# In order to enforce a consistent style, we error if a redundant
@@ -244,9 +233,9 @@ class Field(object):
validated data.
"""
if data is empty:
+ if getattr(self.root, 'partial', False):
+ raise SkipField()
if self.required:
- if getattr(self.root, 'partial', False):
- raise SkipField()
self.fail('required')
return self.get_default()
@@ -314,6 +303,25 @@ class Field(object):
"""
return getattr(self.root, '_context', {})
+ def __new__(cls, *args, **kwargs):
+ """
+ When a field is instantiated, we store the arguments that were used,
+ so that we can present a helpful representation of the object.
+ """
+ instance = super(Field, cls).__new__(cls)
+ instance._args = args
+ instance._kwargs = kwargs
+ return instance
+
+ def __deepcopy__(self, memo):
+ """
+ When cloning fields we instantiate using the arguments it was
+ originally created with, rather than copying the complete state.
+ """
+ args = copy.deepcopy(self._args)
+ kwargs = copy.deepcopy(self._kwargs)
+ return self.__class__(*args, **kwargs)
+
def __repr__(self):
"""
Fields are represented using their initial calling arguments.
@@ -358,6 +366,7 @@ class NullBooleanField(Field):
'invalid': _('`{input}` is not a valid boolean.')
}
default_empty_html = None
+ initial = None
TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True))
FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False))
NULL_VALUES = set(('n', 'N', 'null', 'Null', 'NULL', '', None))
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index c580f935..085dfe65 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -64,7 +64,7 @@ class DjangoFilterBackend(BaseFilterBackend):
filter_class = self.get_filter_class(view, queryset)
if filter_class:
- return filter_class(request.QUERY_PARAMS, queryset=queryset).qs
+ return filter_class(request.query_params, queryset=queryset).qs
return queryset
@@ -78,7 +78,7 @@ class SearchFilter(BaseFilterBackend):
Search terms are set by a ?search=... query parameter,
and may be comma and/or whitespace delimited.
"""
- params = request.QUERY_PARAMS.get(self.search_param, '')
+ params = request.query_params.get(self.search_param, '')
return params.replace(',', ' ').split()
def construct_search(self, field_name):
@@ -121,7 +121,7 @@ class OrderingFilter(BaseFilterBackend):
the `ordering_param` value on the OrderingFilter or by
specifying an `ORDERING_PARAM` value in the API settings.
"""
- params = request.QUERY_PARAMS.get(self.ordering_param)
+ params = request.query_params.get(self.ordering_param)
if params:
return [param.strip() for param in params.split(',')]
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index f49b0a43..cf903dab 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -112,7 +112,7 @@ class GenericAPIView(views.APIView):
paginator = self.paginator_class(queryset, page_size)
page_kwarg = self.kwargs.get(self.page_kwarg)
- page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)
+ page_query_param = self.request.query_params.get(self.page_kwarg)
page = page_kwarg or page_query_param or 1
try:
page_number = paginator.validate_number(page)
@@ -166,7 +166,7 @@ class GenericAPIView(views.APIView):
if self.paginate_by_param:
try:
return strict_positive_int(
- self.request.QUERY_PARAMS[self.paginate_by_param],
+ self.request.query_params[self.paginate_by_param],
cutoff=self.max_paginate_by
)
except (KeyError, ValueError):
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 14a6b44b..04b7a763 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -18,7 +18,7 @@ class CreateModelMixin(object):
Create a model instance.
"""
def create(self, request, *args, **kwargs):
- serializer = self.get_serializer(data=request.DATA)
+ serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
serializer.save()
headers = self.get_success_headers(serializer.data)
@@ -62,7 +62,7 @@ class UpdateModelMixin(object):
def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
instance = self.get_object()
- serializer = self.get_serializer(instance, data=request.DATA, partial=partial)
+ serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(serializer.data)
@@ -95,7 +95,7 @@ class AllowPUTAsCreateMixin(object):
def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
instance = self.get_object_or_none()
- serializer = self.get_serializer(instance, data=request.DATA, partial=partial)
+ serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
if instance is None:
diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py
index ca7b5397..1838130a 100644
--- a/rest_framework/negotiation.py
+++ b/rest_framework/negotiation.py
@@ -38,7 +38,7 @@ class DefaultContentNegotiation(BaseContentNegotiation):
"""
# Allow URL style format override. eg. "?format=json
format_query_param = self.settings.URL_FORMAT_OVERRIDE
- format = format_suffix or request.QUERY_PARAMS.get(format_query_param)
+ format = format_suffix or request.query_params.get(format_query_param)
if format:
renderers = self.filter_renderers(renderers, format)
@@ -87,5 +87,5 @@ class DefaultContentNegotiation(BaseContentNegotiation):
Allows URL style accept override. eg. "?accept=application/json"
"""
header = request.META.get('HTTP_ACCEPT', '*/*')
- header = request.QUERY_PARAMS.get(self.settings.URL_ACCEPT_OVERRIDE, header)
+ header = request.query_params.get(self.settings.URL_ACCEPT_OVERRIDE, header)
return [token.strip() for token in header.split(',')]
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 3bf03e62..225f9fe8 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -120,7 +120,7 @@ class JSONPRenderer(JSONRenderer):
Determine the name of the callback to wrap around the json output.
"""
request = renderer_context.get('request', None)
- params = request and request.QUERY_PARAMS or {}
+ params = request and request.query_params or {}
return params.get(self.callback_parameter, self.default_callback)
def render(self, data, accepted_media_type=None, renderer_context=None):
@@ -426,7 +426,7 @@ class BrowsableAPIRenderer(BaseRenderer):
"""
if request.method == method:
try:
- data = request.DATA
+ data = request.data
# files = request.FILES
except ParseError:
data = None
diff --git a/rest_framework/request.py b/rest_framework/request.py
index 27532661..d80baa70 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -4,7 +4,7 @@ The Request class is used as a wrapper around the standard request object.
The wrapped request then offers a richer API, in particular :
- content automatically parsed according to `Content-Type` header,
- and available as `request.DATA`
+ and available as `request.data`
- full support of PUT method, including support for file uploads
- form overloading of HTTP method, content type and content
"""
@@ -13,6 +13,7 @@ from django.conf import settings
from django.http import QueryDict
from django.http.multipartparser import parse_header
from django.utils.datastructures import MultiValueDict
+from django.utils.datastructures import MergeDict as DjangoMergeDict
from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions
from rest_framework.compat import BytesIO
@@ -58,6 +59,15 @@ class override_method(object):
self.view.action = self.action
+class MergeDict(DjangoMergeDict, dict):
+ """
+ Using this as a workaround until the parsers API is properly
+ addressed in 3.1.
+ """
+ def __init__(self, *dicts):
+ self.dicts = dicts
+
+
class Empty(object):
"""
Placeholder for unset attributes.
@@ -82,6 +92,7 @@ def clone_request(request, method):
parser_context=request.parser_context)
ret._data = request._data
ret._files = request._files
+ ret._full_data = request._full_data
ret._content_type = request._content_type
ret._stream = request._stream
ret._method = method
@@ -133,6 +144,7 @@ class Request(object):
self.parser_context = parser_context
self._data = Empty
self._files = Empty
+ self._full_data = Empty
self._method = Empty
self._content_type = Empty
self._stream = Empty
@@ -186,12 +198,25 @@ class Request(object):
return self._stream
@property
- def QUERY_PARAMS(self):
+ def query_params(self):
"""
More semantically correct name for request.GET.
"""
return self._request.GET
+ @property
+ def QUERY_PARAMS(self):
+ """
+ Synonym for `.query_params`, for backwards compatibility.
+ """
+ return self._request.GET
+
+ @property
+ def data(self):
+ if not _hasattr(self, '_full_data'):
+ self._load_data_and_files()
+ return self._full_data
+
@property
def DATA(self):
"""
@@ -272,6 +297,10 @@ class Request(object):
if not _hasattr(self, '_data'):
self._data, self._files = self._parse()
+ if self._files:
+ self._full_data = MergeDict(self._data, self._files)
+ else:
+ self._full_data = self._data
def _load_method_and_content_type(self):
"""
@@ -333,6 +362,7 @@ class Request(object):
# At this point we're committed to parsing the request as form data.
self._data = self._request.POST
self._files = self._request.FILES
+ self._full_data = MergeDict(self._data, self._files)
# Method overloading - change the method and remove the param from the content.
if (
@@ -350,7 +380,7 @@ class Request(object):
):
self._content_type = self._data[self._CONTENTTYPE_PARAM]
self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding']))
- self._data, self._files = (Empty, Empty)
+ self._data, self._files, self._full_data = (Empty, Empty, Empty)
def _parse(self):
"""
@@ -380,6 +410,7 @@ class Request(object):
# logging the request or similar.
self._data = QueryDict('', encoding=self._request._encoding)
self._files = MultiValueDict()
+ self._full_data = self._data
raise
# Parser classes may return the raw data, or a
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index b6a1898c..a2b878ec 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -57,21 +57,24 @@ class BaseSerializer(Field):
def to_representation(self, instance):
raise NotImplementedError('`to_representation()` must be implemented.')
- def update(self, instance, attrs):
+ def update(self, instance, validated_data):
raise NotImplementedError('`update()` must be implemented.')
- def create(self, attrs):
+ def create(self, validated_data):
raise NotImplementedError('`create()` must be implemented.')
def save(self, extras=None):
- attrs = self.validated_data
+ validated_data = self.validated_data
if extras is not None:
- attrs = dict(list(attrs.items()) + list(extras.items()))
+ validated_data = dict(
+ list(validated_data.items()) +
+ list(extras.items())
+ )
if self.instance is not None:
- self.update(self.instance, attrs)
+ self.update(self.instance, validated_data)
else:
- self.instance = self.create(attrs)
+ self.instance = self.create(validated_data)
return self.instance
@@ -321,12 +324,6 @@ class ListSerializer(BaseSerializer):
def create(self, attrs_list):
return [self.child.create(attrs) for attrs in attrs_list]
- def save(self):
- if self.instance is not None:
- self.update(self.instance, self.validated_data)
- self.instance = self.create(self.validated_data)
- return self.instance
-
def __repr__(self):
return representation.list_repr(self, indent=1)
--
cgit v1.2.3
From 43e80c74b225e17edfe8a90da893823bf50b946f Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 26 Sep 2014 11:56:29 +0100
Subject: Release notes
---
rest_framework/serializers.py | 3 +++
1 file changed, 3 insertions(+)
(limited to 'rest_framework')
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index a2b878ec..86bed773 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -75,6 +75,9 @@ class BaseSerializer(Field):
self.update(self.instance, validated_data)
else:
self.instance = self.create(validated_data)
+ assert self.instance is not None, (
+ '`create()` did not return an object instance.'
+ )
return self.instance
--
cgit v1.2.3
From 8b8623c5f84d443d26804cac52a793a3037a1dd0 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 26 Sep 2014 12:48:20 +0100
Subject: Allow many, partial and context in BaseSerializer
---
rest_framework/serializers.py | 28 ++++++++++++----------------
1 file changed, 12 insertions(+), 16 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 86bed773..245ec26f 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -47,9 +47,20 @@ class BaseSerializer(Field):
"""
def __init__(self, instance=None, data=None, **kwargs):
- super(BaseSerializer, self).__init__(**kwargs)
self.instance = instance
self._initial_data = data
+ self.partial = kwargs.pop('partial', False)
+ self._context = kwargs.pop('context', {})
+ kwargs.pop('many', None)
+ super(BaseSerializer, self).__init__(**kwargs)
+
+ def __new__(cls, *args, **kwargs):
+ # We override this method in order to automagically create
+ # `ListSerializer` classes instead when `many=True` is set.
+ if kwargs.pop('many', False):
+ kwargs['child'] = cls()
+ return ListSerializer(*args, **kwargs)
+ return super(BaseSerializer, cls).__new__(cls, *args, **kwargs)
def to_internal_value(self, data):
raise NotImplementedError('`to_internal_value()` must be implemented.')
@@ -187,10 +198,6 @@ class BindingDict(object):
@six.add_metaclass(SerializerMetaclass)
class Serializer(BaseSerializer):
def __init__(self, *args, **kwargs):
- kwargs.pop('many', None)
- self.partial = kwargs.pop('partial', False)
- self._context = kwargs.pop('context', {})
-
super(Serializer, self).__init__(*args, **kwargs)
# Every new serializer is created with a clone of the field instances.
@@ -200,14 +207,6 @@ class Serializer(BaseSerializer):
for key, value in self._get_base_fields().items():
self.fields[key] = value
- def __new__(cls, *args, **kwargs):
- # We override this method in order to automagically create
- # `ListSerializer` classes instead when `many=True` is set.
- if kwargs.pop('many', False):
- kwargs['child'] = cls()
- return ListSerializer(*args, **kwargs)
- return super(Serializer, cls).__new__(cls, *args, **kwargs)
-
def _get_base_fields(self):
return copy.deepcopy(self._declared_fields)
@@ -296,9 +295,6 @@ class ListSerializer(BaseSerializer):
self.child = kwargs.pop('child', copy.deepcopy(self.child))
assert self.child is not None, '`child` is a required argument.'
assert not inspect.isclass(self.child), '`child` has not been instantiated.'
- self.partial = kwargs.pop('partial', False)
- self._context = kwargs.pop('context', {})
-
super(ListSerializer, self).__init__(*args, **kwargs)
self.child.bind(field_name='', parent=self)
--
cgit v1.2.3
From 2e87de01430d7fec83f00948e60c8d61b317053b Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 26 Sep 2014 13:08:20 +0100
Subject: Added ListField
---
rest_framework/fields.py | 38 ++++++++++++++++++++++++++++++++++++++
rest_framework/serializers.py | 6 ++++--
2 files changed, 42 insertions(+), 2 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index ec07a413..cf42d36c 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -881,6 +881,44 @@ class ImageField(Field):
# Advanced field types...
+class ListField(Field):
+ child = None
+ initial = []
+ default_error_messages = {
+ 'not_a_list': _('Expected a list of items but got type `{input_type}`')
+ }
+
+ def __init__(self, *args, **kwargs):
+ self.child = kwargs.pop('child', copy.deepcopy(self.child))
+ assert self.child is not None, '`child` is a required argument.'
+ assert not inspect.isclass(self.child), '`child` has not been instantiated.'
+ super(ListField, self).__init__(*args, **kwargs)
+ self.child.bind(field_name='', parent=self)
+
+ def get_value(self, dictionary):
+ # We override the default field access in order to support
+ # lists in HTML forms.
+ if html.is_html_input(dictionary):
+ return html.parse_html_list(dictionary, prefix=self.field_name)
+ return dictionary.get(self.field_name, empty)
+
+ def to_internal_value(self, data):
+ """
+ List of dicts of native values <- List of dicts of primitive datatypes.
+ """
+ if html.is_html_input(data):
+ data = html.parse_html_list(data)
+ if isinstance(data, type('')) or not hasattr(data, '__iter__'):
+ self.fail('not_a_list', input_type=type(data).__name__)
+ return [self.child.run_validation(item) for item in data]
+
+ def to_representation(self, data):
+ """
+ List of object instances -> List of dicts of primitive datatypes.
+ """
+ return [self.child.to_representation(item) for item in data]
+
+
class ReadOnlyField(Field):
"""
A read-only field that simply returns the field value.
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 245ec26f..fa2e8fb1 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -287,6 +287,9 @@ class Serializer(BaseSerializer):
return representation.serializer_repr(self, indent=1)
+# There's some replication of `ListField` here,
+# but that's probably better than obfuscating the call hierarchy.
+
class ListSerializer(BaseSerializer):
child = None
initial = []
@@ -301,7 +304,7 @@ class ListSerializer(BaseSerializer):
def get_value(self, dictionary):
# We override the default field access in order to support
# lists in HTML forms.
- if is_html_input(dictionary):
+ if html.is_html_input(dictionary):
return html.parse_html_list(dictionary, prefix=self.field_name)
return dictionary.get(self.field_name, empty)
@@ -311,7 +314,6 @@ class ListSerializer(BaseSerializer):
"""
if html.is_html_input(data):
data = html.parse_html_list(data)
-
return [self.child.run_validation(item) for item in data]
def to_representation(self, data):
--
cgit v1.2.3
From 33ccf40b76ddae790c34c294a133219e68efb946 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 26 Sep 2014 13:14:08 +0100
Subject: Update version number
---
rest_framework/__init__.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'rest_framework')
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 7f724c18..261c9c98 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -8,7 +8,7 @@ ______ _____ _____ _____ __
"""
__title__ = 'Django REST framework'
-__version__ = '2.4.3'
+__version__ = '3.0.0'
__author__ = 'Tom Christie'
__license__ = 'BSD 2-Clause'
__copyright__ = 'Copyright 2011-2014 Tom Christie'
--
cgit v1.2.3
From 609014460861fdfe82054551790d6439292dde7b Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 26 Sep 2014 14:32:44 +0100
Subject: Simplify serialization slightly
---
rest_framework/fields.py | 11 +++++++----
rest_framework/serializers.py | 3 +--
2 files changed, 8 insertions(+), 6 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index cf42d36c..4c49aaba 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -202,12 +202,15 @@ class Field(object):
return self.default_empty_html if (ret == '') else ret
return dictionary.get(self.field_name, empty)
- def get_attribute(self, instance):
+ def get_field_representation(self, instance):
"""
- Given the *outgoing* object instance, return the value for this field
- that should be returned as a primative value.
+ Given the outgoing object instance, return the primative value
+ that should be used for this field.
"""
- return get_attribute(instance, self.source_attrs)
+ attribute = get_attribute(instance, self.source_attrs)
+ if attribute is None:
+ return None
+ return self.to_representation(attribute)
def get_default(self):
"""
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index fa2e8fb1..080b958d 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -268,8 +268,7 @@ class Serializer(BaseSerializer):
fields = [field for field in self.fields.values() if not field.write_only]
for field in fields:
- native_value = field.get_attribute(instance)
- ret[field.field_name] = field.to_representation(native_value)
+ ret[field.field_name] = field.get_field_representation(instance)
return ret
--
cgit v1.2.3
From dee3f78cb688b1bee892ef78d6eec23ccf67a80e Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 26 Sep 2014 17:06:20 +0100
Subject: FileField and ImageField
---
rest_framework/compat.py | 9 -----
rest_framework/fields.py | 82 ++++++++++++++++++++++++++++++++++++----------
rest_framework/settings.py | 3 +-
3 files changed, 66 insertions(+), 28 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 7303c32a..89af9b48 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -84,15 +84,6 @@ except ImportError:
from collections import UserDict
from collections import MutableMapping as DictMixin
-# Try to import PIL in either of the two ways it can end up installed.
-try:
- from PIL import Image
-except ImportError:
- try:
- import Image
- except ImportError:
- Image = None
-
def get_model_name(model_cls):
try:
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 4c49aaba..f4b53279 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,3 +1,4 @@
+from django import forms
from django.conf import settings
from django.core import validators
from django.core.exceptions import ValidationError
@@ -427,8 +428,6 @@ class CharField(Field):
return str(data)
def to_representation(self, value):
- if value is None:
- return None
return str(value)
@@ -446,8 +445,6 @@ class EmailField(CharField):
return str(data).strip()
def to_representation(self, value):
- if value is None:
- return None
return str(value).strip()
@@ -513,8 +510,6 @@ class IntegerField(Field):
return data
def to_representation(self, value):
- if value is None:
- return None
return int(value)
@@ -543,8 +538,6 @@ class FloatField(Field):
self.fail('invalid')
def to_representation(self, value):
- if value is None:
- return None
return float(value)
@@ -616,9 +609,6 @@ class DecimalField(Field):
return value
def to_representation(self, value):
- if value in (None, ''):
- return None
-
if not isinstance(value, decimal.Decimal):
value = decimal.Decimal(str(value).strip())
@@ -689,7 +679,7 @@ class DateTimeField(Field):
self.fail('invalid', format=humanized_format)
def to_representation(self, value):
- if value is None or self.format is None:
+ if self.format is None:
return value
if self.format.lower() == ISO_8601:
@@ -741,7 +731,7 @@ class DateField(Field):
self.fail('invalid', format=humanized_format)
def to_representation(self, value):
- if value is None or self.format is None:
+ if self.format is None:
return value
# Applying a `DateField` to a datetime value is almost always
@@ -795,7 +785,7 @@ class TimeField(Field):
self.fail('invalid', format=humanized_format)
def to_representation(self, value):
- if value is None or self.format is None:
+ if self.format is None:
return value
# Applying a `TimeField` to a datetime value is almost always
@@ -875,14 +865,68 @@ class MultipleChoiceField(ChoiceField):
# File types...
class FileField(Field):
- pass # TODO
+ default_error_messages = {
+ 'required': _("No file was submitted."),
+ 'invalid': _("The submitted data was not a file. Check the encoding type on the form."),
+ 'no_name': _("No filename could be determined."),
+ 'empty': _("The submitted file is empty."),
+ 'max_length': _('Ensure this filename has at most {max_length} characters (it has {length}).'),
+ }
+ use_url = api_settings.UPLOADED_FILES_USE_URL
+ def __init__(self, *args, **kwargs):
+ self.max_length = kwargs.pop('max_length', None)
+ self.allow_empty_file = kwargs.pop('allow_empty_file', False)
+ self.use_url = kwargs.pop('use_url', self.use_url)
+ super(FileField, self).__init__(*args, **kwargs)
-class ImageField(Field):
- pass # TODO
+ def to_internal_value(self, data):
+ try:
+ # `UploadedFile` objects should have name and size attributes.
+ file_name = data.name
+ file_size = data.size
+ except AttributeError:
+ self.fail('invalid')
+ if not file_name:
+ self.fail('no_name')
+ if not self.allow_empty_file and not file_size:
+ self.fail('empty')
+ if self.max_length and len(file_name) > self.max_length:
+ self.fail('max_length', max_length=self.max_length, length=len(file_name))
-# Advanced field types...
+ return data
+
+ def to_representation(self, value):
+ if self.use_url:
+ return settings.MEDIA_URL + value.url
+ return value.name
+
+
+class ImageField(FileField):
+ default_error_messages = {
+ 'invalid_image': _(
+ 'Upload a valid image. The file you uploaded was either not an '
+ 'image or a corrupted image.'
+ ),
+ }
+
+ def __init__(self, *args, **kwargs):
+ self._DjangoImageField = kwargs.pop('_DjangoImageField', forms.ImageField)
+ super(ImageField, self).__init__(*args, **kwargs)
+
+ def to_internal_value(self, data):
+ # Image validation is a bit grungy, so we'll just outright
+ # defer to Django's implementation so we don't need to
+ # consider it, or treat PIL as a test dependancy.
+ file_object = super(ImageField, self).to_internal_value(data)
+ django_field = self._DjangoImageField()
+ django_field.error_messages = self.error_messages
+ django_field.to_python(file_object)
+ return file_object
+
+
+# Composite field types...
class ListField(Field):
child = None
@@ -922,6 +966,8 @@ class ListField(Field):
return [self.child.to_representation(item) for item in data]
+# Miscellaneous field types...
+
class ReadOnlyField(Field):
"""
A read-only field that simply returns the field value.
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index d7fb0a43..1e8c27fc 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -110,7 +110,8 @@ DEFAULTS = {
# Encoding
'UNICODE_JSON': True,
'COMPACT_JSON': True,
- 'COERCE_DECIMAL_TO_STRING': True
+ 'COERCE_DECIMAL_TO_STRING': True,
+ 'UPLOADED_FILES_USE_URL': True
}
--
cgit v1.2.3
From 43fd5a873051c99600386c1fdc9fa368edeb6eda Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 29 Sep 2014 09:24:03 +0100
Subject: Uniqueness validation
---
rest_framework/fields.py | 4 +++
rest_framework/utils/field_mapping.py | 5 +++
rest_framework/validators.py | 57 +++++++++++++++++++++++++++++++++++
3 files changed, 66 insertions(+)
create mode 100644 rest_framework/validators.py
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index f4b53279..231f693c 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -150,6 +150,10 @@ class Field(object):
messages.update(error_messages or {})
self.error_messages = messages
+ for validator in validators:
+ if getattr(validator, 'requires_context', False):
+ validator.serializer_field = self
+
def bind(self, field_name, parent):
"""
Initializes the field name and parent for the field instance.
diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py
index c3794083..cf9d910a 100644
--- a/rest_framework/utils/field_mapping.py
+++ b/rest_framework/utils/field_mapping.py
@@ -6,6 +6,7 @@ from django.core import validators
from django.db import models
from django.utils.text import capfirst
from rest_framework.compat import clean_manytomany_helptext
+from rest_framework.validators import UniqueValidator
import inspect
@@ -156,6 +157,10 @@ def get_field_kwargs(field_name, model_field):
if validator is not validators.validate_slug
]
+ if getattr(model_field, 'unique', False):
+ validator = UniqueValidator(queryset=model_field.model._default_manager)
+ validator_kwarg.append(validator)
+
max_digits = getattr(model_field, 'max_digits', None)
if max_digits is not None:
kwargs['max_digits'] = max_digits
diff --git a/rest_framework/validators.py b/rest_framework/validators.py
new file mode 100644
index 00000000..f5fbeb3c
--- /dev/null
+++ b/rest_framework/validators.py
@@ -0,0 +1,57 @@
+from django.core.exceptions import ValidationError
+
+
+class UniqueValidator:
+ # Validators with `requires_context` will have the field instance
+ # passed to them when the field is instantiated.
+ requires_context = True
+
+ def __init__(self, queryset):
+ self.queryset = queryset
+ self.serializer_field = None
+
+ def get_queryset(self):
+ return self.queryset.all()
+
+ def __call__(self, value):
+ field = self.serializer_field
+
+ # Determine the model field name that the serializer field corresponds to.
+ field_name = field.source_attrs[0] if field.source_attrs else field.field_name
+
+ # Determine the existing instance, if this is an update operation.
+ instance = getattr(field.parent, 'instance', None)
+
+ # Ensure uniqueness.
+ filter_kwargs = {field_name: value}
+ queryset = self.get_queryset().filter(**filter_kwargs)
+ if instance:
+ queryset = queryset.exclude(pk=instance.pk)
+ if queryset.exists():
+ raise ValidationError('This field must be unique.')
+
+
+class UniqueTogetherValidator:
+ requires_context = True
+
+ def __init__(self, queryset, fields):
+ self.queryset = queryset
+ self.fields = fields
+ self.serializer_field = None
+
+ def __call__(self, value):
+ serializer = self.serializer_field
+
+ # Determine the existing instance, if this is an update operation.
+ instance = getattr(serializer, 'instance', None)
+
+ # Ensure uniqueness.
+ filter_kwargs = dict([
+ (field_name, value[field_name]) for field_name in self.fields
+ ])
+ queryset = self.get_queryset().filter(**filter_kwargs)
+ if instance:
+ queryset = queryset.exclude(pk=instance.pk)
+ if queryset.exists():
+ field_names = ' and '.join(self.fields)
+ raise ValidationError('The fields %s must make a unique set.' % field_names)
--
cgit v1.2.3
From 9805a085fb115785f272489dc24b51ba6f8e6329 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 29 Sep 2014 11:23:02 +0100
Subject: UniqueTogetherValidator
---
rest_framework/serializers.py | 80 ++++++++++++++++++++++++++++++++++++++-----
rest_framework/validators.py | 38 +++++++++++++++-----
2 files changed, 101 insertions(+), 17 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 080b958d..09ad376a 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -23,6 +23,7 @@ from rest_framework.utils.field_mapping import (
get_relation_kwargs, get_nested_relation_kwargs,
ClassLookupDict
)
+from rest_framework.validators import UniqueTogetherValidator
import copy
import inspect
@@ -95,7 +96,7 @@ class BaseSerializer(Field):
def is_valid(self, raise_exception=False):
if not hasattr(self, '_validated_data'):
try:
- self._validated_data = self.to_internal_value(self._initial_data)
+ self._validated_data = self.run_validation(self._initial_data)
except ValidationError as exc:
self._validated_data = {}
self._errors = exc.message_dict
@@ -223,15 +224,43 @@ class Serializer(BaseSerializer):
return html.parse_html_dict(dictionary, prefix=self.field_name)
return dictionary.get(self.field_name, empty)
- def to_internal_value(self, data):
+ def run_validation(self, data=empty):
"""
- Dict of native values <- Dict of primitive datatypes.
+ We override the default `run_validation`, because the validation
+ performed by validators and the `.validate()` method should
+ be coerced into an error dictionary with a 'non_fields_error' key.
"""
+ if data is empty:
+ if getattr(self.root, 'partial', False):
+ raise SkipField()
+ if self.required:
+ self.fail('required')
+ return self.get_default()
+
+ if data is None:
+ if not self.allow_null:
+ self.fail('null')
+ return None
+
if not isinstance(data, dict):
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data']
})
+ value = self.to_internal_value(data)
+ try:
+ self.run_validators(value)
+ self.validate(value)
+ except ValidationError as exc:
+ raise ValidationError({
+ api_settings.NON_FIELD_ERRORS_KEY: exc.messages
+ })
+ return value
+
+ def to_internal_value(self, data):
+ """
+ Dict of native values <- Dict of primitive datatypes.
+ """
ret = {}
errors = {}
fields = [field for field in self.fields.values() if not field.read_only]
@@ -253,12 +282,7 @@ class Serializer(BaseSerializer):
if errors:
raise ValidationError(errors)
- try:
- return self.validate(ret)
- except ValidationError as exc:
- raise ValidationError({
- api_settings.NON_FIELD_ERRORS_KEY: exc.messages
- })
+ return ret
def to_representation(self, instance):
"""
@@ -355,6 +379,14 @@ class ModelSerializer(Serializer):
})
_related_class = PrimaryKeyRelatedField
+ def __init__(self, *args, **kwargs):
+ super(ModelSerializer, self).__init__(*args, **kwargs)
+ if 'validators' not in kwargs:
+ validators = self.get_unique_together_validators()
+ if validators:
+ self.validators.extend(validators)
+ self._kwargs['validators'] = validators
+
def create(self, attrs):
ModelClass = self.Meta.model
@@ -381,6 +413,36 @@ class ModelSerializer(Serializer):
setattr(obj, attr, value)
obj.save()
+ def get_unique_together_validators(self):
+ field_names = set([
+ field.source for field in self.fields.values()
+ if (field.source != '*') and ('.' not in field.source)
+ ])
+
+ validators = []
+ model_class = self.Meta.model
+
+ for unique_together in model_class._meta.unique_together:
+ if field_names.issuperset(set(unique_together)):
+ validator = UniqueTogetherValidator(
+ queryset=model_class._default_manager,
+ fields=unique_together
+ )
+ validator.serializer_field = self
+ validators.append(validator)
+
+ for parent_class in model_class._meta.parents.keys():
+ for unique_together in parent_class._meta.unique_together:
+ if field_names.issuperset(set(unique_together)):
+ validator = UniqueTogetherValidator(
+ queryset=parent_class._default_manager,
+ fields=unique_together
+ )
+ validator.serializer_field = self
+ validators.append(validator)
+
+ return validators
+
def _get_base_fields(self):
declared_fields = copy.deepcopy(self._declared_fields)
diff --git a/rest_framework/validators.py b/rest_framework/validators.py
index f5fbeb3c..20de4b42 100644
--- a/rest_framework/validators.py
+++ b/rest_framework/validators.py
@@ -1,18 +1,26 @@
+"""
+We perform uniqueness checks explicitly on the serializer class, rather
+the using Django's `.full_clean()`.
+
+This gives us better seperation of concerns, allows us to use single-step
+object creation, and makes it possible to switch between using the implicit
+`ModelSerializer` class and an equivelent explicit `Serializer` class.
+"""
from django.core.exceptions import ValidationError
+from django.utils.translation import ugettext_lazy as _
+from rest_framework.utils.representation import smart_repr
class UniqueValidator:
# Validators with `requires_context` will have the field instance
# passed to them when the field is instantiated.
requires_context = True
+ message = _('This field must be unique.')
def __init__(self, queryset):
self.queryset = queryset
self.serializer_field = None
- def get_queryset(self):
- return self.queryset.all()
-
def __call__(self, value):
field = self.serializer_field
@@ -24,15 +32,22 @@ class UniqueValidator:
# Ensure uniqueness.
filter_kwargs = {field_name: value}
- queryset = self.get_queryset().filter(**filter_kwargs)
+ queryset = self.queryset.filter(**filter_kwargs)
if instance:
queryset = queryset.exclude(pk=instance.pk)
if queryset.exists():
- raise ValidationError('This field must be unique.')
+ raise ValidationError(self.message)
+
+ def __repr__(self):
+ return '<%s(queryset=%s)>' % (
+ self.__class__.__name__,
+ smart_repr(self.queryset)
+ )
class UniqueTogetherValidator:
requires_context = True
+ message = _('The fields {field_names} must make a unique set.')
def __init__(self, queryset, fields):
self.queryset = queryset
@@ -49,9 +64,16 @@ class UniqueTogetherValidator:
filter_kwargs = dict([
(field_name, value[field_name]) for field_name in self.fields
])
- queryset = self.get_queryset().filter(**filter_kwargs)
+ queryset = self.queryset.filter(**filter_kwargs)
if instance:
queryset = queryset.exclude(pk=instance.pk)
if queryset.exists():
- field_names = ' and '.join(self.fields)
- raise ValidationError('The fields %s must make a unique set.' % field_names)
+ field_names = ', '.join(self.fields)
+ raise ValidationError(self.message.format(field_names=field_names))
+
+ def __repr__(self):
+ return '<%s(queryset=%s, fields=%s)>' % (
+ self.__class__.__name__,
+ smart_repr(self.queryset),
+ smart_repr(self.fields)
+ )
--
cgit v1.2.3
From d2d412993f537952fd7809ded3e981f85ec318e9 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 29 Sep 2014 11:24:21 +0100
Subject: .validate() on serializer fields
---
rest_framework/fields.py | 4 ++++
1 file changed, 4 insertions(+)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 231f693c..fee6080a 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -254,6 +254,7 @@ class Field(object):
value = self.to_internal_value(data)
self.run_validators(value)
+ self.validate(value)
return value
def run_validators(self, value):
@@ -270,6 +271,9 @@ class Field(object):
if errors:
raise ValidationError(errors)
+ def validate(self, value):
+ pass
+
def to_internal_value(self, data):
"""
Transform the *incoming* primative data into a native value.
--
cgit v1.2.3
From d1b2c8ac7faec65483cbddf4f1718ca4f5805246 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 29 Sep 2014 14:12:26 +0100
Subject: Absolute URLs for file fields
---
rest_framework/fields.py | 12 +++++++-----
rest_framework/serializers.py | 2 --
2 files changed, 7 insertions(+), 7 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index fee6080a..f7ea3b0c 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -150,10 +150,6 @@ class Field(object):
messages.update(error_messages or {})
self.error_messages = messages
- for validator in validators:
- if getattr(validator, 'requires_context', False):
- validator.serializer_field = self
-
def bind(self, field_name, parent):
"""
Initializes the field name and parent for the field instance.
@@ -264,6 +260,8 @@ class Field(object):
"""
errors = []
for validator in self.validators:
+ if getattr(validator, 'requires_context', False):
+ validator.serializer_field = self
try:
validator(value)
except ValidationError as exc:
@@ -907,7 +905,11 @@ class FileField(Field):
def to_representation(self, value):
if self.use_url:
- return settings.MEDIA_URL + value.url
+ url = settings.MEDIA_URL + value.url
+ request = self.context.get('request', None)
+ if request is not None:
+ return request.build_absolute_uri(url)
+ return url
return value.name
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 09ad376a..0faa5671 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -428,7 +428,6 @@ class ModelSerializer(Serializer):
queryset=model_class._default_manager,
fields=unique_together
)
- validator.serializer_field = self
validators.append(validator)
for parent_class in model_class._meta.parents.keys():
@@ -438,7 +437,6 @@ class ModelSerializer(Serializer):
queryset=parent_class._default_manager,
fields=unique_together
)
- validator.serializer_field = self
validators.append(validator)
return validators
--
cgit v1.2.3
From 381771731f48c75e7d5951e353049cceec386512 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Wed, 1 Oct 2014 13:09:14 +0100
Subject: Use six.text_type instead of str everywhere
---
rest_framework/compat.py | 9 +++++----
rest_framework/fields.py | 22 +++++++++++-----------
rest_framework/filters.py | 3 ++-
rest_framework/generics.py | 5 +++--
rest_framework/parsers.py | 3 ++-
rest_framework/relations.py | 3 ++-
rest_framework/reverse.py | 3 ++-
rest_framework/utils/encoders.py | 6 +++---
8 files changed, 30 insertions(+), 24 deletions(-)
(limited to 'rest_framework')
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 89af9b48..3993cee6 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -5,11 +5,12 @@ versions of django/python, and compatibility wrappers around optional packages.
# flake8: noqa
from __future__ import unicode_literals
-import django
-import inspect
+
from django.core.exceptions import ImproperlyConfigured
from django.conf import settings
from django.utils import six
+import django
+import inspect
# Handle django.utils.encoding rename in 1.5 onwards.
@@ -177,12 +178,12 @@ class RequestFactory(DjangoRequestFactory):
r = {
'PATH_INFO': self._get_path(parsed),
'QUERY_STRING': force_text(parsed[4]),
- 'REQUEST_METHOD': str(method),
+ 'REQUEST_METHOD': six.text_type(method),
}
if data:
r.update({
'CONTENT_LENGTH': len(data),
- 'CONTENT_TYPE': str(content_type),
+ 'CONTENT_TYPE': six.text_type(content_type),
'wsgi.input': FakePayload(data),
})
elif django.VERSION <= (1, 4):
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index f7ea3b0c..f3ff2233 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -2,7 +2,7 @@ from django import forms
from django.conf import settings
from django.core import validators
from django.core.exceptions import ValidationError
-from django.utils import timezone
+from django.utils import six, timezone
from django.utils.datastructures import SortedDict
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.encoding import is_protected_type
@@ -431,10 +431,10 @@ class CharField(Field):
return super(CharField, self).run_validation(data)
def to_internal_value(self, data):
- return str(data)
+ return six.text_type(data)
def to_representation(self, value):
- return str(value)
+ return six.text_type(value)
class EmailField(CharField):
@@ -448,10 +448,10 @@ class EmailField(CharField):
self.validators.append(validator)
def to_internal_value(self, data):
- return str(data).strip()
+ return six.text_type(data).strip()
def to_representation(self, value):
- return str(value).strip()
+ return six.text_type(value).strip()
class RegexField(CharField):
@@ -510,7 +510,7 @@ class IntegerField(Field):
def to_internal_value(self, data):
try:
- data = int(str(data))
+ data = int(six.text_type(data))
except (ValueError, TypeError):
self.fail('invalid')
return data
@@ -616,7 +616,7 @@ class DecimalField(Field):
def to_representation(self, value):
if not isinstance(value, decimal.Decimal):
- value = decimal.Decimal(str(value).strip())
+ value = decimal.Decimal(six.text_type(value).strip())
context = decimal.getcontext().copy()
context.prec = self.max_digits
@@ -832,19 +832,19 @@ class ChoiceField(Field):
# Allows us to deal with eg. integer choices while supporting either
# integer or string input, but still get the correct datatype out.
self.choice_strings_to_values = dict([
- (str(key), key) for key in self.choices.keys()
+ (six.text_type(key), key) for key in self.choices.keys()
])
super(ChoiceField, self).__init__(**kwargs)
def to_internal_value(self, data):
try:
- return self.choice_strings_to_values[str(data)]
+ return self.choice_strings_to_values[six.text_type(data)]
except KeyError:
self.fail('invalid_choice', input=data)
def to_representation(self, value):
- return self.choice_strings_to_values[str(value)]
+ return self.choice_strings_to_values[six.text_type(value)]
class MultipleChoiceField(ChoiceField):
@@ -864,7 +864,7 @@ class MultipleChoiceField(ChoiceField):
def to_representation(self, value):
return set([
- self.choice_strings_to_values[str(item)] for item in value
+ self.choice_strings_to_values[six.text_type(item)] for item in value
])
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index 085dfe65..4c485668 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -3,6 +3,7 @@ Provides generic filtering backends that can be used to filter the results
returned by list views.
"""
from __future__ import unicode_literals
+
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from django.utils import six
@@ -97,7 +98,7 @@ class SearchFilter(BaseFilterBackend):
if not search_fields:
return queryset
- orm_lookups = [self.construct_search(str(search_field))
+ orm_lookups = [self.construct_search(six.text_type(search_field))
for search_field in search_fields]
for search_term in self.get_search_terms(request):
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index cf903dab..3d6cf168 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -3,10 +3,11 @@ Generic views that provide commonly needed behaviour.
"""
from __future__ import unicode_literals
-from django.db.models.query import QuerySet
from django.core.paginator import Paginator, InvalidPage
+from django.db.models.query import QuerySet
from django.http import Http404
from django.shortcuts import get_object_or_404 as _get_object_or_404
+from django.utils import six
from django.utils.translation import ugettext as _
from rest_framework import views, mixins
from rest_framework.settings import api_settings
@@ -127,7 +128,7 @@ class GenericAPIView(views.APIView):
error_format = _('Invalid page (%(page_number)s): %(message)s')
raise Http404(error_format % {
'page_number': page_number,
- 'message': str(exc)
+ 'message': six.text_type(exc)
})
return page
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index fa02ecf1..ccb82f03 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -5,6 +5,7 @@ They give us a generic way of being able to handle various media types
on the request, such as form content or json encoded data.
"""
from __future__ import unicode_literals
+
from django.conf import settings
from django.core.files.uploadhandler import StopFutureHandlers
from django.http import QueryDict
@@ -132,7 +133,7 @@ class MultiPartParser(BaseParser):
data, files = parser.parse()
return DataAndFiles(data, files)
except MultiPartParserError as exc:
- raise ParseError('Multipart form parse error - %s' % str(exc))
+ raise ParseError('Multipart form parse error - %s' % six.text_type(exc))
class XMLParser(BaseParser):
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index b37a6fed..b5effc6c 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -4,6 +4,7 @@ from rest_framework.reverse import reverse
from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured
from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch, Resolver404
from django.db.models.query import QuerySet
+from django.utils import six
from django.utils.translation import ugettext_lazy as _
@@ -49,7 +50,7 @@ class StringRelatedField(Field):
super(StringRelatedField, self).__init__(**kwargs)
def to_representation(self, value):
- return str(value)
+ return six.text_type(value)
class PrimaryKeyRelatedField(RelatedField):
diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py
index a51b07f5..a74e8aa2 100644
--- a/rest_framework/reverse.py
+++ b/rest_framework/reverse.py
@@ -3,6 +3,7 @@ Provide reverse functions that return fully qualified URLs
"""
from __future__ import unicode_literals
from django.core.urlresolvers import reverse as django_reverse
+from django.utils import six
from django.utils.functional import lazy
@@ -20,4 +21,4 @@ def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra
return url
-reverse_lazy = lazy(reverse, str)
+reverse_lazy = lazy(reverse, six.text_type)
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index 174b08b8..7c4179a1 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -2,8 +2,8 @@
Helper classes for parsers.
"""
from __future__ import unicode_literals
-from django.utils import timezone
from django.db.models.query import QuerySet
+from django.utils import six, timezone
from django.utils.datastructures import SortedDict
from django.utils.functional import Promise
from rest_framework.compat import force_text
@@ -40,7 +40,7 @@ class JSONEncoder(json.JSONEncoder):
representation = representation[:12]
return representation
elif isinstance(obj, datetime.timedelta):
- return str(obj.total_seconds())
+ return six.text_type(obj.total_seconds())
elif isinstance(obj, decimal.Decimal):
# Serializers will coerce decimals to strings by default.
return float(obj)
@@ -72,7 +72,7 @@ else:
than the usual behaviour of sorting the keys.
"""
def represent_decimal(self, data):
- return self.represent_scalar('tag:yaml.org,2002:str', str(data))
+ return self.represent_scalar('tag:yaml.org,2002:str', six.text_type(data))
def represent_mapping(self, tag, mapping, flow_style=None):
value = []
--
cgit v1.2.3
From c630a12e26f29145784523dd1b01ab0b3576f42c Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Wed, 1 Oct 2014 13:24:47 +0100
Subject: Deal with lazy strings in serializer reprs
---
rest_framework/utils/representation.py | 5 +++++
1 file changed, 5 insertions(+)
(limited to 'rest_framework')
diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py
index e64fdd22..180b51f8 100644
--- a/rest_framework/utils/representation.py
+++ b/rest_framework/utils/representation.py
@@ -3,6 +3,8 @@ Helper functions for creating user-friendly representations
of serializer classes and serializer fields.
"""
from django.db import models
+from django.utils.functional import Promise
+from rest_framework.compat import force_text
import re
@@ -19,6 +21,9 @@ def smart_repr(value):
if isinstance(value, models.Manager):
return manager_repr(value)
+ if isinstance(value, Promise) and value._delegate_text:
+ value = force_text(value)
+
value = repr(value)
# Representations like u'help text'
--
cgit v1.2.3
From c171fa21ac62538331755524057d2435f33ec8a5 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Wed, 1 Oct 2014 19:44:46 +0100
Subject: First pass at HTML form rendering
---
rest_framework/renderers.py | 47 ++++++++++++++++++++--
rest_framework/serializers.py | 2 +
.../templates/rest_framework/fields/attrs.html | 1 +
.../rest_framework/fields/horizontal/checkbox.html | 10 +++++
.../rest_framework/fields/horizontal/fieldset.html | 10 +++++
.../rest_framework/fields/horizontal/input.html | 7 ++++
.../rest_framework/fields/horizontal/label.html | 1 +
.../rest_framework/fields/horizontal/select.html | 10 +++++
.../fields/horizontal/select_checkbox.html | 22 ++++++++++
.../fields/horizontal/select_multiple.html | 10 +++++
.../fields/horizontal/select_radio.html | 22 ++++++++++
.../rest_framework/fields/horizontal/textarea.html | 7 ++++
.../rest_framework/fields/inline/checkbox.html | 6 +++
.../rest_framework/fields/inline/fieldset.html | 3 ++
.../rest_framework/fields/inline/input.html | 4 ++
.../rest_framework/fields/inline/label.html | 1 +
.../rest_framework/fields/inline/select.html | 8 ++++
.../fields/inline/select_checkbox.html | 11 +++++
.../fields/inline/select_multiple.html | 8 ++++
.../rest_framework/fields/inline/select_radio.html | 11 +++++
.../rest_framework/fields/inline/textarea.html | 4 ++
.../rest_framework/fields/vertical/checkbox.html | 6 +++
.../rest_framework/fields/vertical/fieldset.html | 6 +++
.../rest_framework/fields/vertical/input.html | 5 +++
.../rest_framework/fields/vertical/label.html | 1 +
.../rest_framework/fields/vertical/select.html | 8 ++++
.../fields/vertical/select_checkbox.html | 22 ++++++++++
.../fields/vertical/select_multiple.html | 8 ++++
.../fields/vertical/select_radio.html | 22 ++++++++++
.../rest_framework/fields/vertical/textarea.html | 5 +++
rest_framework/templates/rest_framework/form.html | 40 ++++++++++++------
rest_framework/templatetags/rest_framework.py | 8 ++++
rest_framework/utils/field_mapping.py | 3 ++
33 files changed, 324 insertions(+), 15 deletions(-)
create mode 100644 rest_framework/templates/rest_framework/fields/attrs.html
create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/checkbox.html
create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/fieldset.html
create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/input.html
create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/label.html
create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/select.html
create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/select_checkbox.html
create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/select_multiple.html
create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/select_radio.html
create mode 100644 rest_framework/templates/rest_framework/fields/horizontal/textarea.html
create mode 100644 rest_framework/templates/rest_framework/fields/inline/checkbox.html
create mode 100644 rest_framework/templates/rest_framework/fields/inline/fieldset.html
create mode 100644 rest_framework/templates/rest_framework/fields/inline/input.html
create mode 100644 rest_framework/templates/rest_framework/fields/inline/label.html
create mode 100644 rest_framework/templates/rest_framework/fields/inline/select.html
create mode 100644 rest_framework/templates/rest_framework/fields/inline/select_checkbox.html
create mode 100644 rest_framework/templates/rest_framework/fields/inline/select_multiple.html
create mode 100644 rest_framework/templates/rest_framework/fields/inline/select_radio.html
create mode 100644 rest_framework/templates/rest_framework/fields/inline/textarea.html
create mode 100644 rest_framework/templates/rest_framework/fields/vertical/checkbox.html
create mode 100644 rest_framework/templates/rest_framework/fields/vertical/fieldset.html
create mode 100644 rest_framework/templates/rest_framework/fields/vertical/input.html
create mode 100644 rest_framework/templates/rest_framework/fields/vertical/label.html
create mode 100644 rest_framework/templates/rest_framework/fields/vertical/select.html
create mode 100644 rest_framework/templates/rest_framework/fields/vertical/select_checkbox.html
create mode 100644 rest_framework/templates/rest_framework/fields/vertical/select_multiple.html
create mode 100644 rest_framework/templates/rest_framework/fields/vertical/select_radio.html
create mode 100644 rest_framework/templates/rest_framework/fields/vertical/textarea.html
(limited to 'rest_framework')
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 225f9fe8..6483a47c 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -13,17 +13,18 @@ import django
from django import forms
from django.core.exceptions import ImproperlyConfigured
from django.http.multipartparser import parse_header
-from django.template import RequestContext, loader, Template
+from django.template import Context, RequestContext, loader, Template
from django.test.client import encode_multipart
from django.utils import six
from django.utils.xmlutils import SimplerXMLGenerator
+from rest_framework import exceptions, serializers, status, VERSION
from rest_framework.compat import StringIO, smart_text, yaml
from rest_framework.exceptions import ParseError
from rest_framework.settings import api_settings
from rest_framework.request import is_form_media_type, override_method
from rest_framework.utils import encoders
from rest_framework.utils.breadcrumbs import get_breadcrumbs
-from rest_framework import exceptions, status, VERSION
+from rest_framework.utils.field_mapping import ClassLookupDict
def zero_as_none(value):
@@ -341,6 +342,42 @@ class HTMLFormRenderer(BaseRenderer):
template = 'rest_framework/form.html'
charset = 'utf-8'
+ field_templates = ClassLookupDict({
+ serializers.Field: {
+ 'default': 'input.html'
+ },
+ serializers.BooleanField: {
+ 'default': 'checkbox.html'
+ },
+ serializers.CharField: {
+ 'default': 'input.html',
+ 'textarea': 'textarea.html'
+ },
+ serializers.ChoiceField: {
+ 'default': 'select.html',
+ 'radio': 'select_radio.html'
+ },
+ serializers.MultipleChoiceField: {
+ 'default': 'select_multiple.html',
+ 'checkbox': 'select_checkbox.html'
+ }
+ })
+
+ def render_field(self, field, value, errors, layout=None):
+ layout = layout or 'vertical'
+ style_type = field.style.get('type', 'default')
+ if style_type == 'textarea' and layout == 'inline':
+ style_type = 'default'
+ base = self.field_templates[field][style_type]
+ template_name = 'rest_framework/fields/' + layout + '/' + base
+ template = loader.get_template(template_name)
+ context = Context({
+ 'field': field,
+ 'value': value,
+ 'errors': errors
+ })
+ return template.render(context)
+
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
Render serializer data and return an HTML form, as a string.
@@ -349,7 +386,11 @@ class HTMLFormRenderer(BaseRenderer):
request = renderer_context['request']
template = loader.get_template(self.template)
- context = RequestContext(request, {'form': data})
+ context = RequestContext(request, {
+ 'form': data,
+ 'layout': getattr(getattr(data, 'Meta', None), 'layout', 'vertical'),
+ 'renderer': self
+ })
return template.render(context)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 0faa5671..5da81247 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -302,6 +302,8 @@ class Serializer(BaseSerializer):
def __iter__(self):
errors = self.errors if hasattr(self, '_errors') else {}
for field in self.fields.values():
+ if field.read_only:
+ continue
value = self.data.get(field.field_name) if self.data else None
error = errors.get(field.field_name)
yield FieldResult(field, value, error)
diff --git a/rest_framework/templates/rest_framework/fields/attrs.html b/rest_framework/templates/rest_framework/fields/attrs.html
new file mode 100644
index 00000000..b5a4dbcf
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/attrs.html
@@ -0,0 +1 @@
+name="{{ field.field_name }}" {% if field.style.placeholder %}placeholder="{{ field.style.placeholder }}"{% endif %} {% if field.style.rows %}rows="{{ field.style.rows }}"{% endif %}
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/checkbox.html b/rest_framework/templates/rest_framework/fields/horizontal/checkbox.html
new file mode 100644
index 00000000..dce4a5cf
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/horizontal/checkbox.html
@@ -0,0 +1,10 @@
+
+
+
+
+
+
+
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/fieldset.html b/rest_framework/templates/rest_framework/fields/horizontal/fieldset.html
new file mode 100644
index 00000000..86417633
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/horizontal/fieldset.html
@@ -0,0 +1,10 @@
+
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/input.html b/rest_framework/templates/rest_framework/fields/horizontal/input.html
new file mode 100644
index 00000000..310154bb
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/horizontal/input.html
@@ -0,0 +1,7 @@
+
+ {% include "rest_framework/fields/horizontal/label.html" %}
+
+
+ {% if field.help_text %}
{{ field.help_text }}
{% endif %}
+
+
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/label.html b/rest_framework/templates/rest_framework/fields/horizontal/label.html
new file mode 100644
index 00000000..bf21f78c
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/horizontal/label.html
@@ -0,0 +1 @@
+{% if field.label %}{% endif %}
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/select.html b/rest_framework/templates/rest_framework/fields/horizontal/select.html
new file mode 100644
index 00000000..3f8cab0a
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/horizontal/select.html
@@ -0,0 +1,10 @@
+
+ {% include "rest_framework/fields/horizontal/label.html" %}
+
+
+
+
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/select_checkbox.html b/rest_framework/templates/rest_framework/fields/horizontal/select_checkbox.html
new file mode 100644
index 00000000..659eede8
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/horizontal/select_checkbox.html
@@ -0,0 +1,22 @@
+
+ {% include "rest_framework/fields/horizontal/label.html" %}
+
+ {% if field.style.inline %}
+ {% for key, text in field.choices.items %}
+
+ {% endfor %}
+ {% else %}
+ {% for key, text in field.choices.items %}
+
+
+
+ {% endfor %}
+ {% endif %}
+
+
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/select_multiple.html b/rest_framework/templates/rest_framework/fields/horizontal/select_multiple.html
new file mode 100644
index 00000000..da25eb2b
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/horizontal/select_multiple.html
@@ -0,0 +1,10 @@
+
+ {% include "rest_framework/fields/horizontal/label.html" %}
+
+
+
+
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/select_radio.html b/rest_framework/templates/rest_framework/fields/horizontal/select_radio.html
new file mode 100644
index 00000000..188f05e2
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/horizontal/select_radio.html
@@ -0,0 +1,22 @@
+
+ {% include "rest_framework/fields/horizontal/label.html" %}
+
+ {% if field.style.inline %}
+ {% for key, text in field.choices.items %}
+
+ {% endfor %}
+ {% else %}
+ {% for key, text in field.choices.items %}
+
+
+
+ {% endfor %}
+ {% endif %}
+
+
diff --git a/rest_framework/templates/rest_framework/fields/horizontal/textarea.html b/rest_framework/templates/rest_framework/fields/horizontal/textarea.html
new file mode 100644
index 00000000..e99266f3
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/horizontal/textarea.html
@@ -0,0 +1,7 @@
+
+ {% include "rest_framework/fields/horizontal/label.html" %}
+
+
+ {% if field.help_text %}
{{ field.help_text }}
{% endif %}
+
+
diff --git a/rest_framework/templates/rest_framework/fields/inline/checkbox.html b/rest_framework/templates/rest_framework/fields/inline/checkbox.html
new file mode 100644
index 00000000..01d30aae
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/inline/checkbox.html
@@ -0,0 +1,6 @@
+
+
+
diff --git a/rest_framework/templates/rest_framework/fields/inline/fieldset.html b/rest_framework/templates/rest_framework/fields/inline/fieldset.html
new file mode 100644
index 00000000..d22982fd
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/inline/fieldset.html
@@ -0,0 +1,3 @@
+{% for field_item in value.field_items.values() %}
+ {{ renderer.render_field(field_item, layout=layout) }}
+{% endfor %}
diff --git a/rest_framework/templates/rest_framework/fields/inline/input.html b/rest_framework/templates/rest_framework/fields/inline/input.html
new file mode 100644
index 00000000..aefd1672
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/inline/input.html
@@ -0,0 +1,4 @@
+
+ {% include "rest_framework/fields/inline/label.html" %}
+
+
diff --git a/rest_framework/templates/rest_framework/fields/inline/label.html b/rest_framework/templates/rest_framework/fields/inline/label.html
new file mode 100644
index 00000000..7d546a57
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/inline/label.html
@@ -0,0 +1 @@
+{% if field.label %}{% endif %}
diff --git a/rest_framework/templates/rest_framework/fields/inline/select.html b/rest_framework/templates/rest_framework/fields/inline/select.html
new file mode 100644
index 00000000..cb9a7013
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/inline/select.html
@@ -0,0 +1,8 @@
+
+ {% include "rest_framework/fields/inline/label.html" %}
+
+
diff --git a/rest_framework/templates/rest_framework/fields/inline/select_checkbox.html b/rest_framework/templates/rest_framework/fields/inline/select_checkbox.html
new file mode 100644
index 00000000..424df93e
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/inline/select_checkbox.html
@@ -0,0 +1,11 @@
+
+ {% include "rest_framework/fields/inline/label.html" %}
+ {% for key, text in field.choices.items %}
+
+
+
+ {% endfor %}
+
diff --git a/rest_framework/templates/rest_framework/fields/inline/select_multiple.html b/rest_framework/templates/rest_framework/fields/inline/select_multiple.html
new file mode 100644
index 00000000..6fdfd672
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/inline/select_multiple.html
@@ -0,0 +1,8 @@
+
+ {% include "rest_framework/fields/inline/label.html" %}
+
+
diff --git a/rest_framework/templates/rest_framework/fields/inline/select_radio.html b/rest_framework/templates/rest_framework/fields/inline/select_radio.html
new file mode 100644
index 00000000..ddabc9e9
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/inline/select_radio.html
@@ -0,0 +1,11 @@
+
+ {% include "rest_framework/fields/inline/label.html" %}
+ {% for key, text in field.choices.items %}
+
+
+
+ {% endfor %}
+
diff --git a/rest_framework/templates/rest_framework/fields/inline/textarea.html b/rest_framework/templates/rest_framework/fields/inline/textarea.html
new file mode 100644
index 00000000..31366809
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/inline/textarea.html
@@ -0,0 +1,4 @@
+
+ {% include "rest_framework/fields/inline/label.html" %}
+
+
diff --git a/rest_framework/templates/rest_framework/fields/vertical/checkbox.html b/rest_framework/templates/rest_framework/fields/vertical/checkbox.html
new file mode 100644
index 00000000..01d30aae
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/vertical/checkbox.html
@@ -0,0 +1,6 @@
+
+
+
diff --git a/rest_framework/templates/rest_framework/fields/vertical/fieldset.html b/rest_framework/templates/rest_framework/fields/vertical/fieldset.html
new file mode 100644
index 00000000..cad32df9
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/vertical/fieldset.html
@@ -0,0 +1,6 @@
+
diff --git a/rest_framework/templates/rest_framework/fields/vertical/input.html b/rest_framework/templates/rest_framework/fields/vertical/input.html
new file mode 100644
index 00000000..c25407d1
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/vertical/input.html
@@ -0,0 +1,5 @@
+
+ {% include "rest_framework/fields/vertical/label.html" %}
+
+ {% if field.help_text %}
{{ field.help_text }}
{% endif %}
+
diff --git a/rest_framework/templates/rest_framework/fields/vertical/label.html b/rest_framework/templates/rest_framework/fields/vertical/label.html
new file mode 100644
index 00000000..651939b2
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/vertical/label.html
@@ -0,0 +1 @@
+{% if field.label %}{% endif %}
diff --git a/rest_framework/templates/rest_framework/fields/vertical/select.html b/rest_framework/templates/rest_framework/fields/vertical/select.html
new file mode 100644
index 00000000..44679d8a
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/vertical/select.html
@@ -0,0 +1,8 @@
+
+ {% include "rest_framework/fields/vertical/label.html" %}
+
+
diff --git a/rest_framework/templates/rest_framework/fields/vertical/select_checkbox.html b/rest_framework/templates/rest_framework/fields/vertical/select_checkbox.html
new file mode 100644
index 00000000..e60574c0
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/vertical/select_checkbox.html
@@ -0,0 +1,22 @@
+
+ {% include "rest_framework/fields/vertical/label.html" %}
+ {% if field.style.inline %}
+
+ {% for key, text in field.choices.items %}
+
+ {% endfor %}
+
+ {% else %}
+ {% for key, text in field.choices.items %}
+
+
+
+ {% endfor %}
+ {% endif %}
+
diff --git a/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html b/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html
new file mode 100644
index 00000000..f0fa418b
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/vertical/select_multiple.html
@@ -0,0 +1,8 @@
+
+ {% include "rest_framework/fields/vertical/label.html" %}
+
+
diff --git a/rest_framework/templates/rest_framework/fields/vertical/select_radio.html b/rest_framework/templates/rest_framework/fields/vertical/select_radio.html
new file mode 100644
index 00000000..4ffe38ea
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/vertical/select_radio.html
@@ -0,0 +1,22 @@
+
+ {% include "rest_framework/fields/vertical/label.html" %}
+ {% if field.style.inline %}
+
+ {% for key, text in field.choices.items %}
+
+ {% endfor %}
+
+ {% else %}
+ {% for key, text in field.choices.items %}
+
+
+
+ {% endfor %}
+ {% endif %}
+
diff --git a/rest_framework/templates/rest_framework/fields/vertical/textarea.html b/rest_framework/templates/rest_framework/fields/vertical/textarea.html
new file mode 100644
index 00000000..33cb27c7
--- /dev/null
+++ b/rest_framework/templates/rest_framework/fields/vertical/textarea.html
@@ -0,0 +1,5 @@
+
+ {% include "rest_framework/fields/vertical/label.html" %}
+
+ {% if field.help_text %}