diff options
Diffstat (limited to 'rest_framework/utils')
| -rw-r--r-- | rest_framework/utils/breadcrumbs.py | 6 | ||||
| -rw-r--r-- | rest_framework/utils/encoders.py | 124 | ||||
| -rw-r--r-- | rest_framework/utils/field_mapping.py | 249 | ||||
| -rw-r--r-- | rest_framework/utils/formatting.py | 43 | ||||
| -rw-r--r-- | rest_framework/utils/html.py | 88 | ||||
| -rw-r--r-- | rest_framework/utils/humanize_datetime.py | 47 | ||||
| -rw-r--r-- | rest_framework/utils/mediatypes.py | 9 | ||||
| -rw-r--r-- | rest_framework/utils/model_meta.py | 169 | ||||
| -rw-r--r-- | rest_framework/utils/representation.py | 99 | ||||
| -rw-r--r-- | rest_framework/utils/serializer_helpers.py | 120 | ||||
| -rw-r--r-- | rest_framework/utils/urls.py | 25 |
11 files changed, 861 insertions, 118 deletions
diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index d51374b0..e6690d17 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -1,6 +1,5 @@ from __future__ import unicode_literals from django.core.urlresolvers import resolve, get_script_prefix -from rest_framework.utils.formatting import get_view_name def get_breadcrumbs(url): @@ -9,8 +8,11 @@ def get_breadcrumbs(url): tuple of (name, url). """ + from rest_framework.settings import api_settings from rest_framework.views import APIView + view_name_func = api_settings.VIEW_NAME_FUNCTION + def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen): """ Add tuples of (name, url) to the breadcrumbs list, @@ -30,7 +32,7 @@ def get_breadcrumbs(url): # Probably an optional trailing slash. if not seen or seen[-1] != view: suffix = getattr(view, 'suffix', None) - name = get_view_name(view.cls, suffix) + name = view_name_func(cls, suffix) breadcrumbs_list.insert(0, (name, prefix + url)) seen.append(view) diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index b26a2085..2160d18b 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -2,96 +2,60 @@ Helper classes for parsers. """ from __future__ import unicode_literals -from django.utils.datastructures import SortedDict +from django.db.models.query import QuerySet +from django.utils import six, timezone +from django.utils.encoding import force_text from django.utils.functional import Promise -from rest_framework.compat import timezone, force_text -from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata +from rest_framework.compat import total_seconds import datetime import decimal -import types import json +import uuid class JSONEncoder(json.JSONEncoder): """ JSONEncoder subclass that knows how to encode date/time/timedelta, - decimal types, and generators. + decimal types, generators and other basic python objects. """ - def default(self, o): + def default(self, obj): # For Date Time string spec, see ECMA 262 # http://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15 - if isinstance(o, Promise): - return force_text(o) - elif isinstance(o, datetime.datetime): - r = o.isoformat() - if o.microsecond: - r = r[:23] + r[26:] - if r.endswith('+00:00'): - r = r[:-6] + 'Z' - return r - elif isinstance(o, datetime.date): - return o.isoformat() - elif isinstance(o, datetime.time): - if timezone and timezone.is_aware(o): + if isinstance(obj, Promise): + return force_text(obj) + elif isinstance(obj, datetime.datetime): + representation = obj.isoformat() + if obj.microsecond: + representation = representation[:23] + representation[26:] + if representation.endswith('+00:00'): + representation = representation[:-6] + 'Z' + return representation + elif isinstance(obj, datetime.date): + return obj.isoformat() + elif isinstance(obj, datetime.time): + if timezone and timezone.is_aware(obj): raise ValueError("JSON can't represent timezone-aware times.") - r = o.isoformat() - if o.microsecond: - r = r[:12] - return r - elif isinstance(o, datetime.timedelta): - return str(o.total_seconds()) - elif isinstance(o, decimal.Decimal): - return str(o) - elif hasattr(o, '__iter__'): - return [i for i in o] - return super(JSONEncoder, self).default(o) - - -try: - import yaml -except ImportError: - SafeDumper = None -else: - # Adapted from http://pyyaml.org/attachment/ticket/161/use_ordered_dict.py - class SafeDumper(yaml.SafeDumper): - """ - Handles decimals as strings. - Handles SortedDicts as usual dicts, but preserves field order, rather - than the usual behaviour of sorting the keys. - """ - def represent_decimal(self, data): - return self.represent_scalar('tag:yaml.org,2002:str', str(data)) - - def represent_mapping(self, tag, mapping, flow_style=None): - value = [] - node = yaml.MappingNode(tag, value, flow_style=flow_style) - if self.alias_key is not None: - self.represented_objects[self.alias_key] = node - best_style = True - if hasattr(mapping, 'items'): - mapping = list(mapping.items()) - if not isinstance(mapping, SortedDict): - mapping.sort() - for item_key, item_value in mapping: - node_key = self.represent_data(item_key) - node_value = self.represent_data(item_value) - if not (isinstance(node_key, yaml.ScalarNode) and not node_key.style): - best_style = False - if not (isinstance(node_value, yaml.ScalarNode) and not node_value.style): - best_style = False - value.append((node_key, node_value)) - if flow_style is None: - if self.default_flow_style is not None: - node.flow_style = self.default_flow_style - else: - node.flow_style = best_style - return node - - SafeDumper.add_representer(SortedDict, - yaml.representer.SafeRepresenter.represent_dict) - SafeDumper.add_representer(DictWithMetadata, - yaml.representer.SafeRepresenter.represent_dict) - SafeDumper.add_representer(SortedDictWithMetadata, - yaml.representer.SafeRepresenter.represent_dict) - SafeDumper.add_representer(types.GeneratorType, - yaml.representer.SafeRepresenter.represent_list) + representation = obj.isoformat() + if obj.microsecond: + representation = representation[:12] + return representation + elif isinstance(obj, datetime.timedelta): + return six.text_type(total_seconds(obj)) + elif isinstance(obj, decimal.Decimal): + # Serializers will coerce decimals to strings by default. + return float(obj) + elif isinstance(obj, uuid.UUID): + return six.text_type(obj) + elif isinstance(obj, QuerySet): + return tuple(obj) + elif hasattr(obj, 'tolist'): + # Numpy arrays and array scalars. + return obj.tolist() + elif hasattr(obj, '__getitem__'): + try: + return dict(obj) + except: + pass + elif hasattr(obj, '__iter__'): + return tuple(item for item in obj) + return super(JSONEncoder, self).default(obj) diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py new file mode 100644 index 00000000..c97ec5d0 --- /dev/null +++ b/rest_framework/utils/field_mapping.py @@ -0,0 +1,249 @@ +""" +Helper functions for mapping model fields to a dictionary of default +keyword arguments that should be used for their equivelent serializer fields. +""" +from django.core import validators +from django.db import models +from django.utils.text import capfirst +from rest_framework.compat import clean_manytomany_helptext +from rest_framework.validators import UniqueValidator +import inspect + + +NUMERIC_FIELD_TYPES = ( + models.IntegerField, models.FloatField, models.DecimalField +) + + +class ClassLookupDict(object): + """ + Takes a dictionary with classes as keys. + Lookups against this object will traverses the object's inheritance + hierarchy in method resolution order, and returns the first matching value + from the dictionary or raises a KeyError if nothing matches. + """ + def __init__(self, mapping): + self.mapping = mapping + + def __getitem__(self, key): + if hasattr(key, '_proxy_class'): + # Deal with proxy classes. Ie. BoundField behaves as if it + # is a Field instance when using ClassLookupDict. + base_class = key._proxy_class + else: + base_class = key.__class__ + + for cls in inspect.getmro(base_class): + if cls in self.mapping: + return self.mapping[cls] + raise KeyError('Class %s not found in lookup.', cls.__name__) + + def __setitem__(self, key, value): + self.mapping[key] = value + + +def needs_label(model_field, field_name): + """ + Returns `True` if the label based on the model's verbose name + is not equal to the default label it would have based on it's field name. + """ + default_label = field_name.replace('_', ' ').capitalize() + return capfirst(model_field.verbose_name) != default_label + + +def get_detail_view_name(model): + """ + Given a model class, return the view name to use for URL relationships + that refer to instances of the model. + """ + return '%(model_name)s-detail' % { + 'app_label': model._meta.app_label, + 'model_name': model._meta.object_name.lower() + } + + +def get_field_kwargs(field_name, model_field): + """ + Creates a default instance of a basic non-relational field. + """ + kwargs = {} + validator_kwarg = list(model_field.validators) + + # The following will only be used by ModelField classes. + # Gets removed for everything else. + kwargs['model_field'] = model_field + + if model_field.verbose_name and needs_label(model_field, field_name): + kwargs['label'] = capfirst(model_field.verbose_name) + + if model_field.help_text: + kwargs['help_text'] = model_field.help_text + + max_digits = getattr(model_field, 'max_digits', None) + if max_digits is not None: + kwargs['max_digits'] = max_digits + + decimal_places = getattr(model_field, 'decimal_places', None) + if decimal_places is not None: + kwargs['decimal_places'] = decimal_places + + if isinstance(model_field, models.TextField): + kwargs['style'] = {'base_template': 'textarea.html'} + + if isinstance(model_field, models.AutoField) or not model_field.editable: + # If this field is read-only, then return early. + # Further keyword arguments are not valid. + kwargs['read_only'] = True + return kwargs + + if model_field.has_default() or model_field.blank or model_field.null: + kwargs['required'] = False + + if model_field.null and not isinstance(model_field, models.NullBooleanField): + kwargs['allow_null'] = True + + if model_field.blank: + kwargs['allow_blank'] = True + + if model_field.flatchoices: + # If this model field contains choices, then return early. + # Further keyword arguments are not valid. + kwargs['choices'] = model_field.flatchoices + return kwargs + + # Ensure that max_length is passed explicitly as a keyword arg, + # rather than as a validator. + max_length = getattr(model_field, 'max_length', None) + if max_length is not None and isinstance(model_field, models.CharField): + kwargs['max_length'] = max_length + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MaxLengthValidator) + ] + + # Ensure that min_length is passed explicitly as a keyword arg, + # rather than as a validator. + min_length = next(( + validator.limit_value for validator in validator_kwarg + if isinstance(validator, validators.MinLengthValidator) + ), None) + if min_length is not None and isinstance(model_field, models.CharField): + kwargs['min_length'] = min_length + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MinLengthValidator) + ] + + # Ensure that max_value is passed explicitly as a keyword arg, + # rather than as a validator. + max_value = next(( + validator.limit_value for validator in validator_kwarg + if isinstance(validator, validators.MaxValueValidator) + ), None) + if max_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES): + kwargs['max_value'] = max_value + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MaxValueValidator) + ] + + # Ensure that max_value is passed explicitly as a keyword arg, + # rather than as a validator. + min_value = next(( + validator.limit_value for validator in validator_kwarg + if isinstance(validator, validators.MinValueValidator) + ), None) + if min_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES): + kwargs['min_value'] = min_value + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MinValueValidator) + ] + + # URLField does not need to include the URLValidator argument, + # as it is explicitly added in. + if isinstance(model_field, models.URLField): + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.URLValidator) + ] + + # EmailField does not need to include the validate_email argument, + # as it is explicitly added in. + if isinstance(model_field, models.EmailField): + validator_kwarg = [ + validator for validator in validator_kwarg + if validator is not validators.validate_email + ] + + # SlugField do not need to include the 'validate_slug' argument, + if isinstance(model_field, models.SlugField): + validator_kwarg = [ + validator for validator in validator_kwarg + if validator is not validators.validate_slug + ] + + if getattr(model_field, 'unique', False): + validator = UniqueValidator(queryset=model_field.model._default_manager) + validator_kwarg.append(validator) + + if validator_kwarg: + kwargs['validators'] = validator_kwarg + + return kwargs + + +def get_relation_kwargs(field_name, relation_info): + """ + Creates a default instance of a flat relational field. + """ + model_field, related_model, to_many, has_through_model = relation_info + kwargs = { + 'queryset': related_model._default_manager, + 'view_name': get_detail_view_name(related_model) + } + + if to_many: + kwargs['many'] = True + + if has_through_model: + kwargs['read_only'] = True + kwargs.pop('queryset', None) + + if model_field: + if model_field.verbose_name and needs_label(model_field, field_name): + kwargs['label'] = capfirst(model_field.verbose_name) + help_text = clean_manytomany_helptext(model_field.help_text) + if help_text: + kwargs['help_text'] = help_text + if not model_field.editable: + kwargs['read_only'] = True + kwargs.pop('queryset', None) + if kwargs.get('read_only', False): + # If this field is read-only, then return early. + # No further keyword arguments are valid. + return kwargs + if model_field.has_default() or model_field.null: + kwargs['required'] = False + if model_field.null: + kwargs['allow_null'] = True + if model_field.validators: + kwargs['validators'] = model_field.validators + if getattr(model_field, 'unique', False): + validator = UniqueValidator(queryset=model_field.model._default_manager) + kwargs['validators'] = kwargs.get('validators', []) + [validator] + + return kwargs + + +def get_nested_relation_kwargs(relation_info): + kwargs = {'read_only': True} + if relation_info.to_many: + kwargs['many'] = True + return kwargs + + +def get_url_kwargs(model_field): + return { + 'view_name': get_detail_view_name(model_field) + } diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index 4bec8387..8b6f005e 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -2,14 +2,13 @@ Utility functions to return a formatted name and description for a given view. """ from __future__ import unicode_literals - from django.utils.html import escape from django.utils.safestring import mark_safe -from rest_framework.compat import apply_markdown, smart_text +from rest_framework.compat import apply_markdown, force_text import re -def _remove_trailing_string(content, trailing): +def remove_trailing_string(content, trailing): """ Strip trailing component `trailing` from `content` if it exists. Used when generating names from view classes. @@ -19,11 +18,16 @@ def _remove_trailing_string(content, trailing): return content -def _remove_leading_indent(content): +def dedent(content): """ Remove leading indent from a block of text. Used when generating descriptions from docstrings. + + Note that python's `textwrap.dedent` doesn't quite cut it, + as it fails to dedent multiline docstrings that include + unindented text on the initial line. """ + content = force_text(content) whitespace_counts = [len(line) - len(line.lstrip(' ')) for line in content.splitlines()[1:] if line.lstrip()] @@ -31,11 +35,11 @@ def _remove_leading_indent(content): if whitespace_counts: whitespace_pattern = '^' + (' ' * min(whitespace_counts)) content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) - content = content.strip('\n') - return content + return content.strip() -def _camelcase_to_spaces(content): + +def camelcase_to_spaces(content): """ Translate 'CamelCaseNames' to 'Camel Case Names'. Used when generating names from view classes. @@ -45,30 +49,6 @@ def _camelcase_to_spaces(content): return ' '.join(content.split('_')).title() -def get_view_name(cls, suffix=None): - """ - Return a formatted name for an `APIView` class or `@api_view` function. - """ - name = cls.__name__ - name = _remove_trailing_string(name, 'View') - name = _remove_trailing_string(name, 'ViewSet') - name = _camelcase_to_spaces(name) - if suffix: - name += ' ' + suffix - return name - - -def get_view_description(cls, html=False): - """ - Return a description for an `APIView` class or `@api_view` function. - """ - description = cls.__doc__ or '' - description = _remove_leading_indent(smart_text(description)) - if html: - return markup_description(description) - return description - - def markup_description(description): """ Apply HTML markup to the given description. @@ -77,4 +57,5 @@ def markup_description(description): description = apply_markdown(description) else: description = escape(description).replace('\n', '<br />') + description = '<p>' + description + '</p>' return mark_safe(description) diff --git a/rest_framework/utils/html.py b/rest_framework/utils/html.py new file mode 100644 index 00000000..d773952d --- /dev/null +++ b/rest_framework/utils/html.py @@ -0,0 +1,88 @@ +""" +Helpers for dealing with HTML input. +""" +import re +from django.utils.datastructures import MultiValueDict + + +def is_html_input(dictionary): + # MultiDict type datastructures are used to represent HTML form input, + # which may have more than one value for each key. + return hasattr(dictionary, 'getlist') + + +def parse_html_list(dictionary, prefix=''): + """ + Used to suport list values in HTML forms. + Supports lists of primitives and/or dictionaries. + + * List of primitives. + + { + '[0]': 'abc', + '[1]': 'def', + '[2]': 'hij' + } + --> + [ + 'abc', + 'def', + 'hij' + ] + + * List of dictionaries. + + { + '[0]foo': 'abc', + '[0]bar': 'def', + '[1]foo': 'hij', + '[1]bar': 'klm', + } + --> + [ + {'foo': 'abc', 'bar': 'def'}, + {'foo': 'hij', 'bar': 'klm'} + ] + """ + ret = {} + regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix)) + for field, value in dictionary.items(): + match = regex.match(field) + if not match: + continue + index, key = match.groups() + index = int(index) + if not key: + ret[index] = value + elif isinstance(ret.get(index), dict): + ret[index][key] = value + else: + ret[index] = MultiValueDict({key: [value]}) + return [ret[item] for item in sorted(ret.keys())] + + +def parse_html_dict(dictionary, prefix): + """ + Used to support dictionary values in HTML forms. + + { + 'profile.username': 'example', + 'profile.email': 'example@example.com', + } + --> + { + 'profile': { + 'username': 'example', + 'email': 'example@example.com' + } + } + """ + ret = {} + regex = re.compile(r'^%s\.(.+)$' % re.escape(prefix)) + for field, value in dictionary.items(): + match = regex.match(field) + if not match: + continue + key = match.groups()[0] + ret[key] = value + return ret diff --git a/rest_framework/utils/humanize_datetime.py b/rest_framework/utils/humanize_datetime.py new file mode 100644 index 00000000..649f2abc --- /dev/null +++ b/rest_framework/utils/humanize_datetime.py @@ -0,0 +1,47 @@ +""" +Helper functions that convert strftime formats into more readable representations. +""" +from rest_framework import ISO_8601 + + +def datetime_formats(formats): + format = ', '.join(formats).replace( + ISO_8601, + 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]' + ) + return humanize_strptime(format) + + +def date_formats(formats): + format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]') + return humanize_strptime(format) + + +def time_formats(formats): + format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]') + return humanize_strptime(format) + + +def humanize_strptime(format_string): + # Note that we're missing some of the locale specific mappings that + # don't really make sense. + mapping = { + "%Y": "YYYY", + "%y": "YY", + "%m": "MM", + "%b": "[Jan-Dec]", + "%B": "[January-December]", + "%d": "DD", + "%H": "hh", + "%I": "hh", # Requires '%p' to differentiate from '%H'. + "%M": "mm", + "%S": "ss", + "%f": "uuuuuu", + "%a": "[Mon-Sun]", + "%A": "[Monday-Sunday]", + "%p": "[AM|PM]", + "%z": "[+HHMM|-HHMM]" + } + for key, val in mapping.items(): + format_string = format_string.replace(key, val) + return format_string diff --git a/rest_framework/utils/mediatypes.py b/rest_framework/utils/mediatypes.py index c09c2933..de2931c2 100644 --- a/rest_framework/utils/mediatypes.py +++ b/rest_framework/utils/mediatypes.py @@ -5,6 +5,7 @@ See http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7 """ from __future__ import unicode_literals from django.http.multipartparser import parse_header +from django.utils.encoding import python_2_unicode_compatible from rest_framework import HTTP_HEADER_ENCODING @@ -43,6 +44,7 @@ def order_by_precedence(media_type_lst): return [media_types for media_types in ret if media_types] +@python_2_unicode_compatible class _MediaType(object): def __init__(self, media_type_str): if media_type_str is None: @@ -57,7 +59,7 @@ class _MediaType(object): if key != 'q' and other.params.get(key, None) != self.params.get(key, None): return False - if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type: + if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type: return False if self.main_type != '*' and other.main_type != '*' and other.main_type != self.main_type: @@ -74,14 +76,11 @@ class _MediaType(object): return 0 elif self.sub_type == '*': return 1 - elif not self.params or self.params.keys() == ['q']: + elif not self.params or list(self.params.keys()) == ['q']: return 2 return 3 def __str__(self): - return unicode(self).encode('utf-8') - - def __unicode__(self): ret = "%s/%s" % (self.main_type, self.sub_type) for key, val in self.params.items(): ret += "; %s=%s" % (key, val) diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py new file mode 100644 index 00000000..d92bceb9 --- /dev/null +++ b/rest_framework/utils/model_meta.py @@ -0,0 +1,169 @@ +""" +Helper function for returning the field information that is associated +with a model class. This includes returning all the forward and reverse +relationships and their associated metadata. + +Usage: `get_field_info(model)` returns a `FieldInfo` instance. +""" +from collections import namedtuple +from django.core.exceptions import ImproperlyConfigured +from django.db import models +from django.utils import six +from rest_framework.compat import OrderedDict +import inspect + + +FieldInfo = namedtuple('FieldResult', [ + 'pk', # Model field instance + 'fields', # Dict of field name -> model field instance + 'forward_relations', # Dict of field name -> RelationInfo + 'reverse_relations', # Dict of field name -> RelationInfo + 'fields_and_pk', # Shortcut for 'pk' + 'fields' + 'relations' # Shortcut for 'forward_relations' + 'reverse_relations' +]) + +RelationInfo = namedtuple('RelationInfo', [ + 'model_field', + 'related_model', + 'to_many', + 'has_through_model' +]) + + +def _resolve_model(obj): + """ + Resolve supplied `obj` to a Django model class. + + `obj` must be a Django model class itself, or a string + representation of one. Useful in situations like GH #1225 where + Django may not have resolved a string-based reference to a model in + another model's foreign key definition. + + String representations should have the format: + 'appname.ModelName' + """ + if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: + app_name, model_name = obj.split('.') + resolved_model = models.get_model(app_name, model_name) + if resolved_model is None: + msg = "Django did not return a model for {0}.{1}" + raise ImproperlyConfigured(msg.format(app_name, model_name)) + return resolved_model + elif inspect.isclass(obj) and issubclass(obj, models.Model): + return obj + raise ValueError("{0} is not a Django model".format(obj)) + + +def get_field_info(model): + """ + Given a model class, returns a `FieldInfo` instance, which is a + `namedtuple`, containing metadata about the various field types on the model + including information about their relationships. + """ + opts = model._meta.concrete_model._meta + + pk = _get_pk(opts) + fields = _get_fields(opts) + forward_relations = _get_forward_relationships(opts) + reverse_relations = _get_reverse_relationships(opts) + fields_and_pk = _merge_fields_and_pk(pk, fields) + relationships = _merge_relationships(forward_relations, reverse_relations) + + return FieldInfo(pk, fields, forward_relations, reverse_relations, + fields_and_pk, relationships) + + +def _get_pk(opts): + pk = opts.pk + while pk.rel and pk.rel.parent_link: + # If model is a child via multi-table inheritance, use parent's pk. + pk = pk.rel.to._meta.pk + + return pk + + +def _get_fields(opts): + fields = OrderedDict() + for field in [field for field in opts.fields if field.serialize and not field.rel]: + fields[field.name] = field + + return fields + + +def _get_forward_relationships(opts): + """ + Returns an `OrderedDict` of field names to `RelationInfo`. + """ + forward_relations = OrderedDict() + for field in [field for field in opts.fields if field.serialize and field.rel]: + forward_relations[field.name] = RelationInfo( + model_field=field, + related_model=_resolve_model(field.rel.to), + to_many=False, + has_through_model=False + ) + + # Deal with forward many-to-many relationships. + for field in [field for field in opts.many_to_many if field.serialize]: + forward_relations[field.name] = RelationInfo( + model_field=field, + related_model=_resolve_model(field.rel.to), + to_many=True, + has_through_model=( + not field.rel.through._meta.auto_created + ) + ) + + return forward_relations + + +def _get_reverse_relationships(opts): + """ + Returns an `OrderedDict` of field names to `RelationInfo`. + """ + # Note that we have a hack here to handle internal API differences for + # this internal API across Django 1.7 -> Django 1.8. + # See: https://code.djangoproject.com/ticket/24208 + + reverse_relations = OrderedDict() + for relation in opts.get_all_related_objects(): + accessor_name = relation.get_accessor_name() + related = getattr(relation, 'related_model', relation.model) + reverse_relations[accessor_name] = RelationInfo( + model_field=None, + related_model=related, + to_many=relation.field.rel.multiple, + has_through_model=False + ) + + # Deal with reverse many-to-many relationships. + for relation in opts.get_all_related_many_to_many_objects(): + accessor_name = relation.get_accessor_name() + related = getattr(relation, 'related_model', relation.model) + reverse_relations[accessor_name] = RelationInfo( + model_field=None, + related_model=related, + to_many=True, + has_through_model=( + (getattr(relation.field.rel, 'through', None) is not None) and + not relation.field.rel.through._meta.auto_created + ) + ) + + return reverse_relations + + +def _merge_fields_and_pk(pk, fields): + fields_and_pk = OrderedDict() + fields_and_pk['pk'] = pk + fields_and_pk[pk.name] = pk + fields_and_pk.update(fields) + + return fields_and_pk + + +def _merge_relationships(forward_relations, reverse_relations): + return OrderedDict( + list(forward_relations.items()) + + list(reverse_relations.items()) + ) diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py new file mode 100644 index 00000000..1bfc64c1 --- /dev/null +++ b/rest_framework/utils/representation.py @@ -0,0 +1,99 @@ +""" +Helper functions for creating user-friendly representations +of serializer classes and serializer fields. +""" +from __future__ import unicode_literals +from django.db import models +from django.utils.encoding import force_text +from django.utils.functional import Promise +from rest_framework.compat import unicode_repr +import re + + +def manager_repr(value): + model = value.model + opts = model._meta + for _, name, manager in opts.concrete_managers + opts.abstract_managers: + if manager == value: + return '%s.%s.all()' % (model._meta.object_name, name) + return repr(value) + + +def smart_repr(value): + if isinstance(value, models.Manager): + return manager_repr(value) + + if isinstance(value, Promise) and value._delegate_text: + value = force_text(value) + + value = unicode_repr(value) + + # Representations like u'help text' + # should simply be presented as 'help text' + if value.startswith("u'") and value.endswith("'"): + return value[1:] + + # Representations like + # <django.core.validators.RegexValidator object at 0x1047af050> + # Should be presented as + # <django.core.validators.RegexValidator object> + value = re.sub(' at 0x[0-9a-f]{4,32}>', '>', value) + + return value + + +def field_repr(field, force_many=False): + kwargs = field._kwargs + if force_many: + kwargs = kwargs.copy() + kwargs['many'] = True + kwargs.pop('child', None) + + arg_string = ', '.join([smart_repr(val) for val in field._args]) + kwarg_string = ', '.join([ + '%s=%s' % (key, smart_repr(val)) + for key, val in sorted(kwargs.items()) + ]) + if arg_string and kwarg_string: + arg_string += ', ' + + if force_many: + class_name = force_many.__class__.__name__ + else: + class_name = field.__class__.__name__ + + return "%s(%s%s)" % (class_name, arg_string, kwarg_string) + + +def serializer_repr(serializer, indent, force_many=None): + ret = field_repr(serializer, force_many) + ':' + indent_str = ' ' * indent + + if force_many: + fields = force_many.fields + else: + fields = serializer.fields + + for field_name, field in fields.items(): + ret += '\n' + indent_str + field_name + ' = ' + if hasattr(field, 'fields'): + ret += serializer_repr(field, indent + 1) + elif hasattr(field, 'child'): + ret += list_repr(field, indent + 1) + elif hasattr(field, 'child_relation'): + ret += field_repr(field.child_relation, force_many=field.child_relation) + else: + ret += field_repr(field) + + if serializer.validators: + ret += '\n' + indent_str + 'class Meta:' + ret += '\n' + indent_str + ' validators = ' + smart_repr(serializer.validators) + + return ret + + +def list_repr(serializer, indent): + child = serializer.child + if hasattr(child, 'fields'): + return serializer_repr(serializer, indent, force_many=child) + return field_repr(serializer) diff --git a/rest_framework/utils/serializer_helpers.py b/rest_framework/utils/serializer_helpers.py new file mode 100644 index 00000000..87bb3ac0 --- /dev/null +++ b/rest_framework/utils/serializer_helpers.py @@ -0,0 +1,120 @@ +from __future__ import unicode_literals +import collections +from rest_framework.compat import OrderedDict, unicode_to_repr + + +class ReturnDict(OrderedDict): + """ + Return object from `serialier.data` for the `Serializer` class. + Includes a backlink to the serializer instance for renderers + to use if they need richer field information. + """ + def __init__(self, *args, **kwargs): + self.serializer = kwargs.pop('serializer') + super(ReturnDict, self).__init__(*args, **kwargs) + + def copy(self): + return ReturnDict(self, serializer=self.serializer) + + def __repr__(self): + return dict.__repr__(self) + + def __reduce__(self): + # Pickling these objects will drop the .serializer backlink, + # but preserve the raw data. + return (dict, (dict(self),)) + + +class ReturnList(list): + """ + Return object from `serialier.data` for the `SerializerList` class. + Includes a backlink to the serializer instance for renderers + to use if they need richer field information. + """ + def __init__(self, *args, **kwargs): + self.serializer = kwargs.pop('serializer') + super(ReturnList, self).__init__(*args, **kwargs) + + def __repr__(self): + return list.__repr__(self) + + def __reduce__(self): + # Pickling these objects will drop the .serializer backlink, + # but preserve the raw data. + return (list, (list(self),)) + + +class BoundField(object): + """ + A field object that also includes `.value` and `.error` properties. + Returned when iterating over a serializer instance, + providing an API similar to Django forms and form fields. + """ + def __init__(self, field, value, errors, prefix=''): + self._field = field + self.value = value + self.errors = errors + self.name = prefix + self.field_name + + def __getattr__(self, attr_name): + return getattr(self._field, attr_name) + + @property + def _proxy_class(self): + return self._field.__class__ + + def __repr__(self): + return unicode_to_repr('<%s value=%s errors=%s>' % ( + self.__class__.__name__, self.value, self.errors + )) + + +class NestedBoundField(BoundField): + """ + This `BoundField` additionally implements __iter__ and __getitem__ + in order to support nested bound fields. This class is the type of + `BoundField` that is used for serializer fields. + """ + def __iter__(self): + for field in self.fields.values(): + yield self[field.field_name] + + def __getitem__(self, key): + field = self.fields[key] + value = self.value.get(key) if self.value else None + error = self.errors.get(key) if self.errors else None + if hasattr(field, 'fields'): + return NestedBoundField(field, value, error, prefix=self.name + '.') + return BoundField(field, value, error, prefix=self.name + '.') + + +class BindingDict(collections.MutableMapping): + """ + This dict-like object is used to store fields on a serializer. + + This ensures that whenever fields are added to the serializer we call + `field.bind()` so that the `field_name` and `parent` attributes + can be set correctly. + """ + def __init__(self, serializer): + self.serializer = serializer + self.fields = OrderedDict() + + def __setitem__(self, key, field): + self.fields[key] = field + field.bind(field_name=key, parent=self.serializer) + + def __getitem__(self, key): + return self.fields[key] + + def __delitem__(self, key): + del self.fields[key] + + def __iter__(self): + return iter(self.fields) + + def __len__(self): + return len(self.fields) + + def __repr__(self): + return dict.__repr__(self.fields) diff --git a/rest_framework/utils/urls.py b/rest_framework/utils/urls.py new file mode 100644 index 00000000..880ef9ed --- /dev/null +++ b/rest_framework/utils/urls.py @@ -0,0 +1,25 @@ +from django.utils.six.moves.urllib import parse as urlparse + + +def replace_query_param(url, key, val): + """ + Given a URL and a key/val pair, set or replace an item in the query + parameters of the URL, and return the new URL. + """ + (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url) + query_dict = urlparse.parse_qs(query) + query_dict[key] = [val] + query = urlparse.urlencode(sorted(list(query_dict.items())), doseq=True) + return urlparse.urlunsplit((scheme, netloc, path, query, fragment)) + + +def remove_query_param(url, key): + """ + Given a URL and a key/val pair, remove an item in the query + parameters of the URL, and return the new URL. + """ + (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url) + query_dict = urlparse.parse_qs(query) + query_dict.pop(key, None) + query = urlparse.urlencode(sorted(list(query_dict.items())), doseq=True) + return urlparse.urlunsplit((scheme, netloc, path, query, fragment)) |
