aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/utils
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/utils')
-rw-r--r--rest_framework/utils/breadcrumbs.py6
-rw-r--r--rest_framework/utils/encoders.py124
-rw-r--r--rest_framework/utils/field_mapping.py249
-rw-r--r--rest_framework/utils/formatting.py43
-rw-r--r--rest_framework/utils/html.py88
-rw-r--r--rest_framework/utils/humanize_datetime.py47
-rw-r--r--rest_framework/utils/mediatypes.py9
-rw-r--r--rest_framework/utils/model_meta.py169
-rw-r--r--rest_framework/utils/representation.py99
-rw-r--r--rest_framework/utils/serializer_helpers.py120
-rw-r--r--rest_framework/utils/urls.py25
11 files changed, 861 insertions, 118 deletions
diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py
index d51374b0..e6690d17 100644
--- a/rest_framework/utils/breadcrumbs.py
+++ b/rest_framework/utils/breadcrumbs.py
@@ -1,6 +1,5 @@
from __future__ import unicode_literals
from django.core.urlresolvers import resolve, get_script_prefix
-from rest_framework.utils.formatting import get_view_name
def get_breadcrumbs(url):
@@ -9,8 +8,11 @@ def get_breadcrumbs(url):
tuple of (name, url).
"""
+ from rest_framework.settings import api_settings
from rest_framework.views import APIView
+ view_name_func = api_settings.VIEW_NAME_FUNCTION
+
def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen):
"""
Add tuples of (name, url) to the breadcrumbs list,
@@ -30,7 +32,7 @@ def get_breadcrumbs(url):
# Probably an optional trailing slash.
if not seen or seen[-1] != view:
suffix = getattr(view, 'suffix', None)
- name = get_view_name(view.cls, suffix)
+ name = view_name_func(cls, suffix)
breadcrumbs_list.insert(0, (name, prefix + url))
seen.append(view)
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index b26a2085..2160d18b 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -2,96 +2,60 @@
Helper classes for parsers.
"""
from __future__ import unicode_literals
-from django.utils.datastructures import SortedDict
+from django.db.models.query import QuerySet
+from django.utils import six, timezone
+from django.utils.encoding import force_text
from django.utils.functional import Promise
-from rest_framework.compat import timezone, force_text
-from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
+from rest_framework.compat import total_seconds
import datetime
import decimal
-import types
import json
+import uuid
class JSONEncoder(json.JSONEncoder):
"""
JSONEncoder subclass that knows how to encode date/time/timedelta,
- decimal types, and generators.
+ decimal types, generators and other basic python objects.
"""
- def default(self, o):
+ def default(self, obj):
# For Date Time string spec, see ECMA 262
# http://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15
- if isinstance(o, Promise):
- return force_text(o)
- elif isinstance(o, datetime.datetime):
- r = o.isoformat()
- if o.microsecond:
- r = r[:23] + r[26:]
- if r.endswith('+00:00'):
- r = r[:-6] + 'Z'
- return r
- elif isinstance(o, datetime.date):
- return o.isoformat()
- elif isinstance(o, datetime.time):
- if timezone and timezone.is_aware(o):
+ if isinstance(obj, Promise):
+ return force_text(obj)
+ elif isinstance(obj, datetime.datetime):
+ representation = obj.isoformat()
+ if obj.microsecond:
+ representation = representation[:23] + representation[26:]
+ if representation.endswith('+00:00'):
+ representation = representation[:-6] + 'Z'
+ return representation
+ elif isinstance(obj, datetime.date):
+ return obj.isoformat()
+ elif isinstance(obj, datetime.time):
+ if timezone and timezone.is_aware(obj):
raise ValueError("JSON can't represent timezone-aware times.")
- r = o.isoformat()
- if o.microsecond:
- r = r[:12]
- return r
- elif isinstance(o, datetime.timedelta):
- return str(o.total_seconds())
- elif isinstance(o, decimal.Decimal):
- return str(o)
- elif hasattr(o, '__iter__'):
- return [i for i in o]
- return super(JSONEncoder, self).default(o)
-
-
-try:
- import yaml
-except ImportError:
- SafeDumper = None
-else:
- # Adapted from http://pyyaml.org/attachment/ticket/161/use_ordered_dict.py
- class SafeDumper(yaml.SafeDumper):
- """
- Handles decimals as strings.
- Handles SortedDicts as usual dicts, but preserves field order, rather
- than the usual behaviour of sorting the keys.
- """
- def represent_decimal(self, data):
- return self.represent_scalar('tag:yaml.org,2002:str', str(data))
-
- def represent_mapping(self, tag, mapping, flow_style=None):
- value = []
- node = yaml.MappingNode(tag, value, flow_style=flow_style)
- if self.alias_key is not None:
- self.represented_objects[self.alias_key] = node
- best_style = True
- if hasattr(mapping, 'items'):
- mapping = list(mapping.items())
- if not isinstance(mapping, SortedDict):
- mapping.sort()
- for item_key, item_value in mapping:
- node_key = self.represent_data(item_key)
- node_value = self.represent_data(item_value)
- if not (isinstance(node_key, yaml.ScalarNode) and not node_key.style):
- best_style = False
- if not (isinstance(node_value, yaml.ScalarNode) and not node_value.style):
- best_style = False
- value.append((node_key, node_value))
- if flow_style is None:
- if self.default_flow_style is not None:
- node.flow_style = self.default_flow_style
- else:
- node.flow_style = best_style
- return node
-
- SafeDumper.add_representer(SortedDict,
- yaml.representer.SafeRepresenter.represent_dict)
- SafeDumper.add_representer(DictWithMetadata,
- yaml.representer.SafeRepresenter.represent_dict)
- SafeDumper.add_representer(SortedDictWithMetadata,
- yaml.representer.SafeRepresenter.represent_dict)
- SafeDumper.add_representer(types.GeneratorType,
- yaml.representer.SafeRepresenter.represent_list)
+ representation = obj.isoformat()
+ if obj.microsecond:
+ representation = representation[:12]
+ return representation
+ elif isinstance(obj, datetime.timedelta):
+ return six.text_type(total_seconds(obj))
+ elif isinstance(obj, decimal.Decimal):
+ # Serializers will coerce decimals to strings by default.
+ return float(obj)
+ elif isinstance(obj, uuid.UUID):
+ return six.text_type(obj)
+ elif isinstance(obj, QuerySet):
+ return tuple(obj)
+ elif hasattr(obj, 'tolist'):
+ # Numpy arrays and array scalars.
+ return obj.tolist()
+ elif hasattr(obj, '__getitem__'):
+ try:
+ return dict(obj)
+ except:
+ pass
+ elif hasattr(obj, '__iter__'):
+ return tuple(item for item in obj)
+ return super(JSONEncoder, self).default(obj)
diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py
new file mode 100644
index 00000000..c97ec5d0
--- /dev/null
+++ b/rest_framework/utils/field_mapping.py
@@ -0,0 +1,249 @@
+"""
+Helper functions for mapping model fields to a dictionary of default
+keyword arguments that should be used for their equivelent serializer fields.
+"""
+from django.core import validators
+from django.db import models
+from django.utils.text import capfirst
+from rest_framework.compat import clean_manytomany_helptext
+from rest_framework.validators import UniqueValidator
+import inspect
+
+
+NUMERIC_FIELD_TYPES = (
+ models.IntegerField, models.FloatField, models.DecimalField
+)
+
+
+class ClassLookupDict(object):
+ """
+ Takes a dictionary with classes as keys.
+ Lookups against this object will traverses the object's inheritance
+ hierarchy in method resolution order, and returns the first matching value
+ from the dictionary or raises a KeyError if nothing matches.
+ """
+ def __init__(self, mapping):
+ self.mapping = mapping
+
+ def __getitem__(self, key):
+ if hasattr(key, '_proxy_class'):
+ # Deal with proxy classes. Ie. BoundField behaves as if it
+ # is a Field instance when using ClassLookupDict.
+ base_class = key._proxy_class
+ else:
+ base_class = key.__class__
+
+ for cls in inspect.getmro(base_class):
+ if cls in self.mapping:
+ return self.mapping[cls]
+ raise KeyError('Class %s not found in lookup.', cls.__name__)
+
+ def __setitem__(self, key, value):
+ self.mapping[key] = value
+
+
+def needs_label(model_field, field_name):
+ """
+ Returns `True` if the label based on the model's verbose name
+ is not equal to the default label it would have based on it's field name.
+ """
+ default_label = field_name.replace('_', ' ').capitalize()
+ return capfirst(model_field.verbose_name) != default_label
+
+
+def get_detail_view_name(model):
+ """
+ Given a model class, return the view name to use for URL relationships
+ that refer to instances of the model.
+ """
+ return '%(model_name)s-detail' % {
+ 'app_label': model._meta.app_label,
+ 'model_name': model._meta.object_name.lower()
+ }
+
+
+def get_field_kwargs(field_name, model_field):
+ """
+ Creates a default instance of a basic non-relational field.
+ """
+ kwargs = {}
+ validator_kwarg = list(model_field.validators)
+
+ # The following will only be used by ModelField classes.
+ # Gets removed for everything else.
+ kwargs['model_field'] = model_field
+
+ if model_field.verbose_name and needs_label(model_field, field_name):
+ kwargs['label'] = capfirst(model_field.verbose_name)
+
+ if model_field.help_text:
+ kwargs['help_text'] = model_field.help_text
+
+ max_digits = getattr(model_field, 'max_digits', None)
+ if max_digits is not None:
+ kwargs['max_digits'] = max_digits
+
+ decimal_places = getattr(model_field, 'decimal_places', None)
+ if decimal_places is not None:
+ kwargs['decimal_places'] = decimal_places
+
+ if isinstance(model_field, models.TextField):
+ kwargs['style'] = {'base_template': 'textarea.html'}
+
+ if isinstance(model_field, models.AutoField) or not model_field.editable:
+ # If this field is read-only, then return early.
+ # Further keyword arguments are not valid.
+ kwargs['read_only'] = True
+ return kwargs
+
+ if model_field.has_default() or model_field.blank or model_field.null:
+ kwargs['required'] = False
+
+ if model_field.null and not isinstance(model_field, models.NullBooleanField):
+ kwargs['allow_null'] = True
+
+ if model_field.blank:
+ kwargs['allow_blank'] = True
+
+ if model_field.flatchoices:
+ # If this model field contains choices, then return early.
+ # Further keyword arguments are not valid.
+ kwargs['choices'] = model_field.flatchoices
+ return kwargs
+
+ # Ensure that max_length is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ max_length = getattr(model_field, 'max_length', None)
+ if max_length is not None and isinstance(model_field, models.CharField):
+ kwargs['max_length'] = max_length
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MaxLengthValidator)
+ ]
+
+ # Ensure that min_length is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ min_length = next((
+ validator.limit_value for validator in validator_kwarg
+ if isinstance(validator, validators.MinLengthValidator)
+ ), None)
+ if min_length is not None and isinstance(model_field, models.CharField):
+ kwargs['min_length'] = min_length
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MinLengthValidator)
+ ]
+
+ # Ensure that max_value is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ max_value = next((
+ validator.limit_value for validator in validator_kwarg
+ if isinstance(validator, validators.MaxValueValidator)
+ ), None)
+ if max_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES):
+ kwargs['max_value'] = max_value
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MaxValueValidator)
+ ]
+
+ # Ensure that max_value is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ min_value = next((
+ validator.limit_value for validator in validator_kwarg
+ if isinstance(validator, validators.MinValueValidator)
+ ), None)
+ if min_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES):
+ kwargs['min_value'] = min_value
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MinValueValidator)
+ ]
+
+ # URLField does not need to include the URLValidator argument,
+ # as it is explicitly added in.
+ if isinstance(model_field, models.URLField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.URLValidator)
+ ]
+
+ # EmailField does not need to include the validate_email argument,
+ # as it is explicitly added in.
+ if isinstance(model_field, models.EmailField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if validator is not validators.validate_email
+ ]
+
+ # SlugField do not need to include the 'validate_slug' argument,
+ if isinstance(model_field, models.SlugField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if validator is not validators.validate_slug
+ ]
+
+ if getattr(model_field, 'unique', False):
+ validator = UniqueValidator(queryset=model_field.model._default_manager)
+ validator_kwarg.append(validator)
+
+ if validator_kwarg:
+ kwargs['validators'] = validator_kwarg
+
+ return kwargs
+
+
+def get_relation_kwargs(field_name, relation_info):
+ """
+ Creates a default instance of a flat relational field.
+ """
+ model_field, related_model, to_many, has_through_model = relation_info
+ kwargs = {
+ 'queryset': related_model._default_manager,
+ 'view_name': get_detail_view_name(related_model)
+ }
+
+ if to_many:
+ kwargs['many'] = True
+
+ if has_through_model:
+ kwargs['read_only'] = True
+ kwargs.pop('queryset', None)
+
+ if model_field:
+ if model_field.verbose_name and needs_label(model_field, field_name):
+ kwargs['label'] = capfirst(model_field.verbose_name)
+ help_text = clean_manytomany_helptext(model_field.help_text)
+ if help_text:
+ kwargs['help_text'] = help_text
+ if not model_field.editable:
+ kwargs['read_only'] = True
+ kwargs.pop('queryset', None)
+ if kwargs.get('read_only', False):
+ # If this field is read-only, then return early.
+ # No further keyword arguments are valid.
+ return kwargs
+ if model_field.has_default() or model_field.null:
+ kwargs['required'] = False
+ if model_field.null:
+ kwargs['allow_null'] = True
+ if model_field.validators:
+ kwargs['validators'] = model_field.validators
+ if getattr(model_field, 'unique', False):
+ validator = UniqueValidator(queryset=model_field.model._default_manager)
+ kwargs['validators'] = kwargs.get('validators', []) + [validator]
+
+ return kwargs
+
+
+def get_nested_relation_kwargs(relation_info):
+ kwargs = {'read_only': True}
+ if relation_info.to_many:
+ kwargs['many'] = True
+ return kwargs
+
+
+def get_url_kwargs(model_field):
+ return {
+ 'view_name': get_detail_view_name(model_field)
+ }
diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py
index 4bec8387..8b6f005e 100644
--- a/rest_framework/utils/formatting.py
+++ b/rest_framework/utils/formatting.py
@@ -2,14 +2,13 @@
Utility functions to return a formatted name and description for a given view.
"""
from __future__ import unicode_literals
-
from django.utils.html import escape
from django.utils.safestring import mark_safe
-from rest_framework.compat import apply_markdown, smart_text
+from rest_framework.compat import apply_markdown, force_text
import re
-def _remove_trailing_string(content, trailing):
+def remove_trailing_string(content, trailing):
"""
Strip trailing component `trailing` from `content` if it exists.
Used when generating names from view classes.
@@ -19,11 +18,16 @@ def _remove_trailing_string(content, trailing):
return content
-def _remove_leading_indent(content):
+def dedent(content):
"""
Remove leading indent from a block of text.
Used when generating descriptions from docstrings.
+
+ Note that python's `textwrap.dedent` doesn't quite cut it,
+ as it fails to dedent multiline docstrings that include
+ unindented text on the initial line.
"""
+ content = force_text(content)
whitespace_counts = [len(line) - len(line.lstrip(' '))
for line in content.splitlines()[1:] if line.lstrip()]
@@ -31,11 +35,11 @@ def _remove_leading_indent(content):
if whitespace_counts:
whitespace_pattern = '^' + (' ' * min(whitespace_counts))
content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
- content = content.strip('\n')
- return content
+ return content.strip()
-def _camelcase_to_spaces(content):
+
+def camelcase_to_spaces(content):
"""
Translate 'CamelCaseNames' to 'Camel Case Names'.
Used when generating names from view classes.
@@ -45,30 +49,6 @@ def _camelcase_to_spaces(content):
return ' '.join(content.split('_')).title()
-def get_view_name(cls, suffix=None):
- """
- Return a formatted name for an `APIView` class or `@api_view` function.
- """
- name = cls.__name__
- name = _remove_trailing_string(name, 'View')
- name = _remove_trailing_string(name, 'ViewSet')
- name = _camelcase_to_spaces(name)
- if suffix:
- name += ' ' + suffix
- return name
-
-
-def get_view_description(cls, html=False):
- """
- Return a description for an `APIView` class or `@api_view` function.
- """
- description = cls.__doc__ or ''
- description = _remove_leading_indent(smart_text(description))
- if html:
- return markup_description(description)
- return description
-
-
def markup_description(description):
"""
Apply HTML markup to the given description.
@@ -77,4 +57,5 @@ def markup_description(description):
description = apply_markdown(description)
else:
description = escape(description).replace('\n', '<br />')
+ description = '<p>' + description + '</p>'
return mark_safe(description)
diff --git a/rest_framework/utils/html.py b/rest_framework/utils/html.py
new file mode 100644
index 00000000..d773952d
--- /dev/null
+++ b/rest_framework/utils/html.py
@@ -0,0 +1,88 @@
+"""
+Helpers for dealing with HTML input.
+"""
+import re
+from django.utils.datastructures import MultiValueDict
+
+
+def is_html_input(dictionary):
+ # MultiDict type datastructures are used to represent HTML form input,
+ # which may have more than one value for each key.
+ return hasattr(dictionary, 'getlist')
+
+
+def parse_html_list(dictionary, prefix=''):
+ """
+ Used to suport list values in HTML forms.
+ Supports lists of primitives and/or dictionaries.
+
+ * List of primitives.
+
+ {
+ '[0]': 'abc',
+ '[1]': 'def',
+ '[2]': 'hij'
+ }
+ -->
+ [
+ 'abc',
+ 'def',
+ 'hij'
+ ]
+
+ * List of dictionaries.
+
+ {
+ '[0]foo': 'abc',
+ '[0]bar': 'def',
+ '[1]foo': 'hij',
+ '[1]bar': 'klm',
+ }
+ -->
+ [
+ {'foo': 'abc', 'bar': 'def'},
+ {'foo': 'hij', 'bar': 'klm'}
+ ]
+ """
+ ret = {}
+ regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix))
+ for field, value in dictionary.items():
+ match = regex.match(field)
+ if not match:
+ continue
+ index, key = match.groups()
+ index = int(index)
+ if not key:
+ ret[index] = value
+ elif isinstance(ret.get(index), dict):
+ ret[index][key] = value
+ else:
+ ret[index] = MultiValueDict({key: [value]})
+ return [ret[item] for item in sorted(ret.keys())]
+
+
+def parse_html_dict(dictionary, prefix):
+ """
+ Used to support dictionary values in HTML forms.
+
+ {
+ 'profile.username': 'example',
+ 'profile.email': 'example@example.com',
+ }
+ -->
+ {
+ 'profile': {
+ 'username': 'example',
+ 'email': 'example@example.com'
+ }
+ }
+ """
+ ret = {}
+ regex = re.compile(r'^%s\.(.+)$' % re.escape(prefix))
+ for field, value in dictionary.items():
+ match = regex.match(field)
+ if not match:
+ continue
+ key = match.groups()[0]
+ ret[key] = value
+ return ret
diff --git a/rest_framework/utils/humanize_datetime.py b/rest_framework/utils/humanize_datetime.py
new file mode 100644
index 00000000..649f2abc
--- /dev/null
+++ b/rest_framework/utils/humanize_datetime.py
@@ -0,0 +1,47 @@
+"""
+Helper functions that convert strftime formats into more readable representations.
+"""
+from rest_framework import ISO_8601
+
+
+def datetime_formats(formats):
+ format = ', '.join(formats).replace(
+ ISO_8601,
+ 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
+ )
+ return humanize_strptime(format)
+
+
+def date_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]')
+ return humanize_strptime(format)
+
+
+def time_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
+ return humanize_strptime(format)
+
+
+def humanize_strptime(format_string):
+ # Note that we're missing some of the locale specific mappings that
+ # don't really make sense.
+ mapping = {
+ "%Y": "YYYY",
+ "%y": "YY",
+ "%m": "MM",
+ "%b": "[Jan-Dec]",
+ "%B": "[January-December]",
+ "%d": "DD",
+ "%H": "hh",
+ "%I": "hh", # Requires '%p' to differentiate from '%H'.
+ "%M": "mm",
+ "%S": "ss",
+ "%f": "uuuuuu",
+ "%a": "[Mon-Sun]",
+ "%A": "[Monday-Sunday]",
+ "%p": "[AM|PM]",
+ "%z": "[+HHMM|-HHMM]"
+ }
+ for key, val in mapping.items():
+ format_string = format_string.replace(key, val)
+ return format_string
diff --git a/rest_framework/utils/mediatypes.py b/rest_framework/utils/mediatypes.py
index c09c2933..de2931c2 100644
--- a/rest_framework/utils/mediatypes.py
+++ b/rest_framework/utils/mediatypes.py
@@ -5,6 +5,7 @@ See http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7
"""
from __future__ import unicode_literals
from django.http.multipartparser import parse_header
+from django.utils.encoding import python_2_unicode_compatible
from rest_framework import HTTP_HEADER_ENCODING
@@ -43,6 +44,7 @@ def order_by_precedence(media_type_lst):
return [media_types for media_types in ret if media_types]
+@python_2_unicode_compatible
class _MediaType(object):
def __init__(self, media_type_str):
if media_type_str is None:
@@ -57,7 +59,7 @@ class _MediaType(object):
if key != 'q' and other.params.get(key, None) != self.params.get(key, None):
return False
- if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type:
+ if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type:
return False
if self.main_type != '*' and other.main_type != '*' and other.main_type != self.main_type:
@@ -74,14 +76,11 @@ class _MediaType(object):
return 0
elif self.sub_type == '*':
return 1
- elif not self.params or self.params.keys() == ['q']:
+ elif not self.params or list(self.params.keys()) == ['q']:
return 2
return 3
def __str__(self):
- return unicode(self).encode('utf-8')
-
- def __unicode__(self):
ret = "%s/%s" % (self.main_type, self.sub_type)
for key, val in self.params.items():
ret += "; %s=%s" % (key, val)
diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py
new file mode 100644
index 00000000..d92bceb9
--- /dev/null
+++ b/rest_framework/utils/model_meta.py
@@ -0,0 +1,169 @@
+"""
+Helper function for returning the field information that is associated
+with a model class. This includes returning all the forward and reverse
+relationships and their associated metadata.
+
+Usage: `get_field_info(model)` returns a `FieldInfo` instance.
+"""
+from collections import namedtuple
+from django.core.exceptions import ImproperlyConfigured
+from django.db import models
+from django.utils import six
+from rest_framework.compat import OrderedDict
+import inspect
+
+
+FieldInfo = namedtuple('FieldResult', [
+ 'pk', # Model field instance
+ 'fields', # Dict of field name -> model field instance
+ 'forward_relations', # Dict of field name -> RelationInfo
+ 'reverse_relations', # Dict of field name -> RelationInfo
+ 'fields_and_pk', # Shortcut for 'pk' + 'fields'
+ 'relations' # Shortcut for 'forward_relations' + 'reverse_relations'
+])
+
+RelationInfo = namedtuple('RelationInfo', [
+ 'model_field',
+ 'related_model',
+ 'to_many',
+ 'has_through_model'
+])
+
+
+def _resolve_model(obj):
+ """
+ Resolve supplied `obj` to a Django model class.
+
+ `obj` must be a Django model class itself, or a string
+ representation of one. Useful in situations like GH #1225 where
+ Django may not have resolved a string-based reference to a model in
+ another model's foreign key definition.
+
+ String representations should have the format:
+ 'appname.ModelName'
+ """
+ if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
+ app_name, model_name = obj.split('.')
+ resolved_model = models.get_model(app_name, model_name)
+ if resolved_model is None:
+ msg = "Django did not return a model for {0}.{1}"
+ raise ImproperlyConfigured(msg.format(app_name, model_name))
+ return resolved_model
+ elif inspect.isclass(obj) and issubclass(obj, models.Model):
+ return obj
+ raise ValueError("{0} is not a Django model".format(obj))
+
+
+def get_field_info(model):
+ """
+ Given a model class, returns a `FieldInfo` instance, which is a
+ `namedtuple`, containing metadata about the various field types on the model
+ including information about their relationships.
+ """
+ opts = model._meta.concrete_model._meta
+
+ pk = _get_pk(opts)
+ fields = _get_fields(opts)
+ forward_relations = _get_forward_relationships(opts)
+ reverse_relations = _get_reverse_relationships(opts)
+ fields_and_pk = _merge_fields_and_pk(pk, fields)
+ relationships = _merge_relationships(forward_relations, reverse_relations)
+
+ return FieldInfo(pk, fields, forward_relations, reverse_relations,
+ fields_and_pk, relationships)
+
+
+def _get_pk(opts):
+ pk = opts.pk
+ while pk.rel and pk.rel.parent_link:
+ # If model is a child via multi-table inheritance, use parent's pk.
+ pk = pk.rel.to._meta.pk
+
+ return pk
+
+
+def _get_fields(opts):
+ fields = OrderedDict()
+ for field in [field for field in opts.fields if field.serialize and not field.rel]:
+ fields[field.name] = field
+
+ return fields
+
+
+def _get_forward_relationships(opts):
+ """
+ Returns an `OrderedDict` of field names to `RelationInfo`.
+ """
+ forward_relations = OrderedDict()
+ for field in [field for field in opts.fields if field.serialize and field.rel]:
+ forward_relations[field.name] = RelationInfo(
+ model_field=field,
+ related_model=_resolve_model(field.rel.to),
+ to_many=False,
+ has_through_model=False
+ )
+
+ # Deal with forward many-to-many relationships.
+ for field in [field for field in opts.many_to_many if field.serialize]:
+ forward_relations[field.name] = RelationInfo(
+ model_field=field,
+ related_model=_resolve_model(field.rel.to),
+ to_many=True,
+ has_through_model=(
+ not field.rel.through._meta.auto_created
+ )
+ )
+
+ return forward_relations
+
+
+def _get_reverse_relationships(opts):
+ """
+ Returns an `OrderedDict` of field names to `RelationInfo`.
+ """
+ # Note that we have a hack here to handle internal API differences for
+ # this internal API across Django 1.7 -> Django 1.8.
+ # See: https://code.djangoproject.com/ticket/24208
+
+ reverse_relations = OrderedDict()
+ for relation in opts.get_all_related_objects():
+ accessor_name = relation.get_accessor_name()
+ related = getattr(relation, 'related_model', relation.model)
+ reverse_relations[accessor_name] = RelationInfo(
+ model_field=None,
+ related_model=related,
+ to_many=relation.field.rel.multiple,
+ has_through_model=False
+ )
+
+ # Deal with reverse many-to-many relationships.
+ for relation in opts.get_all_related_many_to_many_objects():
+ accessor_name = relation.get_accessor_name()
+ related = getattr(relation, 'related_model', relation.model)
+ reverse_relations[accessor_name] = RelationInfo(
+ model_field=None,
+ related_model=related,
+ to_many=True,
+ has_through_model=(
+ (getattr(relation.field.rel, 'through', None) is not None) and
+ not relation.field.rel.through._meta.auto_created
+ )
+ )
+
+ return reverse_relations
+
+
+def _merge_fields_and_pk(pk, fields):
+ fields_and_pk = OrderedDict()
+ fields_and_pk['pk'] = pk
+ fields_and_pk[pk.name] = pk
+ fields_and_pk.update(fields)
+
+ return fields_and_pk
+
+
+def _merge_relationships(forward_relations, reverse_relations):
+ return OrderedDict(
+ list(forward_relations.items()) +
+ list(reverse_relations.items())
+ )
diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py
new file mode 100644
index 00000000..1bfc64c1
--- /dev/null
+++ b/rest_framework/utils/representation.py
@@ -0,0 +1,99 @@
+"""
+Helper functions for creating user-friendly representations
+of serializer classes and serializer fields.
+"""
+from __future__ import unicode_literals
+from django.db import models
+from django.utils.encoding import force_text
+from django.utils.functional import Promise
+from rest_framework.compat import unicode_repr
+import re
+
+
+def manager_repr(value):
+ model = value.model
+ opts = model._meta
+ for _, name, manager in opts.concrete_managers + opts.abstract_managers:
+ if manager == value:
+ return '%s.%s.all()' % (model._meta.object_name, name)
+ return repr(value)
+
+
+def smart_repr(value):
+ if isinstance(value, models.Manager):
+ return manager_repr(value)
+
+ if isinstance(value, Promise) and value._delegate_text:
+ value = force_text(value)
+
+ value = unicode_repr(value)
+
+ # Representations like u'help text'
+ # should simply be presented as 'help text'
+ if value.startswith("u'") and value.endswith("'"):
+ return value[1:]
+
+ # Representations like
+ # <django.core.validators.RegexValidator object at 0x1047af050>
+ # Should be presented as
+ # <django.core.validators.RegexValidator object>
+ value = re.sub(' at 0x[0-9a-f]{4,32}>', '>', value)
+
+ return value
+
+
+def field_repr(field, force_many=False):
+ kwargs = field._kwargs
+ if force_many:
+ kwargs = kwargs.copy()
+ kwargs['many'] = True
+ kwargs.pop('child', None)
+
+ arg_string = ', '.join([smart_repr(val) for val in field._args])
+ kwarg_string = ', '.join([
+ '%s=%s' % (key, smart_repr(val))
+ for key, val in sorted(kwargs.items())
+ ])
+ if arg_string and kwarg_string:
+ arg_string += ', '
+
+ if force_many:
+ class_name = force_many.__class__.__name__
+ else:
+ class_name = field.__class__.__name__
+
+ return "%s(%s%s)" % (class_name, arg_string, kwarg_string)
+
+
+def serializer_repr(serializer, indent, force_many=None):
+ ret = field_repr(serializer, force_many) + ':'
+ indent_str = ' ' * indent
+
+ if force_many:
+ fields = force_many.fields
+ else:
+ fields = serializer.fields
+
+ for field_name, field in fields.items():
+ ret += '\n' + indent_str + field_name + ' = '
+ if hasattr(field, 'fields'):
+ ret += serializer_repr(field, indent + 1)
+ elif hasattr(field, 'child'):
+ ret += list_repr(field, indent + 1)
+ elif hasattr(field, 'child_relation'):
+ ret += field_repr(field.child_relation, force_many=field.child_relation)
+ else:
+ ret += field_repr(field)
+
+ if serializer.validators:
+ ret += '\n' + indent_str + 'class Meta:'
+ ret += '\n' + indent_str + ' validators = ' + smart_repr(serializer.validators)
+
+ return ret
+
+
+def list_repr(serializer, indent):
+ child = serializer.child
+ if hasattr(child, 'fields'):
+ return serializer_repr(serializer, indent, force_many=child)
+ return field_repr(serializer)
diff --git a/rest_framework/utils/serializer_helpers.py b/rest_framework/utils/serializer_helpers.py
new file mode 100644
index 00000000..87bb3ac0
--- /dev/null
+++ b/rest_framework/utils/serializer_helpers.py
@@ -0,0 +1,120 @@
+from __future__ import unicode_literals
+import collections
+from rest_framework.compat import OrderedDict, unicode_to_repr
+
+
+class ReturnDict(OrderedDict):
+ """
+ Return object from `serialier.data` for the `Serializer` class.
+ Includes a backlink to the serializer instance for renderers
+ to use if they need richer field information.
+ """
+ def __init__(self, *args, **kwargs):
+ self.serializer = kwargs.pop('serializer')
+ super(ReturnDict, self).__init__(*args, **kwargs)
+
+ def copy(self):
+ return ReturnDict(self, serializer=self.serializer)
+
+ def __repr__(self):
+ return dict.__repr__(self)
+
+ def __reduce__(self):
+ # Pickling these objects will drop the .serializer backlink,
+ # but preserve the raw data.
+ return (dict, (dict(self),))
+
+
+class ReturnList(list):
+ """
+ Return object from `serialier.data` for the `SerializerList` class.
+ Includes a backlink to the serializer instance for renderers
+ to use if they need richer field information.
+ """
+ def __init__(self, *args, **kwargs):
+ self.serializer = kwargs.pop('serializer')
+ super(ReturnList, self).__init__(*args, **kwargs)
+
+ def __repr__(self):
+ return list.__repr__(self)
+
+ def __reduce__(self):
+ # Pickling these objects will drop the .serializer backlink,
+ # but preserve the raw data.
+ return (list, (list(self),))
+
+
+class BoundField(object):
+ """
+ A field object that also includes `.value` and `.error` properties.
+ Returned when iterating over a serializer instance,
+ providing an API similar to Django forms and form fields.
+ """
+ def __init__(self, field, value, errors, prefix=''):
+ self._field = field
+ self.value = value
+ self.errors = errors
+ self.name = prefix + self.field_name
+
+ def __getattr__(self, attr_name):
+ return getattr(self._field, attr_name)
+
+ @property
+ def _proxy_class(self):
+ return self._field.__class__
+
+ def __repr__(self):
+ return unicode_to_repr('<%s value=%s errors=%s>' % (
+ self.__class__.__name__, self.value, self.errors
+ ))
+
+
+class NestedBoundField(BoundField):
+ """
+ This `BoundField` additionally implements __iter__ and __getitem__
+ in order to support nested bound fields. This class is the type of
+ `BoundField` that is used for serializer fields.
+ """
+ def __iter__(self):
+ for field in self.fields.values():
+ yield self[field.field_name]
+
+ def __getitem__(self, key):
+ field = self.fields[key]
+ value = self.value.get(key) if self.value else None
+ error = self.errors.get(key) if self.errors else None
+ if hasattr(field, 'fields'):
+ return NestedBoundField(field, value, error, prefix=self.name + '.')
+ return BoundField(field, value, error, prefix=self.name + '.')
+
+
+class BindingDict(collections.MutableMapping):
+ """
+ This dict-like object is used to store fields on a serializer.
+
+ This ensures that whenever fields are added to the serializer we call
+ `field.bind()` so that the `field_name` and `parent` attributes
+ can be set correctly.
+ """
+ def __init__(self, serializer):
+ self.serializer = serializer
+ self.fields = OrderedDict()
+
+ def __setitem__(self, key, field):
+ self.fields[key] = field
+ field.bind(field_name=key, parent=self.serializer)
+
+ def __getitem__(self, key):
+ return self.fields[key]
+
+ def __delitem__(self, key):
+ del self.fields[key]
+
+ def __iter__(self):
+ return iter(self.fields)
+
+ def __len__(self):
+ return len(self.fields)
+
+ def __repr__(self):
+ return dict.__repr__(self.fields)
diff --git a/rest_framework/utils/urls.py b/rest_framework/utils/urls.py
new file mode 100644
index 00000000..880ef9ed
--- /dev/null
+++ b/rest_framework/utils/urls.py
@@ -0,0 +1,25 @@
+from django.utils.six.moves.urllib import parse as urlparse
+
+
+def replace_query_param(url, key, val):
+ """
+ Given a URL and a key/val pair, set or replace an item in the query
+ parameters of the URL, and return the new URL.
+ """
+ (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url)
+ query_dict = urlparse.parse_qs(query)
+ query_dict[key] = [val]
+ query = urlparse.urlencode(sorted(list(query_dict.items())), doseq=True)
+ return urlparse.urlunsplit((scheme, netloc, path, query, fragment))
+
+
+def remove_query_param(url, key):
+ """
+ Given a URL and a key/val pair, remove an item in the query
+ parameters of the URL, and return the new URL.
+ """
+ (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url)
+ query_dict = urlparse.parse_qs(query)
+ query_dict.pop(key, None)
+ query = urlparse.urlencode(sorted(list(query_dict.items())), doseq=True)
+ return urlparse.urlunsplit((scheme, netloc, path, query, fragment))