diff options
Diffstat (limited to 'rest_framework/utils')
| -rw-r--r-- | rest_framework/utils/encoders.py | 21 | ||||
| -rw-r--r-- | rest_framework/utils/field_mapping.py | 114 | ||||
| -rw-r--r-- | rest_framework/utils/html.py | 8 | ||||
| -rw-r--r-- | rest_framework/utils/model_meta.py | 23 | ||||
| -rw-r--r-- | rest_framework/utils/representation.py | 10 | ||||
| -rw-r--r-- | rest_framework/utils/serializer_helpers.py | 102 |
6 files changed, 210 insertions, 68 deletions
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 174b08b8..4d6bb3a3 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -2,11 +2,10 @@ Helper classes for parsers. """ from __future__ import unicode_literals -from django.utils import timezone from django.db.models.query import QuerySet -from django.utils.datastructures import SortedDict +from django.utils import six, timezone from django.utils.functional import Promise -from rest_framework.compat import force_text +from rest_framework.compat import force_text, OrderedDict import datetime import decimal import types @@ -40,12 +39,12 @@ class JSONEncoder(json.JSONEncoder): representation = representation[:12] return representation elif isinstance(obj, datetime.timedelta): - return str(obj.total_seconds()) + return six.text_type(obj.total_seconds()) elif isinstance(obj, decimal.Decimal): # Serializers will coerce decimals to strings by default. return float(obj) elif isinstance(obj, QuerySet): - return list(obj) + return tuple(obj) elif hasattr(obj, 'tolist'): # Numpy arrays and array scalars. return obj.tolist() @@ -55,7 +54,7 @@ class JSONEncoder(json.JSONEncoder): except: pass elif hasattr(obj, '__iter__'): - return [item for item in obj] + return tuple(item for item in obj) return super(JSONEncoder, self).default(obj) @@ -68,11 +67,11 @@ else: class SafeDumper(yaml.SafeDumper): """ Handles decimals as strings. - Handles SortedDicts as usual dicts, but preserves field order, rather + Handles OrderedDicts as usual dicts, but preserves field order, rather than the usual behaviour of sorting the keys. """ def represent_decimal(self, data): - return self.represent_scalar('tag:yaml.org,2002:str', str(data)) + return self.represent_scalar('tag:yaml.org,2002:str', six.text_type(data)) def represent_mapping(self, tag, mapping, flow_style=None): value = [] @@ -82,7 +81,7 @@ else: best_style = True if hasattr(mapping, 'items'): mapping = list(mapping.items()) - if not isinstance(mapping, SortedDict): + if not isinstance(mapping, OrderedDict): mapping.sort() for item_key, item_value in mapping: node_key = self.represent_data(item_key) @@ -104,7 +103,7 @@ else: SafeDumper.represent_decimal ) SafeDumper.add_representer( - SortedDict, + OrderedDict, yaml.representer.SafeRepresenter.represent_dict ) # SafeDumper.add_representer( @@ -112,7 +111,7 @@ else: # yaml.representer.SafeRepresenter.represent_dict # ) # SafeDumper.add_representer( - # SortedDictWithMetadata, + # OrderedDictWithMetadata, # yaml.representer.SafeRepresenter.represent_dict # ) SafeDumper.add_representer( diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index be72e444..9c187176 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -6,20 +6,32 @@ from django.core import validators from django.db import models from django.utils.text import capfirst from rest_framework.compat import clean_manytomany_helptext +from rest_framework.validators import UniqueValidator import inspect -def lookup_class(mapping, instance): +class ClassLookupDict(object): """ - Takes a dictionary with classes as keys, and an object. - Traverses the object's inheritance hierarchy in method - resolution order, and returns the first matching value + Takes a dictionary with classes as keys. + Lookups against this object will traverses the object's inheritance + hierarchy in method resolution order, and returns the first matching value from the dictionary or raises a KeyError if nothing matches. """ - for cls in inspect.getmro(instance.__class__): - if cls in mapping: - return mapping[cls] - raise KeyError('Class %s not found in lookup.', cls.__name__) + def __init__(self, mapping): + self.mapping = mapping + + def __getitem__(self, key): + if hasattr(key, '_proxy_class'): + # Deal with proxy classes. Ie. BoundField behaves as if it + # is a Field instance when using ClassLookupDict. + base_class = key._proxy_class + else: + base_class = key.__class__ + + for cls in inspect.getmro(base_class): + if cls in self.mapping: + return self.mapping[cls] + raise KeyError('Class %s not found in lookup.', cls.__name__) def needs_label(model_field, field_name): @@ -49,8 +61,9 @@ def get_field_kwargs(field_name, model_field): kwargs = {} validator_kwarg = model_field.validators - if model_field.null or model_field.blank: - kwargs['required'] = False + # The following will only be used by ModelField classes. + # Gets removed for everything else. + kwargs['model_field'] = model_field if model_field.verbose_name and needs_label(model_field, field_name): kwargs['label'] = capfirst(model_field.verbose_name) @@ -58,24 +71,38 @@ def get_field_kwargs(field_name, model_field): if model_field.help_text: kwargs['help_text'] = model_field.help_text + max_digits = getattr(model_field, 'max_digits', None) + if max_digits is not None: + kwargs['max_digits'] = max_digits + + decimal_places = getattr(model_field, 'decimal_places', None) + if decimal_places is not None: + kwargs['decimal_places'] = decimal_places + + if isinstance(model_field, models.TextField): + kwargs['style'] = {'type': 'textarea'} + if isinstance(model_field, models.AutoField) or not model_field.editable: + # If this field is read-only, then return early. + # Further keyword arguments are not valid. kwargs['read_only'] = True - # Read only implies that the field is not required. - # We have a cleaner repr on the instance if we don't set it. - kwargs.pop('required', None) + return kwargs - if model_field.has_default(): - kwargs['default'] = model_field.get_default() - # Having a default implies that the field is not required. - # We have a cleaner repr on the instance if we don't set it. - kwargs.pop('required', None) + if model_field.has_default() or model_field.blank or model_field.null: + kwargs['required'] = False if model_field.flatchoices: - # If this model field contains choices, then return now, - # any further keyword arguments are not valid. + # If this model field contains choices, then return early. + # Further keyword arguments are not valid. kwargs['choices'] = model_field.flatchoices return kwargs + if model_field.null and not isinstance(model_field, models.NullBooleanField): + kwargs['allow_null'] = True + + if model_field.blank: + kwargs['allow_blank'] = True + # Ensure that max_length is passed explicitly as a keyword arg, # rather than as a validator. max_length = getattr(model_field, 'max_length', None) @@ -88,7 +115,10 @@ def get_field_kwargs(field_name, model_field): # Ensure that min_length is passed explicitly as a keyword arg, # rather than as a validator. - min_length = getattr(model_field, 'min_length', None) + min_length = next(( + validator.limit_value for validator in validator_kwarg + if isinstance(validator, validators.MinLengthValidator) + ), None) if min_length is not None: kwargs['min_length'] = min_length validator_kwarg = [ @@ -145,28 +175,13 @@ def get_field_kwargs(field_name, model_field): if validator is not validators.validate_slug ] - max_digits = getattr(model_field, 'max_digits', None) - if max_digits is not None: - kwargs['max_digits'] = max_digits - - decimal_places = getattr(model_field, 'decimal_places', None) - if decimal_places is not None: - kwargs['decimal_places'] = decimal_places - - if isinstance(model_field, models.BooleanField): - # models.BooleanField has `blank=True`, but *is* actually - # required *unless* a default is provided. - # Also note that Django<1.6 uses `default=False` for - # models.BooleanField, but Django>=1.6 uses `default=None`. - kwargs.pop('required', None) + if getattr(model_field, 'unique', False): + validator = UniqueValidator(queryset=model_field.model._default_manager) + validator_kwarg.append(validator) if validator_kwarg: kwargs['validators'] = validator_kwarg - # The following will only be used by ModelField classes. - # Gets removed for everything else. - kwargs['model_field'] = model_field - return kwargs @@ -188,16 +203,27 @@ def get_relation_kwargs(field_name, relation_info): kwargs.pop('queryset', None) if model_field: - if model_field.null or model_field.blank: - kwargs['required'] = False if model_field.verbose_name and needs_label(model_field, field_name): kwargs['label'] = capfirst(model_field.verbose_name) - if not model_field.editable: - kwargs['read_only'] = True - kwargs.pop('queryset', None) help_text = clean_manytomany_helptext(model_field.help_text) if help_text: kwargs['help_text'] = help_text + if not model_field.editable: + kwargs['read_only'] = True + kwargs.pop('queryset', None) + if kwargs.get('read_only', False): + # If this field is read-only, then return early. + # No further keyword arguments are valid. + return kwargs + if model_field.has_default() or model_field.null: + kwargs['required'] = False + if model_field.null: + kwargs['allow_null'] = True + if model_field.validators: + kwargs['validators'] = model_field.validators + if getattr(model_field, 'unique', False): + validator = UniqueValidator(queryset=model_field.model._default_manager) + kwargs['validators'] = kwargs.get('validators', []) + [validator] return kwargs diff --git a/rest_framework/utils/html.py b/rest_framework/utils/html.py index edc591e9..d773952d 100644 --- a/rest_framework/utils/html.py +++ b/rest_framework/utils/html.py @@ -2,6 +2,7 @@ Helpers for dealing with HTML input. """ import re +from django.utils.datastructures import MultiValueDict def is_html_input(dictionary): @@ -35,7 +36,7 @@ def parse_html_list(dictionary, prefix=''): '[0]foo': 'abc', '[0]bar': 'def', '[1]foo': 'hij', - '[2]bar': 'klm', + '[1]bar': 'klm', } --> [ @@ -43,7 +44,6 @@ def parse_html_list(dictionary, prefix=''): {'foo': 'hij', 'bar': 'klm'} ] """ - Dict = type(dictionary) ret = {} regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix)) for field, value in dictionary.items(): @@ -57,7 +57,7 @@ def parse_html_list(dictionary, prefix=''): elif isinstance(ret.get(index), dict): ret[index][key] = value else: - ret[index] = Dict({key: value}) + ret[index] = MultiValueDict({key: [value]}) return [ret[item] for item in sorted(ret.keys())] @@ -72,7 +72,7 @@ def parse_html_dict(dictionary, prefix): --> { 'profile': { - 'username': 'example, + 'username': 'example', 'email': 'example@example.com' } } diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py index b6c41174..c98725c6 100644 --- a/rest_framework/utils/model_meta.py +++ b/rest_framework/utils/model_meta.py @@ -6,9 +6,10 @@ relationships and their associated metadata. Usage: `get_field_info(model)` returns a `FieldInfo` instance. """ from collections import namedtuple +from django.core.exceptions import ImproperlyConfigured from django.db import models from django.utils import six -from django.utils.datastructures import SortedDict +from rest_framework.compat import OrderedDict import inspect @@ -43,7 +44,11 @@ def _resolve_model(obj): """ if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: app_name, model_name = obj.split('.') - return models.get_model(app_name, model_name) + resolved_model = models.get_model(app_name, model_name) + if resolved_model is None: + msg = "Django did not return a model for {0}.{1}" + raise ImproperlyConfigured(msg.format(app_name, model_name)) + return resolved_model elif inspect.isclass(obj) and issubclass(obj, models.Model): return obj raise ValueError("{0} is not a Django model".format(obj)) @@ -63,12 +68,12 @@ def get_field_info(model): pk = pk.rel.to._meta.pk # Deal with regular fields. - fields = SortedDict() + fields = OrderedDict() for field in [field for field in opts.fields if field.serialize and not field.rel]: fields[field.name] = field # Deal with forward relationships. - forward_relations = SortedDict() + forward_relations = OrderedDict() for field in [field for field in opts.fields if field.serialize and field.rel]: forward_relations[field.name] = RelationInfo( model_field=field, @@ -89,7 +94,7 @@ def get_field_info(model): ) # Deal with reverse relationships. - reverse_relations = SortedDict() + reverse_relations = OrderedDict() for relation in opts.get_all_related_objects(): accessor_name = relation.get_accessor_name() reverse_relations[accessor_name] = RelationInfo( @@ -107,21 +112,21 @@ def get_field_info(model): related=relation.model, to_many=True, has_through_model=( - hasattr(relation.field.rel, 'through') and - not relation.field.rel.through._meta.auto_created + (getattr(relation.field.rel, 'through', None) is not None) + and not relation.field.rel.through._meta.auto_created ) ) # Shortcut that merges both regular fields and the pk, # for simplifying regular field lookup. - fields_and_pk = SortedDict() + fields_and_pk = OrderedDict() fields_and_pk['pk'] = pk fields_and_pk[pk.name] = pk fields_and_pk.update(fields) # Shortcut that merges both forward and reverse relationships - relations = SortedDict( + relations = OrderedDict( list(forward_relations.items()) + list(reverse_relations.items()) ) diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py index e64fdd22..2a7c4675 100644 --- a/rest_framework/utils/representation.py +++ b/rest_framework/utils/representation.py @@ -3,6 +3,8 @@ Helper functions for creating user-friendly representations of serializer classes and serializer fields. """ from django.db import models +from django.utils.functional import Promise +from rest_framework.compat import force_text import re @@ -19,6 +21,9 @@ def smart_repr(value): if isinstance(value, models.Manager): return manager_repr(value) + if isinstance(value, Promise) and value._delegate_text: + value = force_text(value) + value = repr(value) # Representations like u'help text' @@ -77,6 +82,11 @@ def serializer_repr(serializer, indent, force_many=None): ret += field_repr(field.child_relation, force_many=field.child_relation) else: ret += field_repr(field) + + if serializer.validators: + ret += '\n' + indent_str + 'class Meta:' + ret += '\n' + indent_str + ' validators = ' + smart_repr(serializer.validators) + return ret diff --git a/rest_framework/utils/serializer_helpers.py b/rest_framework/utils/serializer_helpers.py new file mode 100644 index 00000000..92d19857 --- /dev/null +++ b/rest_framework/utils/serializer_helpers.py @@ -0,0 +1,102 @@ +from rest_framework.compat import OrderedDict + + +class ReturnDict(OrderedDict): + """ + Return object from `serialier.data` for the `Serializer` class. + Includes a backlink to the serializer instance for renderers + to use if they need richer field information. + """ + def __init__(self, *args, **kwargs): + self.serializer = kwargs.pop('serializer') + super(ReturnDict, self).__init__(*args, **kwargs) + + def copy(self): + return ReturnDict(self, serializer=self.serializer) + + +class ReturnList(list): + """ + Return object from `serialier.data` for the `SerializerList` class. + Includes a backlink to the serializer instance for renderers + to use if they need richer field information. + """ + def __init__(self, *args, **kwargs): + self.serializer = kwargs.pop('serializer') + super(ReturnList, self).__init__(*args, **kwargs) + + +class BoundField(object): + """ + A field object that also includes `.value` and `.error` properties. + Returned when iterating over a serializer instance, + providing an API similar to Django forms and form fields. + """ + def __init__(self, field, value, errors, prefix=''): + self._field = field + self.value = value + self.errors = errors + self.name = prefix + self.field_name + + def __getattr__(self, attr_name): + return getattr(self._field, attr_name) + + @property + def _proxy_class(self): + return self._field.__class__ + + def __repr__(self): + return '<%s value=%s errors=%s>' % ( + self.__class__.__name__, self.value, self.errors + ) + + +class NestedBoundField(BoundField): + """ + This `BoundField` additionally implements __iter__ and __getitem__ + in order to support nested bound fields. This class is the type of + `BoundField` that is used for serializer fields. + """ + def __iter__(self): + for field in self.fields.values(): + yield self[field.field_name] + + def __getitem__(self, key): + field = self.fields[key] + value = self.value.get(key) if self.value else None + error = self.errors.get(key) if self.errors else None + if hasattr(field, 'fields'): + return NestedBoundField(field, value, error, prefix=self.name + '.') + return BoundField(field, value, error, prefix=self.name + '.') + + +class BindingDict(object): + """ + This dict-like object is used to store fields on a serializer. + + This ensures that whenever fields are added to the serializer we call + `field.bind()` so that the `field_name` and `parent` attributes + can be set correctly. + """ + def __init__(self, serializer): + self.serializer = serializer + self.fields = OrderedDict() + + def __setitem__(self, key, field): + self.fields[key] = field + field.bind(field_name=key, parent=self.serializer) + + def __getitem__(self, key): + return self.fields[key] + + def __delitem__(self, key): + del self.fields[key] + + def items(self): + return self.fields.items() + + def keys(self): + return self.fields.keys() + + def values(self): + return self.fields.values() |
