aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/utils
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/utils')
-rw-r--r--rest_framework/utils/encoders.py6
-rw-r--r--rest_framework/utils/field_mapping.py18
-rw-r--r--rest_framework/utils/formatting.py5
-rw-r--r--rest_framework/utils/model_meta.py73
-rw-r--r--rest_framework/utils/serializer_helpers.py19
-rw-r--r--rest_framework/utils/urls.py25
6 files changed, 118 insertions, 28 deletions
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index 0bd24939..2160d18b 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -6,9 +6,11 @@ from django.db.models.query import QuerySet
from django.utils import six, timezone
from django.utils.encoding import force_text
from django.utils.functional import Promise
+from rest_framework.compat import total_seconds
import datetime
import decimal
import json
+import uuid
class JSONEncoder(json.JSONEncoder):
@@ -38,10 +40,12 @@ class JSONEncoder(json.JSONEncoder):
representation = representation[:12]
return representation
elif isinstance(obj, datetime.timedelta):
- return six.text_type(obj.total_seconds())
+ return six.text_type(total_seconds(obj))
elif isinstance(obj, decimal.Decimal):
# Serializers will coerce decimals to strings by default.
return float(obj)
+ elif isinstance(obj, uuid.UUID):
+ return six.text_type(obj)
elif isinstance(obj, QuerySet):
return tuple(obj)
elif hasattr(obj, 'tolist'):
diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py
index fca97b4b..c97ec5d0 100644
--- a/rest_framework/utils/field_mapping.py
+++ b/rest_framework/utils/field_mapping.py
@@ -10,6 +10,11 @@ from rest_framework.validators import UniqueValidator
import inspect
+NUMERIC_FIELD_TYPES = (
+ models.IntegerField, models.FloatField, models.DecimalField
+)
+
+
class ClassLookupDict(object):
"""
Takes a dictionary with classes as keys.
@@ -33,6 +38,9 @@ class ClassLookupDict(object):
return self.mapping[cls]
raise KeyError('Class %s not found in lookup.', cls.__name__)
+ def __setitem__(self, key, value):
+ self.mapping[key] = value
+
def needs_label(model_field, field_name):
"""
@@ -80,7 +88,7 @@ def get_field_kwargs(field_name, model_field):
kwargs['decimal_places'] = decimal_places
if isinstance(model_field, models.TextField):
- kwargs['style'] = {'type': 'textarea'}
+ kwargs['style'] = {'base_template': 'textarea.html'}
if isinstance(model_field, models.AutoField) or not model_field.editable:
# If this field is read-only, then return early.
@@ -106,7 +114,7 @@ def get_field_kwargs(field_name, model_field):
# Ensure that max_length is passed explicitly as a keyword arg,
# rather than as a validator.
max_length = getattr(model_field, 'max_length', None)
- if max_length is not None:
+ if max_length is not None and isinstance(model_field, models.CharField):
kwargs['max_length'] = max_length
validator_kwarg = [
validator for validator in validator_kwarg
@@ -119,7 +127,7 @@ def get_field_kwargs(field_name, model_field):
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MinLengthValidator)
), None)
- if min_length is not None:
+ if min_length is not None and isinstance(model_field, models.CharField):
kwargs['min_length'] = min_length
validator_kwarg = [
validator for validator in validator_kwarg
@@ -132,7 +140,7 @@ def get_field_kwargs(field_name, model_field):
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MaxValueValidator)
), None)
- if max_value is not None:
+ if max_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES):
kwargs['max_value'] = max_value
validator_kwarg = [
validator for validator in validator_kwarg
@@ -145,7 +153,7 @@ def get_field_kwargs(field_name, model_field):
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MinValueValidator)
), None)
- if min_value is not None:
+ if min_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES):
kwargs['min_value'] = min_value
validator_kwarg = [
validator for validator in validator_kwarg
diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py
index 470af51b..8b6f005e 100644
--- a/rest_framework/utils/formatting.py
+++ b/rest_framework/utils/formatting.py
@@ -2,12 +2,10 @@
Utility functions to return a formatted name and description for a given view.
"""
from __future__ import unicode_literals
-import re
-
from django.utils.html import escape
from django.utils.safestring import mark_safe
-
from rest_framework.compat import apply_markdown, force_text
+import re
def remove_trailing_string(content, trailing):
@@ -59,4 +57,5 @@ def markup_description(description):
description = apply_markdown(description)
else:
description = escape(description).replace('\n', '<br />')
+ description = '<p>' + description + '</p>'
return mark_safe(description)
diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py
index c98725c6..dd92f8b6 100644
--- a/rest_framework/utils/model_meta.py
+++ b/rest_framework/utils/model_meta.py
@@ -24,7 +24,7 @@ FieldInfo = namedtuple('FieldResult', [
RelationInfo = namedtuple('RelationInfo', [
'model_field',
- 'related',
+ 'related_model',
'to_many',
'has_through_model'
])
@@ -35,7 +35,7 @@ def _resolve_model(obj):
Resolve supplied `obj` to a Django model class.
`obj` must be a Django model class itself, or a string
- representation of one. Useful in situtations like GH #1225 where
+ representation of one. Useful in situations like GH #1225 where
Django may not have resolved a string-based reference to a model in
another model's foreign key definition.
@@ -56,28 +56,49 @@ def _resolve_model(obj):
def get_field_info(model):
"""
- Given a model class, returns a `FieldInfo` instance containing metadata
- about the various field types on the model.
+ Given a model class, returns a `FieldInfo` instance, which is a
+ `namedtuple`, containing metadata about the various field types on the model
+ including information about their relationships.
"""
opts = model._meta.concrete_model._meta
- # Deal with the primary key.
+ pk = _get_pk(opts)
+ fields = _get_fields(opts)
+ forward_relations = _get_forward_relationships(opts)
+ reverse_relations = _get_reverse_relationships(opts)
+ fields_and_pk = _merge_fields_and_pk(pk, fields)
+ relationships = _merge_relationships(forward_relations, reverse_relations)
+
+ return FieldInfo(pk, fields, forward_relations, reverse_relations,
+ fields_and_pk, relationships)
+
+
+def _get_pk(opts):
pk = opts.pk
while pk.rel and pk.rel.parent_link:
- # If model is a child via multitable inheritance, use parent's pk.
+ # If model is a child via multi-table inheritance, use parent's pk.
pk = pk.rel.to._meta.pk
- # Deal with regular fields.
+ return pk
+
+
+def _get_fields(opts):
fields = OrderedDict()
for field in [field for field in opts.fields if field.serialize and not field.rel]:
fields[field.name] = field
- # Deal with forward relationships.
+ return fields
+
+
+def _get_forward_relationships(opts):
+ """
+ Returns an `OrderedDict` of field names to `RelationInfo`.
+ """
forward_relations = OrderedDict()
for field in [field for field in opts.fields if field.serialize and field.rel]:
forward_relations[field.name] = RelationInfo(
model_field=field,
- related=_resolve_model(field.rel.to),
+ related_model=_resolve_model(field.rel.to),
to_many=False,
has_through_model=False
)
@@ -86,20 +107,31 @@ def get_field_info(model):
for field in [field for field in opts.many_to_many if field.serialize]:
forward_relations[field.name] = RelationInfo(
model_field=field,
- related=_resolve_model(field.rel.to),
+ related_model=_resolve_model(field.rel.to),
to_many=True,
has_through_model=(
not field.rel.through._meta.auto_created
)
)
- # Deal with reverse relationships.
+ return forward_relations
+
+
+def _get_reverse_relationships(opts):
+ """
+ Returns an `OrderedDict` of field names to `RelationInfo`.
+ """
+ # Note that we have a hack here to handle internal API differences for
+ # this internal API across Django 1.7 -> Django 1.8.
+ # See: https://code.djangoproject.com/ticket/24208
+
reverse_relations = OrderedDict()
for relation in opts.get_all_related_objects():
accessor_name = relation.get_accessor_name()
+ related = getattr(relation, 'related_model', relation.model)
reverse_relations[accessor_name] = RelationInfo(
model_field=None,
- related=relation.model,
+ related_model=related,
to_many=relation.field.rel.multiple,
has_through_model=False
)
@@ -107,9 +139,10 @@ def get_field_info(model):
# Deal with reverse many-to-many relationships.
for relation in opts.get_all_related_many_to_many_objects():
accessor_name = relation.get_accessor_name()
+ related = getattr(relation, 'related_model', relation.model)
reverse_relations[accessor_name] = RelationInfo(
model_field=None,
- related=relation.model,
+ related_model=related,
to_many=True,
has_through_model=(
(getattr(relation.field.rel, 'through', None) is not None)
@@ -117,18 +150,20 @@ def get_field_info(model):
)
)
- # Shortcut that merges both regular fields and the pk,
- # for simplifying regular field lookup.
+ return reverse_relations
+
+
+def _merge_fields_and_pk(pk, fields):
fields_and_pk = OrderedDict()
fields_and_pk['pk'] = pk
fields_and_pk[pk.name] = pk
fields_and_pk.update(fields)
- # Shortcut that merges both forward and reverse relationships
+ return fields_and_pk
- relations = OrderedDict(
+
+def _merge_relationships(forward_relations, reverse_relations):
+ return OrderedDict(
list(forward_relations.items()) +
list(reverse_relations.items())
)
-
- return FieldInfo(pk, fields, forward_relations, reverse_relations, fields_and_pk, relations)
diff --git a/rest_framework/utils/serializer_helpers.py b/rest_framework/utils/serializer_helpers.py
index 65a04d06..87bb3ac0 100644
--- a/rest_framework/utils/serializer_helpers.py
+++ b/rest_framework/utils/serializer_helpers.py
@@ -16,6 +16,14 @@ class ReturnDict(OrderedDict):
def copy(self):
return ReturnDict(self, serializer=self.serializer)
+ def __repr__(self):
+ return dict.__repr__(self)
+
+ def __reduce__(self):
+ # Pickling these objects will drop the .serializer backlink,
+ # but preserve the raw data.
+ return (dict, (dict(self),))
+
class ReturnList(list):
"""
@@ -27,6 +35,14 @@ class ReturnList(list):
self.serializer = kwargs.pop('serializer')
super(ReturnList, self).__init__(*args, **kwargs)
+ def __repr__(self):
+ return list.__repr__(self)
+
+ def __reduce__(self):
+ # Pickling these objects will drop the .serializer backlink,
+ # but preserve the raw data.
+ return (list, (list(self),))
+
class BoundField(object):
"""
@@ -99,3 +115,6 @@ class BindingDict(collections.MutableMapping):
def __len__(self):
return len(self.fields)
+
+ def __repr__(self):
+ return dict.__repr__(self.fields)
diff --git a/rest_framework/utils/urls.py b/rest_framework/utils/urls.py
new file mode 100644
index 00000000..880ef9ed
--- /dev/null
+++ b/rest_framework/utils/urls.py
@@ -0,0 +1,25 @@
+from django.utils.six.moves.urllib import parse as urlparse
+
+
+def replace_query_param(url, key, val):
+ """
+ Given a URL and a key/val pair, set or replace an item in the query
+ parameters of the URL, and return the new URL.
+ """
+ (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url)
+ query_dict = urlparse.parse_qs(query)
+ query_dict[key] = [val]
+ query = urlparse.urlencode(sorted(list(query_dict.items())), doseq=True)
+ return urlparse.urlunsplit((scheme, netloc, path, query, fragment))
+
+
+def remove_query_param(url, key):
+ """
+ Given a URL and a key/val pair, remove an item in the query
+ parameters of the URL, and return the new URL.
+ """
+ (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url)
+ query_dict = urlparse.parse_qs(query)
+ query_dict.pop(key, None)
+ query = urlparse.urlencode(sorted(list(query_dict.items())), doseq=True)
+ return urlparse.urlunsplit((scheme, netloc, path, query, fragment))