diff options
Diffstat (limited to 'rest_framework/utils')
| -rw-r--r-- | rest_framework/utils/encoders.py | 6 | ||||
| -rw-r--r-- | rest_framework/utils/field_mapping.py | 18 | ||||
| -rw-r--r-- | rest_framework/utils/formatting.py | 5 | ||||
| -rw-r--r-- | rest_framework/utils/model_meta.py | 73 | ||||
| -rw-r--r-- | rest_framework/utils/serializer_helpers.py | 19 | ||||
| -rw-r--r-- | rest_framework/utils/urls.py | 25 | 
6 files changed, 118 insertions, 28 deletions
| diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 0bd24939..2160d18b 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -6,9 +6,11 @@ from django.db.models.query import QuerySet  from django.utils import six, timezone  from django.utils.encoding import force_text  from django.utils.functional import Promise +from rest_framework.compat import total_seconds  import datetime  import decimal  import json +import uuid  class JSONEncoder(json.JSONEncoder): @@ -38,10 +40,12 @@ class JSONEncoder(json.JSONEncoder):                  representation = representation[:12]              return representation          elif isinstance(obj, datetime.timedelta): -            return six.text_type(obj.total_seconds()) +            return six.text_type(total_seconds(obj))          elif isinstance(obj, decimal.Decimal):              # Serializers will coerce decimals to strings by default.              return float(obj) +        elif isinstance(obj, uuid.UUID): +            return six.text_type(obj)          elif isinstance(obj, QuerySet):              return tuple(obj)          elif hasattr(obj, 'tolist'): diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index fca97b4b..c97ec5d0 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -10,6 +10,11 @@ from rest_framework.validators import UniqueValidator  import inspect +NUMERIC_FIELD_TYPES = ( +    models.IntegerField, models.FloatField, models.DecimalField +) + +  class ClassLookupDict(object):      """      Takes a dictionary with classes as keys. @@ -33,6 +38,9 @@ class ClassLookupDict(object):                  return self.mapping[cls]          raise KeyError('Class %s not found in lookup.', cls.__name__) +    def __setitem__(self, key, value): +        self.mapping[key] = value +  def needs_label(model_field, field_name):      """ @@ -80,7 +88,7 @@ def get_field_kwargs(field_name, model_field):          kwargs['decimal_places'] = decimal_places      if isinstance(model_field, models.TextField): -        kwargs['style'] = {'type': 'textarea'} +        kwargs['style'] = {'base_template': 'textarea.html'}      if isinstance(model_field, models.AutoField) or not model_field.editable:          # If this field is read-only, then return early. @@ -106,7 +114,7 @@ def get_field_kwargs(field_name, model_field):      # Ensure that max_length is passed explicitly as a keyword arg,      # rather than as a validator.      max_length = getattr(model_field, 'max_length', None) -    if max_length is not None: +    if max_length is not None and isinstance(model_field, models.CharField):          kwargs['max_length'] = max_length          validator_kwarg = [              validator for validator in validator_kwarg @@ -119,7 +127,7 @@ def get_field_kwargs(field_name, model_field):          validator.limit_value for validator in validator_kwarg          if isinstance(validator, validators.MinLengthValidator)      ), None) -    if min_length is not None: +    if min_length is not None and isinstance(model_field, models.CharField):          kwargs['min_length'] = min_length          validator_kwarg = [              validator for validator in validator_kwarg @@ -132,7 +140,7 @@ def get_field_kwargs(field_name, model_field):          validator.limit_value for validator in validator_kwarg          if isinstance(validator, validators.MaxValueValidator)      ), None) -    if max_value is not None: +    if max_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES):          kwargs['max_value'] = max_value          validator_kwarg = [              validator for validator in validator_kwarg @@ -145,7 +153,7 @@ def get_field_kwargs(field_name, model_field):          validator.limit_value for validator in validator_kwarg          if isinstance(validator, validators.MinValueValidator)      ), None) -    if min_value is not None: +    if min_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES):          kwargs['min_value'] = min_value          validator_kwarg = [              validator for validator in validator_kwarg diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index 470af51b..8b6f005e 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -2,12 +2,10 @@  Utility functions to return a formatted name and description for a given view.  """  from __future__ import unicode_literals -import re -  from django.utils.html import escape  from django.utils.safestring import mark_safe -  from rest_framework.compat import apply_markdown, force_text +import re  def remove_trailing_string(content, trailing): @@ -59,4 +57,5 @@ def markup_description(description):          description = apply_markdown(description)      else:          description = escape(description).replace('\n', '<br />') +        description = '<p>' + description + '</p>'      return mark_safe(description) diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py index c98725c6..dd92f8b6 100644 --- a/rest_framework/utils/model_meta.py +++ b/rest_framework/utils/model_meta.py @@ -24,7 +24,7 @@ FieldInfo = namedtuple('FieldResult', [  RelationInfo = namedtuple('RelationInfo', [      'model_field', -    'related', +    'related_model',      'to_many',      'has_through_model'  ]) @@ -35,7 +35,7 @@ def _resolve_model(obj):      Resolve supplied `obj` to a Django model class.      `obj` must be a Django model class itself, or a string -    representation of one.  Useful in situtations like GH #1225 where +    representation of one.  Useful in situations like GH #1225 where      Django may not have resolved a string-based reference to a model in      another model's foreign key definition. @@ -56,28 +56,49 @@ def _resolve_model(obj):  def get_field_info(model):      """ -    Given a model class, returns a `FieldInfo` instance containing metadata -    about the various field types on the model. +    Given a model class, returns a `FieldInfo` instance, which is a +    `namedtuple`, containing metadata about the various field types on the model +    including information about their relationships.      """      opts = model._meta.concrete_model._meta -    # Deal with the primary key. +    pk = _get_pk(opts) +    fields = _get_fields(opts) +    forward_relations = _get_forward_relationships(opts) +    reverse_relations = _get_reverse_relationships(opts) +    fields_and_pk = _merge_fields_and_pk(pk, fields) +    relationships = _merge_relationships(forward_relations, reverse_relations) + +    return FieldInfo(pk, fields, forward_relations, reverse_relations, +                     fields_and_pk, relationships) + + +def _get_pk(opts):      pk = opts.pk      while pk.rel and pk.rel.parent_link: -        # If model is a child via multitable inheritance, use parent's pk. +        # If model is a child via multi-table inheritance, use parent's pk.          pk = pk.rel.to._meta.pk -    # Deal with regular fields. +    return pk + + +def _get_fields(opts):      fields = OrderedDict()      for field in [field for field in opts.fields if field.serialize and not field.rel]:          fields[field.name] = field -    # Deal with forward relationships. +    return fields + + +def _get_forward_relationships(opts): +    """ +    Returns an `OrderedDict` of field names to `RelationInfo`. +    """      forward_relations = OrderedDict()      for field in [field for field in opts.fields if field.serialize and field.rel]:          forward_relations[field.name] = RelationInfo(              model_field=field, -            related=_resolve_model(field.rel.to), +            related_model=_resolve_model(field.rel.to),              to_many=False,              has_through_model=False          ) @@ -86,20 +107,31 @@ def get_field_info(model):      for field in [field for field in opts.many_to_many if field.serialize]:          forward_relations[field.name] = RelationInfo(              model_field=field, -            related=_resolve_model(field.rel.to), +            related_model=_resolve_model(field.rel.to),              to_many=True,              has_through_model=(                  not field.rel.through._meta.auto_created              )          ) -    # Deal with reverse relationships. +    return forward_relations + + +def _get_reverse_relationships(opts): +    """ +    Returns an `OrderedDict` of field names to `RelationInfo`. +    """ +    # Note that we have a hack here to handle internal API differences for +    # this internal API across Django 1.7 -> Django 1.8. +    # See: https://code.djangoproject.com/ticket/24208 +      reverse_relations = OrderedDict()      for relation in opts.get_all_related_objects():          accessor_name = relation.get_accessor_name() +        related = getattr(relation, 'related_model', relation.model)          reverse_relations[accessor_name] = RelationInfo(              model_field=None, -            related=relation.model, +            related_model=related,              to_many=relation.field.rel.multiple,              has_through_model=False          ) @@ -107,9 +139,10 @@ def get_field_info(model):      # Deal with reverse many-to-many relationships.      for relation in opts.get_all_related_many_to_many_objects():          accessor_name = relation.get_accessor_name() +        related = getattr(relation, 'related_model', relation.model)          reverse_relations[accessor_name] = RelationInfo(              model_field=None, -            related=relation.model, +            related_model=related,              to_many=True,              has_through_model=(                  (getattr(relation.field.rel, 'through', None) is not None) @@ -117,18 +150,20 @@ def get_field_info(model):              )          ) -    # Shortcut that merges both regular fields and the pk, -    # for simplifying regular field lookup. +    return reverse_relations + + +def _merge_fields_and_pk(pk, fields):      fields_and_pk = OrderedDict()      fields_and_pk['pk'] = pk      fields_and_pk[pk.name] = pk      fields_and_pk.update(fields) -    # Shortcut that merges both forward and reverse relationships +    return fields_and_pk -    relations = OrderedDict( + +def _merge_relationships(forward_relations, reverse_relations): +    return OrderedDict(          list(forward_relations.items()) +          list(reverse_relations.items())      ) - -    return FieldInfo(pk, fields, forward_relations, reverse_relations, fields_and_pk, relations) diff --git a/rest_framework/utils/serializer_helpers.py b/rest_framework/utils/serializer_helpers.py index 65a04d06..87bb3ac0 100644 --- a/rest_framework/utils/serializer_helpers.py +++ b/rest_framework/utils/serializer_helpers.py @@ -16,6 +16,14 @@ class ReturnDict(OrderedDict):      def copy(self):          return ReturnDict(self, serializer=self.serializer) +    def __repr__(self): +        return dict.__repr__(self) + +    def __reduce__(self): +        # Pickling these objects will drop the .serializer backlink, +        # but preserve the raw data. +        return (dict, (dict(self),)) +  class ReturnList(list):      """ @@ -27,6 +35,14 @@ class ReturnList(list):          self.serializer = kwargs.pop('serializer')          super(ReturnList, self).__init__(*args, **kwargs) +    def __repr__(self): +        return list.__repr__(self) + +    def __reduce__(self): +        # Pickling these objects will drop the .serializer backlink, +        # but preserve the raw data. +        return (list, (list(self),)) +  class BoundField(object):      """ @@ -99,3 +115,6 @@ class BindingDict(collections.MutableMapping):      def __len__(self):          return len(self.fields) + +    def __repr__(self): +        return dict.__repr__(self.fields) diff --git a/rest_framework/utils/urls.py b/rest_framework/utils/urls.py new file mode 100644 index 00000000..880ef9ed --- /dev/null +++ b/rest_framework/utils/urls.py @@ -0,0 +1,25 @@ +from django.utils.six.moves.urllib import parse as urlparse + + +def replace_query_param(url, key, val): +    """ +    Given a URL and a key/val pair, set or replace an item in the query +    parameters of the URL, and return the new URL. +    """ +    (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url) +    query_dict = urlparse.parse_qs(query) +    query_dict[key] = [val] +    query = urlparse.urlencode(sorted(list(query_dict.items())), doseq=True) +    return urlparse.urlunsplit((scheme, netloc, path, query, fragment)) + + +def remove_query_param(url, key): +    """ +    Given a URL and a key/val pair, remove an item in the query +    parameters of the URL, and return the new URL. +    """ +    (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url) +    query_dict = urlparse.parse_qs(query) +    query_dict.pop(key, None) +    query = urlparse.urlencode(sorted(list(query_dict.items())), doseq=True) +    return urlparse.urlunsplit((scheme, netloc, path, query, fragment)) | 
