diff options
| author | Eleni Lixourioti | 2014-11-15 14:27:41 +0000 | 
|---|---|---|
| committer | Eleni Lixourioti | 2014-11-15 14:27:41 +0000 | 
| commit | 1aa77830955dcdf829f65a9001b6b8900dfc8755 (patch) | |
| tree | 1f6d0bea3c0fe720a298b2da177bb91e8a74a19c /rest_framework | |
| parent | afaa52a378705b7f0475d5ece04a2cf49af4b7c2 (diff) | |
| parent | 88008c0a687219e3104d548196915b1068536d74 (diff) | |
| download | django-rest-framework-1aa77830955dcdf829f65a9001b6b8900dfc8755.tar.bz2 | |
Merge branch 'version-3.1' of github.com:tomchristie/django-rest-framework into oauth_as_package
Conflicts:
	.travis.yml
Diffstat (limited to 'rest_framework')
26 files changed, 1857 insertions, 2728 deletions
| diff --git a/rest_framework/authtoken/migrations/0001_initial.py b/rest_framework/authtoken/migrations/0001_initial.py index 2e5d6b47..769f6202 100644 --- a/rest_framework/authtoken/migrations/0001_initial.py +++ b/rest_framework/authtoken/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# encoding: utf8 +# -*- coding: utf-8 -*-  from __future__ import unicode_literals  from django.db import models, migrations @@ -15,12 +15,11 @@ class Migration(migrations.Migration):          migrations.CreateModel(              name='Token',              fields=[ -                ('key', models.CharField(max_length=40, serialize=False, primary_key=True)), -                ('user', models.OneToOneField(to=settings.AUTH_USER_MODEL, to_field='id')), +                ('key', models.CharField(primary_key=True, serialize=False, max_length=40)),                  ('created', models.DateTimeField(auto_now_add=True)), +                ('user', models.OneToOneField(to=settings.AUTH_USER_MODEL, related_name='auth_token')),              ],              options={ -                'abstract': False,              },              bases=(models.Model,),          ), diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index 99e99ae3..c2c456de 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -19,11 +19,12 @@ class AuthTokenSerializer(serializers.Serializer):                  if not user.is_active:                      msg = _('User account is disabled.')                      raise serializers.ValidationError(msg) -                attrs['user'] = user -                return attrs              else: -                msg = _('Unable to login with provided credentials.') +                msg = _('Unable to log in with provided credentials.')                  raise serializers.ValidationError(msg)          else:              msg = _('Must include "username" and "password"')              raise serializers.ValidationError(msg) + +        attrs['user'] = user +        return attrs diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index 7c03cb76..94e6f061 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -18,7 +18,8 @@ class ObtainAuthToken(APIView):      def post(self, request):          serializer = self.serializer_class(data=request.DATA)          if serializer.is_valid(): -            token, created = Token.objects.get_or_create(user=serializer.object['user']) +            user = serializer.validated_data['user'] +            token, created = Token.objects.get_or_create(user=user)              return Response({'token': token.key})          return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index bc5719ef..6c243462 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -39,6 +39,17 @@ except ImportError:      django_filters = None +if django.VERSION >= (1, 6): +    def clean_manytomany_helptext(text): +        return text +else: +    # Up to version 1.5 many to many fields automatically suffix +    # the `help_text` attribute with hardcoded text. +    def clean_manytomany_helptext(text): +        if text.endswith(' Hold down "Control", or "Command" on a Mac, to select more than one.'): +            text = text[:-69] +        return text +  # Django-guardian is optional. Import only if guardian is in INSTALLED_APPS  # Fixes (#1712). We keep the try/except for the test suite.  guardian = None diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 449ba0a2..d28d6e22 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -10,7 +10,6 @@ from __future__ import unicode_literals  from django.utils import six  from rest_framework.views import APIView  import types -import warnings  def api_view(http_method_names): @@ -130,37 +129,3 @@ def list_route(methods=['get'], **kwargs):          func.kwargs = kwargs          return func      return decorator - - -# These are now pending deprecation, in favor of `detail_route` and `list_route`. - -def link(**kwargs): -    """ -    Used to mark a method on a ViewSet that should be routed for detail GET requests. -    """ -    msg = 'link is pending deprecation. Use detail_route instead.' -    warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) - -    def decorator(func): -        func.bind_to_methods = ['get'] -        func.detail = True -        func.kwargs = kwargs -        return func - -    return decorator - - -def action(methods=['post'], **kwargs): -    """ -    Used to mark a method on a ViewSet that should be routed for detail POST requests. -    """ -    msg = 'action is pending deprecation. Use detail_route instead.' -    warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) - -    def decorator(func): -        func.bind_to_methods = methods -        func.detail = True -        func.kwargs = kwargs -        return func - -    return decorator diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index ad52d172..06b5e8a2 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -15,7 +15,7 @@ class APIException(Exception):      Subclasses should provide `.status_code` and `.default_detail` properties.      """      status_code = status.HTTP_500_INTERNAL_SERVER_ERROR -    default_detail = '' +    default_detail = 'A server error occured'      def __init__(self, detail=None):          self.detail = detail or self.default_detail @@ -54,7 +54,7 @@ class MethodNotAllowed(APIException):  class NotAcceptable(APIException):      status_code = status.HTTP_406_NOT_ACCEPTABLE -    default_detail = "Could not satisfy the request's Accept header" +    default_detail = "Could not satisfy the request Accept header"      def __init__(self, detail=None, available_renderers=None):          self.detail = detail or self.default_detail diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 8e15345d..0c78b3fb 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,34 +1,28 @@ -""" -Serializer fields perform validation on incoming data. - -They are very similar to Django's form fields. -""" -from __future__ import unicode_literals - -import copy -import datetime -import inspect -import re -import warnings -from decimal import Decimal, DecimalException -from django import forms +from django.conf import settings  from django.core import validators  from django.core.exceptions import ValidationError -from django.conf import settings -from django.db.models.fields import BLANK_CHOICE_DASH -from django.http import QueryDict -from django.forms import widgets -from django.utils import six, timezone +from django.utils import timezone +from django.utils.dateparse import parse_date, parse_datetime, parse_time  from django.utils.encoding import is_protected_type  from django.utils.translation import ugettext_lazy as _ -from django.utils.datastructures import SortedDict -from django.utils.dateparse import parse_date, parse_datetime, parse_time  from rest_framework import ISO_8601 -from rest_framework.compat import ( -    BytesIO, smart_text, -    force_text, is_non_str_iterable -) +from rest_framework.compat import smart_text  from rest_framework.settings import api_settings +from rest_framework.utils import html, representation, humanize_datetime +import datetime +import decimal +import inspect +import warnings + + +class empty: +    """ +    This class is used to represent no data being provided for a given input +    or output value. + +    It is required because `None` may be a valid input or output value. +    """ +    pass  def is_simple_callable(obj): @@ -47,597 +41,487 @@ def is_simple_callable(obj):      return len_args <= len_defaults -def get_component(obj, attr_name): +def get_attribute(instance, attrs):      """ -    Given an object, and an attribute name, -    return that attribute on the object. +    Similar to Python's built in `getattr(instance, attr)`, +    but takes a list of nested attributes, instead of a single attribute. + +    Also accepts either attribute lookup on objects or dictionary lookups.      """ -    if isinstance(obj, dict): -        val = obj.get(attr_name) -    else: -        val = getattr(obj, attr_name) - -    if is_simple_callable(val): -        return val() -    return val - - -def readable_datetime_formats(formats): -    format = ', '.join(formats).replace( -        ISO_8601, -        'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]' -    ) -    return humanize_strptime(format) - - -def readable_date_formats(formats): -    format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]') -    return humanize_strptime(format) - - -def readable_time_formats(formats): -    format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]') -    return humanize_strptime(format) - - -def humanize_strptime(format_string): -    # Note that we're missing some of the locale specific mappings that -    # don't really make sense. -    mapping = { -        "%Y": "YYYY", -        "%y": "YY", -        "%m": "MM", -        "%b": "[Jan-Dec]", -        "%B": "[January-December]", -        "%d": "DD", -        "%H": "hh", -        "%I": "hh",  # Requires '%p' to differentiate from '%H'. -        "%M": "mm", -        "%S": "ss", -        "%f": "uuuuuu", -        "%a": "[Mon-Sun]", -        "%A": "[Monday-Sunday]", -        "%p": "[AM|PM]", -        "%z": "[+HHMM|-HHMM]" -    } -    for key, val in mapping.items(): -        format_string = format_string.replace(key, val) -    return format_string +    for attr in attrs: +        try: +            instance = getattr(instance, attr) +        except AttributeError as exc: +            try: +                return instance[attr] +            except (KeyError, TypeError): +                raise exc +    return instance -def strip_multiple_choice_msg(help_text): +def set_value(dictionary, keys, value):      """ -    Remove the 'Hold down "control" ...' message that is Django enforces in -    select multiple fields on ModelForms.  (Required for 1.5 and earlier) +    Similar to Python's built in `dictionary[key] = value`, +    but takes a list of nested keys instead of a single key. -    See https://code.djangoproject.com/ticket/9321 +    set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2} +    set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2} +    set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}}      """ -    multiple_choice_msg = _(' Hold down "Control", or "Command" on a Mac, to select more than one.') -    multiple_choice_msg = force_text(multiple_choice_msg) +    if not keys: +        dictionary.update(value) +        return -    return help_text.replace(multiple_choice_msg, '') +    for key in keys[:-1]: +        if key not in dictionary: +            dictionary[key] = {} +        dictionary = dictionary[key] +    dictionary[keys[-1]] = value -class Field(object): -    read_only = True -    creation_counter = 0 -    empty = '' -    type_name = None -    partial = False -    use_files = False -    form_field_class = forms.CharField -    type_label = 'field' -    widget = None - -    def __init__(self, source=None, label=None, help_text=None): -        self.parent = None - -        self.creation_counter = Field.creation_counter -        Field.creation_counter += 1 -        self.source = source +class SkipField(Exception): +    pass -        if label is not None: -            self.label = smart_text(label) -        else: -            self.label = None -        if help_text is not None: -            self.help_text = strip_multiple_choice_msg(smart_text(help_text)) -        else: -            self.help_text = None +NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`' +NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' +NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`' +NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' +MISSING_ERROR_MESSAGE = ( +    'ValidationError raised by `{class_name}`, but error key `{key}` does ' +    'not exist in the `error_messages` dictionary.' +) -        self._errors = [] -        self._value = None -        self._name = None -    @property -    def errors(self): -        return self._errors +class Field(object): +    _creation_counter = 0 -    def widget_html(self): -        if not self.widget: -            return '' +    default_error_messages = { +        'required': _('This field is required.') +    } +    default_validators = [] -        attrs = {} -        if 'id' not in self.widget.attrs: -            attrs['id'] = self._name +    def __init__(self, read_only=False, write_only=False, +                 required=None, default=empty, initial=None, source=None, +                 label=None, help_text=None, style=None, +                 error_messages=None, validators=[]): +        self._creation_counter = Field._creation_counter +        Field._creation_counter += 1 -        return self.widget.render(self._name, self._value, attrs=attrs) +        # If `required` is unset, then use `True` unless a default is provided. +        if required is None: +            required = default is empty and not read_only -    def label_tag(self): -        return '<label for="%s">%s:</label>' % (self._name, self.label) +        # Some combinations of keyword arguments do not make sense. +        assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY +        assert not (read_only and required), NOT_READ_ONLY_REQUIRED +        assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT +        assert not (required and default is not empty), NOT_REQUIRED_DEFAULT -    def initialize(self, parent, field_name): -        """ -        Called to set up a field prior to field_to_native or field_from_native. +        self.read_only = read_only +        self.write_only = write_only +        self.required = required +        self.default = default +        self.source = source +        self.initial = initial +        self.label = label +        self.help_text = help_text +        self.style = {} if style is None else style +        self.validators = validators or self.default_validators[:] -        parent - The parent serializer. -        field_name - The name of the field being initialized. -        """ -        self.parent = parent -        self.root = parent.root or parent -        self.context = self.root.context -        self.partial = self.root.partial -        if self.partial: -            self.required = False +        # Collect default error message from self and parent classes +        messages = {} +        for cls in reversed(self.__class__.__mro__): +            messages.update(getattr(cls, 'default_error_messages', {})) +        messages.update(error_messages or {}) +        self.error_messages = messages -    def field_from_native(self, data, files, field_name, into): +    def __new__(cls, *args, **kwargs):          """ -        Given a dictionary and a field name, updates the dictionary `into`, -        with the field and it's deserialized value. +        When a field is instantiated, we store the arguments that were used, +        so that we can present a helpful representation of the object.          """ -        return +        instance = super(Field, cls).__new__(cls) +        instance._args = args +        instance._kwargs = kwargs +        return instance -    def field_to_native(self, obj, field_name): +    def bind(self, field_name, parent, root):          """ -        Given an object and a field name, returns the value that should be -        serialized for that field. +        Setup the context for the field instance.          """ -        if obj is None: -            return self.empty - -        if self.source == '*': -            return self.to_native(obj) +        self.field_name = field_name +        self.parent = parent +        self.root = root +        self.context = parent.context -        source = self.source or field_name -        value = obj +        # `self.label` should deafult to being based on the field name. +        if self.label is None: +            self.label = field_name.replace('_', ' ').capitalize() -        for component in source.split('.'): -            value = get_component(value, component) -            if value is None: -                break +        # self.source should default to being the same as the field name. +        if self.source is None: +            self.source = field_name -        return self.to_native(value) +        # self.source_attrs is a list of attributes that need to be looked up +        # when serializing the instance, or populating the validated data. +        if self.source == '*': +            self.source_attrs = [] +        else: +            self.source_attrs = self.source.split('.') -    def to_native(self, value): +    def get_initial(self):          """ -        Converts the field's value into it's simple representation. +        Return a value to use when the field is being returned as a primative +        value, without any object instance.          """ -        if is_simple_callable(value): -            value = value() - -        if is_protected_type(value): -            return value -        elif (is_non_str_iterable(value) and -              not isinstance(value, (dict, six.string_types))): -            return [self.to_native(item) for item in value] -        elif isinstance(value, dict): -            # Make sure we preserve field ordering, if it exists -            ret = SortedDict() -            for key, val in value.items(): -                ret[key] = self.to_native(val) -            return ret -        return force_text(value) +        return self.initial -    def attributes(self): +    def get_value(self, dictionary):          """ -        Returns a dictionary of attributes to be used when serializing to xml. +        Given the *incoming* primative data, return the value for this field +        that should be validated and transformed to a native value.          """ -        if self.type_name: -            return {'type': self.type_name} -        return {} - -    def metadata(self): -        metadata = SortedDict() -        metadata['type'] = self.type_label -        metadata['required'] = getattr(self, 'required', False) -        optional_attrs = ['read_only', 'label', 'help_text', -                          'min_length', 'max_length'] -        for attr in optional_attrs: -            value = getattr(self, attr, None) -            if value is not None and value != '': -                metadata[attr] = force_text(value, strings_only=True) -        return metadata - - -class WritableField(Field): -    """ -    Base for read/write fields. -    """ -    write_only = False -    default_validators = [] -    default_error_messages = { -        'required': _('This field is required.'), -        'invalid': _('Invalid value.'), -    } -    widget = widgets.TextInput -    default = None - -    def __init__(self, source=None, label=None, help_text=None, -                 read_only=False, write_only=False, required=None, -                 validators=[], error_messages=None, widget=None, -                 default=None, blank=None): - -        super(WritableField, self).__init__(source=source, label=label, help_text=help_text) - -        self.read_only = read_only -        self.write_only = write_only - -        assert not (read_only and write_only), "Cannot set read_only=True and write_only=True" - -        if required is None: -            self.required = not(read_only) -        else: -            assert not (read_only and required), "Cannot set required=True and read_only=True" -            self.required = required +        return dictionary.get(self.field_name, empty) -        messages = {} -        for c in reversed(self.__class__.__mro__): -            messages.update(getattr(c, 'default_error_messages', {})) -        messages.update(error_messages or {}) -        self.error_messages = messages +    def get_attribute(self, instance): +        """ +        Given the *outgoing* object instance, return the value for this field +        that should be returned as a primative value. +        """ +        return get_attribute(instance, self.source_attrs) -        self.validators = self.default_validators + validators -        self.default = default if default is not None else self.default +    def get_default(self): +        """ +        Return the default value to use when validating data if no input +        is provided for this field. -        # Widgets are only used for HTML forms. -        widget = widget or self.widget -        if isinstance(widget, type): -            widget = widget() -        self.widget = widget +        If a default has not been set for this field then this will simply +        return `empty`, indicating that no value should be set in the +        validated data for this field. +        """ +        if self.default is empty: +            raise SkipField() +        return self.default -    def __deepcopy__(self, memo): -        result = copy.copy(self) -        memo[id(self)] = result -        result.validators = self.validators[:] -        return result +    def run_validation(self, data=empty): +        """ +        Validate a simple representation and return the internal value. -    def get_default_value(self): -        if is_simple_callable(self.default): -            return self.default() -        return self.default +        The provided data may be `empty` if no representation was included. +        May return `empty` if the field should not be included in the +        validated data. +        """ +        if data is empty: +            if self.required: +                self.fail('required') +            return self.get_default() -    def validate(self, value): -        if value in validators.EMPTY_VALUES and self.required: -            raise ValidationError(self.error_messages['required']) +        value = self.to_internal_value(data) +        self.run_validators(value) +        return value      def run_validators(self, value): -        if value in validators.EMPTY_VALUES: +        if value in (None, '', [], (), {}):              return +          errors = [] -        for v in self.validators: +        for validator in self.validators:              try: -                v(value) -            except ValidationError as e: -                if hasattr(e, 'code') and e.code in self.error_messages: -                    message = self.error_messages[e.code] -                    if e.params: -                        message = message % e.params -                    errors.append(message) -                else: -                    errors.extend(e.messages) +                validator(value) +            except ValidationError as exc: +                errors.extend(exc.messages)          if errors:              raise ValidationError(errors) -    def field_to_native(self, obj, field_name): -        if self.write_only: -            return None -        return super(WritableField, self).field_to_native(obj, field_name) - -    def field_from_native(self, data, files, field_name, into): +    def to_internal_value(self, data):          """ -        Given a dictionary and a field name, updates the dictionary `into`, -        with the field and it's deserialized value. +        Transform the *incoming* primative data into a native value.          """ -        if self.read_only: -            return - -        try: -            data = data or {} -            if self.use_files: -                files = files or {} -                try: -                    native = files[field_name] -                except KeyError: -                    native = data[field_name] -            else: -                native = data[field_name] -        except KeyError: -            if self.default is not None and not self.partial: -                # Note: partial updates shouldn't set defaults -                native = self.get_default_value() -            else: -                if self.required: -                    raise ValidationError(self.error_messages['required']) -                return - -        value = self.from_native(native) -        if self.source == '*': -            if value: -                into.update(value) -        else: -            self.validate(value) -            self.run_validators(value) -            into[self.source or field_name] = value +        raise NotImplementedError('to_internal_value() must be implemented.') -    def from_native(self, value): +    def to_representation(self, value):          """ -        Reverts a simple representation back to the field's value. +        Transform the *outgoing* native value into primative data.          """ -        return value - +        raise NotImplementedError('to_representation() must be implemented.') -class ModelField(WritableField): -    """ -    A generic field that can be used against an arbitrary model field. -    """ -    def __init__(self, *args, **kwargs): +    def fail(self, key, **kwargs): +        """ +        A helper method that simply raises a validation error. +        """          try: -            self.model_field = kwargs.pop('model_field') +            msg = self.error_messages[key]          except KeyError: -            raise ValueError("ModelField requires 'model_field' kwarg") - -        self.min_length = kwargs.pop('min_length', -                                     getattr(self.model_field, 'min_length', None)) -        self.max_length = kwargs.pop('max_length', -                                     getattr(self.model_field, 'max_length', None)) -        self.min_value = kwargs.pop('min_value', -                                    getattr(self.model_field, 'min_value', None)) -        self.max_value = kwargs.pop('max_value', -                                    getattr(self.model_field, 'max_value', None)) - -        super(ModelField, self).__init__(*args, **kwargs) - -        if self.min_length is not None: -            self.validators.append(validators.MinLengthValidator(self.min_length)) -        if self.max_length is not None: -            self.validators.append(validators.MaxLengthValidator(self.max_length)) -        if self.min_value is not None: -            self.validators.append(validators.MinValueValidator(self.min_value)) -        if self.max_value is not None: -            self.validators.append(validators.MaxValueValidator(self.max_value)) +            class_name = self.__class__.__name__ +            msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) +            raise AssertionError(msg) +        raise ValidationError(msg.format(**kwargs)) -    def from_native(self, value): -        rel = getattr(self.model_field, "rel", None) -        if rel is not None: -            return rel.to._meta.get_field(rel.field_name).to_python(value) -        else: -            return self.model_field.to_python(value) - -    def field_to_native(self, obj, field_name): -        value = self.model_field._get_val_from_obj(obj) -        if is_protected_type(value): -            return value -        return self.model_field.value_to_string(obj) +    def __repr__(self): +        return representation.field_repr(self) -    def attributes(self): -        return { -            "type": self.model_field.get_internal_type() -        } +# Boolean types... -# Typed Fields - -class BooleanField(WritableField): -    type_name = 'BooleanField' -    type_label = 'boolean' -    form_field_class = forms.BooleanField -    widget = widgets.CheckboxInput +class BooleanField(Field):      default_error_messages = { -        'invalid': _("'%s' value must be either True or False."), +        'invalid': _('`{input}` is not a valid boolean.')      } -    empty = False - -    def field_from_native(self, data, files, field_name, into): -        # HTML checkboxes do not explicitly represent unchecked as `False` -        # we deal with that here... -        if isinstance(data, QueryDict) and self.default is None: -            self.default = False - -        return super(BooleanField, self).field_from_native( -            data, files, field_name, into -        ) +    TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True)) +    FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False)) + +    def get_value(self, dictionary): +        if html.is_html_input(dictionary): +            # HTML forms do not send a `False` value on an empty checkbox, +            # so we override the default empty value to be False. +            return dictionary.get(self.field_name, False) +        return dictionary.get(self.field_name, empty) + +    def to_internal_value(self, data): +        if data in self.TRUE_VALUES: +            return True +        elif data in self.FALSE_VALUES: +            return False +        self.fail('invalid', input=data) -    def from_native(self, value): -        if value in ('true', 't', 'True', '1'): +    def to_representation(self, value): +        if value is None: +            return None +        if value in self.TRUE_VALUES:              return True -        if value in ('false', 'f', 'False', '0'): +        elif value in self.FALSE_VALUES:              return False          return bool(value) -class CharField(WritableField): -    type_name = 'CharField' -    type_label = 'string' -    form_field_class = forms.CharField +# String types... -    def __init__(self, max_length=None, min_length=None, allow_none=False, *args, **kwargs): -        self.max_length, self.min_length = max_length, min_length -        self.allow_none = allow_none -        super(CharField, self).__init__(*args, **kwargs) -        if min_length is not None: -            self.validators.append(validators.MinLengthValidator(min_length)) -        if max_length is not None: -            self.validators.append(validators.MaxLengthValidator(max_length)) +class CharField(Field): +    default_error_messages = { +        'blank': _('This field may not be blank.') +    } -    def from_native(self, value): -        if isinstance(value, six.string_types): -            return value +    def __init__(self, **kwargs): +        self.allow_blank = kwargs.pop('allow_blank', False) +        self.max_length = kwargs.pop('max_length', None) +        self.min_length = kwargs.pop('min_length', None) +        super(CharField, self).__init__(**kwargs) + +    def to_internal_value(self, data): +        if data == '' and not self.allow_blank: +            self.fail('blank') +        if data is None: +            return None +        return str(data) +    def to_representation(self, value):          if value is None: -            if not self.allow_none: -                return '' -            else: -                # Return None explicitly because smart_text(None) == 'None'. See #1834 for details -                return None +            return None +        return str(value) + -        return smart_text(value) +class EmailField(CharField): +    default_error_messages = { +        'invalid': _('Enter a valid email address.') +    } +    default_validators = [validators.validate_email] +    def to_internal_value(self, data): +        if data == '' and not self.allow_blank: +            self.fail('blank') +        if data is None: +            return None +        return str(data).strip() -class URLField(CharField): -    type_name = 'URLField' -    type_label = 'url' +    def to_representation(self, value): +        if value is None: +            return None +        return str(value).strip() -    def __init__(self, **kwargs): -        if 'validators' not in kwargs: -            kwargs['validators'] = [validators.URLValidator()] -        super(URLField, self).__init__(**kwargs) +class RegexField(CharField): +    def __init__(self, regex, **kwargs): +        kwargs['validators'] = ( +            [validators.RegexValidator(regex)] + +            kwargs.get('validators', []) +        ) +        super(RegexField, self).__init__(**kwargs) -class SlugField(CharField): -    type_name = 'SlugField' -    type_label = 'slug' -    form_field_class = forms.SlugField +class SlugField(CharField):      default_error_messages = { -        'invalid': _("Enter a valid 'slug' consisting of letters, numbers," -                     " underscores or hyphens."), +        'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.")      }      default_validators = [validators.validate_slug] -    def __init__(self, *args, **kwargs): -        super(SlugField, self).__init__(*args, **kwargs) - -class ChoiceField(WritableField): -    type_name = 'ChoiceField' -    type_label = 'choice' -    form_field_class = forms.ChoiceField -    widget = widgets.Select +class URLField(CharField):      default_error_messages = { -        'invalid_choice': _('Select a valid choice. %(value)s is not one of ' -                            'the available choices.'), +        'invalid': _("Enter a valid URL.")      } +    default_validators = [validators.URLValidator()] -    def __init__(self, choices=(), blank_display_value=None, *args, **kwargs): -        self.empty = kwargs.pop('empty', '') -        super(ChoiceField, self).__init__(*args, **kwargs) -        self.choices = choices -        if not self.required: -            if blank_display_value is None: -                blank_choice = BLANK_CHOICE_DASH -            else: -                blank_choice = [('', blank_display_value)] -            self.choices = blank_choice + self.choices -    def _get_choices(self): -        return self._choices +# Number types... -    def _set_choices(self, value): -        # Setting choices also sets the choices on the widget. -        # choices can be any iterable, but we call list() on it because -        # it will be consumed more than once. -        self._choices = self.widget.choices = list(value) +class IntegerField(Field): +    default_error_messages = { +        'invalid': _('A valid integer is required.') +    } -    choices = property(_get_choices, _set_choices) +    def __init__(self, **kwargs): +        max_value = kwargs.pop('max_value', None) +        min_value = kwargs.pop('min_value', None) +        super(IntegerField, self).__init__(**kwargs) +        if max_value is not None: +            self.validators.append(validators.MaxValueValidator(max_value)) +        if min_value is not None: +            self.validators.append(validators.MinValueValidator(min_value)) -    def metadata(self): -        data = super(ChoiceField, self).metadata() -        data['choices'] = [{'value': v, 'display_name': n} for v, n in self.choices] +    def to_internal_value(self, data): +        try: +            data = int(str(data)) +        except (ValueError, TypeError): +            self.fail('invalid')          return data -    def validate(self, value): -        """ -        Validates that the input is in self.choices. -        """ -        super(ChoiceField, self).validate(value) -        if value and not self.valid_value(value): -            raise ValidationError(self.error_messages['invalid_choice'] % {'value': value}) +    def to_representation(self, value): +        if value is None: +            return None +        return int(value) -    def valid_value(self, value): -        """ -        Check to see if the provided value is a valid choice. -        """ -        for k, v in self.choices: -            if isinstance(v, (list, tuple)): -                # This is an optgroup, so look inside the group for options -                for k2, v2 in v: -                    if value == smart_text(k2): -                        return True -            else: -                if value == smart_text(k) or value == k: -                    return True -        return False -    def from_native(self, value): -        value = super(ChoiceField, self).from_native(value) -        if value == self.empty or value in validators.EMPTY_VALUES: -            return self.empty -        return value +class FloatField(Field): +    default_error_messages = { +        'invalid': _("'%s' value must be a float."), +    } + +    def __init__(self, **kwargs): +        max_value = kwargs.pop('max_value', None) +        min_value = kwargs.pop('min_value', None) +        super(FloatField, self).__init__(**kwargs) +        if max_value is not None: +            self.validators.append(validators.MaxValueValidator(max_value)) +        if min_value is not None: +            self.validators.append(validators.MinValueValidator(min_value)) +    def to_internal_value(self, value): +        if value is None: +            return None +        return float(value) -class EmailField(CharField): -    type_name = 'EmailField' -    type_label = 'email' -    form_field_class = forms.EmailField +    def to_representation(self, value): +        if value is None: +            return None +        try: +            return float(value) +        except (TypeError, ValueError): +            self.fail('invalid', value=value) + +class DecimalField(Field):      default_error_messages = { -        'invalid': _('Enter a valid email address.'), +        'invalid': _('Enter a number.'), +        'max_value': _('Ensure this value is less than or equal to {max_value}.'), +        'min_value': _('Ensure this value is greater than or equal to {min_value}.'), +        'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'), +        'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'), +        'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.')      } -    default_validators = [validators.validate_email] -    def from_native(self, value): -        ret = super(EmailField, self).from_native(value) -        if ret is None: +    coerce_to_string = api_settings.COERCE_DECIMAL_TO_STRING + +    def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, **kwargs): +        self.max_digits = max_digits +        self.decimal_places = decimal_places +        self.coerce_to_string = coerce_to_string if (coerce_to_string is not None) else self.coerce_to_string +        super(DecimalField, self).__init__(**kwargs) +        if max_value is not None: +            self.validators.append(validators.MaxValueValidator(max_value)) +        if min_value is not None: +            self.validators.append(validators.MinValueValidator(min_value)) + +    def to_internal_value(self, value): +        """ +        Validates that the input is a decimal number. Returns a Decimal +        instance. Returns None for empty values. Ensures that there are no more +        than max_digits in the number, and no more than decimal_places digits +        after the decimal point. +        """ +        if value in (None, ''):              return None -        return ret.strip() +        value = smart_text(value).strip() +        try: +            value = decimal.Decimal(value) +        except decimal.DecimalException: +            self.fail('invalid') -class RegexField(CharField): -    type_name = 'RegexField' -    type_label = 'regex' -    form_field_class = forms.RegexField +        # Check for NaN. It is the only value that isn't equal to itself, +        # so we can use this to identify NaN values. +        if value != value: +            self.fail('invalid') + +        # Check for infinity and negative infinity. +        if value in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')): +            self.fail('invalid') -    def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs): -        super(RegexField, self).__init__(max_length, min_length, *args, **kwargs) -        self.regex = regex +        sign, digittuple, exponent = value.as_tuple() +        decimals = abs(exponent) +        # digittuple doesn't include any leading zeros. +        digits = len(digittuple) +        if decimals > digits: +            # We have leading zeros up to or past the decimal point.  Count +            # everything past the decimal point as a digit.  We do not count +            # 0 before the decimal point as a digit since that would mean +            # we would not allow max_digits = decimal_places. +            digits = decimals +        whole_digits = digits - decimals -    def _get_regex(self): -        return self._regex +        if self.max_digits is not None and digits > self.max_digits: +            self.fail('max_digits', max_digits=self.max_digits) +        if self.decimal_places is not None and decimals > self.decimal_places: +            self.fail('max_decimal_places', max_decimal_places=self.decimal_places) +        if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): +            self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places) -    def _set_regex(self, regex): -        if isinstance(regex, six.string_types): -            regex = re.compile(regex) -        self._regex = regex -        if hasattr(self, '_regex_validator') and self._regex_validator in self.validators: -            self.validators.remove(self._regex_validator) -        self._regex_validator = validators.RegexValidator(regex=regex) -        self.validators.append(self._regex_validator) +        return value -    regex = property(_get_regex, _set_regex) +    def to_representation(self, value): +        if isinstance(value, decimal.Decimal): +            context = decimal.getcontext().copy() +            context.prec = self.max_digits +            quantized = value.quantize( +                decimal.Decimal('.1') ** self.decimal_places, +                context=context +            ) +            if not self.coerce_to_string: +                return quantized +            return '{0:f}'.format(quantized) + +        if not self.coerce_to_string: +            return value +        return '%.*f' % (self.max_decimal_places, value) -class DateField(WritableField): -    type_name = 'DateField' -    type_label = 'date' -    widget = widgets.DateInput -    form_field_class = forms.DateField +# Date & time fields... +class DateField(Field):      default_error_messages = { -        'invalid': _("Date has wrong format. Use one of these formats instead: %s"), +        'invalid': _('Date has wrong format. Use one of these formats instead: {format}'),      } -    empty = None -    input_formats = api_settings.DATE_INPUT_FORMATS      format = api_settings.DATE_FORMAT +    input_formats = api_settings.DATE_INPUT_FORMATS -    def __init__(self, input_formats=None, format=None, *args, **kwargs): -        self.input_formats = input_formats if input_formats is not None else self.input_formats +    def __init__(self, format=None, input_formats=None, *args, **kwargs):          self.format = format if format is not None else self.format +        self.input_formats = input_formats if input_formats is not None else self.input_formats          super(DateField, self).__init__(*args, **kwargs) -    def from_native(self, value): -        if value in validators.EMPTY_VALUES: +    def to_internal_value(self, value): +        if value in (None, ''):              return None          if isinstance(value, datetime.datetime): @@ -647,6 +531,7 @@ class DateField(WritableField):                  default_timezone = timezone.get_default_timezone()                  value = timezone.make_naive(value, default_timezone)              return value.date() +          if isinstance(value, datetime.date):              return value @@ -667,10 +552,10 @@ class DateField(WritableField):                  else:                      return parsed.date() -        msg = self.error_messages['invalid'] % readable_date_formats(self.input_formats) -        raise ValidationError(msg) +        humanized_format = humanize_datetime.date_formats(self.input_formats) +        self.fail('invalid', format=humanized_format) -    def to_native(self, value): +    def to_representation(self, value):          if value is None or self.format is None:              return value @@ -682,30 +567,25 @@ class DateField(WritableField):          return value.strftime(self.format) -class DateTimeField(WritableField): -    type_name = 'DateTimeField' -    type_label = 'datetime' -    widget = widgets.DateTimeInput -    form_field_class = forms.DateTimeField - +class DateTimeField(Field):      default_error_messages = { -        'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"), +        'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'),      } -    empty = None -    input_formats = api_settings.DATETIME_INPUT_FORMATS      format = api_settings.DATETIME_FORMAT +    input_formats = api_settings.DATETIME_INPUT_FORMATS -    def __init__(self, input_formats=None, format=None, *args, **kwargs): -        self.input_formats = input_formats if input_formats is not None else self.input_formats +    def __init__(self, format=None, input_formats=None, *args, **kwargs):          self.format = format if format is not None else self.format +        self.input_formats = input_formats if input_formats is not None else self.input_formats          super(DateTimeField, self).__init__(*args, **kwargs) -    def from_native(self, value): -        if value in validators.EMPTY_VALUES: +    def to_internal_value(self, value): +        if value in (None, ''):              return None          if isinstance(value, datetime.datetime):              return value +          if isinstance(value, datetime.date):              value = datetime.datetime(value.year, value.month, value.day)              if settings.USE_TZ: @@ -737,10 +617,10 @@ class DateTimeField(WritableField):                  else:                      return parsed -        msg = self.error_messages['invalid'] % readable_datetime_formats(self.input_formats) -        raise ValidationError(msg) +        humanized_format = humanize_datetime.datetime_formats(self.input_formats) +        self.fail('invalid', format=humanized_format) -    def to_native(self, value): +    def to_representation(self, value):          if value is None or self.format is None:              return value @@ -752,26 +632,20 @@ class DateTimeField(WritableField):          return value.strftime(self.format) -class TimeField(WritableField): -    type_name = 'TimeField' -    type_label = 'time' -    widget = widgets.TimeInput -    form_field_class = forms.TimeField - +class TimeField(Field):      default_error_messages = { -        'invalid': _("Time has wrong format. Use one of these formats instead: %s"), +        'invalid': _('Time has wrong format. Use one of these formats instead: {format}'),      } -    empty = None -    input_formats = api_settings.TIME_INPUT_FORMATS      format = api_settings.TIME_FORMAT +    input_formats = api_settings.TIME_INPUT_FORMATS -    def __init__(self, input_formats=None, format=None, *args, **kwargs): -        self.input_formats = input_formats if input_formats is not None else self.input_formats +    def __init__(self, format=None, input_formats=None, *args, **kwargs):          self.format = format if format is not None else self.format +        self.input_formats = input_formats if input_formats is not None else self.input_formats          super(TimeField, self).__init__(*args, **kwargs)      def from_native(self, value): -        if value in validators.EMPTY_VALUES: +        if value in (None, ''):              return None          if isinstance(value, datetime.time): @@ -794,10 +668,10 @@ class TimeField(WritableField):                  else:                      return parsed.time() -        msg = self.error_messages['invalid'] % readable_time_formats(self.input_formats) -        raise ValidationError(msg) +        humanized_format = humanize_datetime.time_formats(self.input_formats) +        self.fail('invalid', format=humanized_format) -    def to_native(self, value): +    def to_representation(self, value):          if value is None or self.format is None:              return value @@ -809,234 +683,147 @@ class TimeField(WritableField):          return value.strftime(self.format) -class IntegerField(WritableField): -    type_name = 'IntegerField' -    type_label = 'integer' -    form_field_class = forms.IntegerField -    empty = 0 +# Choice types... +class ChoiceField(Field):      default_error_messages = { -        'invalid': _('Enter a whole number.'), -        'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), -        'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), +        'invalid_choice': _('`{input}` is not a valid choice.')      } -    def __init__(self, max_value=None, min_value=None, *args, **kwargs): -        self.max_value, self.min_value = max_value, min_value -        super(IntegerField, self).__init__(*args, **kwargs) +    def __init__(self, choices, **kwargs): +        # Allow either single or paired choices style: +        # choices = [1, 2, 3] +        # choices = [(1, 'First'), (2, 'Second'), (3, 'Third')] +        pairs = [ +            isinstance(item, (list, tuple)) and len(item) == 2 +            for item in choices +        ] +        if all(pairs): +            self.choices = dict([(key, display_value) for key, display_value in choices]) +        else: +            self.choices = dict([(item, item) for item in choices]) -        if max_value is not None: -            self.validators.append(validators.MaxValueValidator(max_value)) -        if min_value is not None: -            self.validators.append(validators.MinValueValidator(min_value)) +        # Map the string representation of choices to the underlying value. +        # Allows us to deal with eg. integer choices while supporting either +        # integer or string input, but still get the correct datatype out. +        self.choice_strings_to_values = dict([ +            (str(key), key) for key in self.choices.keys() +        ]) -    def from_native(self, value): -        if value in validators.EMPTY_VALUES: -            return None +        super(ChoiceField, self).__init__(**kwargs) +    def to_internal_value(self, data):          try: -            value = int(str(value)) -        except (ValueError, TypeError): -            raise ValidationError(self.error_messages['invalid']) -        return value +            return self.choice_strings_to_values[str(data)] +        except KeyError: +            self.fail('invalid_choice', input=data) +    def to_representation(self, value): +        return value -class FloatField(WritableField): -    type_name = 'FloatField' -    type_label = 'float' -    form_field_class = forms.FloatField -    empty = 0 +class MultipleChoiceField(ChoiceField):      default_error_messages = { -        'invalid': _("'%s' value must be a float."), +        'invalid_choice': _('`{input}` is not a valid choice.'), +        'not_a_list': _('Expected a list of items but got type `{input_type}`')      } -    def from_native(self, value): -        if value in validators.EMPTY_VALUES: -            return None - -        try: -            return float(value) -        except (TypeError, ValueError): -            msg = self.error_messages['invalid'] % value -            raise ValidationError(msg) - - -class DecimalField(WritableField): -    type_name = 'DecimalField' -    type_label = 'decimal' -    form_field_class = forms.DecimalField -    empty = Decimal('0') - -    default_error_messages = { -        'invalid': _('Enter a number.'), -        'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), -        'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), -        'max_digits': _('Ensure that there are no more than %s digits in total.'), -        'max_decimal_places': _('Ensure that there are no more than %s decimal places.'), -        'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.') -    } +    def to_internal_value(self, data): +        if not hasattr(data, '__iter__'): +            self.fail('not_a_list', input_type=type(data).__name__) +        return set([ +            super(MultipleChoiceField, self).to_internal_value(item) +            for item in data +        ]) -    def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): -        self.max_value, self.min_value = max_value, min_value -        self.max_digits, self.decimal_places = max_digits, decimal_places -        super(DecimalField, self).__init__(*args, **kwargs) +    def to_representation(self, value): +        return value -        if max_value is not None: -            self.validators.append(validators.MaxValueValidator(max_value)) -        if min_value is not None: -            self.validators.append(validators.MinValueValidator(min_value)) -    def from_native(self, value): -        """ -        Validates that the input is a decimal number. Returns a Decimal -        instance. Returns None for empty values. Ensures that there are no more -        than max_digits in the number, and no more than decimal_places digits -        after the decimal point. -        """ -        if value in validators.EMPTY_VALUES: -            return None -        value = smart_text(value).strip() -        try: -            value = Decimal(value) -        except DecimalException: -            raise ValidationError(self.error_messages['invalid']) -        return value +# File types... -    def validate(self, value): -        super(DecimalField, self).validate(value) -        if value in validators.EMPTY_VALUES: -            return -        # Check for NaN, Inf and -Inf values. We can't compare directly for NaN, -        # since it is never equal to itself. However, NaN is the only value that -        # isn't equal to itself, so we can use this to identify NaN -        if value != value or value == Decimal("Inf") or value == Decimal("-Inf"): -            raise ValidationError(self.error_messages['invalid']) -        sign, digittuple, exponent = value.as_tuple() -        decimals = abs(exponent) -        # digittuple doesn't include any leading zeros. -        digits = len(digittuple) -        if decimals > digits: -            # We have leading zeros up to or past the decimal point.  Count -            # everything past the decimal point as a digit.  We do not count -            # 0 before the decimal point as a digit since that would mean -            # we would not allow max_digits = decimal_places. -            digits = decimals -        whole_digits = digits - decimals +class FileField(Field): +    pass  # TODO -        if self.max_digits is not None and digits > self.max_digits: -            raise ValidationError(self.error_messages['max_digits'] % self.max_digits) -        if self.decimal_places is not None and decimals > self.decimal_places: -            raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places) -        if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): -            raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places)) -        return value +class ImageField(Field): +    pass  # TODO -class FileField(WritableField): -    use_files = True -    type_name = 'FileField' -    type_label = 'file upload' -    form_field_class = forms.FileField -    widget = widgets.FileInput -    default_error_messages = { -        'invalid': _("No file was submitted. Check the encoding type on the form."), -        'missing': _("No file was submitted."), -        'empty': _("The submitted file is empty."), -        'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'), -        'contradiction': _('Please either submit a file or check the clear checkbox, not both.') -    } +# Advanced field types... -    def __init__(self, *args, **kwargs): -        self.max_length = kwargs.pop('max_length', None) -        self.allow_empty_file = kwargs.pop('allow_empty_file', False) -        super(FileField, self).__init__(*args, **kwargs) +class ReadOnlyField(Field): +    """ +    A read-only field that simply returns the field value. -    def from_native(self, data): -        if data in validators.EMPTY_VALUES: -            return None +    If the field is a method with no parameters, the method will be called +    and it's return value used as the representation. -        # UploadedFile objects should have name and size attributes. -        try: -            file_name = data.name -            file_size = data.size -        except AttributeError: -            raise ValidationError(self.error_messages['invalid']) - -        if self.max_length is not None and len(file_name) > self.max_length: -            error_values = {'max': self.max_length, 'length': len(file_name)} -            raise ValidationError(self.error_messages['max_length'] % error_values) -        if not file_name: -            raise ValidationError(self.error_messages['invalid']) -        if not self.allow_empty_file and not file_size: -            raise ValidationError(self.error_messages['empty']) +    For example, the following would call `get_expiry_date()` on the object: -        return data +    class ExampleSerializer(self): +        expiry_date = ReadOnlyField(source='get_expiry_date') +    """ -    def to_native(self, value): -        return value.name +    def __init__(self, **kwargs): +        kwargs['read_only'] = True +        super(ReadOnlyField, self).__init__(**kwargs) +    def to_representation(self, value): +        if is_simple_callable(value): +            return value() +        return value -class ImageField(FileField): -    use_files = True -    type_name = 'ImageField' -    type_label = 'image upload' -    form_field_class = forms.ImageField -    default_error_messages = { -        'invalid_image': _("Upload a valid image. The file you uploaded was " -                           "either not an image or a corrupted image."), -    } +class SerializerMethodField(Field): +    """ +    A read-only field that get its representation from calling a method on the +    parent serializer class. The method called will be of the form +    "get_{field_name}", and should take a single argument, which is the +    object being serialized. -    def from_native(self, data): -        """ -        Checks that the file-upload field data contains a valid image (GIF, JPG, -        PNG, possibly others -- whatever the Python Imaging Library supports). -        """ -        f = super(ImageField, self).from_native(data) -        if f is None: -            return None +    For example: -        from rest_framework.compat import Image -        assert Image is not None, 'Either Pillow or PIL must be installed for ImageField support.' +    class ExampleSerializer(self): +        extra_info = SerializerMethodField() -        # We need to get a file object for PIL. We might have a path or we might -        # have to read the data into memory. -        if hasattr(data, 'temporary_file_path'): -            file = data.temporary_file_path() -        else: -            if hasattr(data, 'read'): -                file = BytesIO(data.read()) -            else: -                file = BytesIO(data['content']) +        def get_extra_info(self, obj): +            return ...  # Calculate some data to return. +    """ +    def __init__(self, method_attr=None, **kwargs): +        self.method_attr = method_attr +        kwargs['source'] = '*' +        kwargs['read_only'] = True +        super(SerializerMethodField, self).__init__(**kwargs) -        try: -            # load() could spot a truncated JPEG, but it loads the entire -            # image in memory, which is a DoS vector. See #3848 and #18520. -            # verify() must be called immediately after the constructor. -            Image.open(file).verify() -        except ImportError: -            # Under PyPy, it is possible to import PIL. However, the underlying -            # _imaging C module isn't available, so an ImportError will be -            # raised. Catch and re-raise. -            raise -        except Exception:  # Python Imaging Library doesn't recognize it as an image -            raise ValidationError(self.error_messages['invalid_image']) -        if hasattr(f, 'seek') and callable(f.seek): -            f.seek(0) -        return f +    def to_representation(self, value): +        method_attr = self.method_attr +        if method_attr is None: +            method_attr = 'get_{field_name}'.format(field_name=self.field_name) +        method = getattr(self.parent, method_attr) +        return method(value) -class SerializerMethodField(Field): +class ModelField(Field):      """ -    A field that gets its value by calling a method on the serializer it's attached to. +    A generic field that can be used against an arbitrary model field. + +    This is used by `ModelSerializer` when dealing with custom model fields, +    that do not have a serializer field to be mapped to.      """ +    def __init__(self, model_field, **kwargs): +        self.model_field = model_field +        kwargs['source'] = '*' +        super(ModelField, self).__init__(**kwargs) -    def __init__(self, method_name, *args, **kwargs): -        self.method_name = method_name -        super(SerializerMethodField, self).__init__(*args, **kwargs) +    def to_internal_value(self, data): +        rel = getattr(self.model_field, 'rel', None) +        if rel is not None: +            return rel.to._meta.get_field(rel.field_name).to_python(data) +        return self.model_field.to_python(data) -    def field_to_native(self, obj, field_name): -        value = getattr(self.parent, self.method_name)(obj) -        return self.to_native(value) +    def to_representation(self, obj): +        value = self.model_field._get_val_from_obj(obj) +        if is_protected_type(value): +            return value +        return self.model_field.value_to_string(obj) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index a6f68657..eb6b64ef 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -3,7 +3,8 @@ Generic views that provide commonly needed behaviour.  """  from __future__ import unicode_literals -from django.core.exceptions import ImproperlyConfigured, PermissionDenied +from django.db.models.query import QuerySet +from django.core.exceptions import PermissionDenied  from django.core.paginator import Paginator, InvalidPage  from django.http import Http404  from django.shortcuts import get_object_or_404 as _get_object_or_404 @@ -11,7 +12,6 @@ from django.utils.translation import ugettext as _  from rest_framework import views, mixins, exceptions  from rest_framework.request import clone_request  from rest_framework.settings import api_settings -import warnings  def strict_positive_int(integer_string, cutoff=None): @@ -28,7 +28,7 @@ def strict_positive_int(integer_string, cutoff=None):  def get_object_or_404(queryset, *filter_args, **filter_kwargs):      """ -    Same as Django's standard shortcut, but make sure to raise 404 +    Same as Django's standard shortcut, but make sure to also raise 404      if the filter_kwargs don't match the required types.      """      try: @@ -51,11 +51,6 @@ class GenericAPIView(views.APIView):      queryset = None      serializer_class = None -    # This shortcut may be used instead of setting either or both -    # of the `queryset`/`serializer_class` attributes, although using -    # the explicit style is generally preferred. -    model = None -      # If you want to use object lookups other than pk, set this attribute.      # For more complex lookup requirements override `get_object()`.      lookup_field = 'pk' @@ -71,20 +66,10 @@ class GenericAPIView(views.APIView):      # The filter backend classes to use for queryset filtering      filter_backends = api_settings.DEFAULT_FILTER_BACKENDS -    # The following attributes may be subject to change, +    # The following attribute may be subject to change,      # and should be considered private API. -    model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS      paginator_class = Paginator -    ###################################### -    # These are pending deprecation... - -    pk_url_kwarg = 'pk' -    slug_url_kwarg = 'slug' -    slug_field = 'slug' -    allow_empty = True -    filter_backend = api_settings.FILTER_BACKEND -      def get_serializer_context(self):          """          Extra context provided to the serializer class. @@ -95,18 +80,16 @@ class GenericAPIView(views.APIView):              'view': self          } -    def get_serializer(self, instance=None, data=None, files=None, many=False, -                       partial=False, allow_add_remove=False): +    def get_serializer(self, instance=None, data=None, many=False, partial=False):          """          Return the serializer instance that should be used for validating and          deserializing input, and for serializing output.          """          serializer_class = self.get_serializer_class()          context = self.get_serializer_context() -        return serializer_class(instance, data=data, files=files, -                                many=many, partial=partial, -                                allow_add_remove=allow_add_remove, -                                context=context) +        return serializer_class( +            instance, data=data, many=many, partial=partial, context=context +        )      def get_pagination_serializer(self, page):          """ @@ -120,37 +103,16 @@ class GenericAPIView(views.APIView):          context = self.get_serializer_context()          return pagination_serializer_class(instance=page, context=context) -    def paginate_queryset(self, queryset, page_size=None): +    def paginate_queryset(self, queryset):          """          Paginate a queryset if required, either returning a page object,          or `None` if pagination is not configured for this view.          """ -        deprecated_style = False -        if page_size is not None: -            warnings.warn('The `page_size` parameter to `paginate_queryset()` ' -                          'is deprecated. ' -                          'Note that the return style of this method is also ' -                          'changed, and will simply return a page object ' -                          'when called without a `page_size` argument.', -                          DeprecationWarning, stacklevel=2) -            deprecated_style = True -        else: -            # Determine the required page size. -            # If pagination is not configured, simply return None. -            page_size = self.get_paginate_by() -            if not page_size: -                return None - -        if not self.allow_empty: -            warnings.warn( -                'The `allow_empty` parameter is deprecated. ' -                'To use `allow_empty=False` style behavior, You should override ' -                '`get_queryset()` and explicitly raise a 404 on empty querysets.', -                DeprecationWarning, stacklevel=2 -            ) - -        paginator = self.paginator_class(queryset, page_size, -                                         allow_empty_first_page=self.allow_empty) +        page_size = self.get_paginate_by() +        if not page_size: +            return None + +        paginator = self.paginator_class(queryset, page_size)          page_kwarg = self.kwargs.get(self.page_kwarg)          page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)          page = page_kwarg or page_query_param or 1 @@ -170,8 +132,6 @@ class GenericAPIView(views.APIView):                  'message': str(exc)              }) -        if deprecated_style: -            return (paginator, page, page.object_list, page.has_other_pages())          return page      def filter_queryset(self, queryset): @@ -191,29 +151,12 @@ class GenericAPIView(views.APIView):          """          Returns the list of filter backends that this view requires.          """ -        if self.filter_backends is None: -            filter_backends = [] -        else: -            # Note that we are returning a *copy* of the class attribute, -            # so that it is safe for the view to mutate it if needed. -            filter_backends = list(self.filter_backends) - -        if not filter_backends and self.filter_backend: -            warnings.warn( -                'The `filter_backend` attribute and `FILTER_BACKEND` setting ' -                'are deprecated in favor of a `filter_backends` ' -                'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take ' -                'a *list* of filter backend classes.', -                DeprecationWarning, stacklevel=2 -            ) -            filter_backends = [self.filter_backend] - -        return filter_backends +        return list(self.filter_backends)      # The following methods provide default implementations      # that you may want to override for more complex cases. -    def get_paginate_by(self, queryset=None): +    def get_paginate_by(self):          """          Return the size of pages to use with pagination. @@ -222,11 +165,6 @@ class GenericAPIView(views.APIView):          Otherwise defaults to using `self.paginate_by`.          """ -        if queryset is not None: -            warnings.warn('The `queryset` parameter to `get_paginate_by()` ' -                          'is deprecated.', -                          DeprecationWarning, stacklevel=2) -          if self.paginate_by_param:              try:                  return strict_positive_int( @@ -248,26 +186,13 @@ class GenericAPIView(views.APIView):          (Eg. admins get full serialization, others get basic serialization)          """ -        serializer_class = self.serializer_class -        if serializer_class is not None: -            return serializer_class - -        warnings.warn( -            'The `.model` attribute on view classes is now deprecated in favor ' -            'of the more explicit `serializer_class` and `queryset` attributes.', -            DeprecationWarning, stacklevel=2 -        ) - -        assert self.model is not None, \ -            "'%s' should either include a 'serializer_class' attribute, " \ -            "or use the 'model' attribute as a shortcut for " \ -            "automatically generating a serializer class." \ +        assert self.serializer_class is not None, ( +            "'%s' should either include a `serializer_class` attribute, " +            "or override the `get_serializer_class()` method."              % self.__class__.__name__ +        ) -        class DefaultSerializer(self.model_serializer_class): -            class Meta: -                model = self.model -        return DefaultSerializer +        return self.serializer_class      def get_queryset(self):          """ @@ -284,21 +209,19 @@ class GenericAPIView(views.APIView):          (Eg. return a list of items that is specific to the user)          """ -        if self.queryset is not None: -            return self.queryset._clone() - -        if self.model is not None: -            warnings.warn( -                'The `.model` attribute on view classes is now deprecated in favor ' -                'of the more explicit `serializer_class` and `queryset` attributes.', -                DeprecationWarning, stacklevel=2 -            ) -            return self.model._default_manager.all() +        assert self.queryset is not None, ( +            "'%s' should either include a `queryset` attribute, " +            "or override the `get_queryset()` method." +            % self.__class__.__name__ +        ) -        error_format = "'%s' must define 'queryset' or 'model'" -        raise ImproperlyConfigured(error_format % self.__class__.__name__) +        queryset = self.queryset +        if isinstance(queryset, QuerySet): +            # Ensure queryset is re-evaluated on each request. +            queryset = queryset.all() +        return queryset -    def get_object(self, queryset=None): +    def get_object(self):          """          Returns the object the view is displaying. @@ -306,43 +229,19 @@ class GenericAPIView(views.APIView):          queryset lookups.  Eg if objects are referenced using multiple          keyword arguments in the url conf.          """ -        # Determine the base queryset to use. -        if queryset is None: -            queryset = self.filter_queryset(self.get_queryset()) -        else: -            pass  # Deprecation warning +        queryset = self.filter_queryset(self.get_queryset())          # Perform the lookup filtering. -        # Note that `pk` and `slug` are deprecated styles of lookup filtering.          lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field -        lookup = self.kwargs.get(lookup_url_kwarg, None) -        pk = self.kwargs.get(self.pk_url_kwarg, None) -        slug = self.kwargs.get(self.slug_url_kwarg, None) - -        if lookup is not None: -            filter_kwargs = {self.lookup_field: lookup} -        elif pk is not None and self.lookup_field == 'pk': -            warnings.warn( -                'The `pk_url_kwarg` attribute is deprecated. ' -                'Use the `lookup_field` attribute instead', -                DeprecationWarning -            ) -            filter_kwargs = {'pk': pk} -        elif slug is not None and self.lookup_field == 'pk': -            warnings.warn( -                'The `slug_url_kwarg` attribute is deprecated. ' -                'Use the `lookup_field` attribute instead', -                DeprecationWarning -            ) -            filter_kwargs = {self.slug_field: slug} -        else: -            raise ImproperlyConfigured( -                'Expected view %s to be called with a URL keyword argument ' -                'named "%s". Fix your URL conf, or set the `.lookup_field` ' -                'attribute on the view correctly.' % -                (self.__class__.__name__, self.lookup_field) -            ) +        assert lookup_url_kwarg in self.kwargs, ( +            'Expected view %s to be called with a URL keyword argument ' +            'named "%s". Fix your URL conf, or set the `.lookup_field` ' +            'attribute on the view correctly.' % +            (self.__class__.__name__, lookup_url_kwarg) +        ) + +        filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}          obj = get_object_or_404(queryset, **filter_kwargs)          # May raise a permission denied @@ -355,34 +254,6 @@ class GenericAPIView(views.APIView):      #      # The are not called by GenericAPIView directly,      # but are used by the mixin methods. - -    def pre_save(self, obj): -        """ -        Placeholder method for calling before saving an object. - -        May be used to set attributes on the object that are implicit -        in either the request, or the url. -        """ -        pass - -    def post_save(self, obj, created=False): -        """ -        Placeholder method for calling after saving an object. -        """ -        pass - -    def pre_delete(self, obj): -        """ -        Placeholder method for calling before deleting an object. -        """ -        pass - -    def post_delete(self, obj): -        """ -        Placeholder method for calling after deleting an object. -        """ -        pass -      def metadata(self, request):          """          Return a dictionary of metadata about the view. @@ -540,25 +411,3 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,      def delete(self, request, *args, **kwargs):          return self.destroy(request, *args, **kwargs) - - -# Deprecated classes - -class MultipleObjectAPIView(GenericAPIView): -    def __init__(self, *args, **kwargs): -        warnings.warn( -            'Subclassing `MultipleObjectAPIView` is deprecated. ' -            'You should simply subclass `GenericAPIView` instead.', -            DeprecationWarning, stacklevel=2 -        ) -        super(MultipleObjectAPIView, self).__init__(*args, **kwargs) - - -class SingleObjectAPIView(GenericAPIView): -    def __init__(self, *args, **kwargs): -        warnings.warn( -            'Subclassing `SingleObjectAPIView` is deprecated. ' -            'You should simply subclass `GenericAPIView` instead.', -            DeprecationWarning, stacklevel=2 -        ) -        super(SingleObjectAPIView, self).__init__(*args, **kwargs) diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 2cc87eef..14a6b44b 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -6,40 +6,11 @@ which allows mixin classes to be composed in interesting ways.  """  from __future__ import unicode_literals -from django.core.exceptions import ValidationError  from django.http import Http404  from rest_framework import status  from rest_framework.response import Response  from rest_framework.request import clone_request  from rest_framework.settings import api_settings -import warnings - - -def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None): -    """ -    Given a model instance, and an optional pk and slug field, -    return the full list of all other field names on that model. - -    For use when performing full_clean on a model instance, -    so we only clean the required fields. -    """ -    include = [] - -    if pk: -        # Deprecated -        pk_field = obj._meta.pk -        while pk_field.rel: -            pk_field = pk_field.rel.to._meta.pk -        include.append(pk_field.name) - -    if slug_field: -        # Deprecated -        include.append(slug_field) - -    if lookup_field and lookup_field != 'pk': -        include.append(lookup_field) - -    return [field.name for field in obj._meta.fields if field.name not in include]  class CreateModelMixin(object): @@ -47,17 +18,11 @@ class CreateModelMixin(object):      Create a model instance.      """      def create(self, request, *args, **kwargs): -        serializer = self.get_serializer(data=request.DATA, files=request.FILES) - -        if serializer.is_valid(): -            self.pre_save(serializer.object) -            self.object = serializer.save(force_insert=True) -            self.post_save(self.object, created=True) -            headers = self.get_success_headers(serializer.data) -            return Response(serializer.data, status=status.HTTP_201_CREATED, -                            headers=headers) - -        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) +        serializer = self.get_serializer(data=request.DATA) +        serializer.is_valid(raise_exception=True) +        serializer.save() +        headers = self.get_success_headers(serializer.data) +        return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)      def get_success_headers(self, data):          try: @@ -70,31 +35,13 @@ class ListModelMixin(object):      """      List a queryset.      """ -    empty_error = "Empty list and '%(class_name)s.allow_empty' is False." -      def list(self, request, *args, **kwargs): -        self.object_list = self.filter_queryset(self.get_queryset()) - -        # Default is to allow empty querysets.  This can be altered by setting -        # `.allow_empty = False`, to raise 404 errors on empty querysets. -        if not self.allow_empty and not self.object_list: -            warnings.warn( -                'The `allow_empty` parameter is deprecated. ' -                'To use `allow_empty=False` style behavior, You should override ' -                '`get_queryset()` and explicitly raise a 404 on empty querysets.', -                DeprecationWarning -            ) -            class_name = self.__class__.__name__ -            error_msg = self.empty_error % {'class_name': class_name} -            raise Http404(error_msg) - -        # Switch between paginated or standard style responses -        page = self.paginate_queryset(self.object_list) +        instance = self.filter_queryset(self.get_queryset()) +        page = self.paginate_queryset(instance)          if page is not None:              serializer = self.get_pagination_serializer(page)          else: -            serializer = self.get_serializer(self.object_list, many=True) - +            serializer = self.get_serializer(instance, many=True)          return Response(serializer.data) @@ -103,8 +50,8 @@ class RetrieveModelMixin(object):      Retrieve a model instance.      """      def retrieve(self, request, *args, **kwargs): -        self.object = self.get_object() -        serializer = self.get_serializer(self.object) +        instance = self.get_object() +        serializer = self.get_serializer(instance)          return Response(serializer.data) @@ -114,29 +61,52 @@ class UpdateModelMixin(object):      """      def update(self, request, *args, **kwargs):          partial = kwargs.pop('partial', False) -        self.object = self.get_object_or_none() +        instance = self.get_object() +        serializer = self.get_serializer(instance, data=request.DATA, partial=partial) +        serializer.is_valid(raise_exception=True) +        serializer.save() +        return Response(serializer.data) -        serializer = self.get_serializer(self.object, data=request.DATA, -                                         files=request.FILES, partial=partial) +    def partial_update(self, request, *args, **kwargs): +        kwargs['partial'] = True +        return self.update(request, *args, **kwargs) -        if not serializer.is_valid(): -            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) -        try: -            self.pre_save(serializer.object) -        except ValidationError as err: -            # full_clean on model instance may be called in pre_save, -            # so we have to handle eventual errors. -            return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST) - -        if self.object is None: -            self.object = serializer.save(force_insert=True) -            self.post_save(self.object, created=True) +class DestroyModelMixin(object): +    """ +    Destroy a model instance. +    """ +    def destroy(self, request, *args, **kwargs): +        instance = self.get_object() +        instance.delete() +        return Response(status=status.HTTP_204_NO_CONTENT) + + +# The AllowPUTAsCreateMixin was previously the default behaviour +# for PUT requests. This has now been removed and must be *explictly* +# included if it is the behavior that you want. +# For more info see: ... + +class AllowPUTAsCreateMixin(object): +    """ +    The following mixin class may be used in order to support PUT-as-create +    behavior for incoming requests. +    """ +    def update(self, request, *args, **kwargs): +        partial = kwargs.pop('partial', False) +        instance = self.get_object_or_none() +        serializer = self.get_serializer(instance, data=request.DATA, partial=partial) +        serializer.is_valid(raise_exception=True) + +        if instance is None: +            lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field +            lookup_value = self.kwargs[lookup_url_kwarg] +            extras = {self.lookup_field: lookup_value} +            serializer.save(extras=extras)              return Response(serializer.data, status=status.HTTP_201_CREATED) -        self.object = serializer.save(force_update=True) -        self.post_save(self.object, created=False) -        return Response(serializer.data, status=status.HTTP_200_OK) +        serializer.save() +        return Response(serializer.data)      def partial_update(self, request, *args, **kwargs):          kwargs['partial'] = True @@ -156,41 +126,3 @@ class UpdateModelMixin(object):                  # PATCH requests where the object does not exist should still                  # return a 404 response.                  raise - -    def pre_save(self, obj): -        """ -        Set any attributes on the object that are implicit in the request. -        """ -        # pk and/or slug attributes are implicit in the URL. -        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field -        lookup = self.kwargs.get(lookup_url_kwarg, None) -        pk = self.kwargs.get(self.pk_url_kwarg, None) -        slug = self.kwargs.get(self.slug_url_kwarg, None) -        slug_field = slug and self.slug_field or None - -        if lookup: -            setattr(obj, self.lookup_field, lookup) - -        if pk: -            setattr(obj, 'pk', pk) - -        if slug: -            setattr(obj, slug_field, slug) - -        # Ensure we clean the attributes so that we don't eg return integer -        # pk using a string representation, as provided by the url conf kwarg. -        if hasattr(obj, 'full_clean'): -            exclude = _get_validation_exclusions(obj, pk, slug_field, self.lookup_field) -            obj.full_clean(exclude) - - -class DestroyModelMixin(object): -    """ -    Destroy a model instance. -    """ -    def destroy(self, request, *args, **kwargs): -        obj = self.get_object() -        self.pre_delete(obj) -        obj.delete() -        self.post_delete(obj) -        return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 1f5749f1..c5a9270a 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -13,7 +13,7 @@ class NextPageField(serializers.Field):      """      page_field = 'page' -    def to_native(self, value): +    def to_representation(self, value):          if not value.has_next():              return None          page = value.next_page_number() @@ -28,7 +28,7 @@ class PreviousPageField(serializers.Field):      """      page_field = 'page' -    def to_native(self, value): +    def to_representation(self, value):          if not value.has_previous():              return None          page = value.previous_page_number() @@ -37,7 +37,7 @@ class PreviousPageField(serializers.Field):          return replace_query_param(url, self.page_field, page) -class DefaultObjectSerializer(serializers.Field): +class DefaultObjectSerializer(serializers.ReadOnlyField):      """      If no object serializer is specified, then this serializer will be applied      as the default. @@ -49,25 +49,11 @@ class DefaultObjectSerializer(serializers.Field):          super(DefaultObjectSerializer, self).__init__(source=source) -class PaginationSerializerOptions(serializers.SerializerOptions): -    """ -    An object that stores the options that may be provided to a -    pagination serializer by using the inner `Meta` class. - -    Accessible on the instance as `serializer.opts`. -    """ -    def __init__(self, meta): -        super(PaginationSerializerOptions, self).__init__(meta) -        self.object_serializer_class = getattr(meta, 'object_serializer_class', -                                               DefaultObjectSerializer) - -  class BasePaginationSerializer(serializers.Serializer):      """      A base class for pagination serializers to inherit from,      to make implementing custom serializers more easy.      """ -    _options_class = PaginationSerializerOptions      results_field = 'results'      def __init__(self, *args, **kwargs): @@ -76,22 +62,23 @@ class BasePaginationSerializer(serializers.Serializer):          """          super(BasePaginationSerializer, self).__init__(*args, **kwargs)          results_field = self.results_field -        object_serializer = self.opts.object_serializer_class -        if 'context' in kwargs: -            context_kwarg = {'context': kwargs['context']} -        else: -            context_kwarg = {} +        try: +            object_serializer = self.Meta.object_serializer_class +        except AttributeError: +            object_serializer = DefaultObjectSerializer -        self.fields[results_field] = object_serializer(source='object_list', -                                                       many=True, -                                                       **context_kwarg) +        self.fields[results_field] = serializers.ListSerializer( +            child=object_serializer(), +            source='object_list' +        ) +        self.fields[results_field].bind(results_field, self, self)  class PaginationSerializer(BasePaginationSerializer):      """      A default implementation of a pagination serializer.      """ -    count = serializers.Field(source='paginator.count') +    count = serializers.ReadOnlyField(source='paginator.count')      next = NextPageField(source='*')      previous = PreviousPageField(source='*') diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index aa4fd3f1..fa02ecf1 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -11,7 +11,7 @@ from django.http import QueryDict  from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser  from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter  from django.utils import six -from rest_framework.compat import etree, yaml, force_text +from rest_framework.compat import etree, yaml, force_text, urlparse  from rest_framework.exceptions import ParseError  from rest_framework import renderers  import json @@ -48,7 +48,7 @@ class JSONParser(BaseParser):      """      media_type = 'application/json' -    renderer_class = renderers.UnicodeJSONRenderer +    renderer_class = renderers.JSONRenderer      def parse(self, stream, media_type=None, parser_context=None):          """ @@ -290,6 +290,22 @@ class FileUploadParser(BaseParser):          try:              meta = parser_context['request'].META              disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'].encode('utf-8')) -            return force_text(disposition[1]['filename']) +            filename_parm = disposition[1] +            if 'filename*' in filename_parm: +                return self.get_encoded_filename(filename_parm) +            return force_text(filename_parm['filename'])          except (AttributeError, KeyError):              pass + +    def get_encoded_filename(self, filename_parm): +        """ +        Handle encoded filenames per RFC6266. See also: +        http://tools.ietf.org/html/rfc2231#section-4 +        """ +        encoded_filename = force_text(filename_parm['filename*']) +        try: +            charset, lang, filename = encoded_filename.split('\'', 2) +            filename = urlparse.unquote(filename) +        except (ValueError, LookupError): +            filename = force_text(filename_parm['filename']) +        return filename diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 1acbdce2..5aa1f8bd 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,356 +1,112 @@ -""" -Serializer fields that deal with relationships. - -These fields allow you to specify the style that should be used to represent -model relationships, including hyperlinks, primary keys, or slugs. -""" -from __future__ import unicode_literals -from django.core.exceptions import ObjectDoesNotExist, ValidationError -from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch -from django import forms -from django.db.models.fields import BLANK_CHOICE_DASH -from django.forms import widgets -from django.forms.models import ModelChoiceIterator -from django.utils.translation import ugettext_lazy as _ -from rest_framework.fields import Field, WritableField, get_component, is_simple_callable +from rest_framework.compat import smart_text, urlparse +from rest_framework.fields import Field  from rest_framework.reverse import reverse -from rest_framework.compat import urlparse -from rest_framework.compat import smart_text -import warnings - - -# Relational fields - -# Not actually Writable, but subclasses may need to be. -class RelatedField(WritableField): -    """ -    Base class for related model fields. - -    This represents a relationship using the unicode representation of the target. -    """ -    widget = widgets.Select -    many_widget = widgets.SelectMultiple -    form_field_class = forms.ChoiceField -    many_form_field_class = forms.MultipleChoiceField -    null_values = (None, '', 'None') - -    cache_choices = False -    empty_label = None -    read_only = True -    many = False - -    def __init__(self, *args, **kwargs): -        queryset = kwargs.pop('queryset', None) -        self.many = kwargs.pop('many', self.many) -        if self.many: -            self.widget = self.many_widget -            self.form_field_class = self.many_form_field_class - -        kwargs['read_only'] = kwargs.pop('read_only', self.read_only) -        super(RelatedField, self).__init__(*args, **kwargs) - -        if not self.required: -            # Accessed in ModelChoiceIterator django/forms/models.py:1034 -            # If set adds empty choice. -            self.empty_label = BLANK_CHOICE_DASH[0][1] - -        self.queryset = queryset - -    def initialize(self, parent, field_name): -        super(RelatedField, self).initialize(parent, field_name) -        if self.queryset is None and not self.read_only: -            manager = getattr(self.parent.opts.model, self.source or field_name) -            if hasattr(manager, 'related'):  # Forward -                self.queryset = manager.related.model._default_manager.all() -            else:  # Reverse -                self.queryset = manager.field.rel.to._default_manager.all() - -    # We need this stuff to make form choices work... - -    def prepare_value(self, obj): -        return self.to_native(obj) - -    def label_from_instance(self, obj): -        """ -        Return a readable representation for use with eg. select widgets. -        """ -        desc = smart_text(obj) -        ident = smart_text(self.to_native(obj)) -        if desc == ident: -            return desc -        return "%s - %s" % (desc, ident) - -    def _get_queryset(self): -        return self._queryset - -    def _set_queryset(self, queryset): -        self._queryset = queryset -        self.widget.choices = self.choices - -    queryset = property(_get_queryset, _set_queryset) - -    def _get_choices(self): -        # If self._choices is set, then somebody must have manually set -        # the property self.choices. In this case, just return self._choices. -        if hasattr(self, '_choices'): -            return self._choices - -        # Otherwise, execute the QuerySet in self.queryset to determine the -        # choices dynamically. Return a fresh ModelChoiceIterator that has not been -        # consumed. Note that we're instantiating a new ModelChoiceIterator *each* -        # time _get_choices() is called (and, thus, each time self.choices is -        # accessed) so that we can ensure the QuerySet has not been consumed. This -        # construct might look complicated but it allows for lazy evaluation of -        # the queryset. -        return ModelChoiceIterator(self) - -    def _set_choices(self, value): -        # Setting choices also sets the choices on the widget. -        # choices can be any iterable, but we call list() on it because -        # it will be consumed more than once. -        self._choices = self.widget.choices = list(value) - -    choices = property(_get_choices, _set_choices) - -    # Default value handling - -    def get_default_value(self): -        default = super(RelatedField, self).get_default_value() -        if self.many and default is None: -            return [] -        return default - -    # Regular serializer stuff... - -    def field_to_native(self, obj, field_name): -        try: -            if self.source == '*': -                return self.to_native(obj) - -            source = self.source or field_name -            value = obj - -            for component in source.split('.'): -                if value is None: -                    break -                value = get_component(value, component) -        except ObjectDoesNotExist: -            return None +from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured +from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch, Resolver404 +from django.db.models.query import QuerySet +from django.utils.translation import ugettext_lazy as _ -        if value is None: -            return None -        if self.many: -            if is_simple_callable(getattr(value, 'all', None)): -                return [self.to_native(item) for item in value.all()] -            else: -                # Also support non-queryset iterables. -                # This allows us to also support plain lists of related items. -                return [self.to_native(item) for item in value] -        return self.to_native(value) +class RelatedField(Field): +    def __init__(self, **kwargs): +        self.queryset = kwargs.pop('queryset', None) +        assert self.queryset is not None or kwargs.get('read_only', None), ( +            'Relational field must provide a `queryset` argument, ' +            'or set read_only=`True`.' +        ) +        assert not (self.queryset is not None and kwargs.get('read_only', None)), ( +            'Relational fields should not provide a `queryset` argument, ' +            'when setting read_only=`True`.' +        ) +        super(RelatedField, self).__init__(**kwargs) + +    def __new__(cls, *args, **kwargs): +        # We override this method in order to automagically create +        # `ManyRelation` classes instead when `many=True` is set. +        if kwargs.pop('many', False): +            return ManyRelation( +                child_relation=cls(*args, **kwargs), +                read_only=kwargs.get('read_only', False) +            ) +        return super(RelatedField, cls).__new__(cls, *args, **kwargs) -    def field_from_native(self, data, files, field_name, into): -        if self.read_only: -            return +    def get_queryset(self): +        queryset = self.queryset +        if isinstance(queryset, QuerySet): +            # Ensure queryset is re-evaluated whenever used. +            queryset = queryset.all() +        return queryset -        try: -            if self.many: -                try: -                    # Form data -                    value = data.getlist(field_name) -                    if value == [''] or value == []: -                        raise KeyError -                except AttributeError: -                    # Non-form data -                    value = data[field_name] -            else: -                value = data[field_name] -        except KeyError: -            if self.partial: -                return -            value = self.get_default_value() - -        if value in self.null_values: -            if self.required: -                raise ValidationError(self.error_messages['required']) -            into[(self.source or field_name)] = None -        elif self.many: -            into[(self.source or field_name)] = [self.from_native(item) for item in value] -        else: -            into[(self.source or field_name)] = self.from_native(value) - - -# PrimaryKey relationships -class PrimaryKeyRelatedField(RelatedField): +class StringRelatedField(Field):      """ -    Represents a relationship as a pk value. +    A read only field that represents its targets using their +    plain string representation.      """ -    read_only = False - -    default_error_messages = { -        'does_not_exist': _("Invalid pk '%s' - object does not exist."), -        'incorrect_type': _('Incorrect type.  Expected pk value, received %s.'), -    } - -    # TODO: Remove these field hacks... -    def prepare_value(self, obj): -        return self.to_native(obj.pk) - -    def label_from_instance(self, obj): -        """ -        Return a readable representation for use with eg. select widgets. -        """ -        desc = smart_text(obj) -        ident = smart_text(self.to_native(obj.pk)) -        if desc == ident: -            return desc -        return "%s - %s" % (desc, ident) - -    # TODO: Possibly change this to just take `obj`, through prob less performant -    def to_native(self, pk): -        return pk -    def from_native(self, data): -        if self.queryset is None: -            raise Exception('Writable related fields must include a `queryset` argument') +    def __init__(self, **kwargs): +        kwargs['read_only'] = True +        super(StringRelatedField, self).__init__(**kwargs) -        try: -            return self.queryset.get(pk=data) -        except ObjectDoesNotExist: -            msg = self.error_messages['does_not_exist'] % smart_text(data) -            raise ValidationError(msg) -        except (TypeError, ValueError): -            received = type(data).__name__ -            msg = self.error_messages['incorrect_type'] % received -            raise ValidationError(msg) - -    def field_to_native(self, obj, field_name): -        if self.many: -            # To-many relationship - -            queryset = None -            if not self.source: -                # Prefer obj.serializable_value for performance reasons -                try: -                    queryset = obj.serializable_value(field_name) -                except AttributeError: -                    pass -            if queryset is None: -                # RelatedManager (reverse relationship) -                source = self.source or field_name -                queryset = obj -                for component in source.split('.'): -                    if queryset is None: -                        return [] -                    queryset = get_component(queryset, component) - -            # Forward relationship -            if is_simple_callable(getattr(queryset, 'all', None)): -                return [self.to_native(item.pk) for item in queryset.all()] -            else: -                # Also support non-queryset iterables. -                # This allows us to also support plain lists of related items. -                return [self.to_native(item.pk) for item in queryset] - -        # To-one relationship -        try: -            # Prefer obj.serializable_value for performance reasons -            pk = obj.serializable_value(self.source or field_name) -        except AttributeError: -            # RelatedObject (reverse relationship) -            try: -                pk = getattr(obj, self.source or field_name).pk -            except (ObjectDoesNotExist, AttributeError): -                return None - -        # Forward relationship -        return self.to_native(pk) +    def to_representation(self, value): +        return str(value) -# Slug relationships - -class SlugRelatedField(RelatedField): -    """ -    Represents a relationship using a unique field on the target. -    """ -    read_only = False - +class PrimaryKeyRelatedField(RelatedField):      default_error_messages = { -        'does_not_exist': _("Object with %s=%s does not exist."), -        'invalid': _('Invalid value.'), +        'required': 'This field is required.', +        'does_not_exist': "Invalid pk '{pk_value}' - object does not exist.", +        'incorrect_type': 'Incorrect type. Expected pk value, received {data_type}.',      } -    def __init__(self, *args, **kwargs): -        self.slug_field = kwargs.pop('slug_field', None) -        assert self.slug_field, 'slug_field is required' -        super(SlugRelatedField, self).__init__(*args, **kwargs) - -    def to_native(self, obj): -        return getattr(obj, self.slug_field) - -    def from_native(self, data): -        if self.queryset is None: -            raise Exception('Writable related fields must include a `queryset` argument') - +    def to_internal_value(self, data):          try: -            return self.queryset.get(**{self.slug_field: data}) +            return self.get_queryset().get(pk=data)          except ObjectDoesNotExist: -            raise ValidationError(self.error_messages['does_not_exist'] % -                                  (self.slug_field, smart_text(data))) +            self.fail('does_not_exist', pk_value=data)          except (TypeError, ValueError): -            msg = self.error_messages['invalid'] -            raise ValidationError(msg) +            self.fail('incorrect_type', data_type=type(data).__name__) +    def to_representation(self, value): +        return value.pk -# Hyperlinked relationships  class HyperlinkedRelatedField(RelatedField): -    """ -    Represents a relationship using hyperlinking. -    """ -    read_only = False      lookup_field = 'pk'      default_error_messages = { -        'no_match': _('Invalid hyperlink - No URL match'), -        'incorrect_match': _('Invalid hyperlink - Incorrect URL match'), -        'configuration_error': _('Invalid hyperlink due to configuration error'), -        'does_not_exist': _("Invalid hyperlink - object does not exist."), -        'incorrect_type': _('Incorrect type.  Expected url string, received %s.'), +        'required': 'This field is required.', +        'no_match': 'Invalid hyperlink - No URL match', +        'incorrect_match': 'Invalid hyperlink - Incorrect URL match.', +        'does_not_exist': 'Invalid hyperlink - Object does not exist.', +        'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.',      } -    # These are all deprecated -    pk_url_kwarg = 'pk' -    slug_field = 'slug' -    slug_url_kwarg = None  # Defaults to same as `slug_field` unless overridden - -    def __init__(self, *args, **kwargs): -        try: -            self.view_name = kwargs.pop('view_name') -        except KeyError: -            raise ValueError("Hyperlinked field requires 'view_name' kwarg") - +    def __init__(self, view_name=None, **kwargs): +        assert view_name is not None, 'The `view_name` argument is required.' +        self.view_name = view_name          self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) +        self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)          self.format = kwargs.pop('format', None) -        # These are deprecated -        if 'pk_url_kwarg' in kwargs: -            msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.' -            warnings.warn(msg, DeprecationWarning, stacklevel=2) -        if 'slug_url_kwarg' in kwargs: -            msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.' -            warnings.warn(msg, DeprecationWarning, stacklevel=2) -        if 'slug_field' in kwargs: -            msg = 'slug_field is deprecated. Use lookup_field instead.' -            warnings.warn(msg, DeprecationWarning, stacklevel=2) +        # We include these simply for dependancy injection in tests. +        # We can't add them as class attributes or they would expect an +        # implict `self` argument to be passed. +        self.reverse = reverse +        self.resolve = resolve -        self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) -        self.slug_field = kwargs.pop('slug_field', self.slug_field) -        default_slug_kwarg = self.slug_url_kwarg or self.slug_field -        self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) +        super(HyperlinkedRelatedField, self).__init__(**kwargs) -        super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) +    def get_object(self, view_name, view_args, view_kwargs): +        """ +        Return the object corresponding to a matched URL. + +        Takes the matched URL conf arguments, and should return an +        object instance, or raise an `ObjectDoesNotExist` exception. +        """ +        lookup_value = view_kwargs[self.lookup_url_kwarg] +        lookup_kwargs = {self.lookup_field: lookup_value} +        return self.get_queryset().get(**lookup_kwargs)      def get_url(self, obj, view_name, request, format):          """ @@ -359,176 +115,48 @@ class HyperlinkedRelatedField(RelatedField):          May raise a `NoReverseMatch` if the `view_name` and `lookup_field`          attributes are not configured to correctly match the URL conf.          """ -        lookup_field = getattr(obj, self.lookup_field) -        kwargs = {self.lookup_field: lookup_field} -        try: -            return reverse(view_name, kwargs=kwargs, request=request, format=format) -        except NoReverseMatch: -            pass - -        if self.pk_url_kwarg != 'pk': -            # Only try pk if it has been explicitly set. -            # Otherwise, the default `lookup_field = 'pk'` has us covered. -            pk = obj.pk -            kwargs = {self.pk_url_kwarg: pk} -            try: -                return reverse(view_name, kwargs=kwargs, request=request, format=format) -            except NoReverseMatch: -                pass - -        slug = getattr(obj, self.slug_field, None) -        if slug is not None: -            # Only try slug if it corresponds to an attribute on the object. -            kwargs = {self.slug_url_kwarg: slug} -            try: -                ret = reverse(view_name, kwargs=kwargs, request=request, format=format) -                if self.slug_field == 'slug' and self.slug_url_kwarg == 'slug': -                    # If the lookup succeeds using the default slug params, -                    # then `slug_field` is being used implicitly, and we -                    # we need to warn about the pending deprecation. -                    msg = 'Implicit slug field hyperlinked fields are deprecated.' \ -                          'You should set `lookup_field=slug` on the HyperlinkedRelatedField.' -                    warnings.warn(msg, DeprecationWarning, stacklevel=2) -                return ret -            except NoReverseMatch: -                pass - -        raise NoReverseMatch() - -    def get_object(self, queryset, view_name, view_args, view_kwargs): -        """ -        Return the object corresponding to a matched URL. - -        Takes the matched URL conf arguments, and the queryset, and should -        return an object instance, or raise an `ObjectDoesNotExist` exception. -        """ -        lookup = view_kwargs.get(self.lookup_field, None) -        pk = view_kwargs.get(self.pk_url_kwarg, None) -        slug = view_kwargs.get(self.slug_url_kwarg, None) - -        if lookup is not None: -            filter_kwargs = {self.lookup_field: lookup} -        elif pk is not None: -            filter_kwargs = {'pk': pk} -        elif slug is not None: -            filter_kwargs = {self.slug_field: slug} -        else: -            raise ObjectDoesNotExist() - -        return queryset.get(**filter_kwargs) - -    def to_native(self, obj): -        view_name = self.view_name -        request = self.context.get('request', None) -        format = self.format or self.context.get('format', None) - -        assert request is not None, ( -            "`HyperlinkedRelatedField` requires the request in the serializer " -            "context. Add `context={'request': request}` when instantiating " -            "the serializer." -        ) - -        # If the object has not yet been saved then we cannot hyperlink to it. -        if getattr(obj, 'pk', None) is None: -            return - -        # Return the hyperlink, or error if incorrectly configured. -        try: -            return self.get_url(obj, view_name, request, format) -        except NoReverseMatch: -            msg = ( -                'Could not resolve URL for hyperlinked relationship using ' -                'view name "%s". You may have failed to include the related ' -                'model in your API, or incorrectly configured the ' -                '`lookup_field` attribute on this field.' -            ) -            raise Exception(msg % view_name) +        # Unsaved objects will not yet have a valid URL. +        if obj.pk is None: +            return None -    def from_native(self, value): -        # Convert URL -> model instance pk -        # TODO: Use values_list -        queryset = self.queryset -        if queryset is None: -            raise Exception('Writable related fields must include a `queryset` argument') +        lookup_value = getattr(obj, self.lookup_field) +        kwargs = {self.lookup_url_kwarg: lookup_value} +        return self.reverse(view_name, kwargs=kwargs, request=request, format=format) +    def to_internal_value(self, data):          try: -            http_prefix = value.startswith(('http:', 'https:')) +            http_prefix = data.startswith(('http:', 'https:'))          except AttributeError: -            msg = self.error_messages['incorrect_type'] -            raise ValidationError(msg % type(value).__name__) +            self.fail('incorrect_type', data_type=type(data).__name__)          if http_prefix:              # If needed convert absolute URLs to relative path -            value = urlparse.urlparse(value).path +            data = urlparse.urlparse(data).path              prefix = get_script_prefix() -            if value.startswith(prefix): -                value = '/' + value[len(prefix):] +            if data.startswith(prefix): +                data = '/' + data[len(prefix):]          try: -            match = resolve(value) -        except Exception: -            raise ValidationError(self.error_messages['no_match']) +            match = self.resolve(data) +        except Resolver404: +            self.fail('no_match')          if match.view_name != self.view_name: -            raise ValidationError(self.error_messages['incorrect_match']) +            self.fail('incorrect_match')          try: -            return self.get_object(queryset, match.view_name, -                                   match.args, match.kwargs) +            return self.get_object(match.view_name, match.args, match.kwargs)          except (ObjectDoesNotExist, TypeError, ValueError): -            raise ValidationError(self.error_messages['does_not_exist']) +            self.fail('does_not_exist') - -class HyperlinkedIdentityField(Field): -    """ -    Represents the instance, or a property on the instance, using hyperlinking. -    """ -    lookup_field = 'pk' -    read_only = True - -    # These are all deprecated -    pk_url_kwarg = 'pk' -    slug_field = 'slug' -    slug_url_kwarg = None  # Defaults to same as `slug_field` unless overridden - -    def __init__(self, *args, **kwargs): -        try: -            self.view_name = kwargs.pop('view_name') -        except KeyError: -            msg = "HyperlinkedIdentityField requires 'view_name' argument" -            raise ValueError(msg) - -        self.format = kwargs.pop('format', None) -        lookup_field = kwargs.pop('lookup_field', None) -        self.lookup_field = lookup_field or self.lookup_field - -        # These are deprecated -        if 'pk_url_kwarg' in kwargs: -            msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.' -            warnings.warn(msg, DeprecationWarning, stacklevel=2) -        if 'slug_url_kwarg' in kwargs: -            msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.' -            warnings.warn(msg, DeprecationWarning, stacklevel=2) -        if 'slug_field' in kwargs: -            msg = 'slug_field is deprecated. Use lookup_field instead.' -            warnings.warn(msg, DeprecationWarning, stacklevel=2) - -        self.slug_field = kwargs.pop('slug_field', self.slug_field) -        default_slug_kwarg = self.slug_url_kwarg or self.slug_field -        self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) -        self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) - -        super(HyperlinkedIdentityField, self).__init__(*args, **kwargs) - -    def field_to_native(self, obj, field_name): +    def to_representation(self, value):          request = self.context.get('request', None)          format = self.context.get('format', None) -        view_name = self.view_name          assert request is not None, ( -            "`HyperlinkedIdentityField` requires the request in the serializer" +            "`%s` requires the request in the serializer"              " context. Add `context={'request': request}` when instantiating " -            "the serializer." +            "the serializer." % self.__class__.__name__          )          # By default use whatever format is given for the current context @@ -545,7 +173,7 @@ class HyperlinkedIdentityField(Field):          # Return the hyperlink, or error if incorrectly configured.          try: -            return self.get_url(obj, view_name, request, format) +            return self.get_url(value, self.view_name, request, format)          except NoReverseMatch:              msg = (                  'Could not resolve URL for hyperlinked relationship using ' @@ -553,43 +181,81 @@ class HyperlinkedIdentityField(Field):                  'model in your API, or incorrectly configured the '                  '`lookup_field` attribute on this field.'              ) -            raise Exception(msg % view_name) +            raise ImproperlyConfigured(msg % self.view_name) -    def get_url(self, obj, view_name, request, format): -        """ -        Given an object, return the URL that hyperlinks to the object. -        May raise a `NoReverseMatch` if the `view_name` and `lookup_field` -        attributes are not configured to correctly match the URL conf. -        """ -        lookup_field = getattr(obj, self.lookup_field, None) -        kwargs = {self.lookup_field: lookup_field} +class HyperlinkedIdentityField(HyperlinkedRelatedField): +    """ +    A read-only field that represents the identity URL for an object, itself. -        # Handle unsaved object case -        if lookup_field is None: -            return None +    This is in contrast to `HyperlinkedRelatedField` which represents the +    URL of relationships to other objects. +    """ + +    def __init__(self, view_name=None, **kwargs): +        assert view_name is not None, 'The `view_name` argument is required.' +        kwargs['read_only'] = True +        kwargs['source'] = '*' +        super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs) + + +class SlugRelatedField(RelatedField): +    """ +    A read-write field the represents the target of the relationship +    by a unique 'slug' attribute. +    """ + +    default_error_messages = { +        'does_not_exist': _("Object with {slug_name}={value} does not exist."), +        'invalid': _('Invalid value.'), +    } +    def __init__(self, slug_field=None, **kwargs): +        assert slug_field is not None, 'The `slug_field` argument is required.' +        self.slug_field = slug_field +        super(SlugRelatedField, self).__init__(**kwargs) + +    def to_internal_value(self, data):          try: -            return reverse(view_name, kwargs=kwargs, request=request, format=format) -        except NoReverseMatch: -            pass - -        if self.pk_url_kwarg != 'pk': -            # Only try pk lookup if it has been explicitly set. -            # Otherwise, the default `lookup_field = 'pk'` has us covered. -            kwargs = {self.pk_url_kwarg: obj.pk} -            try: -                return reverse(view_name, kwargs=kwargs, request=request, format=format) -            except NoReverseMatch: -                pass - -        slug = getattr(obj, self.slug_field, None) -        if slug: -            # Only use slug lookup if a slug field exists on the model -            kwargs = {self.slug_url_kwarg: slug} -            try: -                return reverse(view_name, kwargs=kwargs, request=request, format=format) -            except NoReverseMatch: -                pass - -        raise NoReverseMatch() +            return self.get_queryset().get(**{self.slug_field: data}) +        except ObjectDoesNotExist: +            self.fail('does_not_exist', slug_name=self.slug_field, value=smart_text(data)) +        except (TypeError, ValueError): +            self.fail('invalid') + +    def to_representation(self, obj): +        return getattr(obj, self.slug_field) + + +class ManyRelation(Field): +    """ +    Relationships with `many=True` transparently get coerced into instead being +    a ManyRelation with a child relationship. + +    The `ManyRelation` class is responsible for handling iterating through +    the values and passing each one to the child relationship. + +    You shouldn't need to be using this class directly yourself. +    """ + +    def __init__(self, child_relation=None, *args, **kwargs): +        self.child_relation = child_relation +        assert child_relation is not None, '`child_relation` is a required argument.' +        super(ManyRelation, self).__init__(*args, **kwargs) + +    def bind(self, field_name, parent, root): +        # ManyRelation needs to provide the current context to the child relation. +        super(ManyRelation, self).bind(field_name, parent, root) +        self.child_relation.bind(field_name, parent, root) + +    def to_internal_value(self, data): +        return [ +            self.child_relation.to_internal_value(item) +            for item in data +        ] + +    def to_representation(self, obj): +        return [ +            self.child_relation.to_representation(value) +            for value in obj.all() +        ] diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 748ebac9..3bf03e62 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -26,6 +26,10 @@ from rest_framework.utils.breadcrumbs import get_breadcrumbs  from rest_framework import exceptions, status, VERSION +def zero_as_none(value): +    return None if value == 0 else value + +  class BaseRenderer(object):      """      All renderers should extend this class, setting the `media_type` @@ -44,13 +48,13 @@ class BaseRenderer(object):  class JSONRenderer(BaseRenderer):      """      Renderer which serializes to JSON. -    Applies JSON's backslash-u character escaping for non-ascii characters.      """      media_type = 'application/json'      format = 'json'      encoder_class = encoders.JSONEncoder -    ensure_ascii = True +    ensure_ascii = not api_settings.UNICODE_JSON +    compact = api_settings.COMPACT_JSON      # We don't set a charset because JSON is a binary encoding,      # that can be encoded as utf-8, utf-16 or utf-32. @@ -62,9 +66,10 @@ class JSONRenderer(BaseRenderer):          if accepted_media_type:              # If the media type looks like 'application/json; indent=4',              # then pretty print the result. +            # Note that we coerce `indent=0` into `indent=None`.              base_media_type, params = parse_header(accepted_media_type.encode('ascii'))              try: -                return max(min(int(params['indent']), 8), 0) +                return zero_as_none(max(min(int(params['indent']), 8), 0))              except (KeyError, ValueError, TypeError):                  pass @@ -81,10 +86,12 @@ class JSONRenderer(BaseRenderer):          renderer_context = renderer_context or {}          indent = self.get_indent(accepted_media_type, renderer_context) +        separators = (',', ':') if (indent is None and self.compact) else (', ', ': ')          ret = json.dumps(              data, cls=self.encoder_class, -            indent=indent, ensure_ascii=self.ensure_ascii +            indent=indent, ensure_ascii=self.ensure_ascii, +            separators=separators          )          # On python 2.x json.dumps() returns bytestrings if ensure_ascii=True, @@ -96,14 +103,6 @@ class JSONRenderer(BaseRenderer):          return ret -class UnicodeJSONRenderer(JSONRenderer): -    ensure_ascii = False -    """ -    Renderer which serializes to JSON. -    Does *not* apply JSON's character escaping for non-ascii characters. -    """ - -  class JSONPRenderer(JSONRenderer):      """      Renderer which serializes to json, @@ -196,7 +195,7 @@ class YAMLRenderer(BaseRenderer):      format = 'yaml'      encoder = encoders.SafeDumper      charset = 'utf-8' -    ensure_ascii = True +    ensure_ascii = False      def render(self, data, accepted_media_type=None, renderer_context=None):          """ @@ -210,14 +209,6 @@ class YAMLRenderer(BaseRenderer):          return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder, allow_unicode=not self.ensure_ascii) -class UnicodeYAMLRenderer(YAMLRenderer): -    """ -    Renderer which serializes to YAML. -    Does *not* apply character escaping for non-ascii characters. -    """ -    ensure_ascii = False - -  class TemplateHTMLRenderer(BaseRenderer):      """      An HTML renderer for use with templates. @@ -436,13 +427,13 @@ class BrowsableAPIRenderer(BaseRenderer):          if request.method == method:              try:                  data = request.DATA -                files = request.FILES +                # files = request.FILES              except ParseError:                  data = None -                files = None +                # files = None          else:              data = None -            files = None +            # files = None          with override_method(view, request, method) as request:              obj = getattr(view, 'object', None) @@ -458,7 +449,7 @@ class BrowsableAPIRenderer(BaseRenderer):              ):                  return -            serializer = view.get_serializer(instance=obj, data=data, files=files) +            serializer = view.get_serializer(instance=obj, data=data)              serializer.is_valid()              data = serializer.data @@ -579,10 +570,10 @@ class BrowsableAPIRenderer(BaseRenderer):              'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes],              'response_headers': response_headers, -            'put_form': self.get_rendered_html_form(view, 'PUT', request), -            'post_form': self.get_rendered_html_form(view, 'POST', request), -            'delete_form': self.get_rendered_html_form(view, 'DELETE', request), -            'options_form': self.get_rendered_html_form(view, 'OPTIONS', request), +            # 'put_form': self.get_rendered_html_form(view, 'PUT', request), +            # 'post_form': self.get_rendered_html_form(view, 'POST', request), +            # 'delete_form': self.get_rendered_html_form(view, 'DELETE', request), +            # 'options_form': self.get_rendered_html_form(view, 'OPTIONS', request),              'raw_data_put_form': raw_data_put_form,              'raw_data_post_form': raw_data_post_form, diff --git a/rest_framework/routers.py b/rest_framework/routers.py index ae56673d..f2d06211 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -19,6 +19,7 @@ import itertools  from collections import namedtuple  from django.conf.urls import patterns, url  from django.core.exceptions import ImproperlyConfigured +from django.core.urlresolvers import NoReverseMatch  from rest_framework import views  from rest_framework.response import Response  from rest_framework.reverse import reverse @@ -284,10 +285,19 @@ class DefaultRouter(SimpleRouter):          class APIRoot(views.APIView):              _ignore_model_permissions = True -            def get(self, request, format=None): +            def get(self, request, *args, **kwargs):                  ret = {}                  for key, url_name in api_root_dict.items(): -                    ret[key] = reverse(url_name, request=request, format=format) +                    try: +                        ret[key] = reverse( +                            url_name, +                            request=request, +                            format=kwargs.get('format', None) +                        ) +                    except NoReverseMatch: +                        # Don't bail out if eg. no list routes exist, only detail routes. +                        continue +                  return Response(ret)          return APIRoot.as_view() diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index be8ad3f2..d2740fc2 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -10,21 +10,20 @@ python primitives.  2. The process of marshalling between python primitives and request and  response content is handled by parsers and renderers.  """ -from __future__ import unicode_literals -import copy -import datetime -import inspect -import types -from decimal import Decimal -from django.contrib.contenttypes.generic import GenericForeignKey -from django.core.paginator import Page +from django.core.exceptions import ImproperlyConfigured, ValidationError  from django.db import models -from django.forms import widgets  from django.utils import six  from django.utils.datastructures import SortedDict -from django.core.exceptions import ObjectDoesNotExist +from collections import namedtuple +from rest_framework.fields import empty, set_value, Field, SkipField  from rest_framework.settings import api_settings - +from rest_framework.utils import html, model_meta, representation +from rest_framework.utils.field_mapping import ( +    get_url_kwargs, get_field_kwargs, +    get_relation_kwargs, get_nested_relation_kwargs, +    lookup_class +) +import copy  # Note: We do the following so that users of the framework can use this style:  # @@ -37,1107 +36,453 @@ from rest_framework.relations import *  # NOQA  from rest_framework.fields import *  # NOQA -def _resolve_model(obj): -    """ -    Resolve supplied `obj` to a Django model class. +FieldResult = namedtuple('FieldResult', ['field', 'value', 'error']) -    `obj` must be a Django model class itself, or a string -    representation of one.  Useful in situtations like GH #1225 where -    Django may not have resolved a string-based reference to a model in -    another model's foreign key definition. -    String representations should have the format: -        'appname.ModelName' +class BaseSerializer(Field): +    """ +    The BaseSerializer class provides a minimal class which may be used +    for writing custom serializer implementations.      """ -    if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: -        app_name, model_name = obj.split('.') -        return models.get_model(app_name, model_name) -    elif inspect.isclass(obj) and issubclass(obj, models.Model): -        return obj -    else: -        raise ValueError("{0} is not a Django model".format(obj)) - - -def pretty_name(name): -    """Converts 'first_name' to 'First name'""" -    if not name: -        return '' -    return name.replace('_', ' ').capitalize() +    def __init__(self, instance=None, data=None, **kwargs): +        super(BaseSerializer, self).__init__(**kwargs) +        self.instance = instance +        self._initial_data = data -class RelationsList(list): -    _deleted = [] +    def to_internal_value(self, data): +        raise NotImplementedError('`to_internal_value()` must be implemented.') +    def to_representation(self, instance): +        raise NotImplementedError('`to_representation()` must be implemented.') -class NestedValidationError(ValidationError): -    """ -    The default ValidationError behavior is to stringify each item in the list -    if the messages are a list of error messages. +    def update(self, instance, attrs): +        raise NotImplementedError('`update()` must be implemented.') -    In the case of nested serializers, where the parent has many children, -    then the child's `serializer.errors` will be a list of dicts.  In the case -    of a single child, the `serializer.errors` will be a dict. +    def create(self, attrs): +        raise NotImplementedError('`create()` must be implemented.') -    We need to override the default behavior to get properly nested error dicts. -    """ +    def save(self, extras=None): +        attrs = self.validated_data +        if extras is not None: +            attrs = dict(list(attrs.items()) + list(extras.items())) -    def __init__(self, message): -        if isinstance(message, dict): -            self._messages = [message] +        if self.instance is not None: +            self.update(self.instance, attrs)          else: -            self._messages = message +            self.instance = self.create(attrs) -    @property -    def messages(self): -        return self._messages +        return self.instance +    def is_valid(self, raise_exception=False): +        if not hasattr(self, '_validated_data'): +            try: +                self._validated_data = self.to_internal_value(self._initial_data) +            except ValidationError as exc: +                self._validated_data = {} +                self._errors = exc.message_dict +            else: +                self._errors = {} -class DictWithMetadata(dict): -    """ -    A dict-like object, that can have additional properties attached. -    """ -    def __getstate__(self): -        """ -        Used by pickle (e.g., caching). -        Overridden to remove the metadata from the dict, since it shouldn't be -        pickled and may in some instances be unpickleable. -        """ -        return dict(self) +        if self._errors and raise_exception: +            raise ValidationError(self._errors) +        return not bool(self._errors) -class SortedDictWithMetadata(SortedDict): -    """ -    A sorted dict-like object, that can have additional properties attached. -    """ -    def __getstate__(self): -        """ -        Used by pickle (e.g., caching). -        Overriden to remove the metadata from the dict, since it shouldn't be -        pickle and may in some instances be unpickleable. -        """ -        return SortedDict(self).__dict__ +    @property +    def data(self): +        if not hasattr(self, '_data'): +            if self.instance is not None: +                self._data = self.to_representation(self.instance) +            elif self._initial_data is not None: +                self._data = dict([ +                    (field_name, field.get_value(self._initial_data)) +                    for field_name, field in self.fields.items() +                ]) +            else: +                self._data = self.get_initial() +        return self._data +    @property +    def errors(self): +        if not hasattr(self, '_errors'): +            msg = 'You must call `.is_valid()` before accessing `.errors`.' +            raise AssertionError(msg) +        return self._errors -def _is_protected_type(obj): -    """ -    True if the object is a native datatype that does not need to -    be serialized further. -    """ -    return isinstance(obj, ( -        types.NoneType, -        int, long, -        datetime.datetime, datetime.date, datetime.time, -        float, Decimal, -        basestring) -    ) +    @property +    def validated_data(self): +        if not hasattr(self, '_validated_data'): +            msg = 'You must call `.is_valid()` before accessing `.validated_data`.' +            raise AssertionError(msg) +        return self._validated_data -def _get_declared_fields(bases, attrs): +class SerializerMetaclass(type):      """ -    Create a list of serializer field instances from the passed in 'attrs', -    plus any fields on the base classes (in 'bases'). +    This metaclass sets a dictionary named `base_fields` on the class. -    Note that all fields from the base classes are used. +    Any instances of `Field` included as attributes on either the class +    or on any of its superclasses will be include in the +    `base_fields` dictionary.      """ -    fields = [(field_name, attrs.pop(field_name)) -              for field_name, obj in list(six.iteritems(attrs)) -              if isinstance(obj, Field)] -    fields.sort(key=lambda x: x[1].creation_counter) -    # If this class is subclassing another Serializer, add that Serializer's -    # fields.  Note that we loop over the bases in *reverse*. This is necessary -    # in order to maintain the correct order of fields. -    for base in bases[::-1]: -        if hasattr(base, 'base_fields'): -            fields = list(base.base_fields.items()) + fields +    @classmethod +    def _get_declared_fields(cls, bases, attrs): +        fields = [(field_name, attrs.pop(field_name)) +                  for field_name, obj in list(attrs.items()) +                  if isinstance(obj, Field)] +        fields.sort(key=lambda x: x[1]._creation_counter) -    return SortedDict(fields) +        # If this class is subclassing another Serializer, add that Serializer's +        # fields.  Note that we loop over the bases in *reverse*. This is necessary +        # in order to maintain the correct order of fields. +        for base in bases[::-1]: +            if hasattr(base, '_declared_fields'): +                fields = list(base._declared_fields.items()) + fields +        return SortedDict(fields) -class SerializerMetaclass(type):      def __new__(cls, name, bases, attrs): -        attrs['base_fields'] = _get_declared_fields(bases, attrs) +        attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)          return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) -class SerializerOptions(object): -    """ -    Meta class options for Serializer -    """ -    def __init__(self, meta): -        self.depth = getattr(meta, 'depth', 0) -        self.fields = getattr(meta, 'fields', ()) -        self.exclude = getattr(meta, 'exclude', ()) - - -class BaseSerializer(WritableField): -    """ -    This is the Serializer implementation. -    We need to implement it as `BaseSerializer` due to metaclass magicks. -    """ -    class Meta(object): -        pass - -    _options_class = SerializerOptions -    _dict_class = SortedDictWithMetadata - -    def __init__(self, instance=None, data=None, files=None, -                 context=None, partial=False, many=False, -                 allow_add_remove=False, **kwargs): -        super(BaseSerializer, self).__init__(**kwargs) -        self.opts = self._options_class(self.Meta) -        self.parent = None -        self.root = None -        self.partial = partial -        self.many = many -        self.allow_add_remove = allow_add_remove - -        self.context = context or {} - -        self.init_data = data -        self.init_files = files -        self.object = instance -        self.fields = self.get_fields() - -        self._data = None -        self._files = None -        self._errors = None - -        if many and instance is not None and not hasattr(instance, '__iter__'): -            raise ValueError('instance should be a queryset or other iterable with many=True') - -        if allow_add_remove and not many: -            raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True') - -    ##### -    # Methods to determine which fields to use when (de)serializing objects. - -    def get_default_fields(self): -        """ -        Return the complete set of default fields for the object, as a dict. -        """ -        return {} - -    def get_fields(self): -        """ -        Returns the complete set of fields for the object as a dict. - -        This will be the set of any explicitly declared fields, -        plus the set of fields returned by get_default_fields(). -        """ -        ret = SortedDict() - -        # Get the explicitly declared fields -        base_fields = copy.deepcopy(self.base_fields) -        for key, field in base_fields.items(): -            ret[key] = field - -        # Add in the default fields -        default_fields = self.get_default_fields() -        for key, val in default_fields.items(): -            if key not in ret: -                ret[key] = val - -        # If 'fields' is specified, use those fields, in that order. -        if self.opts.fields: -            assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple' -            new = SortedDict() -            for key in self.opts.fields: -                new[key] = ret[key] -            ret = new - -        # Remove anything in 'exclude' -        if self.opts.exclude: -            assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple' -            for key in self.opts.exclude: -                ret.pop(key, None) - -        for key, field in ret.items(): -            field.initialize(parent=self, field_name=key) - -        return ret - -    ##### -    # Methods to convert or revert from objects <--> primitive representations. - -    def get_field_key(self, field_name): -        """ -        Return the key that should be used for a given field. -        """ -        return field_name +@six.add_metaclass(SerializerMetaclass) +class Serializer(BaseSerializer): +    def __init__(self, *args, **kwargs): +        self.context = kwargs.pop('context', {}) +        kwargs.pop('partial', None) +        kwargs.pop('many', None) -    def restore_fields(self, data, files): -        """ -        Core of deserialization, together with `restore_object`. -        Converts a dictionary of data into a dictionary of deserialized fields. -        """ -        reverted_data = {} +        super(Serializer, self).__init__(*args, **kwargs) -        if data is not None and not isinstance(data, dict): -            self._errors['non_field_errors'] = ['Invalid data'] -            return None +        # Every new serializer is created with a clone of the field instances. +        # This allows users to dynamically modify the fields on a serializer +        # instance without affecting every other serializer class. +        self.fields = self._get_base_fields() +        # Setup all the child fields, to provide them with the current context.          for field_name, field in self.fields.items(): -            field.initialize(parent=self, field_name=field_name) -            try: -                field.field_from_native(data, files, field_name, reverted_data) -            except ValidationError as err: -                self._errors[field_name] = list(err.messages) - -        return reverted_data - -    def perform_validation(self, attrs): -        """ -        Run `validate_<fieldname>()` and `validate()` methods on the serializer -        """ +            field.bind(field_name, self, self) + +    def __new__(cls, *args, **kwargs): +        # We override this method in order to automagically create +        # `ListSerializer` classes instead when `many=True` is set. +        if kwargs.pop('many', False): +            kwargs['child'] = cls() +            return ListSerializer(*args, **kwargs) +        return super(Serializer, cls).__new__(cls, *args, **kwargs) + +    def _get_base_fields(self): +        return copy.deepcopy(self._declared_fields) + +    def bind(self, field_name, parent, root): +        # If the serializer is used as a field then when it becomes bound +        # it also needs to bind all its child fields. +        super(Serializer, self).bind(field_name, parent, root)          for field_name, field in self.fields.items(): -            if field_name in self._errors: -                continue +            field.bind(field_name, self, root) -            source = field.source or field_name -            if self.partial and source not in attrs: -                continue -            try: -                validate_method = getattr(self, 'validate_%s' % field_name, None) -                if validate_method: -                    attrs = validate_method(attrs, source) -            except ValidationError as err: -                self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) - -        # If there are already errors, we don't run .validate() because -        # field-validation failed and thus `attrs` may not be complete. -        # which in turn can cause inconsistent validation errors. -        if not self._errors: -            try: -                attrs = self.validate(attrs) -            except ValidationError as err: -                if hasattr(err, 'message_dict'): -                    for field_name, error_messages in err.message_dict.items(): -                        self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages) -                elif hasattr(err, 'messages'): -                    self._errors['non_field_errors'] = err.messages +    def get_initial(self): +        return dict([ +            (field.field_name, field.get_initial()) +            for field in self.fields.values() +        ]) -        return attrs +    def get_value(self, dictionary): +        # We override the default field access in order to support +        # nested HTML forms. +        if html.is_html_input(dictionary): +            return html.parse_html_dict(dictionary, prefix=self.field_name) +        return dictionary.get(self.field_name, empty) -    def validate(self, attrs): +    def to_internal_value(self, data):          """ -        Stub method, to be overridden in Serializer subclasses +        Dict of native values <- Dict of primitive datatypes.          """ -        return attrs +        if not isinstance(data, dict): +            raise ValidationError({ +                api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data'] +            }) -    def restore_object(self, attrs, instance=None): -        """ -        Deserialize a dictionary of attributes into an object instance. -        You should override this method to control how deserialized objects -        are instantiated. -        """ -        if instance is not None: -            instance.update(attrs) -            return instance -        return attrs - -    def to_native(self, obj): -        """ -        Serialize objects -> primitives. -        """ -        ret = self._dict_class() -        ret.fields = self._dict_class() +        ret = {} +        errors = {} +        fields = [field for field in self.fields.values() if not field.read_only] -        for field_name, field in self.fields.items(): -            if field.read_only and obj is None: -                continue -            field.initialize(parent=self, field_name=field_name) -            key = self.get_field_key(field_name) -            value = field.field_to_native(obj, field_name) -            method = getattr(self, 'transform_%s' % field_name, None) -            if callable(method): -                value = method(obj, value) -            if not getattr(field, 'write_only', False): -                ret[key] = value -            ret.fields[key] = self.augment_field(field, field_name, key, value) - -        return ret - -    def from_native(self, data, files=None): -        """ -        Deserialize primitives -> objects. -        """ -        self._errors = {} - -        if data is not None or files is not None: -            attrs = self.restore_fields(data, files) -            if attrs is not None: -                attrs = self.perform_validation(attrs) -        else: -            self._errors['non_field_errors'] = ['No input provided'] - -        if not self._errors: -            return self.restore_object(attrs, instance=getattr(self, 'object', None)) - -    def augment_field(self, field, field_name, key, value): -        # This horrible stuff is to manage serializers rendering to HTML -        field._errors = self._errors.get(key) if self._errors else None -        field._name = field_name -        field._value = self.init_data.get(key) if self._errors and self.init_data else value -        if not field.label: -            field.label = pretty_name(key) -        return field - -    def field_to_native(self, obj, field_name): -        """ -        Override default so that the serializer can be used as a nested field -        across relationships. -        """ -        if self.write_only: -            return None +        for field in fields: +            validate_method = getattr(self, 'validate_' + field.field_name, None) +            primitive_value = field.get_value(data) +            try: +                validated_value = field.run_validation(primitive_value) +                if validate_method is not None: +                    validated_value = validate_method(validated_value) +            except ValidationError as exc: +                errors[field.field_name] = exc.messages +            except SkipField: +                pass +            else: +                set_value(ret, field.source_attrs, validated_value) -        if self.source == '*': -            return self.to_native(obj) +        if errors: +            raise ValidationError(errors) -        # Get the raw field value          try: -            source = self.source or field_name -            value = obj - -            for component in source.split('.'): -                if value is None: -                    break -                value = get_component(value, component) -        except ObjectDoesNotExist: -            return None +            return self.validate(ret) +        except ValidationError as exc: +            raise ValidationError({ +                api_settings.NON_FIELD_ERRORS_KEY: exc.messages +            }) -        if is_simple_callable(getattr(value, 'all', None)): -            return [self.to_native(item) for item in value.all()] - -        if value is None: -            return None - -        if self.many: -            return [self.to_native(item) for item in value] -        return self.to_native(value) - -    def field_from_native(self, data, files, field_name, into): -        """ -        Override default so that the serializer can be used as a writable -        nested field across relationships. +    def to_representation(self, instance):          """ -        if self.read_only: -            return - -        try: -            value = data[field_name] -        except KeyError: -            if self.default is not None and not self.partial: -                # Note: partial updates shouldn't set defaults -                value = copy.deepcopy(self.default) -            else: -                if self.required: -                    raise ValidationError(self.error_messages['required']) -                return - -        if self.source == '*': -            if value: -                reverted_data = self.restore_fields(value, {}) -                if not self._errors: -                    into.update(reverted_data) -        else: -            if value in (None, ''): -                into[(self.source or field_name)] = None -            else: -                # Set the serializer object if it exists -                obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None - -                # If we have a model manager or similar object then we need -                # to iterate through each instance. -                if ( -                    self.many and -                    not hasattr(obj, '__iter__') and -                    is_simple_callable(getattr(obj, 'all', None)) -                ): -                    obj = obj.all() - -                kwargs = { -                    'instance': obj, -                    'data': value, -                    'context': self.context, -                    'partial': self.partial, -                    'many': self.many, -                    'allow_add_remove': self.allow_add_remove -                } -                serializer = self.__class__(**kwargs) - -                if serializer.is_valid(): -                    into[self.source or field_name] = serializer.object -                else: -                    # Propagate errors up to our parent -                    raise NestedValidationError(serializer.errors) - -    def get_identity(self, data): +        Object instance -> Dict of primitive datatypes.          """ -        This hook is required for bulk update. -        It is used to determine the canonical identity of a given object. +        ret = SortedDict() +        fields = [field for field in self.fields.values() if not field.write_only] -        Note that the data has not been validated at this point, so we need -        to make sure that we catch any cases of incorrect datatypes being -        passed to this method. -        """ -        try: -            return data.get('id', None) -        except AttributeError: -            return None +        for field in fields: +            native_value = field.get_attribute(instance) +            ret[field.field_name] = field.to_representation(native_value) -    @property -    def errors(self): -        """ -        Run deserialization and return error data, -        setting self.object if no errors occurred. -        """ -        if self._errors is None: -            data, files = self.init_data, self.init_files +        return ret -            if self.many is not None: -                many = self.many -            else: -                many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type)) -                if many: -                    warnings.warn('Implicit list/queryset serialization is deprecated. ' -                                  'Use the `many=True` flag when instantiating the serializer.', -                                  DeprecationWarning, stacklevel=3) - -            if many: -                ret = RelationsList() -                errors = [] -                update = self.object is not None - -                if update: -                    # If this is a bulk update we need to map all the objects -                    # to a canonical identity so we can determine which -                    # individual object is being updated for each item in the -                    # incoming data -                    objects = self.object -                    identities = [self.get_identity(self.to_native(obj)) for obj in objects] -                    identity_to_objects = dict(zip(identities, objects)) - -                if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)): -                    for item in data: -                        if update: -                            # Determine which object we're updating -                            identity = self.get_identity(item) -                            self.object = identity_to_objects.pop(identity, None) -                            if self.object is None and not self.allow_add_remove: -                                ret.append(None) -                                errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}) -                                continue - -                        ret.append(self.from_native(item, None)) -                        errors.append(self._errors) - -                    if update and self.allow_add_remove: -                        ret._deleted = identity_to_objects.values() - -                    self._errors = any(errors) and errors or [] -                else: -                    self._errors = {'non_field_errors': ['Expected a list of items.']} -            else: -                ret = self.from_native(data, files) +    def validate(self, attrs): +        return attrs -            if not self._errors: -                self.object = ret +    def __iter__(self): +        errors = self.errors if hasattr(self, '_errors') else {} +        for field in self.fields.values(): +            value = self.data.get(field.field_name) if self.data else None +            error = errors.get(field.field_name) +            yield FieldResult(field, value, error) -        return self._errors +    def __repr__(self): +        return representation.serializer_repr(self, indent=1) -    def is_valid(self): -        return not self.errors -    @property -    def data(self): -        """ -        Returns the serialized data on the serializer. -        """ -        if self._data is None: -            obj = self.object +class ListSerializer(BaseSerializer): +    child = None +    initial = [] -            if self.many is not None: -                many = self.many -            else: -                many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict)) -                if many: -                    warnings.warn('Implicit list/queryset serialization is deprecated. ' -                                  'Use the `many=True` flag when instantiating the serializer.', -                                  DeprecationWarning, stacklevel=2) - -            if many: -                self._data = [self.to_native(item) for item in obj] -            else: -                self._data = self.to_native(obj) +    def __init__(self, *args, **kwargs): +        self.child = kwargs.pop('child', copy.deepcopy(self.child)) +        assert self.child is not None, '`child` is a required argument.' +        self.context = kwargs.pop('context', {}) +        kwargs.pop('partial', None) -        return self._data +        super(ListSerializer, self).__init__(*args, **kwargs) +        self.child.bind('', self, self) -    def save_object(self, obj, **kwargs): -        obj.save(**kwargs) +    def bind(self, field_name, parent, root): +        # If the list is used as a field then it needs to provide +        # the current context to the child serializer. +        super(ListSerializer, self).bind(field_name, parent, root) +        self.child.bind(field_name, self, root) -    def delete_object(self, obj): -        obj.delete() +    def get_value(self, dictionary): +        # We override the default field access in order to support +        # lists in HTML forms. +        if is_html_input(dictionary): +            return html.parse_html_list(dictionary, prefix=self.field_name) +        return dictionary.get(self.field_name, empty) -    def save(self, **kwargs): +    def to_internal_value(self, data):          """ -        Save the deserialized object and return it. +        List of dicts of native values <- List of dicts of primitive datatypes.          """ -        # Clear cached _data, which may be invalidated by `save()` -        self._data = None - -        if isinstance(self.object, list): -            [self.save_object(item, **kwargs) for item in self.object] - -            if self.object._deleted: -                [self.delete_object(item) for item in self.object._deleted] -        else: -            self.save_object(self.object, **kwargs) +        if html.is_html_input(data): +            data = html.parse_html_list(data) -        return self.object +        return [self.child.run_validation(item) for item in data] -    def metadata(self): +    def to_representation(self, data):          """ -        Return a dictionary of metadata about the fields on the serializer. -        Useful for things like responding to OPTIONS requests, or generating -        API schemas for auto-documentation. +        List of object instances -> List of dicts of primitive datatypes.          """ -        return SortedDict( -            [ -                (field_name, field.metadata()) -                for field_name, field in six.iteritems(self.fields) -            ] -        ) +        return [self.child.to_representation(item) for item in data] +    def create(self, attrs_list): +        return [self.child.create(attrs) for attrs in attrs_list] -class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)): -    pass +    def save(self): +        if self.instance is not None: +            self.update(self.instance, self.validated_data) +        self.instance = self.create(self.validated_data) +        return self.instance - -class ModelSerializerOptions(SerializerOptions): -    """ -    Meta class options for ModelSerializer -    """ -    def __init__(self, meta): -        super(ModelSerializerOptions, self).__init__(meta) -        self.model = getattr(meta, 'model', None) -        self.read_only_fields = getattr(meta, 'read_only_fields', ()) -        self.write_only_fields = getattr(meta, 'write_only_fields', ()) +    def __repr__(self): +        return representation.list_repr(self, indent=1)  class ModelSerializer(Serializer): -    """ -    A serializer that deals with model instances and querysets. -    """ -    _options_class = ModelSerializerOptions - -    field_mapping = { +    _field_mapping = {          models.AutoField: IntegerField, +        models.BigIntegerField: IntegerField, +        models.BooleanField: BooleanField, +        models.CharField: CharField, +        models.CommaSeparatedIntegerField: CharField, +        models.DateField: DateField, +        models.DateTimeField: DateTimeField, +        models.DecimalField: DecimalField, +        models.EmailField: EmailField, +        models.Field: ModelField, +        models.FileField: FileField,          models.FloatField: FloatField, +        models.ImageField: ImageField,          models.IntegerField: IntegerField, +        models.NullBooleanField: BooleanField,          models.PositiveIntegerField: IntegerField, -        models.SmallIntegerField: IntegerField,          models.PositiveSmallIntegerField: IntegerField, -        models.DateTimeField: DateTimeField, -        models.DateField: DateField, -        models.TimeField: TimeField, -        models.DecimalField: DecimalField, -        models.EmailField: EmailField, -        models.CharField: CharField, -        models.URLField: URLField,          models.SlugField: SlugField, +        models.SmallIntegerField: IntegerField,          models.TextField: CharField, -        models.CommaSeparatedIntegerField: CharField, -        models.BooleanField: BooleanField, -        models.NullBooleanField: BooleanField, -        models.FileField: FileField, -        models.ImageField: ImageField, +        models.TimeField: TimeField, +        models.URLField: URLField,      } +    _related_class = PrimaryKeyRelatedField -    def get_default_fields(self): -        """ -        Return all the fields that should be serialized for the model. -        """ +    def create(self, attrs): +        ModelClass = self.Meta.model -        cls = self.opts.model -        assert cls is not None, ( -            "Serializer class '%s' is missing 'model' Meta option" % -            self.__class__.__name__ -        ) -        opts = cls._meta.concrete_model._meta -        ret = SortedDict() -        nested = bool(self.opts.depth) - -        # Deal with adding the primary key field -        pk_field = opts.pk -        while pk_field.rel and pk_field.rel.parent_link: -            # If model is a child via multitable inheritance, use parent's pk -            pk_field = pk_field.rel.to._meta.pk - -        serializer_pk_field = self.get_pk_field(pk_field) -        if serializer_pk_field: -            ret[pk_field.name] = serializer_pk_field - -        # Deal with forward relationships -        forward_rels = [field for field in opts.fields if field.serialize] -        forward_rels += [field for field in opts.many_to_many if field.serialize] - -        for model_field in forward_rels: -            has_through_model = False - -            if model_field.rel: -                to_many = isinstance(model_field, -                                     models.fields.related.ManyToManyField) -                related_model = _resolve_model(model_field.rel.to) - -                if to_many and not model_field.rel.through._meta.auto_created: -                    has_through_model = True - -            if model_field.rel and nested: -                if len(inspect.getargspec(self.get_nested_field).args) == 2: -                    warnings.warn( -                        'The `get_nested_field(model_field)` call signature ' -                        'is deprecated. ' -                        'Use `get_nested_field(model_field, related_model, ' -                        'to_many) instead', -                        DeprecationWarning -                    ) -                    field = self.get_nested_field(model_field) -                else: -                    field = self.get_nested_field(model_field, related_model, to_many) -            elif model_field.rel: -                if len(inspect.getargspec(self.get_nested_field).args) == 3: -                    warnings.warn( -                        'The `get_related_field(model_field, to_many)` call ' -                        'signature is deprecated. ' -                        'Use `get_related_field(model_field, related_model, ' -                        'to_many) instead', -                        DeprecationWarning -                    ) -                    field = self.get_related_field(model_field, to_many=to_many) -                else: -                    field = self.get_related_field(model_field, related_model, to_many) -            else: -                field = self.get_field(model_field) +        # Remove many-to-many relationships from attrs. +        # They are not valid arguments to the default `.create()` method, +        # as they require that the instance has already been saved. +        info = model_meta.get_field_info(ModelClass) +        many_to_many = {} +        for field_name, relation_info in info.relations.items(): +            if relation_info.to_many and (field_name in attrs): +                many_to_many[field_name] = attrs.pop(field_name) -            if field: -                if has_through_model: -                    field.read_only = True +        instance = ModelClass.objects.create(**attrs) -                ret[model_field.name] = field +        # Save many-to-many relationships after the instance is created. +        if many_to_many: +            for field_name, value in many_to_many.items(): +                setattr(instance, field_name, value) -        # Deal with reverse relationships -        if not self.opts.fields: -            reverse_rels = [] -        else: -            # Reverse relationships are only included if they are explicitly -            # present in the `fields` option on the serializer -            reverse_rels = opts.get_all_related_objects() -            reverse_rels += opts.get_all_related_many_to_many_objects() - -        for relation in reverse_rels: -            accessor_name = relation.get_accessor_name() -            if not self.opts.fields or accessor_name not in self.opts.fields: -                continue -            related_model = relation.model -            to_many = relation.field.rel.multiple -            has_through_model = False -            is_m2m = isinstance(relation.field, -                                models.fields.related.ManyToManyField) - -            if ( -                is_m2m and -                hasattr(relation.field.rel, 'through') and -                not relation.field.rel.through._meta.auto_created -            ): -                has_through_model = True - -            if nested: -                field = self.get_nested_field(None, related_model, to_many) -            else: -                field = self.get_related_field(None, related_model, to_many) - -            if field: -                if has_through_model: -                    field.read_only = True - -                ret[accessor_name] = field - -        # Ensure that 'read_only_fields' is an iterable -        assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple' - -        # Add the `read_only` flag to any fields that have been specified -        # in the `read_only_fields` option -        for field_name in self.opts.read_only_fields: -            assert field_name not in self.base_fields.keys(), ( -                "field '%s' on serializer '%s' specified in " -                "`read_only_fields`, but also added " -                "as an explicit field.  Remove it from `read_only_fields`." % -                (field_name, self.__class__.__name__)) -            assert field_name in ret, ( -                "Non-existant field '%s' specified in `read_only_fields` " -                "on serializer '%s'." % -                (field_name, self.__class__.__name__)) -            ret[field_name].read_only = True - -        # Ensure that 'write_only_fields' is an iterable -        assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple' - -        for field_name in self.opts.write_only_fields: -            assert field_name not in self.base_fields.keys(), ( -                "field '%s' on serializer '%s' specified in " -                "`write_only_fields`, but also added " -                "as an explicit field.  Remove it from `write_only_fields`." % -                (field_name, self.__class__.__name__)) -            assert field_name in ret, ( -                "Non-existant field '%s' specified in `write_only_fields` " -                "on serializer '%s'." % -                (field_name, self.__class__.__name__)) -            ret[field_name].write_only = True - -        return ret - -    def get_pk_field(self, model_field): -        """ -        Returns a default instance of the pk field. -        """ -        return self.get_field(model_field) - -    def get_nested_field(self, model_field, related_model, to_many): -        """ -        Creates a default instance of a nested relational field. - -        Note that model_field will be `None` for reverse relationships. -        """ -        class NestedModelSerializer(ModelSerializer): -            class Meta: -                model = related_model -                depth = self.opts.depth - 1 - -        return NestedModelSerializer(many=to_many) - -    def get_related_field(self, model_field, related_model, to_many): -        """ -        Creates a default instance of a flat relational field. - -        Note that model_field will be `None` for reverse relationships. -        """ -        # TODO: filter queryset using: -        # .using(db).complex_filter(self.rel.limit_choices_to) - -        kwargs = { -            'queryset': related_model._default_manager, -            'many': to_many -        } - -        if model_field: -            kwargs['required'] = not(model_field.null or model_field.blank) -            if model_field.help_text is not None: -                kwargs['help_text'] = model_field.help_text -            if model_field.verbose_name is not None: -                kwargs['label'] = model_field.verbose_name - -            if not model_field.editable: -                kwargs['read_only'] = True - -            if model_field.verbose_name is not None: -                kwargs['label'] = model_field.verbose_name - -            if model_field.help_text is not None: -                kwargs['help_text'] = model_field.help_text - -        return PrimaryKeyRelatedField(**kwargs) - -    def get_field(self, model_field): -        """ -        Creates a default instance of a basic non-relational field. -        """ -        kwargs = {} - -        if model_field.null or model_field.blank: -            kwargs['required'] = False - -        if isinstance(model_field, models.AutoField) or not model_field.editable: -            kwargs['read_only'] = True - -        if model_field.has_default(): -            kwargs['default'] = model_field.get_default() - -        if issubclass(model_field.__class__, models.TextField): -            kwargs['widget'] = widgets.Textarea - -        if model_field.verbose_name is not None: -            kwargs['label'] = model_field.verbose_name - -        if model_field.help_text is not None: -            kwargs['help_text'] = model_field.help_text - -        # TODO: TypedChoiceField? -        if model_field.flatchoices:  # This ModelField contains choices -            kwargs['choices'] = model_field.flatchoices -            if model_field.null: -                kwargs['empty'] = None -            return ChoiceField(**kwargs) - -        # put this below the ChoiceField because min_value isn't a valid initializer -        if issubclass(model_field.__class__, models.PositiveIntegerField) or\ -                issubclass(model_field.__class__, models.PositiveSmallIntegerField): -            kwargs['min_value'] = 0 - -        if model_field.null and \ -                issubclass(model_field.__class__, (models.CharField, models.TextField)): -            kwargs['allow_none'] = True - -        attribute_dict = { -            models.CharField: ['max_length'], -            models.CommaSeparatedIntegerField: ['max_length'], -            models.DecimalField: ['max_digits', 'decimal_places'], -            models.EmailField: ['max_length'], -            models.FileField: ['max_length'], -            models.ImageField: ['max_length'], -            models.SlugField: ['max_length'], -            models.URLField: ['max_length'], -        } - -        if model_field.__class__ in attribute_dict: -            attributes = attribute_dict[model_field.__class__] -            for attribute in attributes: -                kwargs.update({attribute: getattr(model_field, attribute)}) - -        try: -            return self.field_mapping[model_field.__class__](**kwargs) -        except KeyError: -            return ModelField(model_field=model_field, **kwargs) - -    def get_validation_exclusions(self, instance=None): -        """ -        Return a list of field names to exclude from model validation. -        """ -        cls = self.opts.model -        opts = cls._meta.concrete_model._meta -        exclusions = [field.name for field in opts.fields + opts.many_to_many] - -        for field_name, field in self.fields.items(): -            field_name = field.source or field_name -            if ( -                field_name in exclusions -                and not field.read_only -                and (field.required or hasattr(instance, field_name)) -                and not isinstance(field, Serializer) -            ): -                exclusions.remove(field_name) -        return exclusions - -    def full_clean(self, instance): -        """ -        Perform Django's full_clean, and populate the `errors` dictionary -        if any validation errors occur. - -        Note that we don't perform this inside the `.restore_object()` method, -        so that subclasses can override `.restore_object()`, and still get -        the full_clean validation checking. -        """ -        try: -            instance.full_clean(exclude=self.get_validation_exclusions(instance)) -        except ValidationError as err: -            self._errors = err.message_dict -            return None          return instance -    def restore_object(self, attrs, instance=None): -        """ -        Restore the model instance. -        """ -        m2m_data = {} -        related_data = {} -        nested_forward_relations = {} -        meta = self.opts.model._meta - -        # Reverse fk or one-to-one relations -        for (obj, model) in meta.get_all_related_objects_with_model(): -            field_name = obj.get_accessor_name() -            if field_name in attrs: -                related_data[field_name] = attrs.pop(field_name) - -        # Reverse m2m relations -        for (obj, model) in meta.get_all_related_m2m_objects_with_model(): -            field_name = obj.get_accessor_name() -            if field_name in attrs: -                m2m_data[field_name] = attrs.pop(field_name) - -        # Forward m2m relations -        for field in meta.many_to_many + meta.virtual_fields: -            if isinstance(field, GenericForeignKey): -                continue -            if field.name in attrs: -                m2m_data[field.name] = attrs.pop(field.name) +    def update(self, obj, attrs): +        for attr, value in attrs.items(): +            setattr(obj, attr, value) +        obj.save() -        # Nested forward relations - These need to be marked so we can save -        # them before saving the parent model instance. -        for field_name in attrs.keys(): -            if isinstance(self.fields.get(field_name, None), Serializer): -                nested_forward_relations[field_name] = attrs[field_name] +    def _get_base_fields(self): +        declared_fields = copy.deepcopy(self._declared_fields) -        # Create an empty instance of the model -        if instance is None: -            instance = self.opts.model() +        ret = SortedDict() +        model = getattr(self.Meta, 'model') +        fields = getattr(self.Meta, 'fields', None) +        depth = getattr(self.Meta, 'depth', 0) +        extra_kwargs = getattr(self.Meta, 'extra_kwargs', {}) + +        # Retrieve metadata about fields & relationships on the model class. +        info = model_meta.get_field_info(model) + +        # Use the default set of fields if none is supplied explicitly. +        if fields is None: +            fields = self._get_default_field_names(declared_fields, info) + +        for field_name in fields: +            if field_name in declared_fields: +                # Field is explicitly declared on the class, use that. +                ret[field_name] = declared_fields[field_name] +                continue -        for key, val in attrs.items(): -            try: -                setattr(instance, key, val) -            except ValueError: -                self._errors[key] = [self.error_messages['required']] - -        # Any relations that cannot be set until we've -        # saved the model get hidden away on these -        # private attributes, so we can deal with them -        # at the point of save. -        instance._related_data = related_data -        instance._m2m_data = m2m_data -        instance._nested_forward_relations = nested_forward_relations +            elif field_name == api_settings.URL_FIELD_NAME: +                # Create the URL field. +                field_cls = HyperlinkedIdentityField +                kwargs = get_url_kwargs(model) + +            elif field_name in info.fields_and_pk: +                # Create regular model fields. +                model_field = info.fields_and_pk[field_name] +                field_cls = lookup_class(self._field_mapping, model_field) +                kwargs = get_field_kwargs(field_name, model_field) +                if 'choices' in kwargs: +                    # Fields with choices get coerced into `ChoiceField` +                    # instead of using their regular typed field. +                    field_cls = ChoiceField +                if not issubclass(field_cls, ModelField): +                    # `model_field` is only valid for the fallback case of +                    # `ModelField`, which is used when no other typed field +                    # matched to the model field. +                    kwargs.pop('model_field', None) + +            elif field_name in info.relations: +                # Create forward and reverse relationships. +                relation_info = info.relations[field_name] +                if depth: +                    field_cls = self._get_nested_class(depth, relation_info) +                    kwargs = get_nested_relation_kwargs(relation_info) +                else: +                    field_cls = self._related_class +                    kwargs = get_relation_kwargs(field_name, relation_info) +                    # `view_name` is only valid for hyperlinked relationships. +                    if not issubclass(field_cls, HyperlinkedRelatedField): +                        kwargs.pop('view_name', None) -        return instance +            elif hasattr(model, field_name): +                # Create a read only field for model methods and properties. +                field_cls = ReadOnlyField +                kwargs = {} -    def from_native(self, data, files): -        """ -        Override the default method to also include model field validation. -        """ -        instance = super(ModelSerializer, self).from_native(data, files) -        if not self._errors: -            return self.full_clean(instance) +            else: +                raise ImproperlyConfigured( +                    'Field name `%s` is not valid for model `%s`.' % +                    (field_name, model.__class__.__name__) +                ) + +            # Check that any fields declared on the class are +            # also explicity included in `Meta.fields`. +            missing_fields = set(declared_fields.keys()) - set(fields) +            if missing_fields: +                missing_field = list(missing_fields)[0] +                raise ImproperlyConfigured( +                    'Field `%s` has been declared on serializer `%s`, but ' +                    'is missing from `Meta.fields`.' % +                    (missing_field, self.__class__.__name__) +                ) + +            # Populate any kwargs defined in `Meta.extra_kwargs` +            kwargs.update(extra_kwargs.get(field_name, {})) + +            # Create the serializer field. +            ret[field_name] = field_cls(**kwargs) -    def save_object(self, obj, **kwargs): -        """ -        Save the deserialized object. -        """ -        if getattr(obj, '_nested_forward_relations', None): -            # Nested relationships need to be saved before we can save the -            # parent instance. -            for field_name, sub_object in obj._nested_forward_relations.items(): -                if sub_object: -                    self.save_object(sub_object) -                setattr(obj, field_name, sub_object) - -        obj.save(**kwargs) - -        if getattr(obj, '_m2m_data', None): -            for accessor_name, object_list in obj._m2m_data.items(): -                setattr(obj, accessor_name, object_list) -            del(obj._m2m_data) - -        if getattr(obj, '_related_data', None): -            related_fields = dict([ -                (field.get_accessor_name(), field) -                for field, model -                in obj._meta.get_all_related_objects_with_model() -            ]) -            for accessor_name, related in obj._related_data.items(): -                if isinstance(related, RelationsList): -                    # Nested reverse fk relationship -                    for related_item in related: -                        fk_field = related_fields[accessor_name].field.name -                        setattr(related_item, fk_field, obj) -                        self.save_object(related_item) - -                    # Delete any removed objects -                    if related._deleted: -                        [self.delete_object(item) for item in related._deleted] - -                elif isinstance(related, models.Model): -                    # Nested reverse one-one relationship -                    fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name -                    setattr(related, fk_field, obj) -                    self.save_object(related) -                else: -                    # Reverse FK or reverse one-one -                    setattr(obj, accessor_name, related) -            del(obj._related_data) +        return ret +    def _get_default_field_names(self, declared_fields, model_info): +        return ( +            [model_info.pk.name] + +            list(declared_fields.keys()) + +            list(model_info.fields.keys()) + +            list(model_info.forward_relations.keys()) +        ) -class HyperlinkedModelSerializerOptions(ModelSerializerOptions): -    """ -    Options for HyperlinkedModelSerializer -    """ -    def __init__(self, meta): -        super(HyperlinkedModelSerializerOptions, self).__init__(meta) -        self.view_name = getattr(meta, 'view_name', None) -        self.lookup_field = getattr(meta, 'lookup_field', None) -        self.url_field_name = getattr(meta, 'url_field_name', api_settings.URL_FIELD_NAME) +    def _get_nested_class(self, nested_depth, relation_info): +        class NestedSerializer(ModelSerializer): +            class Meta: +                model = relation_info.related +                depth = nested_depth +        return NestedSerializer  class HyperlinkedModelSerializer(ModelSerializer): -    """ -    A subclass of ModelSerializer that uses hyperlinked relationships, -    instead of primary key relationships. -    """ -    _options_class = HyperlinkedModelSerializerOptions -    _default_view_name = '%(model_name)s-detail' -    _hyperlink_field_class = HyperlinkedRelatedField -    _hyperlink_identify_field_class = HyperlinkedIdentityField - -    def get_default_fields(self): -        fields = super(HyperlinkedModelSerializer, self).get_default_fields() - -        if self.opts.view_name is None: -            self.opts.view_name = self._get_default_view_name(self.opts.model) - -        if self.opts.url_field_name not in fields: -            url_field = self._hyperlink_identify_field_class( -                view_name=self.opts.view_name, -                lookup_field=self.opts.lookup_field -            ) -            ret = self._dict_class() -            ret[self.opts.url_field_name] = url_field -            ret.update(fields) -            fields = ret - -        return fields - -    def get_pk_field(self, model_field): -        if self.opts.fields and model_field.name in self.opts.fields: -            return self.get_field(model_field) - -    def get_related_field(self, model_field, related_model, to_many): -        """ -        Creates a default instance of a flat relational field. -        """ -        # TODO: filter queryset using: -        # .using(db).complex_filter(self.rel.limit_choices_to) -        kwargs = { -            'queryset': related_model._default_manager, -            'view_name': self._get_default_view_name(related_model), -            'many': to_many -        } - -        if model_field: -            kwargs['required'] = not(model_field.null or model_field.blank) -            if model_field.help_text is not None: -                kwargs['help_text'] = model_field.help_text -            if model_field.verbose_name is not None: -                kwargs['label'] = model_field.verbose_name - -        if self.opts.lookup_field: -            kwargs['lookup_field'] = self.opts.lookup_field - -        return self._hyperlink_field_class(**kwargs) - -    def get_identity(self, data): -        """ -        This hook is required for bulk update. -        We need to override the default, to use the url as the identity. -        """ -        try: -            return data.get(self.opts.url_field_name, None) -        except AttributeError: -            return None +    _related_class = HyperlinkedRelatedField + +    def _get_default_field_names(self, declared_fields, model_info): +        return ( +            [api_settings.URL_FIELD_NAME] + +            list(declared_fields.keys()) + +            list(model_info.fields.keys()) + +            list(model_info.forward_relations.keys()) +        ) -    def _get_default_view_name(self, model): -        """ -        Return the view name to use if 'view_name' is not specified in 'Meta' -        """ -        model_meta = model._meta -        format_kwargs = { -            'app_label': model_meta.app_label, -            'model_name': model_meta.object_name.lower() -        } -        return self._default_view_name % format_kwargs +    def _get_nested_class(self, nested_depth, relation_info): +        class NestedSerializer(HyperlinkedModelSerializer): +            class Meta: +                model = relation_info.related +                depth = nested_depth +        return NestedSerializer diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 644751f8..421e146c 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -77,6 +77,7 @@ DEFAULTS = {      # Exception handling      'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler', +    'NON_FIELD_ERRORS_KEY': 'non_field_errors',      # Testing      'TEST_REQUEST_RENDERER_CLASSES': ( @@ -96,24 +97,19 @@ DEFAULTS = {      'URL_FIELD_NAME': 'url',      # Input and output formats -    'DATE_INPUT_FORMATS': ( -        ISO_8601, -    ), -    'DATE_FORMAT': None, +    'DATE_FORMAT': ISO_8601, +    'DATE_INPUT_FORMATS': (ISO_8601,), -    'DATETIME_INPUT_FORMATS': ( -        ISO_8601, -    ), -    'DATETIME_FORMAT': None, +    'DATETIME_FORMAT': ISO_8601, +    'DATETIME_INPUT_FORMATS': (ISO_8601,), -    'TIME_INPUT_FORMATS': ( -        ISO_8601, -    ), -    'TIME_FORMAT': None, - -    # Pending deprecation -    'FILTER_BACKEND': None, +    'TIME_FORMAT': ISO_8601, +    'TIME_INPUT_FORMATS': (ISO_8601,), +    # Encoding +    'UNICODE_JSON': True, +    'COMPACT_JSON': True, +    'COERCE_DECIMAL_TO_STRING': True  } @@ -129,7 +125,6 @@ IMPORT_STRINGS = (      'DEFAULT_PAGINATION_SERIALIZER_CLASS',      'DEFAULT_FILTER_BACKENDS',      'EXCEPTION_HANDLER', -    'FILTER_BACKEND',      'TEST_REQUEST_RENDERER_CLASSES',      'UNAUTHENTICATED_USER',      'UNAUTHENTICATED_TOKEN', @@ -196,15 +191,9 @@ class APISettings(object):          if val and attr in self.import_strings:              val = perform_import(val, attr) -        self.validate_setting(attr, val) -          # Cache the result          setattr(self, attr, val)          return val -    def validate_setting(self, attr, val): -        if attr == 'FILTER_BACKEND' and val is not None: -            # Make sure we can initialize the class -            val()  api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS) diff --git a/rest_framework/test.py b/rest_framework/test.py index f89a6dcd..9b40353a 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -36,7 +36,7 @@ class APIRequestFactory(DjangoRequestFactory):          Encode the data returning a two tuple of (bytes, content_type)          """ -        if not data: +        if data is None:              return ('', content_type)          assert format is None or content_type is None, ( diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 00ffdfba..174b08b8 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -7,7 +7,6 @@ from django.db.models.query import QuerySet  from django.utils.datastructures import SortedDict  from django.utils.functional import Promise  from rest_framework.compat import force_text -from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata  import datetime  import decimal  import types @@ -17,45 +16,47 @@ import json  class JSONEncoder(json.JSONEncoder):      """      JSONEncoder subclass that knows how to encode date/time/timedelta, -    decimal types, and generators. +    decimal types, generators and other basic python objects.      """ -    def default(self, o): +    def default(self, obj):          # For Date Time string spec, see ECMA 262          # http://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15 -        if isinstance(o, Promise): -            return force_text(o) -        elif isinstance(o, datetime.datetime): -            r = o.isoformat() -            if o.microsecond: -                r = r[:23] + r[26:] -            if r.endswith('+00:00'): -                r = r[:-6] + 'Z' -            return r -        elif isinstance(o, datetime.date): -            return o.isoformat() -        elif isinstance(o, datetime.time): -            if timezone and timezone.is_aware(o): +        if isinstance(obj, Promise): +            return force_text(obj) +        elif isinstance(obj, datetime.datetime): +            representation = obj.isoformat() +            if obj.microsecond: +                representation = representation[:23] + representation[26:] +            if representation.endswith('+00:00'): +                representation = representation[:-6] + 'Z' +            return representation +        elif isinstance(obj, datetime.date): +            return obj.isoformat() +        elif isinstance(obj, datetime.time): +            if timezone and timezone.is_aware(obj):                  raise ValueError("JSON can't represent timezone-aware times.") -            r = o.isoformat() -            if o.microsecond: -                r = r[:12] -            return r -        elif isinstance(o, datetime.timedelta): -            return str(o.total_seconds()) -        elif isinstance(o, decimal.Decimal): -            return str(o) -        elif isinstance(o, QuerySet): -            return list(o) -        elif hasattr(o, 'tolist'): -            return o.tolist() -        elif hasattr(o, '__getitem__'): +            representation = obj.isoformat() +            if obj.microsecond: +                representation = representation[:12] +            return representation +        elif isinstance(obj, datetime.timedelta): +            return str(obj.total_seconds()) +        elif isinstance(obj, decimal.Decimal): +            # Serializers will coerce decimals to strings by default. +            return float(obj) +        elif isinstance(obj, QuerySet): +            return list(obj) +        elif hasattr(obj, 'tolist'): +            # Numpy arrays and array scalars. +            return obj.tolist() +        elif hasattr(obj, '__getitem__'):              try: -                return dict(o) +                return dict(obj)              except:                  pass -        elif hasattr(o, '__iter__'): -            return [i for i in o] -        return super(JSONEncoder, self).default(o) +        elif hasattr(obj, '__iter__'): +            return [item for item in obj] +        return super(JSONEncoder, self).default(obj)  try: @@ -106,14 +107,14 @@ else:          SortedDict,          yaml.representer.SafeRepresenter.represent_dict      ) -    SafeDumper.add_representer( -        DictWithMetadata, -        yaml.representer.SafeRepresenter.represent_dict -    ) -    SafeDumper.add_representer( -        SortedDictWithMetadata, -        yaml.representer.SafeRepresenter.represent_dict -    ) +    # SafeDumper.add_representer( +    #     DictWithMetadata, +    #     yaml.representer.SafeRepresenter.represent_dict +    # ) +    # SafeDumper.add_representer( +    #     SortedDictWithMetadata, +    #     yaml.representer.SafeRepresenter.represent_dict +    # )      SafeDumper.add_representer(          types.GeneratorType,          yaml.representer.SafeRepresenter.represent_list diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py new file mode 100644 index 00000000..be72e444 --- /dev/null +++ b/rest_framework/utils/field_mapping.py @@ -0,0 +1,215 @@ +""" +Helper functions for mapping model fields to a dictionary of default +keyword arguments that should be used for their equivelent serializer fields. +""" +from django.core import validators +from django.db import models +from django.utils.text import capfirst +from rest_framework.compat import clean_manytomany_helptext +import inspect + + +def lookup_class(mapping, instance): +    """ +    Takes a dictionary with classes as keys, and an object. +    Traverses the object's inheritance hierarchy in method +    resolution order, and returns the first matching value +    from the dictionary or raises a KeyError if nothing matches. +    """ +    for cls in inspect.getmro(instance.__class__): +        if cls in mapping: +            return mapping[cls] +    raise KeyError('Class %s not found in lookup.', cls.__name__) + + +def needs_label(model_field, field_name): +    """ +    Returns `True` if the label based on the model's verbose name +    is not equal to the default label it would have based on it's field name. +    """ +    default_label = field_name.replace('_', ' ').capitalize() +    return capfirst(model_field.verbose_name) != default_label + + +def get_detail_view_name(model): +    """ +    Given a model class, return the view name to use for URL relationships +    that refer to instances of the model. +    """ +    return '%(model_name)s-detail' % { +        'app_label': model._meta.app_label, +        'model_name': model._meta.object_name.lower() +    } + + +def get_field_kwargs(field_name, model_field): +    """ +    Creates a default instance of a basic non-relational field. +    """ +    kwargs = {} +    validator_kwarg = model_field.validators + +    if model_field.null or model_field.blank: +        kwargs['required'] = False + +    if model_field.verbose_name and needs_label(model_field, field_name): +        kwargs['label'] = capfirst(model_field.verbose_name) + +    if model_field.help_text: +        kwargs['help_text'] = model_field.help_text + +    if isinstance(model_field, models.AutoField) or not model_field.editable: +        kwargs['read_only'] = True +        # Read only implies that the field is not required. +        # We have a cleaner repr on the instance if we don't set it. +        kwargs.pop('required', None) + +    if model_field.has_default(): +        kwargs['default'] = model_field.get_default() +        # Having a default implies that the field is not required. +        # We have a cleaner repr on the instance if we don't set it. +        kwargs.pop('required', None) + +    if model_field.flatchoices: +        # If this model field contains choices, then return now, +        # any further keyword arguments are not valid. +        kwargs['choices'] = model_field.flatchoices +        return kwargs + +    # Ensure that max_length is passed explicitly as a keyword arg, +    # rather than as a validator. +    max_length = getattr(model_field, 'max_length', None) +    if max_length is not None: +        kwargs['max_length'] = max_length +        validator_kwarg = [ +            validator for validator in validator_kwarg +            if not isinstance(validator, validators.MaxLengthValidator) +        ] + +    # Ensure that min_length is passed explicitly as a keyword arg, +    # rather than as a validator. +    min_length = getattr(model_field, 'min_length', None) +    if min_length is not None: +        kwargs['min_length'] = min_length +        validator_kwarg = [ +            validator for validator in validator_kwarg +            if not isinstance(validator, validators.MinLengthValidator) +        ] + +    # Ensure that max_value is passed explicitly as a keyword arg, +    # rather than as a validator. +    max_value = next(( +        validator.limit_value for validator in validator_kwarg +        if isinstance(validator, validators.MaxValueValidator) +    ), None) +    if max_value is not None: +        kwargs['max_value'] = max_value +        validator_kwarg = [ +            validator for validator in validator_kwarg +            if not isinstance(validator, validators.MaxValueValidator) +        ] + +    # Ensure that max_value is passed explicitly as a keyword arg, +    # rather than as a validator. +    min_value = next(( +        validator.limit_value for validator in validator_kwarg +        if isinstance(validator, validators.MinValueValidator) +    ), None) +    if min_value is not None: +        kwargs['min_value'] = min_value +        validator_kwarg = [ +            validator for validator in validator_kwarg +            if not isinstance(validator, validators.MinValueValidator) +        ] + +    # URLField does not need to include the URLValidator argument, +    # as it is explicitly added in. +    if isinstance(model_field, models.URLField): +        validator_kwarg = [ +            validator for validator in validator_kwarg +            if not isinstance(validator, validators.URLValidator) +        ] + +    # EmailField does not need to include the validate_email argument, +    # as it is explicitly added in. +    if isinstance(model_field, models.EmailField): +        validator_kwarg = [ +            validator for validator in validator_kwarg +            if validator is not validators.validate_email +        ] + +    # SlugField do not need to include the 'validate_slug' argument, +    if isinstance(model_field, models.SlugField): +        validator_kwarg = [ +            validator for validator in validator_kwarg +            if validator is not validators.validate_slug +        ] + +    max_digits = getattr(model_field, 'max_digits', None) +    if max_digits is not None: +        kwargs['max_digits'] = max_digits + +    decimal_places = getattr(model_field, 'decimal_places', None) +    if decimal_places is not None: +        kwargs['decimal_places'] = decimal_places + +    if isinstance(model_field, models.BooleanField): +        # models.BooleanField has `blank=True`, but *is* actually +        # required *unless* a default is provided. +        # Also note that Django<1.6 uses `default=False` for +        # models.BooleanField, but Django>=1.6 uses `default=None`. +        kwargs.pop('required', None) + +    if validator_kwarg: +        kwargs['validators'] = validator_kwarg + +    # The following will only be used by ModelField classes. +    # Gets removed for everything else. +    kwargs['model_field'] = model_field + +    return kwargs + + +def get_relation_kwargs(field_name, relation_info): +    """ +    Creates a default instance of a flat relational field. +    """ +    model_field, related_model, to_many, has_through_model = relation_info +    kwargs = { +        'queryset': related_model._default_manager, +        'view_name': get_detail_view_name(related_model) +    } + +    if to_many: +        kwargs['many'] = True + +    if has_through_model: +        kwargs['read_only'] = True +        kwargs.pop('queryset', None) + +    if model_field: +        if model_field.null or model_field.blank: +            kwargs['required'] = False +        if model_field.verbose_name and needs_label(model_field, field_name): +            kwargs['label'] = capfirst(model_field.verbose_name) +        if not model_field.editable: +            kwargs['read_only'] = True +            kwargs.pop('queryset', None) +        help_text = clean_manytomany_helptext(model_field.help_text) +        if help_text: +            kwargs['help_text'] = help_text + +    return kwargs + + +def get_nested_relation_kwargs(relation_info): +    kwargs = {'read_only': True} +    if relation_info.to_many: +        kwargs['many'] = True +    return kwargs + + +def get_url_kwargs(model_field): +    return { +        'view_name': get_detail_view_name(model_field) +    } diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index 6d53aed1..470af51b 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -2,11 +2,12 @@  Utility functions to return a formatted name and description for a given view.  """  from __future__ import unicode_literals +import re  from django.utils.html import escape  from django.utils.safestring import mark_safe -from rest_framework.compat import apply_markdown -import re + +from rest_framework.compat import apply_markdown, force_text  def remove_trailing_string(content, trailing): @@ -28,6 +29,7 @@ def dedent(content):      as it fails to dedent multiline docstrings that include      unindented text on the initial line.      """ +    content = force_text(content)      whitespace_counts = [len(line) - len(line.lstrip(' '))                           for line in content.splitlines()[1:] if line.lstrip()] diff --git a/rest_framework/utils/html.py b/rest_framework/utils/html.py new file mode 100644 index 00000000..edc591e9 --- /dev/null +++ b/rest_framework/utils/html.py @@ -0,0 +1,88 @@ +""" +Helpers for dealing with HTML input. +""" +import re + + +def is_html_input(dictionary): +    # MultiDict type datastructures are used to represent HTML form input, +    # which may have more than one value for each key. +    return hasattr(dictionary, 'getlist') + + +def parse_html_list(dictionary, prefix=''): +    """ +    Used to suport list values in HTML forms. +    Supports lists of primitives and/or dictionaries. + +    * List of primitives. + +    { +        '[0]': 'abc', +        '[1]': 'def', +        '[2]': 'hij' +    } +        --> +    [ +        'abc', +        'def', +        'hij' +    ] + +    * List of dictionaries. + +    { +        '[0]foo': 'abc', +        '[0]bar': 'def', +        '[1]foo': 'hij', +        '[2]bar': 'klm', +    } +        --> +    [ +        {'foo': 'abc', 'bar': 'def'}, +        {'foo': 'hij', 'bar': 'klm'} +    ] +    """ +    Dict = type(dictionary) +    ret = {} +    regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix)) +    for field, value in dictionary.items(): +        match = regex.match(field) +        if not match: +            continue +        index, key = match.groups() +        index = int(index) +        if not key: +            ret[index] = value +        elif isinstance(ret.get(index), dict): +            ret[index][key] = value +        else: +            ret[index] = Dict({key: value}) +    return [ret[item] for item in sorted(ret.keys())] + + +def parse_html_dict(dictionary, prefix): +    """ +    Used to support dictionary values in HTML forms. + +    { +        'profile.username': 'example', +        'profile.email': 'example@example.com', +    } +        --> +    { +        'profile': { +            'username': 'example, +            'email': 'example@example.com' +        } +    } +    """ +    ret = {} +    regex = re.compile(r'^%s\.(.+)$' % re.escape(prefix)) +    for field, value in dictionary.items(): +        match = regex.match(field) +        if not match: +            continue +        key = match.groups()[0] +        ret[key] = value +    return ret diff --git a/rest_framework/utils/humanize_datetime.py b/rest_framework/utils/humanize_datetime.py new file mode 100644 index 00000000..649f2abc --- /dev/null +++ b/rest_framework/utils/humanize_datetime.py @@ -0,0 +1,47 @@ +""" +Helper functions that convert strftime formats into more readable representations. +""" +from rest_framework import ISO_8601 + + +def datetime_formats(formats): +    format = ', '.join(formats).replace( +        ISO_8601, +        'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]' +    ) +    return humanize_strptime(format) + + +def date_formats(formats): +    format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]') +    return humanize_strptime(format) + + +def time_formats(formats): +    format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]') +    return humanize_strptime(format) + + +def humanize_strptime(format_string): +    # Note that we're missing some of the locale specific mappings that +    # don't really make sense. +    mapping = { +        "%Y": "YYYY", +        "%y": "YY", +        "%m": "MM", +        "%b": "[Jan-Dec]", +        "%B": "[January-December]", +        "%d": "DD", +        "%H": "hh", +        "%I": "hh",  # Requires '%p' to differentiate from '%H'. +        "%M": "mm", +        "%S": "ss", +        "%f": "uuuuuu", +        "%a": "[Mon-Sun]", +        "%A": "[Monday-Sunday]", +        "%p": "[AM|PM]", +        "%z": "[+HHMM|-HHMM]" +    } +    for key, val in mapping.items(): +        format_string = format_string.replace(key, val) +    return format_string diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py new file mode 100644 index 00000000..b6c41174 --- /dev/null +++ b/rest_framework/utils/model_meta.py @@ -0,0 +1,129 @@ +""" +Helper function for returning the field information that is associated +with a model class. This includes returning all the forward and reverse +relationships and their associated metadata. + +Usage: `get_field_info(model)` returns a `FieldInfo` instance. +""" +from collections import namedtuple +from django.db import models +from django.utils import six +from django.utils.datastructures import SortedDict +import inspect + + +FieldInfo = namedtuple('FieldResult', [ +    'pk',  # Model field instance +    'fields',  # Dict of field name -> model field instance +    'forward_relations',  # Dict of field name -> RelationInfo +    'reverse_relations',  # Dict of field name -> RelationInfo +    'fields_and_pk',  # Shortcut for 'pk' + 'fields' +    'relations'  # Shortcut for 'forward_relations' + 'reverse_relations' +]) + +RelationInfo = namedtuple('RelationInfo', [ +    'model_field', +    'related', +    'to_many', +    'has_through_model' +]) + + +def _resolve_model(obj): +    """ +    Resolve supplied `obj` to a Django model class. + +    `obj` must be a Django model class itself, or a string +    representation of one.  Useful in situtations like GH #1225 where +    Django may not have resolved a string-based reference to a model in +    another model's foreign key definition. + +    String representations should have the format: +        'appname.ModelName' +    """ +    if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: +        app_name, model_name = obj.split('.') +        return models.get_model(app_name, model_name) +    elif inspect.isclass(obj) and issubclass(obj, models.Model): +        return obj +    raise ValueError("{0} is not a Django model".format(obj)) + + +def get_field_info(model): +    """ +    Given a model class, returns a `FieldInfo` instance containing metadata +    about the various field types on the model. +    """ +    opts = model._meta.concrete_model._meta + +    # Deal with the primary key. +    pk = opts.pk +    while pk.rel and pk.rel.parent_link: +        # If model is a child via multitable inheritance, use parent's pk. +        pk = pk.rel.to._meta.pk + +    # Deal with regular fields. +    fields = SortedDict() +    for field in [field for field in opts.fields if field.serialize and not field.rel]: +        fields[field.name] = field + +    # Deal with forward relationships. +    forward_relations = SortedDict() +    for field in [field for field in opts.fields if field.serialize and field.rel]: +        forward_relations[field.name] = RelationInfo( +            model_field=field, +            related=_resolve_model(field.rel.to), +            to_many=False, +            has_through_model=False +        ) + +    # Deal with forward many-to-many relationships. +    for field in [field for field in opts.many_to_many if field.serialize]: +        forward_relations[field.name] = RelationInfo( +            model_field=field, +            related=_resolve_model(field.rel.to), +            to_many=True, +            has_through_model=( +                not field.rel.through._meta.auto_created +            ) +        ) + +    # Deal with reverse relationships. +    reverse_relations = SortedDict() +    for relation in opts.get_all_related_objects(): +        accessor_name = relation.get_accessor_name() +        reverse_relations[accessor_name] = RelationInfo( +            model_field=None, +            related=relation.model, +            to_many=relation.field.rel.multiple, +            has_through_model=False +        ) + +    # Deal with reverse many-to-many relationships. +    for relation in opts.get_all_related_many_to_many_objects(): +        accessor_name = relation.get_accessor_name() +        reverse_relations[accessor_name] = RelationInfo( +            model_field=None, +            related=relation.model, +            to_many=True, +            has_through_model=( +                hasattr(relation.field.rel, 'through') and +                not relation.field.rel.through._meta.auto_created +            ) +        ) + +    # Shortcut that merges both regular fields and the pk, +    # for simplifying regular field lookup. +    fields_and_pk = SortedDict() +    fields_and_pk['pk'] = pk +    fields_and_pk[pk.name] = pk +    fields_and_pk.update(fields) + +    # Shortcut that merges both forward and reverse relationships + +    relations = SortedDict( +        list(forward_relations.items()) + +        list(reverse_relations.items()) +    ) + +    return FieldInfo(pk, fields, forward_relations, reverse_relations, fields_and_pk, relations) diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py new file mode 100644 index 00000000..e64fdd22 --- /dev/null +++ b/rest_framework/utils/representation.py @@ -0,0 +1,87 @@ +""" +Helper functions for creating user-friendly representations +of serializer classes and serializer fields. +""" +from django.db import models +import re + + +def manager_repr(value): +    model = value.model +    opts = model._meta +    for _, name, manager in opts.concrete_managers + opts.abstract_managers: +        if manager == value: +            return '%s.%s.all()' % (model._meta.object_name, name) +    return repr(value) + + +def smart_repr(value): +    if isinstance(value, models.Manager): +        return manager_repr(value) + +    value = repr(value) + +    # Representations like u'help text' +    # should simply be presented as 'help text' +    if value.startswith("u'") and value.endswith("'"): +        return value[1:] + +    # Representations like +    # <django.core.validators.RegexValidator object at 0x1047af050> +    # Should be presented as +    # <django.core.validators.RegexValidator object> +    value = re.sub(' at 0x[0-9a-f]{4,32}>', '>', value) + +    return value + + +def field_repr(field, force_many=False): +    kwargs = field._kwargs +    if force_many: +        kwargs = kwargs.copy() +        kwargs['many'] = True +        kwargs.pop('child', None) + +    arg_string = ', '.join([smart_repr(val) for val in field._args]) +    kwarg_string = ', '.join([ +        '%s=%s' % (key, smart_repr(val)) +        for key, val in sorted(kwargs.items()) +    ]) +    if arg_string and kwarg_string: +        arg_string += ', ' + +    if force_many: +        class_name = force_many.__class__.__name__ +    else: +        class_name = field.__class__.__name__ + +    return "%s(%s%s)" % (class_name, arg_string, kwarg_string) + + +def serializer_repr(serializer, indent, force_many=None): +    ret = field_repr(serializer, force_many) + ':' +    indent_str = '    ' * indent + +    if force_many: +        fields = force_many.fields +    else: +        fields = serializer.fields + +    for field_name, field in fields.items(): +        ret += '\n' + indent_str + field_name + ' = ' +        if hasattr(field, 'fields'): +            ret += serializer_repr(field, indent + 1) +        elif hasattr(field, 'child'): +            ret += list_repr(field, indent + 1) +        elif hasattr(field, 'child_relation'): +            ret += field_repr(field.child_relation, force_many=field.child_relation) +        else: +            ret += field_repr(field) +    return ret + + +def list_repr(serializer, indent): +    child = serializer.child +    if hasattr(child, 'fields'): +        return serializer_repr(serializer, indent, force_many=child) +    return field_repr(serializer) diff --git a/rest_framework/views.py b/rest_framework/views.py index 38346ab7..9f08a4ad 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,7 +3,7 @@ Provides an APIView class that is the base of all views in REST framework.  """  from __future__ import unicode_literals -from django.core.exceptions import PermissionDenied +from django.core.exceptions import PermissionDenied, ValidationError, NON_FIELD_ERRORS  from django.http import Http404  from django.utils.datastructures import SortedDict  from django.views.decorators.csrf import csrf_exempt @@ -51,7 +51,8 @@ def exception_handler(exc):      Returns the response that should be used for any given exception.      By default we handle the REST framework `APIException`, and also -    Django's builtin `Http404` and `PermissionDenied` exceptions. +    Django's built-in `ValidationError`, `Http404` and `PermissionDenied` +    exceptions.      Any unhandled exceptions may return `None`, which will cause a 500 error      to be raised. @@ -61,13 +62,22 @@ def exception_handler(exc):          if getattr(exc, 'auth_header', None):              headers['WWW-Authenticate'] = exc.auth_header          if getattr(exc, 'wait', None): -            headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait              headers['Retry-After'] = '%d' % exc.wait          return Response({'detail': exc.detail},                          status=exc.status_code,                          headers=headers) +    elif isinstance(exc, ValidationError): +        # ValidationErrors may include the non-field key named '__all__'. +        # When returning a response we map this to a key name that can be +        # modified in settings. +        if NON_FIELD_ERRORS in exc.message_dict: +            errors = exc.message_dict.pop(NON_FIELD_ERRORS) +            exc.message_dict[api_settings.NON_FIELD_ERRORS_KEY] = errors +        return Response(exc.message_dict, +                        status=status.HTTP_400_BAD_REQUEST) +      elif isinstance(exc, Http404):          return Response({'detail': 'Not found'},                          status=status.HTTP_404_NOT_FOUND) diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index bb5b304e..84b4bd8d 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -20,6 +20,7 @@ from __future__ import unicode_literals  from functools import update_wrapper  from django.utils.decorators import classonlymethod +from django.views.decorators.csrf import csrf_exempt  from rest_framework import views, generics, mixins @@ -89,7 +90,7 @@ class ViewSetMixin(object):          # resolved URL.          view.cls = cls          view.suffix = initkwargs.get('suffix', None) -        return view +        return csrf_exempt(view)      def initialize_request(self, request, *args, **kargs):          """ | 
