diff options
| author | Tom Christie | 2014-08-29 16:46:26 +0100 | 
|---|---|---|
| committer | Tom Christie | 2014-08-29 16:46:26 +0100 | 
| commit | 4ac4676a40b121d27cfd1173ff548d96b8d3de2f (patch) | |
| tree | 560509db11316a36189088dd8b03df4126a696cd | |
| parent | 371d30aa8737c4b3aaf28ee10cc2b77a9c4d1fd9 (diff) | |
| download | django-rest-framework-4ac4676a40b121d27cfd1173ff548d96b8d3de2f.tar.bz2 | |
First pass
| -rw-r--r-- | rest_framework/fields.py | 1166 | ||||
| -rw-r--r-- | rest_framework/generics.py | 10 | ||||
| -rw-r--r-- | rest_framework/mixins.py | 24 | ||||
| -rw-r--r-- | rest_framework/pagination.py | 22 | ||||
| -rw-r--r-- | rest_framework/relations.py | 486 | ||||
| -rw-r--r-- | rest_framework/renderers.py | 10 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 1096 | ||||
| -rw-r--r-- | rest_framework/utils/encoders.py | 18 | ||||
| -rw-r--r-- | rest_framework/utils/html.py | 86 | 
9 files changed, 632 insertions, 2286 deletions
| diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 9d707c9b..a83bf94c 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,1038 +1,308 @@ -""" -Serializer fields perform validation on incoming data. +from rest_framework.utils import html -They are very similar to Django's form fields. -""" -from __future__ import unicode_literals -import copy -import datetime -import inspect -import re -import warnings -from decimal import Decimal, DecimalException -from django import forms -from django.core import validators -from django.core.exceptions import ValidationError -from django.conf import settings -from django.db.models.fields import BLANK_CHOICE_DASH -from django.http import QueryDict -from django.forms import widgets -from django.utils import six, timezone -from django.utils.encoding import is_protected_type -from django.utils.translation import ugettext_lazy as _ -from django.utils.datastructures import SortedDict -from django.utils.dateparse import parse_date, parse_datetime, parse_time -from rest_framework import ISO_8601 -from rest_framework.compat import ( -    BytesIO, smart_text, -    force_text, is_non_str_iterable -) -from rest_framework.settings import api_settings - - -def is_simple_callable(obj): +class empty:      """ -    True if the object is a callable that takes no arguments. -    """ -    function = inspect.isfunction(obj) -    method = inspect.ismethod(obj) - -    if not (function or method): -        return False +    This class is used to represent no data being provided for a given input +    or output value. -    args, _, _, defaults = inspect.getargspec(obj) -    len_args = len(args) if function else len(args) - 1 -    len_defaults = len(defaults) if defaults else 0 -    return len_args <= len_defaults +    It is required because `None` may be a valid input or output value. +    """ +    pass -def get_component(obj, attr_name): +def get_attribute(instance, attrs):      """ -    Given an object, and an attribute name, -    return that attribute on the object. +    Similar to Python's built in `getattr(instance, attr)`, +    but takes a list of nested attributes, instead of a single attribute.      """ -    if isinstance(obj, dict): -        val = obj.get(attr_name) -    else: -        val = getattr(obj, attr_name) - -    if is_simple_callable(val): -        return val() -    return val - - -def readable_datetime_formats(formats): -    format = ', '.join(formats).replace( -        ISO_8601, -        'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]' -    ) -    return humanize_strptime(format) - - -def readable_date_formats(formats): -    format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]') -    return humanize_strptime(format) +    for attr in attrs: +        instance = getattr(instance, attr) +    return instance -def readable_time_formats(formats): -    format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]') -    return humanize_strptime(format) - - -def humanize_strptime(format_string): -    # Note that we're missing some of the locale specific mappings that -    # don't really make sense. -    mapping = { -        "%Y": "YYYY", -        "%y": "YY", -        "%m": "MM", -        "%b": "[Jan-Dec]", -        "%B": "[January-December]", -        "%d": "DD", -        "%H": "hh", -        "%I": "hh",  # Requires '%p' to differentiate from '%H'. -        "%M": "mm", -        "%S": "ss", -        "%f": "uuuuuu", -        "%a": "[Mon-Sun]", -        "%A": "[Monday-Sunday]", -        "%p": "[AM|PM]", -        "%z": "[+HHMM|-HHMM]" -    } -    for key, val in mapping.items(): -        format_string = format_string.replace(key, val) -    return format_string - - -def strip_multiple_choice_msg(help_text): +def set_value(dictionary, keys, value):      """ -    Remove the 'Hold down "control" ...' message that is Django enforces in -    select multiple fields on ModelForms.  (Required for 1.5 and earlier) +    Similar to Python's built in `dictionary[key] = value`, +    but takes a list of nested keys instead of a single key. -    See https://code.djangoproject.com/ticket/9321 +    set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2} +    set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2} +    set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}}      """ -    multiple_choice_msg = _(' Hold down "Control", or "Command" on a Mac, to select more than one.') -    multiple_choice_msg = force_text(multiple_choice_msg) +    if not keys: +        dictionary.update(value) +        return -    return help_text.replace(multiple_choice_msg, '') +    for key in keys[:-1]: +        if key not in dictionary: +            dictionary[key] = {} +        dictionary = dictionary[key] +    dictionary[keys[-1]] = value -class Field(object): -    read_only = True -    creation_counter = 0 -    empty = '' -    type_name = None -    partial = False -    use_files = False -    form_field_class = forms.CharField -    type_label = 'field' -    widget = None -    def __init__(self, source=None, label=None, help_text=None): -        self.parent = None +class ValidationError(Exception): +    pass -        self.creation_counter = Field.creation_counter -        Field.creation_counter += 1 -        self.source = source +class SkipField(Exception): +    pass -        if label is not None: -            self.label = smart_text(label) -        else: -            self.label = None -        if help_text is not None: -            self.help_text = strip_multiple_choice_msg(smart_text(help_text)) -        else: -            self.help_text = None +class Field(object): +    _creation_counter = 0 -        self._errors = [] -        self._value = None -        self._name = None +    MESSAGES = { +        'required': 'This field is required.' +    } -    @property -    def errors(self): -        return self._errors +    _NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`' +    _NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' +    _NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`' +    _NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' +    _MISSING_ERROR_MESSAGE = ( +        'ValidationError raised by `{class_name}`, but error key `{key}` does ' +        'not exist in the `MESSAGES` dictionary.' +    ) -    def widget_html(self): -        if not self.widget: -            return '' +    def __init__(self, read_only=False, write_only=False, +                 required=None, default=empty, initial=None, source=None, +                 label=None, style=None): +        self._creation_counter = Field._creation_counter +        Field._creation_counter += 1 -        attrs = {} -        if 'id' not in self.widget.attrs: -            attrs['id'] = self._name +        # If `required` is unset, then use `True` unless a default is provided. +        if required is None: +            required = default is empty and not read_only -        return self.widget.render(self._name, self._value, attrs=attrs) +        # Some combinations of keyword arguments do not make sense. +        assert not (read_only and write_only), self._NOT_READ_ONLY_WRITE_ONLY +        assert not (read_only and required), self._NOT_READ_ONLY_REQUIRED +        assert not (read_only and default is not empty), self._NOT_READ_ONLY_DEFAULT +        assert not (required and default is not empty), self._NOT_REQUIRED_DEFAULT -    def label_tag(self): -        return '<label for="%s">%s:</label>' % (self._name, self.label) +        self.read_only = read_only +        self.write_only = write_only +        self.required = required +        self.default = default +        self.source = source +        self.initial = initial +        self.label = label +        self.style = {} if style is None else style -    def initialize(self, parent, field_name): +    def bind(self, field_name, parent, root):          """ -        Called to set up a field prior to field_to_native or field_from_native. - -        parent - The parent serializer. -        field_name - The name of the field being initialized. +        Setup the context for the field instance.          """ +        self.field_name = field_name          self.parent = parent -        self.root = parent.root or parent -        self.context = self.root.context -        self.partial = self.root.partial -        if self.partial: -            self.required = False +        self.root = root -    def field_from_native(self, data, files, field_name, into): -        """ -        Given a dictionary and a field name, updates the dictionary `into`, -        with the field and it's deserialized value. -        """ -        return +        # `self.label` should deafult to being based on the field name. +        if self.label is None: +            self.label = self.field_name.replace('_', ' ').capitalize() -    def field_to_native(self, obj, field_name): -        """ -        Given an object and a field name, returns the value that should be -        serialized for that field. -        """ -        if obj is None: -            return self.empty +        # self.source should default to being the same as the field name. +        if self.source is None: +            self.source = field_name +        # self.source_attrs is a list of attributes that need to be looked up +        # when serializing the instance, or populating the validated data.          if self.source == '*': -            return self.to_native(obj) - -        source = self.source or field_name -        value = obj - -        for component in source.split('.'): -            value = get_component(value, component) -            if value is None: -                break - -        return self.to_native(value) +            self.source_attrs = [] +        else: +            self.source_attrs = self.source.split('.') -    def to_native(self, value): +    def get_initial(self):          """ -        Converts the field's value into it's simple representation. +        Return a value to use when the field is being returned as a primative +        value, without any object instance.          """ -        if is_simple_callable(value): -            value = value() - -        if is_protected_type(value): -            return value -        elif (is_non_str_iterable(value) and -              not isinstance(value, (dict, six.string_types))): -            return [self.to_native(item) for item in value] -        elif isinstance(value, dict): -            # Make sure we preserve field ordering, if it exists -            ret = SortedDict() -            for key, val in value.items(): -                ret[key] = self.to_native(val) -            return ret -        return force_text(value) +        return self.initial -    def attributes(self): +    def get_value(self, dictionary):          """ -        Returns a dictionary of attributes to be used when serializing to xml. +        Given the *incoming* primative data, return the value for this field +        that should be validated and transformed to a native value.          """ -        if self.type_name: -            return {'type': self.type_name} -        return {} - -    def metadata(self): -        metadata = SortedDict() -        metadata['type'] = self.type_label -        metadata['required'] = getattr(self, 'required', False) -        optional_attrs = ['read_only', 'label', 'help_text', -                          'min_length', 'max_length'] -        for attr in optional_attrs: -            value = getattr(self, attr, None) -            if value is not None and value != '': -                metadata[attr] = force_text(value, strings_only=True) -        return metadata - - -class WritableField(Field): -    """ -    Base for read/write fields. -    """ -    write_only = False -    default_validators = [] -    default_error_messages = { -        'required': _('This field is required.'), -        'invalid': _('Invalid value.'), -    } -    widget = widgets.TextInput -    default = None - -    def __init__(self, source=None, label=None, help_text=None, -                 read_only=False, write_only=False, required=None, -                 validators=[], error_messages=None, widget=None, -                 default=None, blank=None): - -        super(WritableField, self).__init__(source=source, label=label, help_text=help_text) - -        self.read_only = read_only -        self.write_only = write_only - -        assert not (read_only and write_only), "Cannot set read_only=True and write_only=True" - -        if required is None: -            self.required = not(read_only) -        else: -            assert not (read_only and required), "Cannot set required=True and read_only=True" -            self.required = required - -        messages = {} -        for c in reversed(self.__class__.__mro__): -            messages.update(getattr(c, 'default_error_messages', {})) -        messages.update(error_messages or {}) -        self.error_messages = messages +        return dictionary.get(self.field_name, empty) -        self.validators = self.default_validators + validators -        self.default = default if default is not None else self.default - -        # Widgets are only used for HTML forms. -        widget = widget or self.widget -        if isinstance(widget, type): -            widget = widget() -        self.widget = widget +    def get_attribute(self, instance): +        """ +        Given the *outgoing* object instance, return the value for this field +        that should be returned as a primative value. +        """ +        return get_attribute(instance, self.source_attrs) -    def __deepcopy__(self, memo): -        result = copy.copy(self) -        memo[id(self)] = result -        result.validators = self.validators[:] -        return result +    def get_default(self): +        """ +        Return the default value to use when validating data if no input +        is provided for this field. -    def get_default_value(self): -        if is_simple_callable(self.default): -            return self.default() +        If a default has not been set for this field then this will simply +        return `empty`, indicating that no value should be set in the +        validated data for this field. +        """ +        if self.default is empty: +            raise SkipField()          return self.default -    def validate(self, value): -        if value in validators.EMPTY_VALUES and self.required: -            raise ValidationError(self.error_messages['required']) +    def validate(self, data=empty): +        """ +        Validate a simple representation and return the internal value. -    def run_validators(self, value): -        if value in validators.EMPTY_VALUES: -            return -        errors = [] -        for v in self.validators: -            try: -                v(value) -            except ValidationError as e: -                if hasattr(e, 'code') and e.code in self.error_messages: -                    message = self.error_messages[e.code] -                    if e.params: -                        message = message % e.params -                    errors.append(message) -                else: -                    errors.extend(e.messages) -        if errors: -            raise ValidationError(errors) +        The provided data may be `empty` if no representation was included. +        May return `empty` if the field should not be included in the +        validated data. +        """ +        if data is empty: +            if self.required: +                self.fail('required') +            return self.get_default() -    def field_to_native(self, obj, field_name): -        if self.write_only: -            return None -        return super(WritableField, self).field_to_native(obj, field_name) +        return self.to_native(data) -    def field_from_native(self, data, files, field_name, into): +    def to_native(self, data):          """ -        Given a dictionary and a field name, updates the dictionary `into`, -        with the field and it's deserialized value. +        Transform the *incoming* primative data into a native value.          """ -        if self.read_only: -            return - -        try: -            data = data or {} -            if self.use_files: -                files = files or {} -                try: -                    native = files[field_name] -                except KeyError: -                    native = data[field_name] -            else: -                native = data[field_name] -        except KeyError: -            if self.default is not None and not self.partial: -                # Note: partial updates shouldn't set defaults -                native = self.get_default_value() -            else: -                if self.required: -                    raise ValidationError(self.error_messages['required']) -                return - -        value = self.from_native(native) -        if self.source == '*': -            if value: -                into.update(value) -        else: -            self.validate(value) -            self.run_validators(value) -            into[self.source or field_name] = value +        return data -    def from_native(self, value): +    def to_primative(self, value):          """ -        Reverts a simple representation back to the field's value. +        Transform the *outgoing* native value into primative data.          """          return value - -class ModelField(WritableField): -    """ -    A generic field that can be used against an arbitrary model field. -    """ -    def __init__(self, *args, **kwargs): +    def fail(self, key, **kwargs): +        """ +        A helper method that simply raises a validation error. +        """          try: -            self.model_field = kwargs.pop('model_field') +            raise ValidationError(self.MESSAGES[key].format(**kwargs))          except KeyError: -            raise ValueError("ModelField requires 'model_field' kwarg") - -        self.min_length = kwargs.pop('min_length', -                                     getattr(self.model_field, 'min_length', None)) -        self.max_length = kwargs.pop('max_length', -                                     getattr(self.model_field, 'max_length', None)) -        self.min_value = kwargs.pop('min_value', -                                    getattr(self.model_field, 'min_value', None)) -        self.max_value = kwargs.pop('max_value', -                                    getattr(self.model_field, 'max_value', None)) - -        super(ModelField, self).__init__(*args, **kwargs) - -        if self.min_length is not None: -            self.validators.append(validators.MinLengthValidator(self.min_length)) -        if self.max_length is not None: -            self.validators.append(validators.MaxLengthValidator(self.max_length)) -        if self.min_value is not None: -            self.validators.append(validators.MinValueValidator(self.min_value)) -        if self.max_value is not None: -            self.validators.append(validators.MaxValueValidator(self.max_value)) - -    def from_native(self, value): -        rel = getattr(self.model_field, "rel", None) -        if rel is not None: -            return rel.to._meta.get_field(rel.field_name).to_python(value) -        else: -            return self.model_field.to_python(value) +            class_name = self.__class__.__name__ +            msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) +            raise AssertionError(msg) -    def field_to_native(self, obj, field_name): -        value = self.model_field._get_val_from_obj(obj) -        if is_protected_type(value): -            return value -        return self.model_field.value_to_string(obj) -    def attributes(self): -        return { -            "type": self.model_field.get_internal_type() -        } - - -# Typed Fields - -class BooleanField(WritableField): -    type_name = 'BooleanField' -    type_label = 'boolean' -    form_field_class = forms.BooleanField -    widget = widgets.CheckboxInput -    default_error_messages = { -        'invalid': _("'%s' value must be either True or False."), +class BooleanField(Field): +    MESSAGES = { +        'required': 'This field is required.', +        'invalid_value': '`{input}` is not a valid boolean.'      } -    empty = False - -    def field_from_native(self, data, files, field_name, into): -        # HTML checkboxes do not explicitly represent unchecked as `False` -        # we deal with that here... -        if isinstance(data, QueryDict) and self.default is None: -            self.default = False - -        return super(BooleanField, self).field_from_native( -            data, files, field_name, into -        ) - -    def from_native(self, value): -        if value in ('true', 't', 'True', '1'): +    TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True} +    FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False} + +    def get_value(self, dictionary): +        if html.is_html_input(dictionary): +            # HTML forms do not send a `False` value on an empty checkbox, +            # so we override the default empty value to be False. +            return dictionary.get(self.field_name, False) +        return dictionary.get(self.field_name, empty) + +    def to_native(self, data): +        if data in self.TRUE_VALUES:              return True -        if value in ('false', 'f', 'False', '0'): +        elif data in self.FALSE_VALUES:              return False -        return bool(value) +        self.fail('invalid_value', input=data) -class CharField(WritableField): -    type_name = 'CharField' -    type_label = 'string' -    form_field_class = forms.CharField +class CharField(Field): +    MESSAGES = { +        'required': 'This field is required.', +        'blank': 'This field may not be blank.' +    } -    def __init__(self, max_length=None, min_length=None, allow_none=False, *args, **kwargs): -        self.max_length, self.min_length = max_length, min_length -        self.allow_none = allow_none +    def __init__(self, *args, **kwargs): +        self.allow_blank = kwargs.pop('allow_blank', False)          super(CharField, self).__init__(*args, **kwargs) -        if min_length is not None: -            self.validators.append(validators.MinLengthValidator(min_length)) -        if max_length is not None: -            self.validators.append(validators.MaxLengthValidator(max_length)) - -    def from_native(self, value): -        if isinstance(value, six.string_types): -            return value - -        if value is None and not self.allow_none: -            return '' - -        return smart_text(value) - -class URLField(CharField): -    type_name = 'URLField' -    type_label = 'url' - -    def __init__(self, **kwargs): -        if 'validators' not in kwargs: -            kwargs['validators'] = [validators.URLValidator()] -        super(URLField, self).__init__(**kwargs) +    def to_native(self, data): +        if data == '' and not self.allow_blank: +            self.fail('blank') +        return str(data) -class SlugField(CharField): -    type_name = 'SlugField' -    type_label = 'slug' -    form_field_class = forms.SlugField - -    default_error_messages = { -        'invalid': _("Enter a valid 'slug' consisting of letters, numbers," -                     " underscores or hyphens."), +class ChoiceField(Field): +    MESSAGES = { +        'required': 'This field is required.', +        'invalid_choice': '`{input}` is not a valid choice.'      } -    default_validators = [validators.validate_slug] +    coerce_to_type = str      def __init__(self, *args, **kwargs): -        super(SlugField, self).__init__(*args, **kwargs) - +        choices = kwargs.pop('choices') + +        assert choices, '`choices` argument is required and may not be empty' + +        # Allow either single or paired choices style: +        # choices = [1, 2, 3] +        # choices = [(1, 'First'), (2, 'Second'), (3, 'Third')] +        pairs = [ +            isinstance(item, (list, tuple)) and len(item) == 2 +            for item in choices +        ] +        if all(pairs): +            self.choices = {key: val for key, val in choices} +        else: +            self.choices = {item: item for item in choices} -class ChoiceField(WritableField): -    type_name = 'ChoiceField' -    type_label = 'choice' -    form_field_class = forms.ChoiceField -    widget = widgets.Select -    default_error_messages = { -        'invalid_choice': _('Select a valid choice. %(value)s is not one of ' -                            'the available choices.'), -    } +        # Map the string representation of choices to the underlying value. +        # Allows us to deal with eg. integer choices while supporting either +        # integer or string input, but still get the correct datatype out. +        self.choice_strings_to_values = { +            str(key): key for key in self.choices.keys() +        } -    def __init__(self, choices=(), blank_display_value=None, *args, **kwargs): -        self.empty = kwargs.pop('empty', '')          super(ChoiceField, self).__init__(*args, **kwargs) -        self.choices = choices -        if not self.required: -            if blank_display_value is None: -                blank_choice = BLANK_CHOICE_DASH -            else: -                blank_choice = [('', blank_display_value)] -            self.choices = blank_choice + self.choices - -    def _get_choices(self): -        return self._choices - -    def _set_choices(self, value): -        # Setting choices also sets the choices on the widget. -        # choices can be any iterable, but we call list() on it because -        # it will be consumed more than once. -        self._choices = self.widget.choices = list(value) - -    choices = property(_get_choices, _set_choices) - -    def metadata(self): -        data = super(ChoiceField, self).metadata() -        data['choices'] = [{'value': v, 'display_name': n} for v, n in self.choices] -        return data - -    def validate(self, value): -        """ -        Validates that the input is in self.choices. -        """ -        super(ChoiceField, self).validate(value) -        if value and not self.valid_value(value): -            raise ValidationError(self.error_messages['invalid_choice'] % {'value': value}) - -    def valid_value(self, value): -        """ -        Check to see if the provided value is a valid choice. -        """ -        for k, v in self.choices: -            if isinstance(v, (list, tuple)): -                # This is an optgroup, so look inside the group for options -                for k2, v2 in v: -                    if value == smart_text(k2): -                        return True -            else: -                if value == smart_text(k) or value == k: -                    return True -        return False - -    def from_native(self, value): -        value = super(ChoiceField, self).from_native(value) -        if value == self.empty or value in validators.EMPTY_VALUES: -            return self.empty -        return value - - -class EmailField(CharField): -    type_name = 'EmailField' -    type_label = 'email' -    form_field_class = forms.EmailField - -    default_error_messages = { -        'invalid': _('Enter a valid email address.'), -    } -    default_validators = [validators.validate_email] - -    def from_native(self, value): -        ret = super(EmailField, self).from_native(value) -        if ret is None: -            return None -        return ret.strip() - - -class RegexField(CharField): -    type_name = 'RegexField' -    type_label = 'regex' -    form_field_class = forms.RegexField - -    def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs): -        super(RegexField, self).__init__(max_length, min_length, *args, **kwargs) -        self.regex = regex - -    def _get_regex(self): -        return self._regex - -    def _set_regex(self, regex): -        if isinstance(regex, six.string_types): -            regex = re.compile(regex) -        self._regex = regex -        if hasattr(self, '_regex_validator') and self._regex_validator in self.validators: -            self.validators.remove(self._regex_validator) -        self._regex_validator = validators.RegexValidator(regex=regex) -        self.validators.append(self._regex_validator) - -    regex = property(_get_regex, _set_regex) - - -class DateField(WritableField): -    type_name = 'DateField' -    type_label = 'date' -    widget = widgets.DateInput -    form_field_class = forms.DateField - -    default_error_messages = { -        'invalid': _("Date has wrong format. Use one of these formats instead: %s"), -    } -    empty = None -    input_formats = api_settings.DATE_INPUT_FORMATS -    format = api_settings.DATE_FORMAT - -    def __init__(self, input_formats=None, format=None, *args, **kwargs): -        self.input_formats = input_formats if input_formats is not None else self.input_formats -        self.format = format if format is not None else self.format -        super(DateField, self).__init__(*args, **kwargs) - -    def from_native(self, value): -        if value in validators.EMPTY_VALUES: -            return None - -        if isinstance(value, datetime.datetime): -            if timezone and settings.USE_TZ and timezone.is_aware(value): -                # Convert aware datetimes to the default time zone -                # before casting them to dates (#17742). -                default_timezone = timezone.get_default_timezone() -                value = timezone.make_naive(value, default_timezone) -            return value.date() -        if isinstance(value, datetime.date): -            return value - -        for format in self.input_formats: -            if format.lower() == ISO_8601: -                try: -                    parsed = parse_date(value) -                except (ValueError, TypeError): -                    pass -                else: -                    if parsed is not None: -                        return parsed -            else: -                try: -                    parsed = datetime.datetime.strptime(value, format) -                except (ValueError, TypeError): -                    pass -                else: -                    return parsed.date() - -        msg = self.error_messages['invalid'] % readable_date_formats(self.input_formats) -        raise ValidationError(msg) - -    def to_native(self, value): -        if value is None or self.format is None: -            return value - -        if isinstance(value, datetime.datetime): -            value = value.date() - -        if self.format.lower() == ISO_8601: -            return value.isoformat() -        return value.strftime(self.format) - - -class DateTimeField(WritableField): -    type_name = 'DateTimeField' -    type_label = 'datetime' -    widget = widgets.DateTimeInput -    form_field_class = forms.DateTimeField - -    default_error_messages = { -        'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"), -    } -    empty = None -    input_formats = api_settings.DATETIME_INPUT_FORMATS -    format = api_settings.DATETIME_FORMAT - -    def __init__(self, input_formats=None, format=None, *args, **kwargs): -        self.input_formats = input_formats if input_formats is not None else self.input_formats -        self.format = format if format is not None else self.format -        super(DateTimeField, self).__init__(*args, **kwargs) - -    def from_native(self, value): -        if value in validators.EMPTY_VALUES: -            return None - -        if isinstance(value, datetime.datetime): -            return value -        if isinstance(value, datetime.date): -            value = datetime.datetime(value.year, value.month, value.day) -            if settings.USE_TZ: -                # For backwards compatibility, interpret naive datetimes in -                # local time. This won't work during DST change, but we can't -                # do much about it, so we let the exceptions percolate up the -                # call stack. -                warnings.warn("DateTimeField received a naive datetime (%s)" -                              " while time zone support is active." % value, -                              RuntimeWarning) -                default_timezone = timezone.get_default_timezone() -                value = timezone.make_aware(value, default_timezone) -            return value - -        for format in self.input_formats: -            if format.lower() == ISO_8601: -                try: -                    parsed = parse_datetime(value) -                except (ValueError, TypeError): -                    pass -                else: -                    if parsed is not None: -                        return parsed -            else: -                try: -                    parsed = datetime.datetime.strptime(value, format) -                except (ValueError, TypeError): -                    pass -                else: -                    return parsed - -        msg = self.error_messages['invalid'] % readable_datetime_formats(self.input_formats) -        raise ValidationError(msg) - -    def to_native(self, value): -        if value is None or self.format is None: -            return value - -        if self.format.lower() == ISO_8601: -            ret = value.isoformat() -            if ret.endswith('+00:00'): -                ret = ret[:-6] + 'Z' -            return ret -        return value.strftime(self.format) - - -class TimeField(WritableField): -    type_name = 'TimeField' -    type_label = 'time' -    widget = widgets.TimeInput -    form_field_class = forms.TimeField - -    default_error_messages = { -        'invalid': _("Time has wrong format. Use one of these formats instead: %s"), -    } -    empty = None -    input_formats = api_settings.TIME_INPUT_FORMATS -    format = api_settings.TIME_FORMAT - -    def __init__(self, input_formats=None, format=None, *args, **kwargs): -        self.input_formats = input_formats if input_formats is not None else self.input_formats -        self.format = format if format is not None else self.format -        super(TimeField, self).__init__(*args, **kwargs) - -    def from_native(self, value): -        if value in validators.EMPTY_VALUES: -            return None - -        if isinstance(value, datetime.time): -            return value - -        for format in self.input_formats: -            if format.lower() == ISO_8601: -                try: -                    parsed = parse_time(value) -                except (ValueError, TypeError): -                    pass -                else: -                    if parsed is not None: -                        return parsed -            else: -                try: -                    parsed = datetime.datetime.strptime(value, format) -                except (ValueError, TypeError): -                    pass -                else: -                    return parsed.time() - -        msg = self.error_messages['invalid'] % readable_time_formats(self.input_formats) -        raise ValidationError(msg) - -    def to_native(self, value): -        if value is None or self.format is None: -            return value - -        if isinstance(value, datetime.datetime): -            value = value.time() - -        if self.format.lower() == ISO_8601: -            return value.isoformat() -        return value.strftime(self.format) - - -class IntegerField(WritableField): -    type_name = 'IntegerField' -    type_label = 'integer' -    form_field_class = forms.IntegerField -    empty = 0 - -    default_error_messages = { -        'invalid': _('Enter a whole number.'), -        'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), -        'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), -    } - -    def __init__(self, max_value=None, min_value=None, *args, **kwargs): -        self.max_value, self.min_value = max_value, min_value -        super(IntegerField, self).__init__(*args, **kwargs) - -        if max_value is not None: -            self.validators.append(validators.MaxValueValidator(max_value)) -        if min_value is not None: -            self.validators.append(validators.MinValueValidator(min_value)) - -    def from_native(self, value): -        if value in validators.EMPTY_VALUES: -            return None - -        try: -            value = int(str(value)) -        except (ValueError, TypeError): -            raise ValidationError(self.error_messages['invalid']) -        return value - - -class FloatField(WritableField): -    type_name = 'FloatField' -    type_label = 'float' -    form_field_class = forms.FloatField -    empty = 0 - -    default_error_messages = { -        'invalid': _("'%s' value must be a float."), -    } - -    def from_native(self, value): -        if value in validators.EMPTY_VALUES: -            return None +    def to_native(self, data):          try: -            return float(value) -        except (TypeError, ValueError): -            msg = self.error_messages['invalid'] % value -            raise ValidationError(msg) - +            return self.choice_strings_to_values[str(data)] +        except KeyError: +            self.fail('invalid_choice', input=data) -class DecimalField(WritableField): -    type_name = 'DecimalField' -    type_label = 'decimal' -    form_field_class = forms.DecimalField -    empty = Decimal('0') -    default_error_messages = { -        'invalid': _('Enter a number.'), -        'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), -        'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), -        'max_digits': _('Ensure that there are no more than %s digits in total.'), -        'max_decimal_places': _('Ensure that there are no more than %s decimal places.'), -        'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.') +class MultipleChoiceField(ChoiceField): +    MESSAGES = { +        'required': 'This field is required.', +        'invalid_choice': '`{input}` is not a valid choice.', +        'not_a_list': 'Expected a list of items but got type `{input_type}`'      } -    def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): -        self.max_value, self.min_value = max_value, min_value -        self.max_digits, self.decimal_places = max_digits, decimal_places -        super(DecimalField, self).__init__(*args, **kwargs) - -        if max_value is not None: -            self.validators.append(validators.MaxValueValidator(max_value)) -        if min_value is not None: -            self.validators.append(validators.MinValueValidator(min_value)) - -    def from_native(self, value): -        """ -        Validates that the input is a decimal number. Returns a Decimal -        instance. Returns None for empty values. Ensures that there are no more -        than max_digits in the number, and no more than decimal_places digits -        after the decimal point. -        """ -        if value in validators.EMPTY_VALUES: -            return None -        value = smart_text(value).strip() -        try: -            value = Decimal(value) -        except DecimalException: -            raise ValidationError(self.error_messages['invalid']) -        return value - -    def validate(self, value): -        super(DecimalField, self).validate(value) -        if value in validators.EMPTY_VALUES: -            return -        # Check for NaN, Inf and -Inf values. We can't compare directly for NaN, -        # since it is never equal to itself. However, NaN is the only value that -        # isn't equal to itself, so we can use this to identify NaN -        if value != value or value == Decimal("Inf") or value == Decimal("-Inf"): -            raise ValidationError(self.error_messages['invalid']) -        sign, digittuple, exponent = value.as_tuple() -        decimals = abs(exponent) -        # digittuple doesn't include any leading zeros. -        digits = len(digittuple) -        if decimals > digits: -            # We have leading zeros up to or past the decimal point.  Count -            # everything past the decimal point as a digit.  We do not count -            # 0 before the decimal point as a digit since that would mean -            # we would not allow max_digits = decimal_places. -            digits = decimals -        whole_digits = digits - decimals - -        if self.max_digits is not None and digits > self.max_digits: -            raise ValidationError(self.error_messages['max_digits'] % self.max_digits) -        if self.decimal_places is not None and decimals > self.decimal_places: -            raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places) -        if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): -            raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places)) -        return value - +    def to_native(self, data): +        if not hasattr(data, '__iter__'): +            self.fail('not_a_list', input_type=type(data).__name__) +        return set([ +            super(MultipleChoiceField, self).to_native(item) +            for item in data +        ]) -class FileField(WritableField): -    use_files = True -    type_name = 'FileField' -    type_label = 'file upload' -    form_field_class = forms.FileField -    widget = widgets.FileInput -    default_error_messages = { -        'invalid': _("No file was submitted. Check the encoding type on the form."), -        'missing': _("No file was submitted."), -        'empty': _("The submitted file is empty."), -        'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'), -        'contradiction': _('Please either submit a file or check the clear checkbox, not both.') +class IntegerField(Field): +    MESSAGES = { +        'required': 'This field is required.', +        'invalid_integer': 'A valid integer is required.'      } -    def __init__(self, *args, **kwargs): -        self.max_length = kwargs.pop('max_length', None) -        self.allow_empty_file = kwargs.pop('allow_empty_file', False) -        super(FileField, self).__init__(*args, **kwargs) - -    def from_native(self, data): -        if data in validators.EMPTY_VALUES: -            return None - -        # UploadedFile objects should have name and size attributes. +    def to_native(self, data):          try: -            file_name = data.name -            file_size = data.size -        except AttributeError: -            raise ValidationError(self.error_messages['invalid']) - -        if self.max_length is not None and len(file_name) > self.max_length: -            error_values = {'max': self.max_length, 'length': len(file_name)} -            raise ValidationError(self.error_messages['max_length'] % error_values) -        if not file_name: -            raise ValidationError(self.error_messages['invalid']) -        if not self.allow_empty_file and not file_size: -            raise ValidationError(self.error_messages['empty']) - +            data = int(str(data)) +        except (ValueError, TypeError): +            self.fail('invalid_integer')          return data -    def to_native(self, value): -        return value.name - - -class ImageField(FileField): -    use_files = True -    type_name = 'ImageField' -    type_label = 'image upload' -    form_field_class = forms.ImageField - -    default_error_messages = { -        'invalid_image': _("Upload a valid image. The file you uploaded was " -                           "either not an image or a corrupted image."), -    } - -    def from_native(self, data): -        """ -        Checks that the file-upload field data contains a valid image (GIF, JPG, -        PNG, possibly others -- whatever the Python Imaging Library supports). -        """ -        f = super(ImageField, self).from_native(data) -        if f is None: -            return None - -        from rest_framework.compat import Image -        assert Image is not None, 'Either Pillow or PIL must be installed for ImageField support.' - -        # We need to get a file object for PIL. We might have a path or we might -        # have to read the data into memory. -        if hasattr(data, 'temporary_file_path'): -            file = data.temporary_file_path() -        else: -            if hasattr(data, 'read'): -                file = BytesIO(data.read()) -            else: -                file = BytesIO(data['content']) - -        try: -            # load() could spot a truncated JPEG, but it loads the entire -            # image in memory, which is a DoS vector. See #3848 and #18520. -            # verify() must be called immediately after the constructor. -            Image.open(file).verify() -        except ImportError: -            # Under PyPy, it is possible to import PIL. However, the underlying -            # _imaging C module isn't available, so an ImportError will be -            # raised. Catch and re-raise. -            raise -        except Exception:  # Python Imaging Library doesn't recognize it as an image -            raise ValidationError(self.error_messages['invalid_image']) -        if hasattr(f, 'seek') and callable(f.seek): -            f.seek(0) -        return f - -class SerializerMethodField(Field): -    """ -    A field that gets its value by calling a method on the serializer it's attached to. -    """ - -    def __init__(self, method_name, *args, **kwargs): -        self.method_name = method_name -        super(SerializerMethodField, self).__init__(*args, **kwargs) - -    def field_to_native(self, obj, field_name): -        value = getattr(self.parent, self.method_name)(obj) -        return self.to_native(value) +class MethodField(Field): +    def __init__(self, **kwargs): +        kwargs['source'] = '*' +        kwargs['read_only'] = True +        super(MethodField, self).__init__(**kwargs) + +    def to_primative(self, value): +        attr = 'get_{field_name}'.format(field_name=self.field_name) +        method = getattr(self.parent, attr) +        return method(value) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index b3bd6ce9..6705cbb2 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -79,18 +79,16 @@ class GenericAPIView(views.APIView):              'view': self          } -    def get_serializer(self, instance=None, data=None, files=None, many=False, -                       partial=False, allow_add_remove=False): +    def get_serializer(self, instance=None, data=None, many=False, partial=False):          """          Return the serializer instance that should be used for validating and          deserializing input, and for serializing output.          """          serializer_class = self.get_serializer_class()          context = self.get_serializer_context() -        return serializer_class(instance, data=data, files=files, -                                many=many, partial=partial, -                                allow_add_remove=allow_add_remove, -                                context=context) +        return serializer_class( +            instance, data=data, many=many, partial=partial, context=context +        )      def get_pagination_serializer(self, page):          """ diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index ac59d979..ee01cabc 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -36,12 +36,10 @@ class CreateModelMixin(object):      Create a model instance.      """      def create(self, request, *args, **kwargs): -        serializer = self.get_serializer(data=request.DATA, files=request.FILES) +        serializer = self.get_serializer(data=request.DATA)          if serializer.is_valid(): -            self.pre_save(serializer.object) -            self.object = serializer.save(force_insert=True) -            self.post_save(self.object, created=True) +            self.object = serializer.save()              headers = self.get_success_headers(serializer.data)              return Response(serializer.data, status=status.HTTP_201_CREATED,                              headers=headers) @@ -90,26 +88,20 @@ class UpdateModelMixin(object):          partial = kwargs.pop('partial', False)          self.object = self.get_object_or_none() -        serializer = self.get_serializer(self.object, data=request.DATA, -                                         files=request.FILES, partial=partial) +        serializer = self.get_serializer(self.object, data=request.DATA, partial=partial)          if not serializer.is_valid():              return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) -        try: -            self.pre_save(serializer.object) -        except ValidationError as err: -            # full_clean on model instance may be called in pre_save, -            # so we have to handle eventual errors. -            return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST) +        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field +        lookup_value = self.kwargs[lookup_url_kwarg] +        extras = {self.lookup_field: lookup_value}          if self.object is None: -            self.object = serializer.save(force_insert=True) -            self.post_save(self.object, created=True) +            self.object = serializer.save(extras=extras)              return Response(serializer.data, status=status.HTTP_201_CREATED) -        self.object = serializer.save(force_update=True) -        self.post_save(self.object, created=False) +        self.object = serializer.save(extras=extras)          return Response(serializer.data, status=status.HTTP_200_OK)      def partial_update(self, request, *args, **kwargs): diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index d51ea929..83ef97c5 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -48,17 +48,17 @@ class DefaultObjectSerializer(serializers.Field):          super(DefaultObjectSerializer, self).__init__(source=source) -class PaginationSerializerOptions(serializers.SerializerOptions): -    """ -    An object that stores the options that may be provided to a -    pagination serializer by using the inner `Meta` class. +# class PaginationSerializerOptions(serializers.SerializerOptions): +#     """ +#     An object that stores the options that may be provided to a +#     pagination serializer by using the inner `Meta` class. -    Accessible on the instance as `serializer.opts`. -    """ -    def __init__(self, meta): -        super(PaginationSerializerOptions, self).__init__(meta) -        self.object_serializer_class = getattr(meta, 'object_serializer_class', -                                               DefaultObjectSerializer) +#     Accessible on the instance as `serializer.opts`. +#     """ +#     def __init__(self, meta): +#         super(PaginationSerializerOptions, self).__init__(meta) +#         self.object_serializer_class = getattr(meta, 'object_serializer_class', +#                                                DefaultObjectSerializer)  class BasePaginationSerializer(serializers.Serializer): @@ -66,7 +66,7 @@ class BasePaginationSerializer(serializers.Serializer):      A base class for pagination serializers to inherit from,      to make implementing custom serializers more easy.      """ -    _options_class = PaginationSerializerOptions +    # _options_class = PaginationSerializerOptions      results_field = 'results'      def __init__(self, *args, **kwargs): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 56870b40..e69de29b 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,486 +0,0 @@ -""" -Serializer fields that deal with relationships. - -These fields allow you to specify the style that should be used to represent -model relationships, including hyperlinks, primary keys, or slugs. -""" -from __future__ import unicode_literals -from django.core.exceptions import ObjectDoesNotExist, ValidationError -from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch -from django import forms -from django.db.models.fields import BLANK_CHOICE_DASH -from django.forms import widgets -from django.forms.models import ModelChoiceIterator -from django.utils.translation import ugettext_lazy as _ -from rest_framework.fields import Field, WritableField, get_component, is_simple_callable -from rest_framework.reverse import reverse -from rest_framework.compat import urlparse -from rest_framework.compat import smart_text - - -# Relational fields - -# Not actually Writable, but subclasses may need to be. -class RelatedField(WritableField): -    """ -    Base class for related model fields. - -    This represents a relationship using the unicode representation of the target. -    """ -    widget = widgets.Select -    many_widget = widgets.SelectMultiple -    form_field_class = forms.ChoiceField -    many_form_field_class = forms.MultipleChoiceField -    null_values = (None, '', 'None') - -    cache_choices = False -    empty_label = None -    read_only = True -    many = False - -    def __init__(self, *args, **kwargs): -        queryset = kwargs.pop('queryset', None) -        self.many = kwargs.pop('many', self.many) -        if self.many: -            self.widget = self.many_widget -            self.form_field_class = self.many_form_field_class - -        kwargs['read_only'] = kwargs.pop('read_only', self.read_only) -        super(RelatedField, self).__init__(*args, **kwargs) - -        if not self.required: -            # Accessed in ModelChoiceIterator django/forms/models.py:1034 -            # If set adds empty choice. -            self.empty_label = BLANK_CHOICE_DASH[0][1] - -        self.queryset = queryset - -    def initialize(self, parent, field_name): -        super(RelatedField, self).initialize(parent, field_name) -        if self.queryset is None and not self.read_only: -            manager = getattr(self.parent.opts.model, self.source or field_name) -            if hasattr(manager, 'related'):  # Forward -                self.queryset = manager.related.model._default_manager.all() -            else:  # Reverse -                self.queryset = manager.field.rel.to._default_manager.all() - -    # We need this stuff to make form choices work... - -    def prepare_value(self, obj): -        return self.to_native(obj) - -    def label_from_instance(self, obj): -        """ -        Return a readable representation for use with eg. select widgets. -        """ -        desc = smart_text(obj) -        ident = smart_text(self.to_native(obj)) -        if desc == ident: -            return desc -        return "%s - %s" % (desc, ident) - -    def _get_queryset(self): -        return self._queryset - -    def _set_queryset(self, queryset): -        self._queryset = queryset -        self.widget.choices = self.choices - -    queryset = property(_get_queryset, _set_queryset) - -    def _get_choices(self): -        # If self._choices is set, then somebody must have manually set -        # the property self.choices. In this case, just return self._choices. -        if hasattr(self, '_choices'): -            return self._choices - -        # Otherwise, execute the QuerySet in self.queryset to determine the -        # choices dynamically. Return a fresh ModelChoiceIterator that has not been -        # consumed. Note that we're instantiating a new ModelChoiceIterator *each* -        # time _get_choices() is called (and, thus, each time self.choices is -        # accessed) so that we can ensure the QuerySet has not been consumed. This -        # construct might look complicated but it allows for lazy evaluation of -        # the queryset. -        return ModelChoiceIterator(self) - -    def _set_choices(self, value): -        # Setting choices also sets the choices on the widget. -        # choices can be any iterable, but we call list() on it because -        # it will be consumed more than once. -        self._choices = self.widget.choices = list(value) - -    choices = property(_get_choices, _set_choices) - -    # Default value handling - -    def get_default_value(self): -        default = super(RelatedField, self).get_default_value() -        if self.many and default is None: -            return [] -        return default - -    # Regular serializer stuff... - -    def field_to_native(self, obj, field_name): -        try: -            if self.source == '*': -                return self.to_native(obj) - -            source = self.source or field_name -            value = obj - -            for component in source.split('.'): -                if value is None: -                    break -                value = get_component(value, component) -        except ObjectDoesNotExist: -            return None - -        if value is None: -            return None - -        if self.many: -            if is_simple_callable(getattr(value, 'all', None)): -                return [self.to_native(item) for item in value.all()] -            else: -                # Also support non-queryset iterables. -                # This allows us to also support plain lists of related items. -                return [self.to_native(item) for item in value] -        return self.to_native(value) - -    def field_from_native(self, data, files, field_name, into): -        if self.read_only: -            return - -        try: -            if self.many: -                try: -                    # Form data -                    value = data.getlist(field_name) -                    if value == [''] or value == []: -                        raise KeyError -                except AttributeError: -                    # Non-form data -                    value = data[field_name] -            else: -                value = data[field_name] -        except KeyError: -            if self.partial: -                return -            value = self.get_default_value() - -        if value in self.null_values: -            if self.required: -                raise ValidationError(self.error_messages['required']) -            into[(self.source or field_name)] = None -        elif self.many: -            into[(self.source or field_name)] = [self.from_native(item) for item in value] -        else: -            into[(self.source or field_name)] = self.from_native(value) - - -# PrimaryKey relationships - -class PrimaryKeyRelatedField(RelatedField): -    """ -    Represents a relationship as a pk value. -    """ -    read_only = False - -    default_error_messages = { -        'does_not_exist': _("Invalid pk '%s' - object does not exist."), -        'incorrect_type': _('Incorrect type.  Expected pk value, received %s.'), -    } - -    # TODO: Remove these field hacks... -    def prepare_value(self, obj): -        return self.to_native(obj.pk) - -    def label_from_instance(self, obj): -        """ -        Return a readable representation for use with eg. select widgets. -        """ -        desc = smart_text(obj) -        ident = smart_text(self.to_native(obj.pk)) -        if desc == ident: -            return desc -        return "%s - %s" % (desc, ident) - -    # TODO: Possibly change this to just take `obj`, through prob less performant -    def to_native(self, pk): -        return pk - -    def from_native(self, data): -        if self.queryset is None: -            raise Exception('Writable related fields must include a `queryset` argument') - -        try: -            return self.queryset.get(pk=data) -        except ObjectDoesNotExist: -            msg = self.error_messages['does_not_exist'] % smart_text(data) -            raise ValidationError(msg) -        except (TypeError, ValueError): -            received = type(data).__name__ -            msg = self.error_messages['incorrect_type'] % received -            raise ValidationError(msg) - -    def field_to_native(self, obj, field_name): -        if self.many: -            # To-many relationship - -            queryset = None -            if not self.source: -                # Prefer obj.serializable_value for performance reasons -                try: -                    queryset = obj.serializable_value(field_name) -                except AttributeError: -                    pass -            if queryset is None: -                # RelatedManager (reverse relationship) -                source = self.source or field_name -                queryset = obj -                for component in source.split('.'): -                    if queryset is None: -                        return [] -                    queryset = get_component(queryset, component) - -            # Forward relationship -            if is_simple_callable(getattr(queryset, 'all', None)): -                return [self.to_native(item.pk) for item in queryset.all()] -            else: -                # Also support non-queryset iterables. -                # This allows us to also support plain lists of related items. -                return [self.to_native(item.pk) for item in queryset] - -        # To-one relationship -        try: -            # Prefer obj.serializable_value for performance reasons -            pk = obj.serializable_value(self.source or field_name) -        except AttributeError: -            # RelatedObject (reverse relationship) -            try: -                pk = getattr(obj, self.source or field_name).pk -            except (ObjectDoesNotExist, AttributeError): -                return None - -        # Forward relationship -        return self.to_native(pk) - - -# Slug relationships - -class SlugRelatedField(RelatedField): -    """ -    Represents a relationship using a unique field on the target. -    """ -    read_only = False - -    default_error_messages = { -        'does_not_exist': _("Object with %s=%s does not exist."), -        'invalid': _('Invalid value.'), -    } - -    def __init__(self, *args, **kwargs): -        self.slug_field = kwargs.pop('slug_field', None) -        assert self.slug_field, 'slug_field is required' -        super(SlugRelatedField, self).__init__(*args, **kwargs) - -    def to_native(self, obj): -        return getattr(obj, self.slug_field) - -    def from_native(self, data): -        if self.queryset is None: -            raise Exception('Writable related fields must include a `queryset` argument') - -        try: -            return self.queryset.get(**{self.slug_field: data}) -        except ObjectDoesNotExist: -            raise ValidationError(self.error_messages['does_not_exist'] % -                                  (self.slug_field, smart_text(data))) -        except (TypeError, ValueError): -            msg = self.error_messages['invalid'] -            raise ValidationError(msg) - - -# Hyperlinked relationships - -class HyperlinkedRelatedField(RelatedField): -    """ -    Represents a relationship using hyperlinking. -    """ -    read_only = False -    lookup_field = 'pk' - -    default_error_messages = { -        'no_match': _('Invalid hyperlink - No URL match'), -        'incorrect_match': _('Invalid hyperlink - Incorrect URL match'), -        'configuration_error': _('Invalid hyperlink due to configuration error'), -        'does_not_exist': _("Invalid hyperlink - object does not exist."), -        'incorrect_type': _('Incorrect type.  Expected url string, received %s.'), -    } - -    def __init__(self, *args, **kwargs): -        try: -            self.view_name = kwargs.pop('view_name') -        except KeyError: -            raise ValueError("Hyperlinked field requires 'view_name' kwarg") - -        self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) -        self.format = kwargs.pop('format', None) - -        super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) - -    def get_url(self, obj, view_name, request, format): -        """ -        Given an object, return the URL that hyperlinks to the object. - -        May raise a `NoReverseMatch` if the `view_name` and `lookup_field` -        attributes are not configured to correctly match the URL conf. -        """ -        lookup_field = getattr(obj, self.lookup_field) -        kwargs = {self.lookup_field: lookup_field} -        return reverse(view_name, kwargs=kwargs, request=request, format=format) - -    def get_object(self, queryset, view_name, view_args, view_kwargs): -        """ -        Return the object corresponding to a matched URL. - -        Takes the matched URL conf arguments, and the queryset, and should -        return an object instance, or raise an `ObjectDoesNotExist` exception. -        """ -        lookup_value = view_kwargs[self.lookup_field] -        filter_kwargs = {self.lookup_field: lookup_value} -        return queryset.get(**filter_kwargs) - -    def to_native(self, obj): -        view_name = self.view_name -        request = self.context.get('request', None) -        format = self.format or self.context.get('format', None) - -        assert request is not None, ( -            "`HyperlinkedRelatedField` requires the request in the serializer " -            "context. Add `context={'request': request}` when instantiating " -            "the serializer." -        ) - -        # If the object has not yet been saved then we cannot hyperlink to it. -        if getattr(obj, 'pk', None) is None: -            return - -        # Return the hyperlink, or error if incorrectly configured. -        try: -            return self.get_url(obj, view_name, request, format) -        except NoReverseMatch: -            msg = ( -                'Could not resolve URL for hyperlinked relationship using ' -                'view name "%s". You may have failed to include the related ' -                'model in your API, or incorrectly configured the ' -                '`lookup_field` attribute on this field.' -            ) -            raise Exception(msg % view_name) - -    def from_native(self, value): -        # Convert URL -> model instance pk -        # TODO: Use values_list -        queryset = self.queryset -        if queryset is None: -            raise Exception('Writable related fields must include a `queryset` argument') - -        try: -            http_prefix = value.startswith(('http:', 'https:')) -        except AttributeError: -            msg = self.error_messages['incorrect_type'] -            raise ValidationError(msg % type(value).__name__) - -        if http_prefix: -            # If needed convert absolute URLs to relative path -            value = urlparse.urlparse(value).path -            prefix = get_script_prefix() -            if value.startswith(prefix): -                value = '/' + value[len(prefix):] - -        try: -            match = resolve(value) -        except Exception: -            raise ValidationError(self.error_messages['no_match']) - -        if match.view_name != self.view_name: -            raise ValidationError(self.error_messages['incorrect_match']) - -        try: -            return self.get_object(queryset, match.view_name, -                                   match.args, match.kwargs) -        except (ObjectDoesNotExist, TypeError, ValueError): -            raise ValidationError(self.error_messages['does_not_exist']) - - -class HyperlinkedIdentityField(Field): -    """ -    Represents the instance, or a property on the instance, using hyperlinking. -    """ -    lookup_field = 'pk' -    read_only = True - -    def __init__(self, *args, **kwargs): -        try: -            self.view_name = kwargs.pop('view_name') -        except KeyError: -            msg = "HyperlinkedIdentityField requires 'view_name' argument" -            raise ValueError(msg) - -        self.format = kwargs.pop('format', None) -        lookup_field = kwargs.pop('lookup_field', None) -        self.lookup_field = lookup_field or self.lookup_field - -        super(HyperlinkedIdentityField, self).__init__(*args, **kwargs) - -    def field_to_native(self, obj, field_name): -        request = self.context.get('request', None) -        format = self.context.get('format', None) -        view_name = self.view_name - -        assert request is not None, ( -            "`HyperlinkedIdentityField` requires the request in the serializer" -            " context. Add `context={'request': request}` when instantiating " -            "the serializer." -        ) - -        # By default use whatever format is given for the current context -        # unless the target is a different type to the source. -        # -        # Eg. Consider a HyperlinkedIdentityField pointing from a json -        # representation to an html property of that representation... -        # -        # '/snippets/1/' should link to '/snippets/1/highlight/' -        # ...but... -        # '/snippets/1/.json' should link to '/snippets/1/highlight/.html' -        if format and self.format and self.format != format: -            format = self.format - -        # Return the hyperlink, or error if incorrectly configured. -        try: -            return self.get_url(obj, view_name, request, format) -        except NoReverseMatch: -            msg = ( -                'Could not resolve URL for hyperlinked relationship using ' -                'view name "%s". You may have failed to include the related ' -                'model in your API, or incorrectly configured the ' -                '`lookup_field` attribute on this field.' -            ) -            raise Exception(msg % view_name) - -    def get_url(self, obj, view_name, request, format): -        """ -        Given an object, return the URL that hyperlinks to the object. - -        May raise a `NoReverseMatch` if the `view_name` and `lookup_field` -        attributes are not configured to correctly match the URL conf. -        """ -        lookup_field = getattr(obj, self.lookup_field, None) -        kwargs = {self.lookup_field: lookup_field} - -        # Handle unsaved object case -        if lookup_field is None: -            return None - -        return reverse(view_name, kwargs=kwargs, request=request, format=format) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 748ebac9..e8935b01 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -458,7 +458,7 @@ class BrowsableAPIRenderer(BaseRenderer):              ):                  return -            serializer = view.get_serializer(instance=obj, data=data, files=files) +            serializer = view.get_serializer(instance=obj, data=data)              serializer.is_valid()              data = serializer.data @@ -579,10 +579,10 @@ class BrowsableAPIRenderer(BaseRenderer):              'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes],              'response_headers': response_headers, -            'put_form': self.get_rendered_html_form(view, 'PUT', request), -            'post_form': self.get_rendered_html_form(view, 'POST', request), -            'delete_form': self.get_rendered_html_form(view, 'DELETE', request), -            'options_form': self.get_rendered_html_form(view, 'OPTIONS', request), +            #'put_form': self.get_rendered_html_form(view, 'PUT', request), +            #'post_form': self.get_rendered_html_form(view, 'POST', request), +            #'delete_form': self.get_rendered_html_form(view, 'DELETE', request), +            #'options_form': self.get_rendered_html_form(view, 'OPTIONS', request),              'raw_data_put_form': raw_data_put_form,              'raw_data_post_form': raw_data_post_form, diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index be8ad3f2..d121812d 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -10,21 +10,14 @@ python primitives.  2. The process of marshalling between python primitives and request and  response content is handled by parsers and renderers.  """ -from __future__ import unicode_literals -import copy -import datetime -import inspect -import types -from decimal import Decimal -from django.contrib.contenttypes.generic import GenericForeignKey -from django.core.paginator import Page  from django.db import models -from django.forms import widgets  from django.utils import six -from django.utils.datastructures import SortedDict -from django.core.exceptions import ObjectDoesNotExist +from collections import namedtuple, OrderedDict +from rest_framework.fields import empty, set_value, Field, SkipField, ValidationError  from rest_framework.settings import api_settings - +from rest_framework.utils import html +import copy +import inspect  # Note: We do the following so that users of the framework can use this style:  # @@ -37,635 +30,339 @@ from rest_framework.relations import *  # NOQA  from rest_framework.fields import *  # NOQA -def _resolve_model(obj): -    """ -    Resolve supplied `obj` to a Django model class. +FieldResult = namedtuple('FieldResult', ['field', 'value', 'error']) -    `obj` must be a Django model class itself, or a string -    representation of one.  Useful in situtations like GH #1225 where -    Django may not have resolved a string-based reference to a model in -    another model's foreign key definition. - -    String representations should have the format: -        'appname.ModelName' -    """ -    if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: -        app_name, model_name = obj.split('.') -        return models.get_model(app_name, model_name) -    elif inspect.isclass(obj) and issubclass(obj, models.Model): -        return obj -    else: -        raise ValueError("{0} is not a Django model".format(obj)) - - -def pretty_name(name): -    """Converts 'first_name' to 'First name'""" -    if not name: -        return '' -    return name.replace('_', ' ').capitalize() +class BaseSerializer(Field): +    def __init__(self, instance=None, data=None, **kwargs): +        super(BaseSerializer, self).__init__(**kwargs) +        self.instance = instance +        self._initial_data = data -class RelationsList(list): -    _deleted = [] +    def to_native(self, data): +        raise NotImplementedError() +    def to_primative(self, instance): +        raise NotImplementedError() -class NestedValidationError(ValidationError): -    """ -    The default ValidationError behavior is to stringify each item in the list -    if the messages are a list of error messages. +    def update(self, instance): +        raise NotImplementedError() -    In the case of nested serializers, where the parent has many children, -    then the child's `serializer.errors` will be a list of dicts.  In the case -    of a single child, the `serializer.errors` will be a dict. +    def create(self): +        raise NotImplementedError() -    We need to override the default behavior to get properly nested error dicts. -    """ +    def save(self, extras=None): +        if extras is not None: +            self._validated_data.update(extras) -    def __init__(self, message): -        if isinstance(message, dict): -            self._messages = [message] +        if self.instance is not None: +            self.update(self.instance)          else: -            self._messages = message - -    @property -    def messages(self): -        return self._messages +            self.instance = self.create() +        return self.instance -class DictWithMetadata(dict): -    """ -    A dict-like object, that can have additional properties attached. -    """ -    def __getstate__(self): -        """ -        Used by pickle (e.g., caching). -        Overridden to remove the metadata from the dict, since it shouldn't be -        pickled and may in some instances be unpickleable. -        """ -        return dict(self) - +    def is_valid(self): +        try: +            self._validated_data = self.to_native(self._initial_data) +        except ValidationError as exc: +            self._validated_data = {} +            self._errors = exc.args[0] +            return False +        self._errors = {} +        return True -class SortedDictWithMetadata(SortedDict): -    """ -    A sorted dict-like object, that can have additional properties attached. -    """ -    def __getstate__(self): -        """ -        Used by pickle (e.g., caching). -        Overriden to remove the metadata from the dict, since it shouldn't be -        pickle and may in some instances be unpickleable. -        """ -        return SortedDict(self).__dict__ +    @property +    def data(self): +        if not hasattr(self, '_data'): +            if self.instance is not None: +                self._data = self.to_primative(self.instance) +            elif self._initial_data is not None: +                self._data = { +                    field_name: field.get_value(self._initial_data) +                    for field_name, field in self.fields.items() +                } +            else: +                self._data = self.get_initial() +        return self._data +    @property +    def errors(self): +        if not hasattr(self, '_errors'): +            msg = 'You must call `.is_valid()` before accessing `.errors`.' +            raise AssertionError(msg) +        return self._errors -def _is_protected_type(obj): -    """ -    True if the object is a native datatype that does not need to -    be serialized further. -    """ -    return isinstance(obj, ( -        types.NoneType, -        int, long, -        datetime.datetime, datetime.date, datetime.time, -        float, Decimal, -        basestring) -    ) +    @property +    def validated_data(self): +        if not hasattr(self, '_validated_data'): +            msg = 'You must call `.is_valid()` before accessing `.validated_data`.' +            raise AssertionError(msg) +        return self._validated_data -def _get_declared_fields(bases, attrs): +class SerializerMetaclass(type):      """ -    Create a list of serializer field instances from the passed in 'attrs', -    plus any fields on the base classes (in 'bases'). +    This metaclass sets a dictionary named `base_fields` on the class. -    Note that all fields from the base classes are used. +    Any fields included as attributes on either the class or it's superclasses +    will be include in the `base_fields` dictionary.      """ -    fields = [(field_name, attrs.pop(field_name)) -              for field_name, obj in list(six.iteritems(attrs)) -              if isinstance(obj, Field)] -    fields.sort(key=lambda x: x[1].creation_counter) -    # If this class is subclassing another Serializer, add that Serializer's -    # fields.  Note that we loop over the bases in *reverse*. This is necessary -    # in order to maintain the correct order of fields. -    for base in bases[::-1]: -        if hasattr(base, 'base_fields'): -            fields = list(base.base_fields.items()) + fields +    @classmethod +    def _get_fields(cls, bases, attrs): +        fields = [(field_name, attrs.pop(field_name)) +                  for field_name, obj in list(attrs.items()) +                  if isinstance(obj, Field)] +        fields.sort(key=lambda x: x[1]._creation_counter) -    return SortedDict(fields) +        # If this class is subclassing another Serializer, add that Serializer's +        # fields.  Note that we loop over the bases in *reverse*. This is necessary +        # in order to maintain the correct order of fields. +        for base in bases[::-1]: +            if hasattr(base, 'base_fields'): +                fields = list(base.base_fields.items()) + fields +        return OrderedDict(fields) -class SerializerMetaclass(type):      def __new__(cls, name, bases, attrs): -        attrs['base_fields'] = _get_declared_fields(bases, attrs) +        attrs['base_fields'] = cls._get_fields(bases, attrs)          return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) -class SerializerOptions(object): -    """ -    Meta class options for Serializer -    """ -    def __init__(self, meta): -        self.depth = getattr(meta, 'depth', 0) -        self.fields = getattr(meta, 'fields', ()) -        self.exclude = getattr(meta, 'exclude', ()) +@six.add_metaclass(SerializerMetaclass) +class Serializer(BaseSerializer): +    def __new__(cls, *args, **kwargs): +        many = kwargs.pop('many', False) +        if many: +            class DynamicListSerializer(ListSerializer): +                child = cls() +            return DynamicListSerializer(*args, **kwargs) +        return super(Serializer, cls).__new__(cls) -class BaseSerializer(WritableField): -    """ -    This is the Serializer implementation. -    We need to implement it as `BaseSerializer` due to metaclass magicks. -    """ -    class Meta(object): -        pass - -    _options_class = SerializerOptions -    _dict_class = SortedDictWithMetadata - -    def __init__(self, instance=None, data=None, files=None, -                 context=None, partial=False, many=False, -                 allow_add_remove=False, **kwargs): -        super(BaseSerializer, self).__init__(**kwargs) -        self.opts = self._options_class(self.Meta) -        self.parent = None -        self.root = None -        self.partial = partial -        self.many = many -        self.allow_add_remove = allow_add_remove +    def __init__(self, *args, **kwargs): +        kwargs.pop('context', None) +        kwargs.pop('partial', None) +        kwargs.pop('many', False) -        self.context = context or {} +        super(Serializer, self).__init__(*args, **kwargs) -        self.init_data = data -        self.init_files = files -        self.object = instance +        # Every new serializer is created with a clone of the field instances. +        # This allows users to dynamically modify the fields on a serializer +        # instance without affecting every other serializer class.          self.fields = self.get_fields() -        self._data = None -        self._files = None -        self._errors = None - -        if many and instance is not None and not hasattr(instance, '__iter__'): -            raise ValueError('instance should be a queryset or other iterable with many=True') - -        if allow_add_remove and not many: -            raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True') - -    ##### -    # Methods to determine which fields to use when (de)serializing objects. - -    def get_default_fields(self): -        """ -        Return the complete set of default fields for the object, as a dict. -        """ -        return {} - -    def get_fields(self): -        """ -        Returns the complete set of fields for the object as a dict. - -        This will be the set of any explicitly declared fields, -        plus the set of fields returned by get_default_fields(). -        """ -        ret = SortedDict() - -        # Get the explicitly declared fields -        base_fields = copy.deepcopy(self.base_fields) -        for key, field in base_fields.items(): -            ret[key] = field - -        # Add in the default fields -        default_fields = self.get_default_fields() -        for key, val in default_fields.items(): -            if key not in ret: -                ret[key] = val - -        # If 'fields' is specified, use those fields, in that order. -        if self.opts.fields: -            assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple' -            new = SortedDict() -            for key in self.opts.fields: -                new[key] = ret[key] -            ret = new - -        # Remove anything in 'exclude' -        if self.opts.exclude: -            assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple' -            for key in self.opts.exclude: -                ret.pop(key, None) - -        for key, field in ret.items(): -            field.initialize(parent=self, field_name=key) - -        return ret - -    ##### -    # Methods to convert or revert from objects <--> primitive representations. - -    def get_field_key(self, field_name): -        """ -        Return the key that should be used for a given field. -        """ -        return field_name - -    def restore_fields(self, data, files): -        """ -        Core of deserialization, together with `restore_object`. -        Converts a dictionary of data into a dictionary of deserialized fields. -        """ -        reverted_data = {} - -        if data is not None and not isinstance(data, dict): -            self._errors['non_field_errors'] = ['Invalid data'] -            return None - +        # Setup all the child fields, to provide them with the current context.          for field_name, field in self.fields.items(): -            field.initialize(parent=self, field_name=field_name) -            try: -                field.field_from_native(data, files, field_name, reverted_data) -            except ValidationError as err: -                self._errors[field_name] = list(err.messages) +            field.bind(field_name, self, self) -        return reverted_data +    def get_fields(self): +        return copy.deepcopy(self.base_fields) -    def perform_validation(self, attrs): -        """ -        Run `validate_<fieldname>()` and `validate()` methods on the serializer -        """ +    def bind(self, field_name, parent, root): +        # If the serializer is used as a field then when it becomes bound +        # it also needs to bind all its child fields. +        super(Serializer, self).bind(field_name, parent, root)          for field_name, field in self.fields.items(): -            if field_name in self._errors: -                continue +            field.bind(field_name, self, root) -            source = field.source or field_name -            if self.partial and source not in attrs: -                continue -            try: -                validate_method = getattr(self, 'validate_%s' % field_name, None) -                if validate_method: -                    attrs = validate_method(attrs, source) -            except ValidationError as err: -                self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) - -        # If there are already errors, we don't run .validate() because -        # field-validation failed and thus `attrs` may not be complete. -        # which in turn can cause inconsistent validation errors. -        if not self._errors: -            try: -                attrs = self.validate(attrs) -            except ValidationError as err: -                if hasattr(err, 'message_dict'): -                    for field_name, error_messages in err.message_dict.items(): -                        self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages) -                elif hasattr(err, 'messages'): -                    self._errors['non_field_errors'] = err.messages - -        return attrs +    def get_initial(self): +        return { +            field.field_name: field.get_initial() +            for field in self.fields.values() +        } -    def validate(self, attrs): -        """ -        Stub method, to be overridden in Serializer subclasses -        """ -        return attrs +    def get_value(self, dictionary): +        # We override the default field access in order to support +        # nested HTML forms. +        if html.is_html_input(dictionary): +            return html.parse_html_dict(dictionary, prefix=self.field_name) +        return dictionary.get(self.field_name, empty) -    def restore_object(self, attrs, instance=None): +    def to_native(self, data):          """ -        Deserialize a dictionary of attributes into an object instance. -        You should override this method to control how deserialized objects -        are instantiated. +        Dict of native values <- Dict of primitive datatypes.          """ -        if instance is not None: -            instance.update(attrs) -            return instance -        return attrs +        ret = {} +        errors = {} +        fields = [field for field in self.fields.values() if not field.read_only] -    def to_native(self, obj): -        """ -        Serialize objects -> primitives. -        """ -        ret = self._dict_class() -        ret.fields = self._dict_class() +        for field in fields: +            primitive_value = field.get_value(data) +            try: +                validated_value = field.validate(primitive_value) +            except ValidationError as exc: +                errors[field.field_name] = str(exc) +            except SkipField: +                pass +            else: +                set_value(ret, field.source_attrs, validated_value) -        for field_name, field in self.fields.items(): -            if field.read_only and obj is None: -                continue -            field.initialize(parent=self, field_name=field_name) -            key = self.get_field_key(field_name) -            value = field.field_to_native(obj, field_name) -            method = getattr(self, 'transform_%s' % field_name, None) -            if callable(method): -                value = method(obj, value) -            if not getattr(field, 'write_only', False): -                ret[key] = value -            ret.fields[key] = self.augment_field(field, field_name, key, value) +        if errors: +            raise ValidationError(errors)          return ret -    def from_native(self, data, files=None): -        """ -        Deserialize primitives -> objects. -        """ -        self._errors = {} - -        if data is not None or files is not None: -            attrs = self.restore_fields(data, files) -            if attrs is not None: -                attrs = self.perform_validation(attrs) -        else: -            self._errors['non_field_errors'] = ['No input provided'] - -        if not self._errors: -            return self.restore_object(attrs, instance=getattr(self, 'object', None)) - -    def augment_field(self, field, field_name, key, value): -        # This horrible stuff is to manage serializers rendering to HTML -        field._errors = self._errors.get(key) if self._errors else None -        field._name = field_name -        field._value = self.init_data.get(key) if self._errors and self.init_data else value -        if not field.label: -            field.label = pretty_name(key) -        return field - -    def field_to_native(self, obj, field_name): +    def to_primative(self, instance):          """ -        Override default so that the serializer can be used as a nested field -        across relationships. +        Object instance -> Dict of primitive datatypes.          """ -        if self.write_only: -            return None +        ret = OrderedDict() +        fields = [field for field in self.fields.values() if not field.write_only] -        if self.source == '*': -            return self.to_native(obj) +        for field in fields: +            native_value = field.get_attribute(instance) +            ret[field.field_name] = field.to_primative(native_value) -        # Get the raw field value -        try: -            source = self.source or field_name -            value = obj - -            for component in source.split('.'): -                if value is None: -                    break -                value = get_component(value, component) -        except ObjectDoesNotExist: -            return None +        return ret -        if is_simple_callable(getattr(value, 'all', None)): -            return [self.to_native(item) for item in value.all()] +    def __iter__(self): +        errors = self.errors if hasattr(self, '_errors') else {} +        for field in self.fields.values(): +            value = self.data.get(field.field_name) if self.data else None +            error = errors.get(field.field_name) +            yield FieldResult(field, value, error) -        if value is None: -            return None -        if self.many: -            return [self.to_native(item) for item in value] -        return self.to_native(value) +class ListSerializer(BaseSerializer): +    child = None +    initial = [] -    def field_from_native(self, data, files, field_name, into): -        """ -        Override default so that the serializer can be used as a writable -        nested field across relationships. -        """ -        if self.read_only: -            return +    def __init__(self, *args, **kwargs): +        self.child = kwargs.pop('child', copy.deepcopy(self.child)) +        assert self.child is not None, '`child` is a required argument.' -        try: -            value = data[field_name] -        except KeyError: -            if self.default is not None and not self.partial: -                # Note: partial updates shouldn't set defaults -                value = copy.deepcopy(self.default) -            else: -                if self.required: -                    raise ValidationError(self.error_messages['required']) -                return - -        if self.source == '*': -            if value: -                reverted_data = self.restore_fields(value, {}) -                if not self._errors: -                    into.update(reverted_data) -        else: -            if value in (None, ''): -                into[(self.source or field_name)] = None -            else: -                # Set the serializer object if it exists -                obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None - -                # If we have a model manager or similar object then we need -                # to iterate through each instance. -                if ( -                    self.many and -                    not hasattr(obj, '__iter__') and -                    is_simple_callable(getattr(obj, 'all', None)) -                ): -                    obj = obj.all() - -                kwargs = { -                    'instance': obj, -                    'data': value, -                    'context': self.context, -                    'partial': self.partial, -                    'many': self.many, -                    'allow_add_remove': self.allow_add_remove -                } -                serializer = self.__class__(**kwargs) +        kwargs.pop('context', None) +        kwargs.pop('partial', None) -                if serializer.is_valid(): -                    into[self.source or field_name] = serializer.object -                else: -                    # Propagate errors up to our parent -                    raise NestedValidationError(serializer.errors) +        super(ListSerializer, self).__init__(*args, **kwargs) +        self.child.bind('', self, self) -    def get_identity(self, data): -        """ -        This hook is required for bulk update. -        It is used to determine the canonical identity of a given object. +    def bind(self, field_name, parent, root): +        # If the list is used as a field then it needs to provide +        # the current context to the child serializer. +        super(ListSerializer, self).bind(field_name, parent, root) +        self.child.bind(field_name, self, root) -        Note that the data has not been validated at this point, so we need -        to make sure that we catch any cases of incorrect datatypes being -        passed to this method. -        """ -        try: -            return data.get('id', None) -        except AttributeError: -            return None +    def get_value(self, dictionary): +        # We override the default field access in order to support +        # lists in HTML forms. +        if is_html_input(dictionary): +            return html.parse_html_list(dictionary, prefix=self.field_name) +        return dictionary.get(self.field_name, empty) -    @property -    def errors(self): +    def to_native(self, data):          """ -        Run deserialization and return error data, -        setting self.object if no errors occurred. +        List of dicts of native values <- List of dicts of primitive datatypes.          """ -        if self._errors is None: -            data, files = self.init_data, self.init_files +        if html.is_html_input(data): +            data = html.parse_html_list(data) -            if self.many is not None: -                many = self.many -            else: -                many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type)) -                if many: -                    warnings.warn('Implicit list/queryset serialization is deprecated. ' -                                  'Use the `many=True` flag when instantiating the serializer.', -                                  DeprecationWarning, stacklevel=3) - -            if many: -                ret = RelationsList() -                errors = [] -                update = self.object is not None - -                if update: -                    # If this is a bulk update we need to map all the objects -                    # to a canonical identity so we can determine which -                    # individual object is being updated for each item in the -                    # incoming data -                    objects = self.object -                    identities = [self.get_identity(self.to_native(obj)) for obj in objects] -                    identity_to_objects = dict(zip(identities, objects)) - -                if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)): -                    for item in data: -                        if update: -                            # Determine which object we're updating -                            identity = self.get_identity(item) -                            self.object = identity_to_objects.pop(identity, None) -                            if self.object is None and not self.allow_add_remove: -                                ret.append(None) -                                errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}) -                                continue - -                        ret.append(self.from_native(item, None)) -                        errors.append(self._errors) - -                    if update and self.allow_add_remove: -                        ret._deleted = identity_to_objects.values() - -                    self._errors = any(errors) and errors or [] -                else: -                    self._errors = {'non_field_errors': ['Expected a list of items.']} -            else: -                ret = self.from_native(data, files) - -            if not self._errors: -                self.object = ret - -        return self._errors - -    def is_valid(self): -        return not self.errors +        return [self.child.validate(item) for item in data] -    @property -    def data(self): +    def to_primative(self, data):          """ -        Returns the serialized data on the serializer. +        List of object instances -> List of dicts of primitive datatypes.          """ -        if self._data is None: -            obj = self.object +        return [self.child.to_primative(item) for item in data] -            if self.many is not None: -                many = self.many -            else: -                many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict)) -                if many: -                    warnings.warn('Implicit list/queryset serialization is deprecated. ' -                                  'Use the `many=True` flag when instantiating the serializer.', -                                  DeprecationWarning, stacklevel=2) - -            if many: -                self._data = [self.to_native(item) for item in obj] -            else: -                self._data = self.to_native(obj) +    def create(self, attrs_list): +        return [self.child.create(attrs) for attrs in attrs_list] -        return self._data +    def save(self): +        if self.instance is not None: +            self.update(self.instance, self.validated_data) +        self.instance = self.create(self.validated_data) +        return self.instance -    def save_object(self, obj, **kwargs): -        obj.save(**kwargs) -    def delete_object(self, obj): -        obj.delete() - -    def save(self, **kwargs): -        """ -        Save the deserialized object and return it. -        """ -        # Clear cached _data, which may be invalidated by `save()` -        self._data = None - -        if isinstance(self.object, list): -            [self.save_object(item, **kwargs) for item in self.object] - -            if self.object._deleted: -                [self.delete_object(item) for item in self.object._deleted] -        else: -            self.save_object(self.object, **kwargs) - -        return self.object - -    def metadata(self): -        """ -        Return a dictionary of metadata about the fields on the serializer. -        Useful for things like responding to OPTIONS requests, or generating -        API schemas for auto-documentation. -        """ -        return SortedDict( -            [ -                (field_name, field.metadata()) -                for field_name, field in six.iteritems(self.fields) -            ] -        ) +def _resolve_model(obj): +    """ +    Resolve supplied `obj` to a Django model class. +    `obj` must be a Django model class itself, or a string +    representation of one.  Useful in situtations like GH #1225 where +    Django may not have resolved a string-based reference to a model in +    another model's foreign key definition. -class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)): -    pass +    String representations should have the format: +        'appname.ModelName' +    """ +    if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: +        app_name, model_name = obj.split('.') +        return models.get_model(app_name, model_name) +    elif inspect.isclass(obj) and issubclass(obj, models.Model): +        return obj +    else: +        raise ValueError("{0} is not a Django model".format(obj)) -class ModelSerializerOptions(SerializerOptions): +class ModelSerializerOptions(object):      """      Meta class options for ModelSerializer      """      def __init__(self, meta): -        super(ModelSerializerOptions, self).__init__(meta) -        self.model = getattr(meta, 'model', None) -        self.read_only_fields = getattr(meta, 'read_only_fields', ()) -        self.write_only_fields = getattr(meta, 'write_only_fields', ()) +        self.model = getattr(meta, 'model') +        self.fields = getattr(meta, 'fields', ()) +        self.depth = getattr(meta, 'depth', 0)  class ModelSerializer(Serializer): -    """ -    A serializer that deals with model instances and querysets. -    """ -    _options_class = ModelSerializerOptions -      field_mapping = {          models.AutoField: IntegerField, -        models.FloatField: FloatField, +        # models.FloatField: FloatField,          models.IntegerField: IntegerField,          models.PositiveIntegerField: IntegerField,          models.SmallIntegerField: IntegerField,          models.PositiveSmallIntegerField: IntegerField, -        models.DateTimeField: DateTimeField, -        models.DateField: DateField, -        models.TimeField: TimeField, -        models.DecimalField: DecimalField, -        models.EmailField: EmailField, +        # models.DateTimeField: DateTimeField, +        # models.DateField: DateField, +        # models.TimeField: TimeField, +        # models.DecimalField: DecimalField, +        # models.EmailField: EmailField,          models.CharField: CharField, -        models.URLField: URLField, -        models.SlugField: SlugField, +        # models.URLField: URLField, +        # models.SlugField: SlugField,          models.TextField: CharField,          models.CommaSeparatedIntegerField: CharField,          models.BooleanField: BooleanField,          models.NullBooleanField: BooleanField, -        models.FileField: FileField, -        models.ImageField: ImageField, +        # models.FileField: FileField, +        # models.ImageField: ImageField,      } +    _options_class = ModelSerializerOptions + +    def __init__(self, *args, **kwargs): +        self.opts = self._options_class(self.Meta) +        super(ModelSerializer, self).__init__(*args, **kwargs) + +    def get_fields(self): +        # Get the explicitly declared fields. +        fields = copy.deepcopy(self.base_fields) + +        # Add in the default fields. +        for key, val in self.get_default_fields().items(): +            if key not in fields: +                fields[key] = val + +        # If `fields` is set on the `Meta` class, +        # then use only those fields, and in that order. +        if self.opts.fields: +            fields = OrderedDict([ +                (key, fields[key]) for key in self.opts.fields +            ]) + +        return fields +      def get_default_fields(self):          """          Return all the fields that should be serialized for the model.          """ -          cls = self.opts.model -        assert cls is not None, ( -            "Serializer class '%s' is missing 'model' Meta option" % -            self.__class__.__name__ -        )          opts = cls._meta.concrete_model._meta -        ret = SortedDict() +        ret = OrderedDict()          nested = bool(self.opts.depth)          # Deal with adding the primary key field @@ -694,29 +391,9 @@ class ModelSerializer(Serializer):                      has_through_model = True              if model_field.rel and nested: -                if len(inspect.getargspec(self.get_nested_field).args) == 2: -                    warnings.warn( -                        'The `get_nested_field(model_field)` call signature ' -                        'is deprecated. ' -                        'Use `get_nested_field(model_field, related_model, ' -                        'to_many) instead', -                        DeprecationWarning -                    ) -                    field = self.get_nested_field(model_field) -                else: -                    field = self.get_nested_field(model_field, related_model, to_many) +                field = self.get_nested_field(model_field, related_model, to_many)              elif model_field.rel: -                if len(inspect.getargspec(self.get_nested_field).args) == 3: -                    warnings.warn( -                        'The `get_related_field(model_field, to_many)` call ' -                        'signature is deprecated. ' -                        'Use `get_related_field(model_field, related_model, ' -                        'to_many) instead', -                        DeprecationWarning -                    ) -                    field = self.get_related_field(model_field, to_many=to_many) -                else: -                    field = self.get_related_field(model_field, related_model, to_many) +                field = self.get_related_field(model_field, related_model, to_many)              else:                  field = self.get_field(model_field) @@ -763,38 +440,6 @@ class ModelSerializer(Serializer):                  ret[accessor_name] = field -        # Ensure that 'read_only_fields' is an iterable -        assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple' - -        # Add the `read_only` flag to any fields that have been specified -        # in the `read_only_fields` option -        for field_name in self.opts.read_only_fields: -            assert field_name not in self.base_fields.keys(), ( -                "field '%s' on serializer '%s' specified in " -                "`read_only_fields`, but also added " -                "as an explicit field.  Remove it from `read_only_fields`." % -                (field_name, self.__class__.__name__)) -            assert field_name in ret, ( -                "Non-existant field '%s' specified in `read_only_fields` " -                "on serializer '%s'." % -                (field_name, self.__class__.__name__)) -            ret[field_name].read_only = True - -        # Ensure that 'write_only_fields' is an iterable -        assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple' - -        for field_name in self.opts.write_only_fields: -            assert field_name not in self.base_fields.keys(), ( -                "field '%s' on serializer '%s' specified in " -                "`write_only_fields`, but also added " -                "as an explicit field.  Remove it from `write_only_fields`." % -                (field_name, self.__class__.__name__)) -            assert field_name in ret, ( -                "Non-existant field '%s' specified in `write_only_fields` " -                "on serializer '%s'." % -                (field_name, self.__class__.__name__)) -            ret[field_name].write_only = True -          return ret      def get_pk_field(self, model_field): @@ -825,28 +470,24 @@ class ModelSerializer(Serializer):          # TODO: filter queryset using:          # .using(db).complex_filter(self.rel.limit_choices_to) -        kwargs = { -            'queryset': related_model._default_manager, -            'many': to_many -        } +        kwargs = {} +        #     'queryset': related_model._default_manager, +        #     'many': to_many +        # }          if model_field:              kwargs['required'] = not(model_field.null or model_field.blank) -            if model_field.help_text is not None: -                kwargs['help_text'] = model_field.help_text +            # if model_field.help_text is not None: +            #     kwargs['help_text'] = model_field.help_text              if model_field.verbose_name is not None:                  kwargs['label'] = model_field.verbose_name -              if not model_field.editable:                  kwargs['read_only'] = True -              if model_field.verbose_name is not None:                  kwargs['label'] = model_field.verbose_name -            if model_field.help_text is not None: -                kwargs['help_text'] = model_field.help_text - -        return PrimaryKeyRelatedField(**kwargs) +        return IntegerField(**kwargs) +        # TODO: return PrimaryKeyRelatedField(**kwargs)      def get_field(self, model_field):          """ @@ -869,8 +510,8 @@ class ModelSerializer(Serializer):          if model_field.verbose_name is not None:              kwargs['label'] = model_field.verbose_name -        if model_field.help_text is not None: -            kwargs['help_text'] = model_field.help_text +        # if model_field.help_text is not None: +        #     kwargs['help_text'] = model_field.help_text          # TODO: TypedChoiceField?          if model_field.flatchoices:  # This ModelField contains choices @@ -880,7 +521,7 @@ class ModelSerializer(Serializer):              return ChoiceField(**kwargs)          # put this below the ChoiceField because min_value isn't a valid initializer -        if issubclass(model_field.__class__, models.PositiveIntegerField) or\ +        if issubclass(model_field.__class__, models.PositiveIntegerField) or \                  issubclass(model_field.__class__, models.PositiveSmallIntegerField):              kwargs['min_value'] = 0 @@ -888,170 +529,27 @@ class ModelSerializer(Serializer):                  issubclass(model_field.__class__, (models.CharField, models.TextField)):              kwargs['allow_none'] = True -        attribute_dict = { -            models.CharField: ['max_length'], -            models.CommaSeparatedIntegerField: ['max_length'], -            models.DecimalField: ['max_digits', 'decimal_places'], -            models.EmailField: ['max_length'], -            models.FileField: ['max_length'], -            models.ImageField: ['max_length'], -            models.SlugField: ['max_length'], -            models.URLField: ['max_length'], -        } - -        if model_field.__class__ in attribute_dict: -            attributes = attribute_dict[model_field.__class__] -            for attribute in attributes: -                kwargs.update({attribute: getattr(model_field, attribute)}) +        # attribute_dict = { +        #     models.CharField: ['max_length'], +        #     models.CommaSeparatedIntegerField: ['max_length'], +        #     models.DecimalField: ['max_digits', 'decimal_places'], +        #     models.EmailField: ['max_length'], +        #     models.FileField: ['max_length'], +        #     models.ImageField: ['max_length'], +        #     models.SlugField: ['max_length'], +        #     models.URLField: ['max_length'], +        # } + +        # if model_field.__class__ in attribute_dict: +        #     attributes = attribute_dict[model_field.__class__] +        #     for attribute in attributes: +        #         kwargs.update({attribute: getattr(model_field, attribute)})          try:              return self.field_mapping[model_field.__class__](**kwargs)          except KeyError: -            return ModelField(model_field=model_field, **kwargs) - -    def get_validation_exclusions(self, instance=None): -        """ -        Return a list of field names to exclude from model validation. -        """ -        cls = self.opts.model -        opts = cls._meta.concrete_model._meta -        exclusions = [field.name for field in opts.fields + opts.many_to_many] - -        for field_name, field in self.fields.items(): -            field_name = field.source or field_name -            if ( -                field_name in exclusions -                and not field.read_only -                and (field.required or hasattr(instance, field_name)) -                and not isinstance(field, Serializer) -            ): -                exclusions.remove(field_name) -        return exclusions - -    def full_clean(self, instance): -        """ -        Perform Django's full_clean, and populate the `errors` dictionary -        if any validation errors occur. - -        Note that we don't perform this inside the `.restore_object()` method, -        so that subclasses can override `.restore_object()`, and still get -        the full_clean validation checking. -        """ -        try: -            instance.full_clean(exclude=self.get_validation_exclusions(instance)) -        except ValidationError as err: -            self._errors = err.message_dict -            return None -        return instance - -    def restore_object(self, attrs, instance=None): -        """ -        Restore the model instance. -        """ -        m2m_data = {} -        related_data = {} -        nested_forward_relations = {} -        meta = self.opts.model._meta - -        # Reverse fk or one-to-one relations -        for (obj, model) in meta.get_all_related_objects_with_model(): -            field_name = obj.get_accessor_name() -            if field_name in attrs: -                related_data[field_name] = attrs.pop(field_name) - -        # Reverse m2m relations -        for (obj, model) in meta.get_all_related_m2m_objects_with_model(): -            field_name = obj.get_accessor_name() -            if field_name in attrs: -                m2m_data[field_name] = attrs.pop(field_name) - -        # Forward m2m relations -        for field in meta.many_to_many + meta.virtual_fields: -            if isinstance(field, GenericForeignKey): -                continue -            if field.name in attrs: -                m2m_data[field.name] = attrs.pop(field.name) - -        # Nested forward relations - These need to be marked so we can save -        # them before saving the parent model instance. -        for field_name in attrs.keys(): -            if isinstance(self.fields.get(field_name, None), Serializer): -                nested_forward_relations[field_name] = attrs[field_name] - -        # Create an empty instance of the model -        if instance is None: -            instance = self.opts.model() - -        for key, val in attrs.items(): -            try: -                setattr(instance, key, val) -            except ValueError: -                self._errors[key] = [self.error_messages['required']] - -        # Any relations that cannot be set until we've -        # saved the model get hidden away on these -        # private attributes, so we can deal with them -        # at the point of save. -        instance._related_data = related_data -        instance._m2m_data = m2m_data -        instance._nested_forward_relations = nested_forward_relations - -        return instance - -    def from_native(self, data, files): -        """ -        Override the default method to also include model field validation. -        """ -        instance = super(ModelSerializer, self).from_native(data, files) -        if not self._errors: -            return self.full_clean(instance) - -    def save_object(self, obj, **kwargs): -        """ -        Save the deserialized object. -        """ -        if getattr(obj, '_nested_forward_relations', None): -            # Nested relationships need to be saved before we can save the -            # parent instance. -            for field_name, sub_object in obj._nested_forward_relations.items(): -                if sub_object: -                    self.save_object(sub_object) -                setattr(obj, field_name, sub_object) - -        obj.save(**kwargs) - -        if getattr(obj, '_m2m_data', None): -            for accessor_name, object_list in obj._m2m_data.items(): -                setattr(obj, accessor_name, object_list) -            del(obj._m2m_data) - -        if getattr(obj, '_related_data', None): -            related_fields = dict([ -                (field.get_accessor_name(), field) -                for field, model -                in obj._meta.get_all_related_objects_with_model() -            ]) -            for accessor_name, related in obj._related_data.items(): -                if isinstance(related, RelationsList): -                    # Nested reverse fk relationship -                    for related_item in related: -                        fk_field = related_fields[accessor_name].field.name -                        setattr(related_item, fk_field, obj) -                        self.save_object(related_item) - -                    # Delete any removed objects -                    if related._deleted: -                        [self.delete_object(item) for item in related._deleted] - -                elif isinstance(related, models.Model): -                    # Nested reverse one-one relationship -                    fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name -                    setattr(related, fk_field, obj) -                    self.save_object(related) -                else: -                    # Reverse FK or reverse one-one -                    setattr(obj, accessor_name, related) -            del(obj._related_data) +            # TODO: Change this to `return ModelField(model_field=model_field, **kwargs)` +            return CharField(**kwargs)  class HyperlinkedModelSerializerOptions(ModelSerializerOptions): @@ -1066,14 +564,10 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):  class HyperlinkedModelSerializer(ModelSerializer): -    """ -    A subclass of ModelSerializer that uses hyperlinked relationships, -    instead of primary key relationships. -    """      _options_class = HyperlinkedModelSerializerOptions      _default_view_name = '%(model_name)s-detail' -    _hyperlink_field_class = HyperlinkedRelatedField -    _hyperlink_identify_field_class = HyperlinkedIdentityField +    #_hyperlink_field_class = HyperlinkedRelatedField +    #_hyperlink_identify_field_class = HyperlinkedIdentityField      def get_default_fields(self):          fields = super(HyperlinkedModelSerializer, self).get_default_fields() @@ -1081,15 +575,15 @@ class HyperlinkedModelSerializer(ModelSerializer):          if self.opts.view_name is None:              self.opts.view_name = self._get_default_view_name(self.opts.model) -        if self.opts.url_field_name not in fields: -            url_field = self._hyperlink_identify_field_class( -                view_name=self.opts.view_name, -                lookup_field=self.opts.lookup_field -            ) -            ret = self._dict_class() -            ret[self.opts.url_field_name] = url_field -            ret.update(fields) -            fields = ret +        # if self.opts.url_field_name not in fields: +        #     url_field = self._hyperlink_identify_field_class( +        #         view_name=self.opts.view_name, +        #         lookup_field=self.opts.lookup_field +        #     ) +        #     ret = self._dict_class() +        #     ret[self.opts.url_field_name] = url_field +        #     ret.update(fields) +        #     fields = ret          return fields @@ -1103,33 +597,25 @@ class HyperlinkedModelSerializer(ModelSerializer):          """          # TODO: filter queryset using:          # .using(db).complex_filter(self.rel.limit_choices_to) -        kwargs = { -            'queryset': related_model._default_manager, -            'view_name': self._get_default_view_name(related_model), -            'many': to_many -        } +        # kwargs = { +        #     'queryset': related_model._default_manager, +        #     'view_name': self._get_default_view_name(related_model), +        #     'many': to_many +        # } +        kwargs = {}          if model_field:              kwargs['required'] = not(model_field.null or model_field.blank) -            if model_field.help_text is not None: -                kwargs['help_text'] = model_field.help_text +            # if model_field.help_text is not None: +            #     kwargs['help_text'] = model_field.help_text              if model_field.verbose_name is not None:                  kwargs['label'] = model_field.verbose_name -        if self.opts.lookup_field: -            kwargs['lookup_field'] = self.opts.lookup_field - -        return self._hyperlink_field_class(**kwargs) +        return IntegerField(**kwargs) +        # if self.opts.lookup_field: +        #     kwargs['lookup_field'] = self.opts.lookup_field -    def get_identity(self, data): -        """ -        This hook is required for bulk update. -        We need to override the default, to use the url as the identity. -        """ -        try: -            return data.get(self.opts.url_field_name, None) -        except AttributeError: -            return None +        # return self._hyperlink_field_class(**kwargs)      def _get_default_view_name(self, model):          """ diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 00ffdfba..6a2f6126 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -7,7 +7,7 @@ from django.db.models.query import QuerySet  from django.utils.datastructures import SortedDict  from django.utils.functional import Promise  from rest_framework.compat import force_text -from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata +# from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata  import datetime  import decimal  import types @@ -106,14 +106,14 @@ else:          SortedDict,          yaml.representer.SafeRepresenter.represent_dict      ) -    SafeDumper.add_representer( -        DictWithMetadata, -        yaml.representer.SafeRepresenter.represent_dict -    ) -    SafeDumper.add_representer( -        SortedDictWithMetadata, -        yaml.representer.SafeRepresenter.represent_dict -    ) +    # SafeDumper.add_representer( +    #     DictWithMetadata, +    #     yaml.representer.SafeRepresenter.represent_dict +    # ) +    # SafeDumper.add_representer( +    #     SortedDictWithMetadata, +    #     yaml.representer.SafeRepresenter.represent_dict +    # )      SafeDumper.add_representer(          types.GeneratorType,          yaml.representer.SafeRepresenter.represent_list diff --git a/rest_framework/utils/html.py b/rest_framework/utils/html.py new file mode 100644 index 00000000..bf17050d --- /dev/null +++ b/rest_framework/utils/html.py @@ -0,0 +1,86 @@ +""" +Helpers for dealing with HTML input. +""" + +def is_html_input(dictionary): +    # MultiDict type datastructures are used to represent HTML form input, +    # which may have more than one value for each key. +    return hasattr(dictionary, 'getlist') + + +def parse_html_list(dictionary, prefix=''): +    """ +    Used to suport list values in HTML forms. +    Supports lists of primitives and/or dictionaries. + +    * List of primitives. + +    { +        '[0]': 'abc', +        '[1]': 'def', +        '[2]': 'hij' +    } +        --> +    [ +        'abc', +        'def', +        'hij' +    ] + +    * List of dictionaries. + +    { +        '[0]foo': 'abc', +        '[0]bar': 'def', +        '[1]foo': 'hij', +        '[2]bar': 'klm', +    } +        --> +    [ +        {'foo': 'abc', 'bar': 'def'}, +        {'foo': 'hij', 'bar': 'klm'} +    ] +    """ +    Dict = type(dictionary) +    ret = {} +    regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix)) +    for field, value in dictionary.items(): +        match = regex.match(field) +        if not match: +            continue +        index, key = match.groups() +        index = int(index) +        if not key: +            ret[index] = value +        elif isinstance(ret.get(index), dict): +            ret[index][key] = value +        else: +            ret[index] = Dict({key: value}) +    return [ret[item] for item in sorted(ret.keys())] + + +def parse_html_dict(dictionary, prefix): +    """ +    Used to support dictionary values in HTML forms. + +    { +        'profile.username': 'example', +        'profile.email': 'example@example.com', +    } +        --> +    { +        'profile': { +            'username': 'example, +            'email': 'example@example.com' +        } +    } +    """ +    ret = {} +    regex = re.compile(r'^%s\.(.+)$' % re.escape(prefix)) +    for field, value in dictionary.items(): +        match = regex.match(field) +        if not match: +            continue +        key = match.groups()[0] +        ret[key] = value +    return ret | 
