aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/utils
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/utils')
-rw-r--r--rest_framework/utils/encoders.py21
-rw-r--r--rest_framework/utils/field_mapping.py114
-rw-r--r--rest_framework/utils/html.py8
-rw-r--r--rest_framework/utils/model_meta.py23
-rw-r--r--rest_framework/utils/representation.py10
-rw-r--r--rest_framework/utils/serializer_helpers.py102
6 files changed, 210 insertions, 68 deletions
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index 174b08b8..4d6bb3a3 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -2,11 +2,10 @@
Helper classes for parsers.
"""
from __future__ import unicode_literals
-from django.utils import timezone
from django.db.models.query import QuerySet
-from django.utils.datastructures import SortedDict
+from django.utils import six, timezone
from django.utils.functional import Promise
-from rest_framework.compat import force_text
+from rest_framework.compat import force_text, OrderedDict
import datetime
import decimal
import types
@@ -40,12 +39,12 @@ class JSONEncoder(json.JSONEncoder):
representation = representation[:12]
return representation
elif isinstance(obj, datetime.timedelta):
- return str(obj.total_seconds())
+ return six.text_type(obj.total_seconds())
elif isinstance(obj, decimal.Decimal):
# Serializers will coerce decimals to strings by default.
return float(obj)
elif isinstance(obj, QuerySet):
- return list(obj)
+ return tuple(obj)
elif hasattr(obj, 'tolist'):
# Numpy arrays and array scalars.
return obj.tolist()
@@ -55,7 +54,7 @@ class JSONEncoder(json.JSONEncoder):
except:
pass
elif hasattr(obj, '__iter__'):
- return [item for item in obj]
+ return tuple(item for item in obj)
return super(JSONEncoder, self).default(obj)
@@ -68,11 +67,11 @@ else:
class SafeDumper(yaml.SafeDumper):
"""
Handles decimals as strings.
- Handles SortedDicts as usual dicts, but preserves field order, rather
+ Handles OrderedDicts as usual dicts, but preserves field order, rather
than the usual behaviour of sorting the keys.
"""
def represent_decimal(self, data):
- return self.represent_scalar('tag:yaml.org,2002:str', str(data))
+ return self.represent_scalar('tag:yaml.org,2002:str', six.text_type(data))
def represent_mapping(self, tag, mapping, flow_style=None):
value = []
@@ -82,7 +81,7 @@ else:
best_style = True
if hasattr(mapping, 'items'):
mapping = list(mapping.items())
- if not isinstance(mapping, SortedDict):
+ if not isinstance(mapping, OrderedDict):
mapping.sort()
for item_key, item_value in mapping:
node_key = self.represent_data(item_key)
@@ -104,7 +103,7 @@ else:
SafeDumper.represent_decimal
)
SafeDumper.add_representer(
- SortedDict,
+ OrderedDict,
yaml.representer.SafeRepresenter.represent_dict
)
# SafeDumper.add_representer(
@@ -112,7 +111,7 @@ else:
# yaml.representer.SafeRepresenter.represent_dict
# )
# SafeDumper.add_representer(
- # SortedDictWithMetadata,
+ # OrderedDictWithMetadata,
# yaml.representer.SafeRepresenter.represent_dict
# )
SafeDumper.add_representer(
diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py
index be72e444..9c187176 100644
--- a/rest_framework/utils/field_mapping.py
+++ b/rest_framework/utils/field_mapping.py
@@ -6,20 +6,32 @@ from django.core import validators
from django.db import models
from django.utils.text import capfirst
from rest_framework.compat import clean_manytomany_helptext
+from rest_framework.validators import UniqueValidator
import inspect
-def lookup_class(mapping, instance):
+class ClassLookupDict(object):
"""
- Takes a dictionary with classes as keys, and an object.
- Traverses the object's inheritance hierarchy in method
- resolution order, and returns the first matching value
+ Takes a dictionary with classes as keys.
+ Lookups against this object will traverses the object's inheritance
+ hierarchy in method resolution order, and returns the first matching value
from the dictionary or raises a KeyError if nothing matches.
"""
- for cls in inspect.getmro(instance.__class__):
- if cls in mapping:
- return mapping[cls]
- raise KeyError('Class %s not found in lookup.', cls.__name__)
+ def __init__(self, mapping):
+ self.mapping = mapping
+
+ def __getitem__(self, key):
+ if hasattr(key, '_proxy_class'):
+ # Deal with proxy classes. Ie. BoundField behaves as if it
+ # is a Field instance when using ClassLookupDict.
+ base_class = key._proxy_class
+ else:
+ base_class = key.__class__
+
+ for cls in inspect.getmro(base_class):
+ if cls in self.mapping:
+ return self.mapping[cls]
+ raise KeyError('Class %s not found in lookup.', cls.__name__)
def needs_label(model_field, field_name):
@@ -49,8 +61,9 @@ def get_field_kwargs(field_name, model_field):
kwargs = {}
validator_kwarg = model_field.validators
- if model_field.null or model_field.blank:
- kwargs['required'] = False
+ # The following will only be used by ModelField classes.
+ # Gets removed for everything else.
+ kwargs['model_field'] = model_field
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
@@ -58,24 +71,38 @@ def get_field_kwargs(field_name, model_field):
if model_field.help_text:
kwargs['help_text'] = model_field.help_text
+ max_digits = getattr(model_field, 'max_digits', None)
+ if max_digits is not None:
+ kwargs['max_digits'] = max_digits
+
+ decimal_places = getattr(model_field, 'decimal_places', None)
+ if decimal_places is not None:
+ kwargs['decimal_places'] = decimal_places
+
+ if isinstance(model_field, models.TextField):
+ kwargs['style'] = {'type': 'textarea'}
+
if isinstance(model_field, models.AutoField) or not model_field.editable:
+ # If this field is read-only, then return early.
+ # Further keyword arguments are not valid.
kwargs['read_only'] = True
- # Read only implies that the field is not required.
- # We have a cleaner repr on the instance if we don't set it.
- kwargs.pop('required', None)
+ return kwargs
- if model_field.has_default():
- kwargs['default'] = model_field.get_default()
- # Having a default implies that the field is not required.
- # We have a cleaner repr on the instance if we don't set it.
- kwargs.pop('required', None)
+ if model_field.has_default() or model_field.blank or model_field.null:
+ kwargs['required'] = False
if model_field.flatchoices:
- # If this model field contains choices, then return now,
- # any further keyword arguments are not valid.
+ # If this model field contains choices, then return early.
+ # Further keyword arguments are not valid.
kwargs['choices'] = model_field.flatchoices
return kwargs
+ if model_field.null and not isinstance(model_field, models.NullBooleanField):
+ kwargs['allow_null'] = True
+
+ if model_field.blank:
+ kwargs['allow_blank'] = True
+
# Ensure that max_length is passed explicitly as a keyword arg,
# rather than as a validator.
max_length = getattr(model_field, 'max_length', None)
@@ -88,7 +115,10 @@ def get_field_kwargs(field_name, model_field):
# Ensure that min_length is passed explicitly as a keyword arg,
# rather than as a validator.
- min_length = getattr(model_field, 'min_length', None)
+ min_length = next((
+ validator.limit_value for validator in validator_kwarg
+ if isinstance(validator, validators.MinLengthValidator)
+ ), None)
if min_length is not None:
kwargs['min_length'] = min_length
validator_kwarg = [
@@ -145,28 +175,13 @@ def get_field_kwargs(field_name, model_field):
if validator is not validators.validate_slug
]
- max_digits = getattr(model_field, 'max_digits', None)
- if max_digits is not None:
- kwargs['max_digits'] = max_digits
-
- decimal_places = getattr(model_field, 'decimal_places', None)
- if decimal_places is not None:
- kwargs['decimal_places'] = decimal_places
-
- if isinstance(model_field, models.BooleanField):
- # models.BooleanField has `blank=True`, but *is* actually
- # required *unless* a default is provided.
- # Also note that Django<1.6 uses `default=False` for
- # models.BooleanField, but Django>=1.6 uses `default=None`.
- kwargs.pop('required', None)
+ if getattr(model_field, 'unique', False):
+ validator = UniqueValidator(queryset=model_field.model._default_manager)
+ validator_kwarg.append(validator)
if validator_kwarg:
kwargs['validators'] = validator_kwarg
- # The following will only be used by ModelField classes.
- # Gets removed for everything else.
- kwargs['model_field'] = model_field
-
return kwargs
@@ -188,16 +203,27 @@ def get_relation_kwargs(field_name, relation_info):
kwargs.pop('queryset', None)
if model_field:
- if model_field.null or model_field.blank:
- kwargs['required'] = False
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
- if not model_field.editable:
- kwargs['read_only'] = True
- kwargs.pop('queryset', None)
help_text = clean_manytomany_helptext(model_field.help_text)
if help_text:
kwargs['help_text'] = help_text
+ if not model_field.editable:
+ kwargs['read_only'] = True
+ kwargs.pop('queryset', None)
+ if kwargs.get('read_only', False):
+ # If this field is read-only, then return early.
+ # No further keyword arguments are valid.
+ return kwargs
+ if model_field.has_default() or model_field.null:
+ kwargs['required'] = False
+ if model_field.null:
+ kwargs['allow_null'] = True
+ if model_field.validators:
+ kwargs['validators'] = model_field.validators
+ if getattr(model_field, 'unique', False):
+ validator = UniqueValidator(queryset=model_field.model._default_manager)
+ kwargs['validators'] = kwargs.get('validators', []) + [validator]
return kwargs
diff --git a/rest_framework/utils/html.py b/rest_framework/utils/html.py
index edc591e9..d773952d 100644
--- a/rest_framework/utils/html.py
+++ b/rest_framework/utils/html.py
@@ -2,6 +2,7 @@
Helpers for dealing with HTML input.
"""
import re
+from django.utils.datastructures import MultiValueDict
def is_html_input(dictionary):
@@ -35,7 +36,7 @@ def parse_html_list(dictionary, prefix=''):
'[0]foo': 'abc',
'[0]bar': 'def',
'[1]foo': 'hij',
- '[2]bar': 'klm',
+ '[1]bar': 'klm',
}
-->
[
@@ -43,7 +44,6 @@ def parse_html_list(dictionary, prefix=''):
{'foo': 'hij', 'bar': 'klm'}
]
"""
- Dict = type(dictionary)
ret = {}
regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix))
for field, value in dictionary.items():
@@ -57,7 +57,7 @@ def parse_html_list(dictionary, prefix=''):
elif isinstance(ret.get(index), dict):
ret[index][key] = value
else:
- ret[index] = Dict({key: value})
+ ret[index] = MultiValueDict({key: [value]})
return [ret[item] for item in sorted(ret.keys())]
@@ -72,7 +72,7 @@ def parse_html_dict(dictionary, prefix):
-->
{
'profile': {
- 'username': 'example,
+ 'username': 'example',
'email': 'example@example.com'
}
}
diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py
index b6c41174..c98725c6 100644
--- a/rest_framework/utils/model_meta.py
+++ b/rest_framework/utils/model_meta.py
@@ -6,9 +6,10 @@ relationships and their associated metadata.
Usage: `get_field_info(model)` returns a `FieldInfo` instance.
"""
from collections import namedtuple
+from django.core.exceptions import ImproperlyConfigured
from django.db import models
from django.utils import six
-from django.utils.datastructures import SortedDict
+from rest_framework.compat import OrderedDict
import inspect
@@ -43,7 +44,11 @@ def _resolve_model(obj):
"""
if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.')
- return models.get_model(app_name, model_name)
+ resolved_model = models.get_model(app_name, model_name)
+ if resolved_model is None:
+ msg = "Django did not return a model for {0}.{1}"
+ raise ImproperlyConfigured(msg.format(app_name, model_name))
+ return resolved_model
elif inspect.isclass(obj) and issubclass(obj, models.Model):
return obj
raise ValueError("{0} is not a Django model".format(obj))
@@ -63,12 +68,12 @@ def get_field_info(model):
pk = pk.rel.to._meta.pk
# Deal with regular fields.
- fields = SortedDict()
+ fields = OrderedDict()
for field in [field for field in opts.fields if field.serialize and not field.rel]:
fields[field.name] = field
# Deal with forward relationships.
- forward_relations = SortedDict()
+ forward_relations = OrderedDict()
for field in [field for field in opts.fields if field.serialize and field.rel]:
forward_relations[field.name] = RelationInfo(
model_field=field,
@@ -89,7 +94,7 @@ def get_field_info(model):
)
# Deal with reverse relationships.
- reverse_relations = SortedDict()
+ reverse_relations = OrderedDict()
for relation in opts.get_all_related_objects():
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
@@ -107,21 +112,21 @@ def get_field_info(model):
related=relation.model,
to_many=True,
has_through_model=(
- hasattr(relation.field.rel, 'through') and
- not relation.field.rel.through._meta.auto_created
+ (getattr(relation.field.rel, 'through', None) is not None)
+ and not relation.field.rel.through._meta.auto_created
)
)
# Shortcut that merges both regular fields and the pk,
# for simplifying regular field lookup.
- fields_and_pk = SortedDict()
+ fields_and_pk = OrderedDict()
fields_and_pk['pk'] = pk
fields_and_pk[pk.name] = pk
fields_and_pk.update(fields)
# Shortcut that merges both forward and reverse relationships
- relations = SortedDict(
+ relations = OrderedDict(
list(forward_relations.items()) +
list(reverse_relations.items())
)
diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py
index e64fdd22..2a7c4675 100644
--- a/rest_framework/utils/representation.py
+++ b/rest_framework/utils/representation.py
@@ -3,6 +3,8 @@ Helper functions for creating user-friendly representations
of serializer classes and serializer fields.
"""
from django.db import models
+from django.utils.functional import Promise
+from rest_framework.compat import force_text
import re
@@ -19,6 +21,9 @@ def smart_repr(value):
if isinstance(value, models.Manager):
return manager_repr(value)
+ if isinstance(value, Promise) and value._delegate_text:
+ value = force_text(value)
+
value = repr(value)
# Representations like u'help text'
@@ -77,6 +82,11 @@ def serializer_repr(serializer, indent, force_many=None):
ret += field_repr(field.child_relation, force_many=field.child_relation)
else:
ret += field_repr(field)
+
+ if serializer.validators:
+ ret += '\n' + indent_str + 'class Meta:'
+ ret += '\n' + indent_str + ' validators = ' + smart_repr(serializer.validators)
+
return ret
diff --git a/rest_framework/utils/serializer_helpers.py b/rest_framework/utils/serializer_helpers.py
new file mode 100644
index 00000000..92d19857
--- /dev/null
+++ b/rest_framework/utils/serializer_helpers.py
@@ -0,0 +1,102 @@
+from rest_framework.compat import OrderedDict
+
+
+class ReturnDict(OrderedDict):
+ """
+ Return object from `serialier.data` for the `Serializer` class.
+ Includes a backlink to the serializer instance for renderers
+ to use if they need richer field information.
+ """
+ def __init__(self, *args, **kwargs):
+ self.serializer = kwargs.pop('serializer')
+ super(ReturnDict, self).__init__(*args, **kwargs)
+
+ def copy(self):
+ return ReturnDict(self, serializer=self.serializer)
+
+
+class ReturnList(list):
+ """
+ Return object from `serialier.data` for the `SerializerList` class.
+ Includes a backlink to the serializer instance for renderers
+ to use if they need richer field information.
+ """
+ def __init__(self, *args, **kwargs):
+ self.serializer = kwargs.pop('serializer')
+ super(ReturnList, self).__init__(*args, **kwargs)
+
+
+class BoundField(object):
+ """
+ A field object that also includes `.value` and `.error` properties.
+ Returned when iterating over a serializer instance,
+ providing an API similar to Django forms and form fields.
+ """
+ def __init__(self, field, value, errors, prefix=''):
+ self._field = field
+ self.value = value
+ self.errors = errors
+ self.name = prefix + self.field_name
+
+ def __getattr__(self, attr_name):
+ return getattr(self._field, attr_name)
+
+ @property
+ def _proxy_class(self):
+ return self._field.__class__
+
+ def __repr__(self):
+ return '<%s value=%s errors=%s>' % (
+ self.__class__.__name__, self.value, self.errors
+ )
+
+
+class NestedBoundField(BoundField):
+ """
+ This `BoundField` additionally implements __iter__ and __getitem__
+ in order to support nested bound fields. This class is the type of
+ `BoundField` that is used for serializer fields.
+ """
+ def __iter__(self):
+ for field in self.fields.values():
+ yield self[field.field_name]
+
+ def __getitem__(self, key):
+ field = self.fields[key]
+ value = self.value.get(key) if self.value else None
+ error = self.errors.get(key) if self.errors else None
+ if hasattr(field, 'fields'):
+ return NestedBoundField(field, value, error, prefix=self.name + '.')
+ return BoundField(field, value, error, prefix=self.name + '.')
+
+
+class BindingDict(object):
+ """
+ This dict-like object is used to store fields on a serializer.
+
+ This ensures that whenever fields are added to the serializer we call
+ `field.bind()` so that the `field_name` and `parent` attributes
+ can be set correctly.
+ """
+ def __init__(self, serializer):
+ self.serializer = serializer
+ self.fields = OrderedDict()
+
+ def __setitem__(self, key, field):
+ self.fields[key] = field
+ field.bind(field_name=key, parent=self.serializer)
+
+ def __getitem__(self, key):
+ return self.fields[key]
+
+ def __delitem__(self, key):
+ del self.fields[key]
+
+ def items(self):
+ return self.fields.items()
+
+ def keys(self):
+ return self.fields.keys()
+
+ def values(self):
+ return self.fields.values()