aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorEleni Lixourioti2014-11-15 14:27:41 +0000
committerEleni Lixourioti2014-11-15 14:27:41 +0000
commit1aa77830955dcdf829f65a9001b6b8900dfc8755 (patch)
tree1f6d0bea3c0fe720a298b2da177bb91e8a74a19c /rest_framework
parentafaa52a378705b7f0475d5ece04a2cf49af4b7c2 (diff)
parent88008c0a687219e3104d548196915b1068536d74 (diff)
downloaddjango-rest-framework-1aa77830955dcdf829f65a9001b6b8900dfc8755.tar.bz2
Merge branch 'version-3.1' of github.com:tomchristie/django-rest-framework into oauth_as_package
Conflicts: .travis.yml
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/authtoken/migrations/0001_initial.py7
-rw-r--r--rest_framework/authtoken/serializers.py7
-rw-r--r--rest_framework/authtoken/views.py3
-rw-r--r--rest_framework/compat.py11
-rw-r--r--rest_framework/decorators.py35
-rw-r--r--rest_framework/exceptions.py4
-rw-r--r--rest_framework/fields.py1229
-rw-r--r--rest_framework/generics.py233
-rw-r--r--rest_framework/mixins.py170
-rw-r--r--rest_framework/pagination.py39
-rw-r--r--rest_framework/parsers.py22
-rw-r--r--rest_framework/relations.py678
-rw-r--r--rest_framework/renderers.py49
-rw-r--r--rest_framework/routers.py14
-rw-r--r--rest_framework/serializers.py1375
-rw-r--r--rest_framework/settings.py33
-rw-r--r--rest_framework/test.py2
-rw-r--r--rest_framework/utils/encoders.py83
-rw-r--r--rest_framework/utils/field_mapping.py215
-rw-r--r--rest_framework/utils/formatting.py6
-rw-r--r--rest_framework/utils/html.py88
-rw-r--r--rest_framework/utils/humanize_datetime.py47
-rw-r--r--rest_framework/utils/model_meta.py129
-rw-r--r--rest_framework/utils/representation.py87
-rw-r--r--rest_framework/views.py16
-rw-r--r--rest_framework/viewsets.py3
26 files changed, 1857 insertions, 2728 deletions
diff --git a/rest_framework/authtoken/migrations/0001_initial.py b/rest_framework/authtoken/migrations/0001_initial.py
index 2e5d6b47..769f6202 100644
--- a/rest_framework/authtoken/migrations/0001_initial.py
+++ b/rest_framework/authtoken/migrations/0001_initial.py
@@ -1,4 +1,4 @@
-# encoding: utf8
+# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.db import models, migrations
@@ -15,12 +15,11 @@ class Migration(migrations.Migration):
migrations.CreateModel(
name='Token',
fields=[
- ('key', models.CharField(max_length=40, serialize=False, primary_key=True)),
- ('user', models.OneToOneField(to=settings.AUTH_USER_MODEL, to_field='id')),
+ ('key', models.CharField(primary_key=True, serialize=False, max_length=40)),
('created', models.DateTimeField(auto_now_add=True)),
+ ('user', models.OneToOneField(to=settings.AUTH_USER_MODEL, related_name='auth_token')),
],
options={
- 'abstract': False,
},
bases=(models.Model,),
),
diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py
index 99e99ae3..c2c456de 100644
--- a/rest_framework/authtoken/serializers.py
+++ b/rest_framework/authtoken/serializers.py
@@ -19,11 +19,12 @@ class AuthTokenSerializer(serializers.Serializer):
if not user.is_active:
msg = _('User account is disabled.')
raise serializers.ValidationError(msg)
- attrs['user'] = user
- return attrs
else:
- msg = _('Unable to login with provided credentials.')
+ msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(msg)
else:
msg = _('Must include "username" and "password"')
raise serializers.ValidationError(msg)
+
+ attrs['user'] = user
+ return attrs
diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py
index 7c03cb76..94e6f061 100644
--- a/rest_framework/authtoken/views.py
+++ b/rest_framework/authtoken/views.py
@@ -18,7 +18,8 @@ class ObtainAuthToken(APIView):
def post(self, request):
serializer = self.serializer_class(data=request.DATA)
if serializer.is_valid():
- token, created = Token.objects.get_or_create(user=serializer.object['user'])
+ user = serializer.validated_data['user']
+ token, created = Token.objects.get_or_create(user=user)
return Response({'token': token.key})
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index bc5719ef..6c243462 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -39,6 +39,17 @@ except ImportError:
django_filters = None
+if django.VERSION >= (1, 6):
+ def clean_manytomany_helptext(text):
+ return text
+else:
+ # Up to version 1.5 many to many fields automatically suffix
+ # the `help_text` attribute with hardcoded text.
+ def clean_manytomany_helptext(text):
+ if text.endswith(' Hold down "Control", or "Command" on a Mac, to select more than one.'):
+ text = text[:-69]
+ return text
+
# Django-guardian is optional. Import only if guardian is in INSTALLED_APPS
# Fixes (#1712). We keep the try/except for the test suite.
guardian = None
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 449ba0a2..d28d6e22 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -10,7 +10,6 @@ from __future__ import unicode_literals
from django.utils import six
from rest_framework.views import APIView
import types
-import warnings
def api_view(http_method_names):
@@ -130,37 +129,3 @@ def list_route(methods=['get'], **kwargs):
func.kwargs = kwargs
return func
return decorator
-
-
-# These are now pending deprecation, in favor of `detail_route` and `list_route`.
-
-def link(**kwargs):
- """
- Used to mark a method on a ViewSet that should be routed for detail GET requests.
- """
- msg = 'link is pending deprecation. Use detail_route instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
-
- def decorator(func):
- func.bind_to_methods = ['get']
- func.detail = True
- func.kwargs = kwargs
- return func
-
- return decorator
-
-
-def action(methods=['post'], **kwargs):
- """
- Used to mark a method on a ViewSet that should be routed for detail POST requests.
- """
- msg = 'action is pending deprecation. Use detail_route instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
-
- def decorator(func):
- func.bind_to_methods = methods
- func.detail = True
- func.kwargs = kwargs
- return func
-
- return decorator
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index ad52d172..06b5e8a2 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -15,7 +15,7 @@ class APIException(Exception):
Subclasses should provide `.status_code` and `.default_detail` properties.
"""
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
- default_detail = ''
+ default_detail = 'A server error occured'
def __init__(self, detail=None):
self.detail = detail or self.default_detail
@@ -54,7 +54,7 @@ class MethodNotAllowed(APIException):
class NotAcceptable(APIException):
status_code = status.HTTP_406_NOT_ACCEPTABLE
- default_detail = "Could not satisfy the request's Accept header"
+ default_detail = "Could not satisfy the request Accept header"
def __init__(self, detail=None, available_renderers=None):
self.detail = detail or self.default_detail
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 8e15345d..0c78b3fb 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,34 +1,28 @@
-"""
-Serializer fields perform validation on incoming data.
-
-They are very similar to Django's form fields.
-"""
-from __future__ import unicode_literals
-
-import copy
-import datetime
-import inspect
-import re
-import warnings
-from decimal import Decimal, DecimalException
-from django import forms
+from django.conf import settings
from django.core import validators
from django.core.exceptions import ValidationError
-from django.conf import settings
-from django.db.models.fields import BLANK_CHOICE_DASH
-from django.http import QueryDict
-from django.forms import widgets
-from django.utils import six, timezone
+from django.utils import timezone
+from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
-from django.utils.datastructures import SortedDict
-from django.utils.dateparse import parse_date, parse_datetime, parse_time
from rest_framework import ISO_8601
-from rest_framework.compat import (
- BytesIO, smart_text,
- force_text, is_non_str_iterable
-)
+from rest_framework.compat import smart_text
from rest_framework.settings import api_settings
+from rest_framework.utils import html, representation, humanize_datetime
+import datetime
+import decimal
+import inspect
+import warnings
+
+
+class empty:
+ """
+ This class is used to represent no data being provided for a given input
+ or output value.
+
+ It is required because `None` may be a valid input or output value.
+ """
+ pass
def is_simple_callable(obj):
@@ -47,597 +41,487 @@ def is_simple_callable(obj):
return len_args <= len_defaults
-def get_component(obj, attr_name):
+def get_attribute(instance, attrs):
"""
- Given an object, and an attribute name,
- return that attribute on the object.
+ Similar to Python's built in `getattr(instance, attr)`,
+ but takes a list of nested attributes, instead of a single attribute.
+
+ Also accepts either attribute lookup on objects or dictionary lookups.
"""
- if isinstance(obj, dict):
- val = obj.get(attr_name)
- else:
- val = getattr(obj, attr_name)
-
- if is_simple_callable(val):
- return val()
- return val
-
-
-def readable_datetime_formats(formats):
- format = ', '.join(formats).replace(
- ISO_8601,
- 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
- )
- return humanize_strptime(format)
-
-
-def readable_date_formats(formats):
- format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]')
- return humanize_strptime(format)
-
-
-def readable_time_formats(formats):
- format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
- return humanize_strptime(format)
-
-
-def humanize_strptime(format_string):
- # Note that we're missing some of the locale specific mappings that
- # don't really make sense.
- mapping = {
- "%Y": "YYYY",
- "%y": "YY",
- "%m": "MM",
- "%b": "[Jan-Dec]",
- "%B": "[January-December]",
- "%d": "DD",
- "%H": "hh",
- "%I": "hh", # Requires '%p' to differentiate from '%H'.
- "%M": "mm",
- "%S": "ss",
- "%f": "uuuuuu",
- "%a": "[Mon-Sun]",
- "%A": "[Monday-Sunday]",
- "%p": "[AM|PM]",
- "%z": "[+HHMM|-HHMM]"
- }
- for key, val in mapping.items():
- format_string = format_string.replace(key, val)
- return format_string
+ for attr in attrs:
+ try:
+ instance = getattr(instance, attr)
+ except AttributeError as exc:
+ try:
+ return instance[attr]
+ except (KeyError, TypeError):
+ raise exc
+ return instance
-def strip_multiple_choice_msg(help_text):
+def set_value(dictionary, keys, value):
"""
- Remove the 'Hold down "control" ...' message that is Django enforces in
- select multiple fields on ModelForms. (Required for 1.5 and earlier)
+ Similar to Python's built in `dictionary[key] = value`,
+ but takes a list of nested keys instead of a single key.
- See https://code.djangoproject.com/ticket/9321
+ set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2}
+ set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2}
+ set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}}
"""
- multiple_choice_msg = _(' Hold down "Control", or "Command" on a Mac, to select more than one.')
- multiple_choice_msg = force_text(multiple_choice_msg)
+ if not keys:
+ dictionary.update(value)
+ return
- return help_text.replace(multiple_choice_msg, '')
+ for key in keys[:-1]:
+ if key not in dictionary:
+ dictionary[key] = {}
+ dictionary = dictionary[key]
+ dictionary[keys[-1]] = value
-class Field(object):
- read_only = True
- creation_counter = 0
- empty = ''
- type_name = None
- partial = False
- use_files = False
- form_field_class = forms.CharField
- type_label = 'field'
- widget = None
-
- def __init__(self, source=None, label=None, help_text=None):
- self.parent = None
-
- self.creation_counter = Field.creation_counter
- Field.creation_counter += 1
- self.source = source
+class SkipField(Exception):
+ pass
- if label is not None:
- self.label = smart_text(label)
- else:
- self.label = None
- if help_text is not None:
- self.help_text = strip_multiple_choice_msg(smart_text(help_text))
- else:
- self.help_text = None
+NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
+NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
+NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
+NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
+MISSING_ERROR_MESSAGE = (
+ 'ValidationError raised by `{class_name}`, but error key `{key}` does '
+ 'not exist in the `error_messages` dictionary.'
+)
- self._errors = []
- self._value = None
- self._name = None
- @property
- def errors(self):
- return self._errors
+class Field(object):
+ _creation_counter = 0
- def widget_html(self):
- if not self.widget:
- return ''
+ default_error_messages = {
+ 'required': _('This field is required.')
+ }
+ default_validators = []
- attrs = {}
- if 'id' not in self.widget.attrs:
- attrs['id'] = self._name
+ def __init__(self, read_only=False, write_only=False,
+ required=None, default=empty, initial=None, source=None,
+ label=None, help_text=None, style=None,
+ error_messages=None, validators=[]):
+ self._creation_counter = Field._creation_counter
+ Field._creation_counter += 1
- return self.widget.render(self._name, self._value, attrs=attrs)
+ # If `required` is unset, then use `True` unless a default is provided.
+ if required is None:
+ required = default is empty and not read_only
- def label_tag(self):
- return '<label for="%s">%s:</label>' % (self._name, self.label)
+ # Some combinations of keyword arguments do not make sense.
+ assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY
+ assert not (read_only and required), NOT_READ_ONLY_REQUIRED
+ assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT
+ assert not (required and default is not empty), NOT_REQUIRED_DEFAULT
- def initialize(self, parent, field_name):
- """
- Called to set up a field prior to field_to_native or field_from_native.
+ self.read_only = read_only
+ self.write_only = write_only
+ self.required = required
+ self.default = default
+ self.source = source
+ self.initial = initial
+ self.label = label
+ self.help_text = help_text
+ self.style = {} if style is None else style
+ self.validators = validators or self.default_validators[:]
- parent - The parent serializer.
- field_name - The name of the field being initialized.
- """
- self.parent = parent
- self.root = parent.root or parent
- self.context = self.root.context
- self.partial = self.root.partial
- if self.partial:
- self.required = False
+ # Collect default error message from self and parent classes
+ messages = {}
+ for cls in reversed(self.__class__.__mro__):
+ messages.update(getattr(cls, 'default_error_messages', {}))
+ messages.update(error_messages or {})
+ self.error_messages = messages
- def field_from_native(self, data, files, field_name, into):
+ def __new__(cls, *args, **kwargs):
"""
- Given a dictionary and a field name, updates the dictionary `into`,
- with the field and it's deserialized value.
+ When a field is instantiated, we store the arguments that were used,
+ so that we can present a helpful representation of the object.
"""
- return
+ instance = super(Field, cls).__new__(cls)
+ instance._args = args
+ instance._kwargs = kwargs
+ return instance
- def field_to_native(self, obj, field_name):
+ def bind(self, field_name, parent, root):
"""
- Given an object and a field name, returns the value that should be
- serialized for that field.
+ Setup the context for the field instance.
"""
- if obj is None:
- return self.empty
-
- if self.source == '*':
- return self.to_native(obj)
+ self.field_name = field_name
+ self.parent = parent
+ self.root = root
+ self.context = parent.context
- source = self.source or field_name
- value = obj
+ # `self.label` should deafult to being based on the field name.
+ if self.label is None:
+ self.label = field_name.replace('_', ' ').capitalize()
- for component in source.split('.'):
- value = get_component(value, component)
- if value is None:
- break
+ # self.source should default to being the same as the field name.
+ if self.source is None:
+ self.source = field_name
- return self.to_native(value)
+ # self.source_attrs is a list of attributes that need to be looked up
+ # when serializing the instance, or populating the validated data.
+ if self.source == '*':
+ self.source_attrs = []
+ else:
+ self.source_attrs = self.source.split('.')
- def to_native(self, value):
+ def get_initial(self):
"""
- Converts the field's value into it's simple representation.
+ Return a value to use when the field is being returned as a primative
+ value, without any object instance.
"""
- if is_simple_callable(value):
- value = value()
-
- if is_protected_type(value):
- return value
- elif (is_non_str_iterable(value) and
- not isinstance(value, (dict, six.string_types))):
- return [self.to_native(item) for item in value]
- elif isinstance(value, dict):
- # Make sure we preserve field ordering, if it exists
- ret = SortedDict()
- for key, val in value.items():
- ret[key] = self.to_native(val)
- return ret
- return force_text(value)
+ return self.initial
- def attributes(self):
+ def get_value(self, dictionary):
"""
- Returns a dictionary of attributes to be used when serializing to xml.
+ Given the *incoming* primative data, return the value for this field
+ that should be validated and transformed to a native value.
"""
- if self.type_name:
- return {'type': self.type_name}
- return {}
-
- def metadata(self):
- metadata = SortedDict()
- metadata['type'] = self.type_label
- metadata['required'] = getattr(self, 'required', False)
- optional_attrs = ['read_only', 'label', 'help_text',
- 'min_length', 'max_length']
- for attr in optional_attrs:
- value = getattr(self, attr, None)
- if value is not None and value != '':
- metadata[attr] = force_text(value, strings_only=True)
- return metadata
-
-
-class WritableField(Field):
- """
- Base for read/write fields.
- """
- write_only = False
- default_validators = []
- default_error_messages = {
- 'required': _('This field is required.'),
- 'invalid': _('Invalid value.'),
- }
- widget = widgets.TextInput
- default = None
-
- def __init__(self, source=None, label=None, help_text=None,
- read_only=False, write_only=False, required=None,
- validators=[], error_messages=None, widget=None,
- default=None, blank=None):
-
- super(WritableField, self).__init__(source=source, label=label, help_text=help_text)
-
- self.read_only = read_only
- self.write_only = write_only
-
- assert not (read_only and write_only), "Cannot set read_only=True and write_only=True"
-
- if required is None:
- self.required = not(read_only)
- else:
- assert not (read_only and required), "Cannot set required=True and read_only=True"
- self.required = required
+ return dictionary.get(self.field_name, empty)
- messages = {}
- for c in reversed(self.__class__.__mro__):
- messages.update(getattr(c, 'default_error_messages', {}))
- messages.update(error_messages or {})
- self.error_messages = messages
+ def get_attribute(self, instance):
+ """
+ Given the *outgoing* object instance, return the value for this field
+ that should be returned as a primative value.
+ """
+ return get_attribute(instance, self.source_attrs)
- self.validators = self.default_validators + validators
- self.default = default if default is not None else self.default
+ def get_default(self):
+ """
+ Return the default value to use when validating data if no input
+ is provided for this field.
- # Widgets are only used for HTML forms.
- widget = widget or self.widget
- if isinstance(widget, type):
- widget = widget()
- self.widget = widget
+ If a default has not been set for this field then this will simply
+ return `empty`, indicating that no value should be set in the
+ validated data for this field.
+ """
+ if self.default is empty:
+ raise SkipField()
+ return self.default
- def __deepcopy__(self, memo):
- result = copy.copy(self)
- memo[id(self)] = result
- result.validators = self.validators[:]
- return result
+ def run_validation(self, data=empty):
+ """
+ Validate a simple representation and return the internal value.
- def get_default_value(self):
- if is_simple_callable(self.default):
- return self.default()
- return self.default
+ The provided data may be `empty` if no representation was included.
+ May return `empty` if the field should not be included in the
+ validated data.
+ """
+ if data is empty:
+ if self.required:
+ self.fail('required')
+ return self.get_default()
- def validate(self, value):
- if value in validators.EMPTY_VALUES and self.required:
- raise ValidationError(self.error_messages['required'])
+ value = self.to_internal_value(data)
+ self.run_validators(value)
+ return value
def run_validators(self, value):
- if value in validators.EMPTY_VALUES:
+ if value in (None, '', [], (), {}):
return
+
errors = []
- for v in self.validators:
+ for validator in self.validators:
try:
- v(value)
- except ValidationError as e:
- if hasattr(e, 'code') and e.code in self.error_messages:
- message = self.error_messages[e.code]
- if e.params:
- message = message % e.params
- errors.append(message)
- else:
- errors.extend(e.messages)
+ validator(value)
+ except ValidationError as exc:
+ errors.extend(exc.messages)
if errors:
raise ValidationError(errors)
- def field_to_native(self, obj, field_name):
- if self.write_only:
- return None
- return super(WritableField, self).field_to_native(obj, field_name)
-
- def field_from_native(self, data, files, field_name, into):
+ def to_internal_value(self, data):
"""
- Given a dictionary and a field name, updates the dictionary `into`,
- with the field and it's deserialized value.
+ Transform the *incoming* primative data into a native value.
"""
- if self.read_only:
- return
-
- try:
- data = data or {}
- if self.use_files:
- files = files or {}
- try:
- native = files[field_name]
- except KeyError:
- native = data[field_name]
- else:
- native = data[field_name]
- except KeyError:
- if self.default is not None and not self.partial:
- # Note: partial updates shouldn't set defaults
- native = self.get_default_value()
- else:
- if self.required:
- raise ValidationError(self.error_messages['required'])
- return
-
- value = self.from_native(native)
- if self.source == '*':
- if value:
- into.update(value)
- else:
- self.validate(value)
- self.run_validators(value)
- into[self.source or field_name] = value
+ raise NotImplementedError('to_internal_value() must be implemented.')
- def from_native(self, value):
+ def to_representation(self, value):
"""
- Reverts a simple representation back to the field's value.
+ Transform the *outgoing* native value into primative data.
"""
- return value
-
+ raise NotImplementedError('to_representation() must be implemented.')
-class ModelField(WritableField):
- """
- A generic field that can be used against an arbitrary model field.
- """
- def __init__(self, *args, **kwargs):
+ def fail(self, key, **kwargs):
+ """
+ A helper method that simply raises a validation error.
+ """
try:
- self.model_field = kwargs.pop('model_field')
+ msg = self.error_messages[key]
except KeyError:
- raise ValueError("ModelField requires 'model_field' kwarg")
-
- self.min_length = kwargs.pop('min_length',
- getattr(self.model_field, 'min_length', None))
- self.max_length = kwargs.pop('max_length',
- getattr(self.model_field, 'max_length', None))
- self.min_value = kwargs.pop('min_value',
- getattr(self.model_field, 'min_value', None))
- self.max_value = kwargs.pop('max_value',
- getattr(self.model_field, 'max_value', None))
-
- super(ModelField, self).__init__(*args, **kwargs)
-
- if self.min_length is not None:
- self.validators.append(validators.MinLengthValidator(self.min_length))
- if self.max_length is not None:
- self.validators.append(validators.MaxLengthValidator(self.max_length))
- if self.min_value is not None:
- self.validators.append(validators.MinValueValidator(self.min_value))
- if self.max_value is not None:
- self.validators.append(validators.MaxValueValidator(self.max_value))
+ class_name = self.__class__.__name__
+ msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
+ raise AssertionError(msg)
+ raise ValidationError(msg.format(**kwargs))
- def from_native(self, value):
- rel = getattr(self.model_field, "rel", None)
- if rel is not None:
- return rel.to._meta.get_field(rel.field_name).to_python(value)
- else:
- return self.model_field.to_python(value)
-
- def field_to_native(self, obj, field_name):
- value = self.model_field._get_val_from_obj(obj)
- if is_protected_type(value):
- return value
- return self.model_field.value_to_string(obj)
+ def __repr__(self):
+ return representation.field_repr(self)
- def attributes(self):
- return {
- "type": self.model_field.get_internal_type()
- }
+# Boolean types...
-# Typed Fields
-
-class BooleanField(WritableField):
- type_name = 'BooleanField'
- type_label = 'boolean'
- form_field_class = forms.BooleanField
- widget = widgets.CheckboxInput
+class BooleanField(Field):
default_error_messages = {
- 'invalid': _("'%s' value must be either True or False."),
+ 'invalid': _('`{input}` is not a valid boolean.')
}
- empty = False
-
- def field_from_native(self, data, files, field_name, into):
- # HTML checkboxes do not explicitly represent unchecked as `False`
- # we deal with that here...
- if isinstance(data, QueryDict) and self.default is None:
- self.default = False
-
- return super(BooleanField, self).field_from_native(
- data, files, field_name, into
- )
+ TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True))
+ FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False))
+
+ def get_value(self, dictionary):
+ if html.is_html_input(dictionary):
+ # HTML forms do not send a `False` value on an empty checkbox,
+ # so we override the default empty value to be False.
+ return dictionary.get(self.field_name, False)
+ return dictionary.get(self.field_name, empty)
+
+ def to_internal_value(self, data):
+ if data in self.TRUE_VALUES:
+ return True
+ elif data in self.FALSE_VALUES:
+ return False
+ self.fail('invalid', input=data)
- def from_native(self, value):
- if value in ('true', 't', 'True', '1'):
+ def to_representation(self, value):
+ if value is None:
+ return None
+ if value in self.TRUE_VALUES:
return True
- if value in ('false', 'f', 'False', '0'):
+ elif value in self.FALSE_VALUES:
return False
return bool(value)
-class CharField(WritableField):
- type_name = 'CharField'
- type_label = 'string'
- form_field_class = forms.CharField
+# String types...
- def __init__(self, max_length=None, min_length=None, allow_none=False, *args, **kwargs):
- self.max_length, self.min_length = max_length, min_length
- self.allow_none = allow_none
- super(CharField, self).__init__(*args, **kwargs)
- if min_length is not None:
- self.validators.append(validators.MinLengthValidator(min_length))
- if max_length is not None:
- self.validators.append(validators.MaxLengthValidator(max_length))
+class CharField(Field):
+ default_error_messages = {
+ 'blank': _('This field may not be blank.')
+ }
- def from_native(self, value):
- if isinstance(value, six.string_types):
- return value
+ def __init__(self, **kwargs):
+ self.allow_blank = kwargs.pop('allow_blank', False)
+ self.max_length = kwargs.pop('max_length', None)
+ self.min_length = kwargs.pop('min_length', None)
+ super(CharField, self).__init__(**kwargs)
+
+ def to_internal_value(self, data):
+ if data == '' and not self.allow_blank:
+ self.fail('blank')
+ if data is None:
+ return None
+ return str(data)
+ def to_representation(self, value):
if value is None:
- if not self.allow_none:
- return ''
- else:
- # Return None explicitly because smart_text(None) == 'None'. See #1834 for details
- return None
+ return None
+ return str(value)
+
- return smart_text(value)
+class EmailField(CharField):
+ default_error_messages = {
+ 'invalid': _('Enter a valid email address.')
+ }
+ default_validators = [validators.validate_email]
+ def to_internal_value(self, data):
+ if data == '' and not self.allow_blank:
+ self.fail('blank')
+ if data is None:
+ return None
+ return str(data).strip()
-class URLField(CharField):
- type_name = 'URLField'
- type_label = 'url'
+ def to_representation(self, value):
+ if value is None:
+ return None
+ return str(value).strip()
- def __init__(self, **kwargs):
- if 'validators' not in kwargs:
- kwargs['validators'] = [validators.URLValidator()]
- super(URLField, self).__init__(**kwargs)
+class RegexField(CharField):
+ def __init__(self, regex, **kwargs):
+ kwargs['validators'] = (
+ [validators.RegexValidator(regex)] +
+ kwargs.get('validators', [])
+ )
+ super(RegexField, self).__init__(**kwargs)
-class SlugField(CharField):
- type_name = 'SlugField'
- type_label = 'slug'
- form_field_class = forms.SlugField
+class SlugField(CharField):
default_error_messages = {
- 'invalid': _("Enter a valid 'slug' consisting of letters, numbers,"
- " underscores or hyphens."),
+ 'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.")
}
default_validators = [validators.validate_slug]
- def __init__(self, *args, **kwargs):
- super(SlugField, self).__init__(*args, **kwargs)
-
-class ChoiceField(WritableField):
- type_name = 'ChoiceField'
- type_label = 'choice'
- form_field_class = forms.ChoiceField
- widget = widgets.Select
+class URLField(CharField):
default_error_messages = {
- 'invalid_choice': _('Select a valid choice. %(value)s is not one of '
- 'the available choices.'),
+ 'invalid': _("Enter a valid URL.")
}
+ default_validators = [validators.URLValidator()]
- def __init__(self, choices=(), blank_display_value=None, *args, **kwargs):
- self.empty = kwargs.pop('empty', '')
- super(ChoiceField, self).__init__(*args, **kwargs)
- self.choices = choices
- if not self.required:
- if blank_display_value is None:
- blank_choice = BLANK_CHOICE_DASH
- else:
- blank_choice = [('', blank_display_value)]
- self.choices = blank_choice + self.choices
- def _get_choices(self):
- return self._choices
+# Number types...
- def _set_choices(self, value):
- # Setting choices also sets the choices on the widget.
- # choices can be any iterable, but we call list() on it because
- # it will be consumed more than once.
- self._choices = self.widget.choices = list(value)
+class IntegerField(Field):
+ default_error_messages = {
+ 'invalid': _('A valid integer is required.')
+ }
- choices = property(_get_choices, _set_choices)
+ def __init__(self, **kwargs):
+ max_value = kwargs.pop('max_value', None)
+ min_value = kwargs.pop('min_value', None)
+ super(IntegerField, self).__init__(**kwargs)
+ if max_value is not None:
+ self.validators.append(validators.MaxValueValidator(max_value))
+ if min_value is not None:
+ self.validators.append(validators.MinValueValidator(min_value))
- def metadata(self):
- data = super(ChoiceField, self).metadata()
- data['choices'] = [{'value': v, 'display_name': n} for v, n in self.choices]
+ def to_internal_value(self, data):
+ try:
+ data = int(str(data))
+ except (ValueError, TypeError):
+ self.fail('invalid')
return data
- def validate(self, value):
- """
- Validates that the input is in self.choices.
- """
- super(ChoiceField, self).validate(value)
- if value and not self.valid_value(value):
- raise ValidationError(self.error_messages['invalid_choice'] % {'value': value})
+ def to_representation(self, value):
+ if value is None:
+ return None
+ return int(value)
- def valid_value(self, value):
- """
- Check to see if the provided value is a valid choice.
- """
- for k, v in self.choices:
- if isinstance(v, (list, tuple)):
- # This is an optgroup, so look inside the group for options
- for k2, v2 in v:
- if value == smart_text(k2):
- return True
- else:
- if value == smart_text(k) or value == k:
- return True
- return False
- def from_native(self, value):
- value = super(ChoiceField, self).from_native(value)
- if value == self.empty or value in validators.EMPTY_VALUES:
- return self.empty
- return value
+class FloatField(Field):
+ default_error_messages = {
+ 'invalid': _("'%s' value must be a float."),
+ }
+
+ def __init__(self, **kwargs):
+ max_value = kwargs.pop('max_value', None)
+ min_value = kwargs.pop('min_value', None)
+ super(FloatField, self).__init__(**kwargs)
+ if max_value is not None:
+ self.validators.append(validators.MaxValueValidator(max_value))
+ if min_value is not None:
+ self.validators.append(validators.MinValueValidator(min_value))
+ def to_internal_value(self, value):
+ if value is None:
+ return None
+ return float(value)
-class EmailField(CharField):
- type_name = 'EmailField'
- type_label = 'email'
- form_field_class = forms.EmailField
+ def to_representation(self, value):
+ if value is None:
+ return None
+ try:
+ return float(value)
+ except (TypeError, ValueError):
+ self.fail('invalid', value=value)
+
+class DecimalField(Field):
default_error_messages = {
- 'invalid': _('Enter a valid email address.'),
+ 'invalid': _('Enter a number.'),
+ 'max_value': _('Ensure this value is less than or equal to {max_value}.'),
+ 'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
+ 'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'),
+ 'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'),
+ 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.')
}
- default_validators = [validators.validate_email]
- def from_native(self, value):
- ret = super(EmailField, self).from_native(value)
- if ret is None:
+ coerce_to_string = api_settings.COERCE_DECIMAL_TO_STRING
+
+ def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, **kwargs):
+ self.max_digits = max_digits
+ self.decimal_places = decimal_places
+ self.coerce_to_string = coerce_to_string if (coerce_to_string is not None) else self.coerce_to_string
+ super(DecimalField, self).__init__(**kwargs)
+ if max_value is not None:
+ self.validators.append(validators.MaxValueValidator(max_value))
+ if min_value is not None:
+ self.validators.append(validators.MinValueValidator(min_value))
+
+ def to_internal_value(self, value):
+ """
+ Validates that the input is a decimal number. Returns a Decimal
+ instance. Returns None for empty values. Ensures that there are no more
+ than max_digits in the number, and no more than decimal_places digits
+ after the decimal point.
+ """
+ if value in (None, ''):
return None
- return ret.strip()
+ value = smart_text(value).strip()
+ try:
+ value = decimal.Decimal(value)
+ except decimal.DecimalException:
+ self.fail('invalid')
-class RegexField(CharField):
- type_name = 'RegexField'
- type_label = 'regex'
- form_field_class = forms.RegexField
+ # Check for NaN. It is the only value that isn't equal to itself,
+ # so we can use this to identify NaN values.
+ if value != value:
+ self.fail('invalid')
+
+ # Check for infinity and negative infinity.
+ if value in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')):
+ self.fail('invalid')
- def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs):
- super(RegexField, self).__init__(max_length, min_length, *args, **kwargs)
- self.regex = regex
+ sign, digittuple, exponent = value.as_tuple()
+ decimals = abs(exponent)
+ # digittuple doesn't include any leading zeros.
+ digits = len(digittuple)
+ if decimals > digits:
+ # We have leading zeros up to or past the decimal point. Count
+ # everything past the decimal point as a digit. We do not count
+ # 0 before the decimal point as a digit since that would mean
+ # we would not allow max_digits = decimal_places.
+ digits = decimals
+ whole_digits = digits - decimals
- def _get_regex(self):
- return self._regex
+ if self.max_digits is not None and digits > self.max_digits:
+ self.fail('max_digits', max_digits=self.max_digits)
+ if self.decimal_places is not None and decimals > self.decimal_places:
+ self.fail('max_decimal_places', max_decimal_places=self.decimal_places)
+ if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
+ self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places)
- def _set_regex(self, regex):
- if isinstance(regex, six.string_types):
- regex = re.compile(regex)
- self._regex = regex
- if hasattr(self, '_regex_validator') and self._regex_validator in self.validators:
- self.validators.remove(self._regex_validator)
- self._regex_validator = validators.RegexValidator(regex=regex)
- self.validators.append(self._regex_validator)
+ return value
- regex = property(_get_regex, _set_regex)
+ def to_representation(self, value):
+ if isinstance(value, decimal.Decimal):
+ context = decimal.getcontext().copy()
+ context.prec = self.max_digits
+ quantized = value.quantize(
+ decimal.Decimal('.1') ** self.decimal_places,
+ context=context
+ )
+ if not self.coerce_to_string:
+ return quantized
+ return '{0:f}'.format(quantized)
+
+ if not self.coerce_to_string:
+ return value
+ return '%.*f' % (self.max_decimal_places, value)
-class DateField(WritableField):
- type_name = 'DateField'
- type_label = 'date'
- widget = widgets.DateInput
- form_field_class = forms.DateField
+# Date & time fields...
+class DateField(Field):
default_error_messages = {
- 'invalid': _("Date has wrong format. Use one of these formats instead: %s"),
+ 'invalid': _('Date has wrong format. Use one of these formats instead: {format}'),
}
- empty = None
- input_formats = api_settings.DATE_INPUT_FORMATS
format = api_settings.DATE_FORMAT
+ input_formats = api_settings.DATE_INPUT_FORMATS
- def __init__(self, input_formats=None, format=None, *args, **kwargs):
- self.input_formats = input_formats if input_formats is not None else self.input_formats
+ def __init__(self, format=None, input_formats=None, *args, **kwargs):
self.format = format if format is not None else self.format
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
super(DateField, self).__init__(*args, **kwargs)
- def from_native(self, value):
- if value in validators.EMPTY_VALUES:
+ def to_internal_value(self, value):
+ if value in (None, ''):
return None
if isinstance(value, datetime.datetime):
@@ -647,6 +531,7 @@ class DateField(WritableField):
default_timezone = timezone.get_default_timezone()
value = timezone.make_naive(value, default_timezone)
return value.date()
+
if isinstance(value, datetime.date):
return value
@@ -667,10 +552,10 @@ class DateField(WritableField):
else:
return parsed.date()
- msg = self.error_messages['invalid'] % readable_date_formats(self.input_formats)
- raise ValidationError(msg)
+ humanized_format = humanize_datetime.date_formats(self.input_formats)
+ self.fail('invalid', format=humanized_format)
- def to_native(self, value):
+ def to_representation(self, value):
if value is None or self.format is None:
return value
@@ -682,30 +567,25 @@ class DateField(WritableField):
return value.strftime(self.format)
-class DateTimeField(WritableField):
- type_name = 'DateTimeField'
- type_label = 'datetime'
- widget = widgets.DateTimeInput
- form_field_class = forms.DateTimeField
-
+class DateTimeField(Field):
default_error_messages = {
- 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"),
+ 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'),
}
- empty = None
- input_formats = api_settings.DATETIME_INPUT_FORMATS
format = api_settings.DATETIME_FORMAT
+ input_formats = api_settings.DATETIME_INPUT_FORMATS
- def __init__(self, input_formats=None, format=None, *args, **kwargs):
- self.input_formats = input_formats if input_formats is not None else self.input_formats
+ def __init__(self, format=None, input_formats=None, *args, **kwargs):
self.format = format if format is not None else self.format
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
super(DateTimeField, self).__init__(*args, **kwargs)
- def from_native(self, value):
- if value in validators.EMPTY_VALUES:
+ def to_internal_value(self, value):
+ if value in (None, ''):
return None
if isinstance(value, datetime.datetime):
return value
+
if isinstance(value, datetime.date):
value = datetime.datetime(value.year, value.month, value.day)
if settings.USE_TZ:
@@ -737,10 +617,10 @@ class DateTimeField(WritableField):
else:
return parsed
- msg = self.error_messages['invalid'] % readable_datetime_formats(self.input_formats)
- raise ValidationError(msg)
+ humanized_format = humanize_datetime.datetime_formats(self.input_formats)
+ self.fail('invalid', format=humanized_format)
- def to_native(self, value):
+ def to_representation(self, value):
if value is None or self.format is None:
return value
@@ -752,26 +632,20 @@ class DateTimeField(WritableField):
return value.strftime(self.format)
-class TimeField(WritableField):
- type_name = 'TimeField'
- type_label = 'time'
- widget = widgets.TimeInput
- form_field_class = forms.TimeField
-
+class TimeField(Field):
default_error_messages = {
- 'invalid': _("Time has wrong format. Use one of these formats instead: %s"),
+ 'invalid': _('Time has wrong format. Use one of these formats instead: {format}'),
}
- empty = None
- input_formats = api_settings.TIME_INPUT_FORMATS
format = api_settings.TIME_FORMAT
+ input_formats = api_settings.TIME_INPUT_FORMATS
- def __init__(self, input_formats=None, format=None, *args, **kwargs):
- self.input_formats = input_formats if input_formats is not None else self.input_formats
+ def __init__(self, format=None, input_formats=None, *args, **kwargs):
self.format = format if format is not None else self.format
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
super(TimeField, self).__init__(*args, **kwargs)
def from_native(self, value):
- if value in validators.EMPTY_VALUES:
+ if value in (None, ''):
return None
if isinstance(value, datetime.time):
@@ -794,10 +668,10 @@ class TimeField(WritableField):
else:
return parsed.time()
- msg = self.error_messages['invalid'] % readable_time_formats(self.input_formats)
- raise ValidationError(msg)
+ humanized_format = humanize_datetime.time_formats(self.input_formats)
+ self.fail('invalid', format=humanized_format)
- def to_native(self, value):
+ def to_representation(self, value):
if value is None or self.format is None:
return value
@@ -809,234 +683,147 @@ class TimeField(WritableField):
return value.strftime(self.format)
-class IntegerField(WritableField):
- type_name = 'IntegerField'
- type_label = 'integer'
- form_field_class = forms.IntegerField
- empty = 0
+# Choice types...
+class ChoiceField(Field):
default_error_messages = {
- 'invalid': _('Enter a whole number.'),
- 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
- 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
+ 'invalid_choice': _('`{input}` is not a valid choice.')
}
- def __init__(self, max_value=None, min_value=None, *args, **kwargs):
- self.max_value, self.min_value = max_value, min_value
- super(IntegerField, self).__init__(*args, **kwargs)
+ def __init__(self, choices, **kwargs):
+ # Allow either single or paired choices style:
+ # choices = [1, 2, 3]
+ # choices = [(1, 'First'), (2, 'Second'), (3, 'Third')]
+ pairs = [
+ isinstance(item, (list, tuple)) and len(item) == 2
+ for item in choices
+ ]
+ if all(pairs):
+ self.choices = dict([(key, display_value) for key, display_value in choices])
+ else:
+ self.choices = dict([(item, item) for item in choices])
- if max_value is not None:
- self.validators.append(validators.MaxValueValidator(max_value))
- if min_value is not None:
- self.validators.append(validators.MinValueValidator(min_value))
+ # Map the string representation of choices to the underlying value.
+ # Allows us to deal with eg. integer choices while supporting either
+ # integer or string input, but still get the correct datatype out.
+ self.choice_strings_to_values = dict([
+ (str(key), key) for key in self.choices.keys()
+ ])
- def from_native(self, value):
- if value in validators.EMPTY_VALUES:
- return None
+ super(ChoiceField, self).__init__(**kwargs)
+ def to_internal_value(self, data):
try:
- value = int(str(value))
- except (ValueError, TypeError):
- raise ValidationError(self.error_messages['invalid'])
- return value
+ return self.choice_strings_to_values[str(data)]
+ except KeyError:
+ self.fail('invalid_choice', input=data)
+ def to_representation(self, value):
+ return value
-class FloatField(WritableField):
- type_name = 'FloatField'
- type_label = 'float'
- form_field_class = forms.FloatField
- empty = 0
+class MultipleChoiceField(ChoiceField):
default_error_messages = {
- 'invalid': _("'%s' value must be a float."),
+ 'invalid_choice': _('`{input}` is not a valid choice.'),
+ 'not_a_list': _('Expected a list of items but got type `{input_type}`')
}
- def from_native(self, value):
- if value in validators.EMPTY_VALUES:
- return None
-
- try:
- return float(value)
- except (TypeError, ValueError):
- msg = self.error_messages['invalid'] % value
- raise ValidationError(msg)
-
-
-class DecimalField(WritableField):
- type_name = 'DecimalField'
- type_label = 'decimal'
- form_field_class = forms.DecimalField
- empty = Decimal('0')
-
- default_error_messages = {
- 'invalid': _('Enter a number.'),
- 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
- 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
- 'max_digits': _('Ensure that there are no more than %s digits in total.'),
- 'max_decimal_places': _('Ensure that there are no more than %s decimal places.'),
- 'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.')
- }
+ def to_internal_value(self, data):
+ if not hasattr(data, '__iter__'):
+ self.fail('not_a_list', input_type=type(data).__name__)
+ return set([
+ super(MultipleChoiceField, self).to_internal_value(item)
+ for item in data
+ ])
- def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs):
- self.max_value, self.min_value = max_value, min_value
- self.max_digits, self.decimal_places = max_digits, decimal_places
- super(DecimalField, self).__init__(*args, **kwargs)
+ def to_representation(self, value):
+ return value
- if max_value is not None:
- self.validators.append(validators.MaxValueValidator(max_value))
- if min_value is not None:
- self.validators.append(validators.MinValueValidator(min_value))
- def from_native(self, value):
- """
- Validates that the input is a decimal number. Returns a Decimal
- instance. Returns None for empty values. Ensures that there are no more
- than max_digits in the number, and no more than decimal_places digits
- after the decimal point.
- """
- if value in validators.EMPTY_VALUES:
- return None
- value = smart_text(value).strip()
- try:
- value = Decimal(value)
- except DecimalException:
- raise ValidationError(self.error_messages['invalid'])
- return value
+# File types...
- def validate(self, value):
- super(DecimalField, self).validate(value)
- if value in validators.EMPTY_VALUES:
- return
- # Check for NaN, Inf and -Inf values. We can't compare directly for NaN,
- # since it is never equal to itself. However, NaN is the only value that
- # isn't equal to itself, so we can use this to identify NaN
- if value != value or value == Decimal("Inf") or value == Decimal("-Inf"):
- raise ValidationError(self.error_messages['invalid'])
- sign, digittuple, exponent = value.as_tuple()
- decimals = abs(exponent)
- # digittuple doesn't include any leading zeros.
- digits = len(digittuple)
- if decimals > digits:
- # We have leading zeros up to or past the decimal point. Count
- # everything past the decimal point as a digit. We do not count
- # 0 before the decimal point as a digit since that would mean
- # we would not allow max_digits = decimal_places.
- digits = decimals
- whole_digits = digits - decimals
+class FileField(Field):
+ pass # TODO
- if self.max_digits is not None and digits > self.max_digits:
- raise ValidationError(self.error_messages['max_digits'] % self.max_digits)
- if self.decimal_places is not None and decimals > self.decimal_places:
- raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places)
- if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
- raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places))
- return value
+class ImageField(Field):
+ pass # TODO
-class FileField(WritableField):
- use_files = True
- type_name = 'FileField'
- type_label = 'file upload'
- form_field_class = forms.FileField
- widget = widgets.FileInput
- default_error_messages = {
- 'invalid': _("No file was submitted. Check the encoding type on the form."),
- 'missing': _("No file was submitted."),
- 'empty': _("The submitted file is empty."),
- 'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'),
- 'contradiction': _('Please either submit a file or check the clear checkbox, not both.')
- }
+# Advanced field types...
- def __init__(self, *args, **kwargs):
- self.max_length = kwargs.pop('max_length', None)
- self.allow_empty_file = kwargs.pop('allow_empty_file', False)
- super(FileField, self).__init__(*args, **kwargs)
+class ReadOnlyField(Field):
+ """
+ A read-only field that simply returns the field value.
- def from_native(self, data):
- if data in validators.EMPTY_VALUES:
- return None
+ If the field is a method with no parameters, the method will be called
+ and it's return value used as the representation.
- # UploadedFile objects should have name and size attributes.
- try:
- file_name = data.name
- file_size = data.size
- except AttributeError:
- raise ValidationError(self.error_messages['invalid'])
-
- if self.max_length is not None and len(file_name) > self.max_length:
- error_values = {'max': self.max_length, 'length': len(file_name)}
- raise ValidationError(self.error_messages['max_length'] % error_values)
- if not file_name:
- raise ValidationError(self.error_messages['invalid'])
- if not self.allow_empty_file and not file_size:
- raise ValidationError(self.error_messages['empty'])
+ For example, the following would call `get_expiry_date()` on the object:
- return data
+ class ExampleSerializer(self):
+ expiry_date = ReadOnlyField(source='get_expiry_date')
+ """
- def to_native(self, value):
- return value.name
+ def __init__(self, **kwargs):
+ kwargs['read_only'] = True
+ super(ReadOnlyField, self).__init__(**kwargs)
+ def to_representation(self, value):
+ if is_simple_callable(value):
+ return value()
+ return value
-class ImageField(FileField):
- use_files = True
- type_name = 'ImageField'
- type_label = 'image upload'
- form_field_class = forms.ImageField
- default_error_messages = {
- 'invalid_image': _("Upload a valid image. The file you uploaded was "
- "either not an image or a corrupted image."),
- }
+class SerializerMethodField(Field):
+ """
+ A read-only field that get its representation from calling a method on the
+ parent serializer class. The method called will be of the form
+ "get_{field_name}", and should take a single argument, which is the
+ object being serialized.
- def from_native(self, data):
- """
- Checks that the file-upload field data contains a valid image (GIF, JPG,
- PNG, possibly others -- whatever the Python Imaging Library supports).
- """
- f = super(ImageField, self).from_native(data)
- if f is None:
- return None
+ For example:
- from rest_framework.compat import Image
- assert Image is not None, 'Either Pillow or PIL must be installed for ImageField support.'
+ class ExampleSerializer(self):
+ extra_info = SerializerMethodField()
- # We need to get a file object for PIL. We might have a path or we might
- # have to read the data into memory.
- if hasattr(data, 'temporary_file_path'):
- file = data.temporary_file_path()
- else:
- if hasattr(data, 'read'):
- file = BytesIO(data.read())
- else:
- file = BytesIO(data['content'])
+ def get_extra_info(self, obj):
+ return ... # Calculate some data to return.
+ """
+ def __init__(self, method_attr=None, **kwargs):
+ self.method_attr = method_attr
+ kwargs['source'] = '*'
+ kwargs['read_only'] = True
+ super(SerializerMethodField, self).__init__(**kwargs)
- try:
- # load() could spot a truncated JPEG, but it loads the entire
- # image in memory, which is a DoS vector. See #3848 and #18520.
- # verify() must be called immediately after the constructor.
- Image.open(file).verify()
- except ImportError:
- # Under PyPy, it is possible to import PIL. However, the underlying
- # _imaging C module isn't available, so an ImportError will be
- # raised. Catch and re-raise.
- raise
- except Exception: # Python Imaging Library doesn't recognize it as an image
- raise ValidationError(self.error_messages['invalid_image'])
- if hasattr(f, 'seek') and callable(f.seek):
- f.seek(0)
- return f
+ def to_representation(self, value):
+ method_attr = self.method_attr
+ if method_attr is None:
+ method_attr = 'get_{field_name}'.format(field_name=self.field_name)
+ method = getattr(self.parent, method_attr)
+ return method(value)
-class SerializerMethodField(Field):
+class ModelField(Field):
"""
- A field that gets its value by calling a method on the serializer it's attached to.
+ A generic field that can be used against an arbitrary model field.
+
+ This is used by `ModelSerializer` when dealing with custom model fields,
+ that do not have a serializer field to be mapped to.
"""
+ def __init__(self, model_field, **kwargs):
+ self.model_field = model_field
+ kwargs['source'] = '*'
+ super(ModelField, self).__init__(**kwargs)
- def __init__(self, method_name, *args, **kwargs):
- self.method_name = method_name
- super(SerializerMethodField, self).__init__(*args, **kwargs)
+ def to_internal_value(self, data):
+ rel = getattr(self.model_field, 'rel', None)
+ if rel is not None:
+ return rel.to._meta.get_field(rel.field_name).to_python(data)
+ return self.model_field.to_python(data)
- def field_to_native(self, obj, field_name):
- value = getattr(self.parent, self.method_name)(obj)
- return self.to_native(value)
+ def to_representation(self, obj):
+ value = self.model_field._get_val_from_obj(obj)
+ if is_protected_type(value):
+ return value
+ return self.model_field.value_to_string(obj)
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index a6f68657..eb6b64ef 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -3,7 +3,8 @@ Generic views that provide commonly needed behaviour.
"""
from __future__ import unicode_literals
-from django.core.exceptions import ImproperlyConfigured, PermissionDenied
+from django.db.models.query import QuerySet
+from django.core.exceptions import PermissionDenied
from django.core.paginator import Paginator, InvalidPage
from django.http import Http404
from django.shortcuts import get_object_or_404 as _get_object_or_404
@@ -11,7 +12,6 @@ from django.utils.translation import ugettext as _
from rest_framework import views, mixins, exceptions
from rest_framework.request import clone_request
from rest_framework.settings import api_settings
-import warnings
def strict_positive_int(integer_string, cutoff=None):
@@ -28,7 +28,7 @@ def strict_positive_int(integer_string, cutoff=None):
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
"""
- Same as Django's standard shortcut, but make sure to raise 404
+ Same as Django's standard shortcut, but make sure to also raise 404
if the filter_kwargs don't match the required types.
"""
try:
@@ -51,11 +51,6 @@ class GenericAPIView(views.APIView):
queryset = None
serializer_class = None
- # This shortcut may be used instead of setting either or both
- # of the `queryset`/`serializer_class` attributes, although using
- # the explicit style is generally preferred.
- model = None
-
# If you want to use object lookups other than pk, set this attribute.
# For more complex lookup requirements override `get_object()`.
lookup_field = 'pk'
@@ -71,20 +66,10 @@ class GenericAPIView(views.APIView):
# The filter backend classes to use for queryset filtering
filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
- # The following attributes may be subject to change,
+ # The following attribute may be subject to change,
# and should be considered private API.
- model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
paginator_class = Paginator
- ######################################
- # These are pending deprecation...
-
- pk_url_kwarg = 'pk'
- slug_url_kwarg = 'slug'
- slug_field = 'slug'
- allow_empty = True
- filter_backend = api_settings.FILTER_BACKEND
-
def get_serializer_context(self):
"""
Extra context provided to the serializer class.
@@ -95,18 +80,16 @@ class GenericAPIView(views.APIView):
'view': self
}
- def get_serializer(self, instance=None, data=None, files=None, many=False,
- partial=False, allow_add_remove=False):
+ def get_serializer(self, instance=None, data=None, many=False, partial=False):
"""
Return the serializer instance that should be used for validating and
deserializing input, and for serializing output.
"""
serializer_class = self.get_serializer_class()
context = self.get_serializer_context()
- return serializer_class(instance, data=data, files=files,
- many=many, partial=partial,
- allow_add_remove=allow_add_remove,
- context=context)
+ return serializer_class(
+ instance, data=data, many=many, partial=partial, context=context
+ )
def get_pagination_serializer(self, page):
"""
@@ -120,37 +103,16 @@ class GenericAPIView(views.APIView):
context = self.get_serializer_context()
return pagination_serializer_class(instance=page, context=context)
- def paginate_queryset(self, queryset, page_size=None):
+ def paginate_queryset(self, queryset):
"""
Paginate a queryset if required, either returning a page object,
or `None` if pagination is not configured for this view.
"""
- deprecated_style = False
- if page_size is not None:
- warnings.warn('The `page_size` parameter to `paginate_queryset()` '
- 'is deprecated. '
- 'Note that the return style of this method is also '
- 'changed, and will simply return a page object '
- 'when called without a `page_size` argument.',
- DeprecationWarning, stacklevel=2)
- deprecated_style = True
- else:
- # Determine the required page size.
- # If pagination is not configured, simply return None.
- page_size = self.get_paginate_by()
- if not page_size:
- return None
-
- if not self.allow_empty:
- warnings.warn(
- 'The `allow_empty` parameter is deprecated. '
- 'To use `allow_empty=False` style behavior, You should override '
- '`get_queryset()` and explicitly raise a 404 on empty querysets.',
- DeprecationWarning, stacklevel=2
- )
-
- paginator = self.paginator_class(queryset, page_size,
- allow_empty_first_page=self.allow_empty)
+ page_size = self.get_paginate_by()
+ if not page_size:
+ return None
+
+ paginator = self.paginator_class(queryset, page_size)
page_kwarg = self.kwargs.get(self.page_kwarg)
page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)
page = page_kwarg or page_query_param or 1
@@ -170,8 +132,6 @@ class GenericAPIView(views.APIView):
'message': str(exc)
})
- if deprecated_style:
- return (paginator, page, page.object_list, page.has_other_pages())
return page
def filter_queryset(self, queryset):
@@ -191,29 +151,12 @@ class GenericAPIView(views.APIView):
"""
Returns the list of filter backends that this view requires.
"""
- if self.filter_backends is None:
- filter_backends = []
- else:
- # Note that we are returning a *copy* of the class attribute,
- # so that it is safe for the view to mutate it if needed.
- filter_backends = list(self.filter_backends)
-
- if not filter_backends and self.filter_backend:
- warnings.warn(
- 'The `filter_backend` attribute and `FILTER_BACKEND` setting '
- 'are deprecated in favor of a `filter_backends` '
- 'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take '
- 'a *list* of filter backend classes.',
- DeprecationWarning, stacklevel=2
- )
- filter_backends = [self.filter_backend]
-
- return filter_backends
+ return list(self.filter_backends)
# The following methods provide default implementations
# that you may want to override for more complex cases.
- def get_paginate_by(self, queryset=None):
+ def get_paginate_by(self):
"""
Return the size of pages to use with pagination.
@@ -222,11 +165,6 @@ class GenericAPIView(views.APIView):
Otherwise defaults to using `self.paginate_by`.
"""
- if queryset is not None:
- warnings.warn('The `queryset` parameter to `get_paginate_by()` '
- 'is deprecated.',
- DeprecationWarning, stacklevel=2)
-
if self.paginate_by_param:
try:
return strict_positive_int(
@@ -248,26 +186,13 @@ class GenericAPIView(views.APIView):
(Eg. admins get full serialization, others get basic serialization)
"""
- serializer_class = self.serializer_class
- if serializer_class is not None:
- return serializer_class
-
- warnings.warn(
- 'The `.model` attribute on view classes is now deprecated in favor '
- 'of the more explicit `serializer_class` and `queryset` attributes.',
- DeprecationWarning, stacklevel=2
- )
-
- assert self.model is not None, \
- "'%s' should either include a 'serializer_class' attribute, " \
- "or use the 'model' attribute as a shortcut for " \
- "automatically generating a serializer class." \
+ assert self.serializer_class is not None, (
+ "'%s' should either include a `serializer_class` attribute, "
+ "or override the `get_serializer_class()` method."
% self.__class__.__name__
+ )
- class DefaultSerializer(self.model_serializer_class):
- class Meta:
- model = self.model
- return DefaultSerializer
+ return self.serializer_class
def get_queryset(self):
"""
@@ -284,21 +209,19 @@ class GenericAPIView(views.APIView):
(Eg. return a list of items that is specific to the user)
"""
- if self.queryset is not None:
- return self.queryset._clone()
-
- if self.model is not None:
- warnings.warn(
- 'The `.model` attribute on view classes is now deprecated in favor '
- 'of the more explicit `serializer_class` and `queryset` attributes.',
- DeprecationWarning, stacklevel=2
- )
- return self.model._default_manager.all()
+ assert self.queryset is not None, (
+ "'%s' should either include a `queryset` attribute, "
+ "or override the `get_queryset()` method."
+ % self.__class__.__name__
+ )
- error_format = "'%s' must define 'queryset' or 'model'"
- raise ImproperlyConfigured(error_format % self.__class__.__name__)
+ queryset = self.queryset
+ if isinstance(queryset, QuerySet):
+ # Ensure queryset is re-evaluated on each request.
+ queryset = queryset.all()
+ return queryset
- def get_object(self, queryset=None):
+ def get_object(self):
"""
Returns the object the view is displaying.
@@ -306,43 +229,19 @@ class GenericAPIView(views.APIView):
queryset lookups. Eg if objects are referenced using multiple
keyword arguments in the url conf.
"""
- # Determine the base queryset to use.
- if queryset is None:
- queryset = self.filter_queryset(self.get_queryset())
- else:
- pass # Deprecation warning
+ queryset = self.filter_queryset(self.get_queryset())
# Perform the lookup filtering.
- # Note that `pk` and `slug` are deprecated styles of lookup filtering.
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
- lookup = self.kwargs.get(lookup_url_kwarg, None)
- pk = self.kwargs.get(self.pk_url_kwarg, None)
- slug = self.kwargs.get(self.slug_url_kwarg, None)
-
- if lookup is not None:
- filter_kwargs = {self.lookup_field: lookup}
- elif pk is not None and self.lookup_field == 'pk':
- warnings.warn(
- 'The `pk_url_kwarg` attribute is deprecated. '
- 'Use the `lookup_field` attribute instead',
- DeprecationWarning
- )
- filter_kwargs = {'pk': pk}
- elif slug is not None and self.lookup_field == 'pk':
- warnings.warn(
- 'The `slug_url_kwarg` attribute is deprecated. '
- 'Use the `lookup_field` attribute instead',
- DeprecationWarning
- )
- filter_kwargs = {self.slug_field: slug}
- else:
- raise ImproperlyConfigured(
- 'Expected view %s to be called with a URL keyword argument '
- 'named "%s". Fix your URL conf, or set the `.lookup_field` '
- 'attribute on the view correctly.' %
- (self.__class__.__name__, self.lookup_field)
- )
+ assert lookup_url_kwarg in self.kwargs, (
+ 'Expected view %s to be called with a URL keyword argument '
+ 'named "%s". Fix your URL conf, or set the `.lookup_field` '
+ 'attribute on the view correctly.' %
+ (self.__class__.__name__, lookup_url_kwarg)
+ )
+
+ filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
obj = get_object_or_404(queryset, **filter_kwargs)
# May raise a permission denied
@@ -355,34 +254,6 @@ class GenericAPIView(views.APIView):
#
# The are not called by GenericAPIView directly,
# but are used by the mixin methods.
-
- def pre_save(self, obj):
- """
- Placeholder method for calling before saving an object.
-
- May be used to set attributes on the object that are implicit
- in either the request, or the url.
- """
- pass
-
- def post_save(self, obj, created=False):
- """
- Placeholder method for calling after saving an object.
- """
- pass
-
- def pre_delete(self, obj):
- """
- Placeholder method for calling before deleting an object.
- """
- pass
-
- def post_delete(self, obj):
- """
- Placeholder method for calling after deleting an object.
- """
- pass
-
def metadata(self, request):
"""
Return a dictionary of metadata about the view.
@@ -540,25 +411,3 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
-
-
-# Deprecated classes
-
-class MultipleObjectAPIView(GenericAPIView):
- def __init__(self, *args, **kwargs):
- warnings.warn(
- 'Subclassing `MultipleObjectAPIView` is deprecated. '
- 'You should simply subclass `GenericAPIView` instead.',
- DeprecationWarning, stacklevel=2
- )
- super(MultipleObjectAPIView, self).__init__(*args, **kwargs)
-
-
-class SingleObjectAPIView(GenericAPIView):
- def __init__(self, *args, **kwargs):
- warnings.warn(
- 'Subclassing `SingleObjectAPIView` is deprecated. '
- 'You should simply subclass `GenericAPIView` instead.',
- DeprecationWarning, stacklevel=2
- )
- super(SingleObjectAPIView, self).__init__(*args, **kwargs)
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 2cc87eef..14a6b44b 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -6,40 +6,11 @@ which allows mixin classes to be composed in interesting ways.
"""
from __future__ import unicode_literals
-from django.core.exceptions import ValidationError
from django.http import Http404
from rest_framework import status
from rest_framework.response import Response
from rest_framework.request import clone_request
from rest_framework.settings import api_settings
-import warnings
-
-
-def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None):
- """
- Given a model instance, and an optional pk and slug field,
- return the full list of all other field names on that model.
-
- For use when performing full_clean on a model instance,
- so we only clean the required fields.
- """
- include = []
-
- if pk:
- # Deprecated
- pk_field = obj._meta.pk
- while pk_field.rel:
- pk_field = pk_field.rel.to._meta.pk
- include.append(pk_field.name)
-
- if slug_field:
- # Deprecated
- include.append(slug_field)
-
- if lookup_field and lookup_field != 'pk':
- include.append(lookup_field)
-
- return [field.name for field in obj._meta.fields if field.name not in include]
class CreateModelMixin(object):
@@ -47,17 +18,11 @@ class CreateModelMixin(object):
Create a model instance.
"""
def create(self, request, *args, **kwargs):
- serializer = self.get_serializer(data=request.DATA, files=request.FILES)
-
- if serializer.is_valid():
- self.pre_save(serializer.object)
- self.object = serializer.save(force_insert=True)
- self.post_save(self.object, created=True)
- headers = self.get_success_headers(serializer.data)
- return Response(serializer.data, status=status.HTTP_201_CREATED,
- headers=headers)
-
- return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+ serializer = self.get_serializer(data=request.DATA)
+ serializer.is_valid(raise_exception=True)
+ serializer.save()
+ headers = self.get_success_headers(serializer.data)
+ return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
def get_success_headers(self, data):
try:
@@ -70,31 +35,13 @@ class ListModelMixin(object):
"""
List a queryset.
"""
- empty_error = "Empty list and '%(class_name)s.allow_empty' is False."
-
def list(self, request, *args, **kwargs):
- self.object_list = self.filter_queryset(self.get_queryset())
-
- # Default is to allow empty querysets. This can be altered by setting
- # `.allow_empty = False`, to raise 404 errors on empty querysets.
- if not self.allow_empty and not self.object_list:
- warnings.warn(
- 'The `allow_empty` parameter is deprecated. '
- 'To use `allow_empty=False` style behavior, You should override '
- '`get_queryset()` and explicitly raise a 404 on empty querysets.',
- DeprecationWarning
- )
- class_name = self.__class__.__name__
- error_msg = self.empty_error % {'class_name': class_name}
- raise Http404(error_msg)
-
- # Switch between paginated or standard style responses
- page = self.paginate_queryset(self.object_list)
+ instance = self.filter_queryset(self.get_queryset())
+ page = self.paginate_queryset(instance)
if page is not None:
serializer = self.get_pagination_serializer(page)
else:
- serializer = self.get_serializer(self.object_list, many=True)
-
+ serializer = self.get_serializer(instance, many=True)
return Response(serializer.data)
@@ -103,8 +50,8 @@ class RetrieveModelMixin(object):
Retrieve a model instance.
"""
def retrieve(self, request, *args, **kwargs):
- self.object = self.get_object()
- serializer = self.get_serializer(self.object)
+ instance = self.get_object()
+ serializer = self.get_serializer(instance)
return Response(serializer.data)
@@ -114,29 +61,52 @@ class UpdateModelMixin(object):
"""
def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
- self.object = self.get_object_or_none()
+ instance = self.get_object()
+ serializer = self.get_serializer(instance, data=request.DATA, partial=partial)
+ serializer.is_valid(raise_exception=True)
+ serializer.save()
+ return Response(serializer.data)
- serializer = self.get_serializer(self.object, data=request.DATA,
- files=request.FILES, partial=partial)
+ def partial_update(self, request, *args, **kwargs):
+ kwargs['partial'] = True
+ return self.update(request, *args, **kwargs)
- if not serializer.is_valid():
- return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- try:
- self.pre_save(serializer.object)
- except ValidationError as err:
- # full_clean on model instance may be called in pre_save,
- # so we have to handle eventual errors.
- return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST)
-
- if self.object is None:
- self.object = serializer.save(force_insert=True)
- self.post_save(self.object, created=True)
+class DestroyModelMixin(object):
+ """
+ Destroy a model instance.
+ """
+ def destroy(self, request, *args, **kwargs):
+ instance = self.get_object()
+ instance.delete()
+ return Response(status=status.HTTP_204_NO_CONTENT)
+
+
+# The AllowPUTAsCreateMixin was previously the default behaviour
+# for PUT requests. This has now been removed and must be *explictly*
+# included if it is the behavior that you want.
+# For more info see: ...
+
+class AllowPUTAsCreateMixin(object):
+ """
+ The following mixin class may be used in order to support PUT-as-create
+ behavior for incoming requests.
+ """
+ def update(self, request, *args, **kwargs):
+ partial = kwargs.pop('partial', False)
+ instance = self.get_object_or_none()
+ serializer = self.get_serializer(instance, data=request.DATA, partial=partial)
+ serializer.is_valid(raise_exception=True)
+
+ if instance is None:
+ lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
+ lookup_value = self.kwargs[lookup_url_kwarg]
+ extras = {self.lookup_field: lookup_value}
+ serializer.save(extras=extras)
return Response(serializer.data, status=status.HTTP_201_CREATED)
- self.object = serializer.save(force_update=True)
- self.post_save(self.object, created=False)
- return Response(serializer.data, status=status.HTTP_200_OK)
+ serializer.save()
+ return Response(serializer.data)
def partial_update(self, request, *args, **kwargs):
kwargs['partial'] = True
@@ -156,41 +126,3 @@ class UpdateModelMixin(object):
# PATCH requests where the object does not exist should still
# return a 404 response.
raise
-
- def pre_save(self, obj):
- """
- Set any attributes on the object that are implicit in the request.
- """
- # pk and/or slug attributes are implicit in the URL.
- lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
- lookup = self.kwargs.get(lookup_url_kwarg, None)
- pk = self.kwargs.get(self.pk_url_kwarg, None)
- slug = self.kwargs.get(self.slug_url_kwarg, None)
- slug_field = slug and self.slug_field or None
-
- if lookup:
- setattr(obj, self.lookup_field, lookup)
-
- if pk:
- setattr(obj, 'pk', pk)
-
- if slug:
- setattr(obj, slug_field, slug)
-
- # Ensure we clean the attributes so that we don't eg return integer
- # pk using a string representation, as provided by the url conf kwarg.
- if hasattr(obj, 'full_clean'):
- exclude = _get_validation_exclusions(obj, pk, slug_field, self.lookup_field)
- obj.full_clean(exclude)
-
-
-class DestroyModelMixin(object):
- """
- Destroy a model instance.
- """
- def destroy(self, request, *args, **kwargs):
- obj = self.get_object()
- self.pre_delete(obj)
- obj.delete()
- self.post_delete(obj)
- return Response(status=status.HTTP_204_NO_CONTENT)
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index 1f5749f1..c5a9270a 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -13,7 +13,7 @@ class NextPageField(serializers.Field):
"""
page_field = 'page'
- def to_native(self, value):
+ def to_representation(self, value):
if not value.has_next():
return None
page = value.next_page_number()
@@ -28,7 +28,7 @@ class PreviousPageField(serializers.Field):
"""
page_field = 'page'
- def to_native(self, value):
+ def to_representation(self, value):
if not value.has_previous():
return None
page = value.previous_page_number()
@@ -37,7 +37,7 @@ class PreviousPageField(serializers.Field):
return replace_query_param(url, self.page_field, page)
-class DefaultObjectSerializer(serializers.Field):
+class DefaultObjectSerializer(serializers.ReadOnlyField):
"""
If no object serializer is specified, then this serializer will be applied
as the default.
@@ -49,25 +49,11 @@ class DefaultObjectSerializer(serializers.Field):
super(DefaultObjectSerializer, self).__init__(source=source)
-class PaginationSerializerOptions(serializers.SerializerOptions):
- """
- An object that stores the options that may be provided to a
- pagination serializer by using the inner `Meta` class.
-
- Accessible on the instance as `serializer.opts`.
- """
- def __init__(self, meta):
- super(PaginationSerializerOptions, self).__init__(meta)
- self.object_serializer_class = getattr(meta, 'object_serializer_class',
- DefaultObjectSerializer)
-
-
class BasePaginationSerializer(serializers.Serializer):
"""
A base class for pagination serializers to inherit from,
to make implementing custom serializers more easy.
"""
- _options_class = PaginationSerializerOptions
results_field = 'results'
def __init__(self, *args, **kwargs):
@@ -76,22 +62,23 @@ class BasePaginationSerializer(serializers.Serializer):
"""
super(BasePaginationSerializer, self).__init__(*args, **kwargs)
results_field = self.results_field
- object_serializer = self.opts.object_serializer_class
- if 'context' in kwargs:
- context_kwarg = {'context': kwargs['context']}
- else:
- context_kwarg = {}
+ try:
+ object_serializer = self.Meta.object_serializer_class
+ except AttributeError:
+ object_serializer = DefaultObjectSerializer
- self.fields[results_field] = object_serializer(source='object_list',
- many=True,
- **context_kwarg)
+ self.fields[results_field] = serializers.ListSerializer(
+ child=object_serializer(),
+ source='object_list'
+ )
+ self.fields[results_field].bind(results_field, self, self)
class PaginationSerializer(BasePaginationSerializer):
"""
A default implementation of a pagination serializer.
"""
- count = serializers.Field(source='paginator.count')
+ count = serializers.ReadOnlyField(source='paginator.count')
next = NextPageField(source='*')
previous = PreviousPageField(source='*')
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index aa4fd3f1..fa02ecf1 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -11,7 +11,7 @@ from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter
from django.utils import six
-from rest_framework.compat import etree, yaml, force_text
+from rest_framework.compat import etree, yaml, force_text, urlparse
from rest_framework.exceptions import ParseError
from rest_framework import renderers
import json
@@ -48,7 +48,7 @@ class JSONParser(BaseParser):
"""
media_type = 'application/json'
- renderer_class = renderers.UnicodeJSONRenderer
+ renderer_class = renderers.JSONRenderer
def parse(self, stream, media_type=None, parser_context=None):
"""
@@ -290,6 +290,22 @@ class FileUploadParser(BaseParser):
try:
meta = parser_context['request'].META
disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'].encode('utf-8'))
- return force_text(disposition[1]['filename'])
+ filename_parm = disposition[1]
+ if 'filename*' in filename_parm:
+ return self.get_encoded_filename(filename_parm)
+ return force_text(filename_parm['filename'])
except (AttributeError, KeyError):
pass
+
+ def get_encoded_filename(self, filename_parm):
+ """
+ Handle encoded filenames per RFC6266. See also:
+ http://tools.ietf.org/html/rfc2231#section-4
+ """
+ encoded_filename = force_text(filename_parm['filename*'])
+ try:
+ charset, lang, filename = encoded_filename.split('\'', 2)
+ filename = urlparse.unquote(filename)
+ except (ValueError, LookupError):
+ filename = force_text(filename_parm['filename'])
+ return filename
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 1acbdce2..5aa1f8bd 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -1,356 +1,112 @@
-"""
-Serializer fields that deal with relationships.
-
-These fields allow you to specify the style that should be used to represent
-model relationships, including hyperlinks, primary keys, or slugs.
-"""
-from __future__ import unicode_literals
-from django.core.exceptions import ObjectDoesNotExist, ValidationError
-from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch
-from django import forms
-from django.db.models.fields import BLANK_CHOICE_DASH
-from django.forms import widgets
-from django.forms.models import ModelChoiceIterator
-from django.utils.translation import ugettext_lazy as _
-from rest_framework.fields import Field, WritableField, get_component, is_simple_callable
+from rest_framework.compat import smart_text, urlparse
+from rest_framework.fields import Field
from rest_framework.reverse import reverse
-from rest_framework.compat import urlparse
-from rest_framework.compat import smart_text
-import warnings
-
-
-# Relational fields
-
-# Not actually Writable, but subclasses may need to be.
-class RelatedField(WritableField):
- """
- Base class for related model fields.
-
- This represents a relationship using the unicode representation of the target.
- """
- widget = widgets.Select
- many_widget = widgets.SelectMultiple
- form_field_class = forms.ChoiceField
- many_form_field_class = forms.MultipleChoiceField
- null_values = (None, '', 'None')
-
- cache_choices = False
- empty_label = None
- read_only = True
- many = False
-
- def __init__(self, *args, **kwargs):
- queryset = kwargs.pop('queryset', None)
- self.many = kwargs.pop('many', self.many)
- if self.many:
- self.widget = self.many_widget
- self.form_field_class = self.many_form_field_class
-
- kwargs['read_only'] = kwargs.pop('read_only', self.read_only)
- super(RelatedField, self).__init__(*args, **kwargs)
-
- if not self.required:
- # Accessed in ModelChoiceIterator django/forms/models.py:1034
- # If set adds empty choice.
- self.empty_label = BLANK_CHOICE_DASH[0][1]
-
- self.queryset = queryset
-
- def initialize(self, parent, field_name):
- super(RelatedField, self).initialize(parent, field_name)
- if self.queryset is None and not self.read_only:
- manager = getattr(self.parent.opts.model, self.source or field_name)
- if hasattr(manager, 'related'): # Forward
- self.queryset = manager.related.model._default_manager.all()
- else: # Reverse
- self.queryset = manager.field.rel.to._default_manager.all()
-
- # We need this stuff to make form choices work...
-
- def prepare_value(self, obj):
- return self.to_native(obj)
-
- def label_from_instance(self, obj):
- """
- Return a readable representation for use with eg. select widgets.
- """
- desc = smart_text(obj)
- ident = smart_text(self.to_native(obj))
- if desc == ident:
- return desc
- return "%s - %s" % (desc, ident)
-
- def _get_queryset(self):
- return self._queryset
-
- def _set_queryset(self, queryset):
- self._queryset = queryset
- self.widget.choices = self.choices
-
- queryset = property(_get_queryset, _set_queryset)
-
- def _get_choices(self):
- # If self._choices is set, then somebody must have manually set
- # the property self.choices. In this case, just return self._choices.
- if hasattr(self, '_choices'):
- return self._choices
-
- # Otherwise, execute the QuerySet in self.queryset to determine the
- # choices dynamically. Return a fresh ModelChoiceIterator that has not been
- # consumed. Note that we're instantiating a new ModelChoiceIterator *each*
- # time _get_choices() is called (and, thus, each time self.choices is
- # accessed) so that we can ensure the QuerySet has not been consumed. This
- # construct might look complicated but it allows for lazy evaluation of
- # the queryset.
- return ModelChoiceIterator(self)
-
- def _set_choices(self, value):
- # Setting choices also sets the choices on the widget.
- # choices can be any iterable, but we call list() on it because
- # it will be consumed more than once.
- self._choices = self.widget.choices = list(value)
-
- choices = property(_get_choices, _set_choices)
-
- # Default value handling
-
- def get_default_value(self):
- default = super(RelatedField, self).get_default_value()
- if self.many and default is None:
- return []
- return default
-
- # Regular serializer stuff...
-
- def field_to_native(self, obj, field_name):
- try:
- if self.source == '*':
- return self.to_native(obj)
-
- source = self.source or field_name
- value = obj
-
- for component in source.split('.'):
- if value is None:
- break
- value = get_component(value, component)
- except ObjectDoesNotExist:
- return None
+from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured
+from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch, Resolver404
+from django.db.models.query import QuerySet
+from django.utils.translation import ugettext_lazy as _
- if value is None:
- return None
- if self.many:
- if is_simple_callable(getattr(value, 'all', None)):
- return [self.to_native(item) for item in value.all()]
- else:
- # Also support non-queryset iterables.
- # This allows us to also support plain lists of related items.
- return [self.to_native(item) for item in value]
- return self.to_native(value)
+class RelatedField(Field):
+ def __init__(self, **kwargs):
+ self.queryset = kwargs.pop('queryset', None)
+ assert self.queryset is not None or kwargs.get('read_only', None), (
+ 'Relational field must provide a `queryset` argument, '
+ 'or set read_only=`True`.'
+ )
+ assert not (self.queryset is not None and kwargs.get('read_only', None)), (
+ 'Relational fields should not provide a `queryset` argument, '
+ 'when setting read_only=`True`.'
+ )
+ super(RelatedField, self).__init__(**kwargs)
+
+ def __new__(cls, *args, **kwargs):
+ # We override this method in order to automagically create
+ # `ManyRelation` classes instead when `many=True` is set.
+ if kwargs.pop('many', False):
+ return ManyRelation(
+ child_relation=cls(*args, **kwargs),
+ read_only=kwargs.get('read_only', False)
+ )
+ return super(RelatedField, cls).__new__(cls, *args, **kwargs)
- def field_from_native(self, data, files, field_name, into):
- if self.read_only:
- return
+ def get_queryset(self):
+ queryset = self.queryset
+ if isinstance(queryset, QuerySet):
+ # Ensure queryset is re-evaluated whenever used.
+ queryset = queryset.all()
+ return queryset
- try:
- if self.many:
- try:
- # Form data
- value = data.getlist(field_name)
- if value == [''] or value == []:
- raise KeyError
- except AttributeError:
- # Non-form data
- value = data[field_name]
- else:
- value = data[field_name]
- except KeyError:
- if self.partial:
- return
- value = self.get_default_value()
-
- if value in self.null_values:
- if self.required:
- raise ValidationError(self.error_messages['required'])
- into[(self.source or field_name)] = None
- elif self.many:
- into[(self.source or field_name)] = [self.from_native(item) for item in value]
- else:
- into[(self.source or field_name)] = self.from_native(value)
-
-
-# PrimaryKey relationships
-class PrimaryKeyRelatedField(RelatedField):
+class StringRelatedField(Field):
"""
- Represents a relationship as a pk value.
+ A read only field that represents its targets using their
+ plain string representation.
"""
- read_only = False
-
- default_error_messages = {
- 'does_not_exist': _("Invalid pk '%s' - object does not exist."),
- 'incorrect_type': _('Incorrect type. Expected pk value, received %s.'),
- }
-
- # TODO: Remove these field hacks...
- def prepare_value(self, obj):
- return self.to_native(obj.pk)
-
- def label_from_instance(self, obj):
- """
- Return a readable representation for use with eg. select widgets.
- """
- desc = smart_text(obj)
- ident = smart_text(self.to_native(obj.pk))
- if desc == ident:
- return desc
- return "%s - %s" % (desc, ident)
-
- # TODO: Possibly change this to just take `obj`, through prob less performant
- def to_native(self, pk):
- return pk
- def from_native(self, data):
- if self.queryset is None:
- raise Exception('Writable related fields must include a `queryset` argument')
+ def __init__(self, **kwargs):
+ kwargs['read_only'] = True
+ super(StringRelatedField, self).__init__(**kwargs)
- try:
- return self.queryset.get(pk=data)
- except ObjectDoesNotExist:
- msg = self.error_messages['does_not_exist'] % smart_text(data)
- raise ValidationError(msg)
- except (TypeError, ValueError):
- received = type(data).__name__
- msg = self.error_messages['incorrect_type'] % received
- raise ValidationError(msg)
-
- def field_to_native(self, obj, field_name):
- if self.many:
- # To-many relationship
-
- queryset = None
- if not self.source:
- # Prefer obj.serializable_value for performance reasons
- try:
- queryset = obj.serializable_value(field_name)
- except AttributeError:
- pass
- if queryset is None:
- # RelatedManager (reverse relationship)
- source = self.source or field_name
- queryset = obj
- for component in source.split('.'):
- if queryset is None:
- return []
- queryset = get_component(queryset, component)
-
- # Forward relationship
- if is_simple_callable(getattr(queryset, 'all', None)):
- return [self.to_native(item.pk) for item in queryset.all()]
- else:
- # Also support non-queryset iterables.
- # This allows us to also support plain lists of related items.
- return [self.to_native(item.pk) for item in queryset]
-
- # To-one relationship
- try:
- # Prefer obj.serializable_value for performance reasons
- pk = obj.serializable_value(self.source or field_name)
- except AttributeError:
- # RelatedObject (reverse relationship)
- try:
- pk = getattr(obj, self.source or field_name).pk
- except (ObjectDoesNotExist, AttributeError):
- return None
-
- # Forward relationship
- return self.to_native(pk)
+ def to_representation(self, value):
+ return str(value)
-# Slug relationships
-
-class SlugRelatedField(RelatedField):
- """
- Represents a relationship using a unique field on the target.
- """
- read_only = False
-
+class PrimaryKeyRelatedField(RelatedField):
default_error_messages = {
- 'does_not_exist': _("Object with %s=%s does not exist."),
- 'invalid': _('Invalid value.'),
+ 'required': 'This field is required.',
+ 'does_not_exist': "Invalid pk '{pk_value}' - object does not exist.",
+ 'incorrect_type': 'Incorrect type. Expected pk value, received {data_type}.',
}
- def __init__(self, *args, **kwargs):
- self.slug_field = kwargs.pop('slug_field', None)
- assert self.slug_field, 'slug_field is required'
- super(SlugRelatedField, self).__init__(*args, **kwargs)
-
- def to_native(self, obj):
- return getattr(obj, self.slug_field)
-
- def from_native(self, data):
- if self.queryset is None:
- raise Exception('Writable related fields must include a `queryset` argument')
-
+ def to_internal_value(self, data):
try:
- return self.queryset.get(**{self.slug_field: data})
+ return self.get_queryset().get(pk=data)
except ObjectDoesNotExist:
- raise ValidationError(self.error_messages['does_not_exist'] %
- (self.slug_field, smart_text(data)))
+ self.fail('does_not_exist', pk_value=data)
except (TypeError, ValueError):
- msg = self.error_messages['invalid']
- raise ValidationError(msg)
+ self.fail('incorrect_type', data_type=type(data).__name__)
+ def to_representation(self, value):
+ return value.pk
-# Hyperlinked relationships
class HyperlinkedRelatedField(RelatedField):
- """
- Represents a relationship using hyperlinking.
- """
- read_only = False
lookup_field = 'pk'
default_error_messages = {
- 'no_match': _('Invalid hyperlink - No URL match'),
- 'incorrect_match': _('Invalid hyperlink - Incorrect URL match'),
- 'configuration_error': _('Invalid hyperlink due to configuration error'),
- 'does_not_exist': _("Invalid hyperlink - object does not exist."),
- 'incorrect_type': _('Incorrect type. Expected url string, received %s.'),
+ 'required': 'This field is required.',
+ 'no_match': 'Invalid hyperlink - No URL match',
+ 'incorrect_match': 'Invalid hyperlink - Incorrect URL match.',
+ 'does_not_exist': 'Invalid hyperlink - Object does not exist.',
+ 'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.',
}
- # These are all deprecated
- pk_url_kwarg = 'pk'
- slug_field = 'slug'
- slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
-
- def __init__(self, *args, **kwargs):
- try:
- self.view_name = kwargs.pop('view_name')
- except KeyError:
- raise ValueError("Hyperlinked field requires 'view_name' kwarg")
-
+ def __init__(self, view_name=None, **kwargs):
+ assert view_name is not None, 'The `view_name` argument is required.'
+ self.view_name = view_name
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
+ self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
self.format = kwargs.pop('format', None)
- # These are deprecated
- if 'pk_url_kwarg' in kwargs:
- msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
- if 'slug_url_kwarg' in kwargs:
- msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
- if 'slug_field' in kwargs:
- msg = 'slug_field is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
+ # We include these simply for dependancy injection in tests.
+ # We can't add them as class attributes or they would expect an
+ # implict `self` argument to be passed.
+ self.reverse = reverse
+ self.resolve = resolve
- self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
- self.slug_field = kwargs.pop('slug_field', self.slug_field)
- default_slug_kwarg = self.slug_url_kwarg or self.slug_field
- self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
+ super(HyperlinkedRelatedField, self).__init__(**kwargs)
- super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
+ def get_object(self, view_name, view_args, view_kwargs):
+ """
+ Return the object corresponding to a matched URL.
+
+ Takes the matched URL conf arguments, and should return an
+ object instance, or raise an `ObjectDoesNotExist` exception.
+ """
+ lookup_value = view_kwargs[self.lookup_url_kwarg]
+ lookup_kwargs = {self.lookup_field: lookup_value}
+ return self.get_queryset().get(**lookup_kwargs)
def get_url(self, obj, view_name, request, format):
"""
@@ -359,176 +115,48 @@ class HyperlinkedRelatedField(RelatedField):
May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
attributes are not configured to correctly match the URL conf.
"""
- lookup_field = getattr(obj, self.lookup_field)
- kwargs = {self.lookup_field: lookup_field}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
-
- if self.pk_url_kwarg != 'pk':
- # Only try pk if it has been explicitly set.
- # Otherwise, the default `lookup_field = 'pk'` has us covered.
- pk = obj.pk
- kwargs = {self.pk_url_kwarg: pk}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
-
- slug = getattr(obj, self.slug_field, None)
- if slug is not None:
- # Only try slug if it corresponds to an attribute on the object.
- kwargs = {self.slug_url_kwarg: slug}
- try:
- ret = reverse(view_name, kwargs=kwargs, request=request, format=format)
- if self.slug_field == 'slug' and self.slug_url_kwarg == 'slug':
- # If the lookup succeeds using the default slug params,
- # then `slug_field` is being used implicitly, and we
- # we need to warn about the pending deprecation.
- msg = 'Implicit slug field hyperlinked fields are deprecated.' \
- 'You should set `lookup_field=slug` on the HyperlinkedRelatedField.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
- return ret
- except NoReverseMatch:
- pass
-
- raise NoReverseMatch()
-
- def get_object(self, queryset, view_name, view_args, view_kwargs):
- """
- Return the object corresponding to a matched URL.
-
- Takes the matched URL conf arguments, and the queryset, and should
- return an object instance, or raise an `ObjectDoesNotExist` exception.
- """
- lookup = view_kwargs.get(self.lookup_field, None)
- pk = view_kwargs.get(self.pk_url_kwarg, None)
- slug = view_kwargs.get(self.slug_url_kwarg, None)
-
- if lookup is not None:
- filter_kwargs = {self.lookup_field: lookup}
- elif pk is not None:
- filter_kwargs = {'pk': pk}
- elif slug is not None:
- filter_kwargs = {self.slug_field: slug}
- else:
- raise ObjectDoesNotExist()
-
- return queryset.get(**filter_kwargs)
-
- def to_native(self, obj):
- view_name = self.view_name
- request = self.context.get('request', None)
- format = self.format or self.context.get('format', None)
-
- assert request is not None, (
- "`HyperlinkedRelatedField` requires the request in the serializer "
- "context. Add `context={'request': request}` when instantiating "
- "the serializer."
- )
-
- # If the object has not yet been saved then we cannot hyperlink to it.
- if getattr(obj, 'pk', None) is None:
- return
-
- # Return the hyperlink, or error if incorrectly configured.
- try:
- return self.get_url(obj, view_name, request, format)
- except NoReverseMatch:
- msg = (
- 'Could not resolve URL for hyperlinked relationship using '
- 'view name "%s". You may have failed to include the related '
- 'model in your API, or incorrectly configured the '
- '`lookup_field` attribute on this field.'
- )
- raise Exception(msg % view_name)
+ # Unsaved objects will not yet have a valid URL.
+ if obj.pk is None:
+ return None
- def from_native(self, value):
- # Convert URL -> model instance pk
- # TODO: Use values_list
- queryset = self.queryset
- if queryset is None:
- raise Exception('Writable related fields must include a `queryset` argument')
+ lookup_value = getattr(obj, self.lookup_field)
+ kwargs = {self.lookup_url_kwarg: lookup_value}
+ return self.reverse(view_name, kwargs=kwargs, request=request, format=format)
+ def to_internal_value(self, data):
try:
- http_prefix = value.startswith(('http:', 'https:'))
+ http_prefix = data.startswith(('http:', 'https:'))
except AttributeError:
- msg = self.error_messages['incorrect_type']
- raise ValidationError(msg % type(value).__name__)
+ self.fail('incorrect_type', data_type=type(data).__name__)
if http_prefix:
# If needed convert absolute URLs to relative path
- value = urlparse.urlparse(value).path
+ data = urlparse.urlparse(data).path
prefix = get_script_prefix()
- if value.startswith(prefix):
- value = '/' + value[len(prefix):]
+ if data.startswith(prefix):
+ data = '/' + data[len(prefix):]
try:
- match = resolve(value)
- except Exception:
- raise ValidationError(self.error_messages['no_match'])
+ match = self.resolve(data)
+ except Resolver404:
+ self.fail('no_match')
if match.view_name != self.view_name:
- raise ValidationError(self.error_messages['incorrect_match'])
+ self.fail('incorrect_match')
try:
- return self.get_object(queryset, match.view_name,
- match.args, match.kwargs)
+ return self.get_object(match.view_name, match.args, match.kwargs)
except (ObjectDoesNotExist, TypeError, ValueError):
- raise ValidationError(self.error_messages['does_not_exist'])
+ self.fail('does_not_exist')
-
-class HyperlinkedIdentityField(Field):
- """
- Represents the instance, or a property on the instance, using hyperlinking.
- """
- lookup_field = 'pk'
- read_only = True
-
- # These are all deprecated
- pk_url_kwarg = 'pk'
- slug_field = 'slug'
- slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
-
- def __init__(self, *args, **kwargs):
- try:
- self.view_name = kwargs.pop('view_name')
- except KeyError:
- msg = "HyperlinkedIdentityField requires 'view_name' argument"
- raise ValueError(msg)
-
- self.format = kwargs.pop('format', None)
- lookup_field = kwargs.pop('lookup_field', None)
- self.lookup_field = lookup_field or self.lookup_field
-
- # These are deprecated
- if 'pk_url_kwarg' in kwargs:
- msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
- if 'slug_url_kwarg' in kwargs:
- msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
- if 'slug_field' in kwargs:
- msg = 'slug_field is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
-
- self.slug_field = kwargs.pop('slug_field', self.slug_field)
- default_slug_kwarg = self.slug_url_kwarg or self.slug_field
- self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
- self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
-
- super(HyperlinkedIdentityField, self).__init__(*args, **kwargs)
-
- def field_to_native(self, obj, field_name):
+ def to_representation(self, value):
request = self.context.get('request', None)
format = self.context.get('format', None)
- view_name = self.view_name
assert request is not None, (
- "`HyperlinkedIdentityField` requires the request in the serializer"
+ "`%s` requires the request in the serializer"
" context. Add `context={'request': request}` when instantiating "
- "the serializer."
+ "the serializer." % self.__class__.__name__
)
# By default use whatever format is given for the current context
@@ -545,7 +173,7 @@ class HyperlinkedIdentityField(Field):
# Return the hyperlink, or error if incorrectly configured.
try:
- return self.get_url(obj, view_name, request, format)
+ return self.get_url(value, self.view_name, request, format)
except NoReverseMatch:
msg = (
'Could not resolve URL for hyperlinked relationship using '
@@ -553,43 +181,81 @@ class HyperlinkedIdentityField(Field):
'model in your API, or incorrectly configured the '
'`lookup_field` attribute on this field.'
)
- raise Exception(msg % view_name)
+ raise ImproperlyConfigured(msg % self.view_name)
- def get_url(self, obj, view_name, request, format):
- """
- Given an object, return the URL that hyperlinks to the object.
- May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
- attributes are not configured to correctly match the URL conf.
- """
- lookup_field = getattr(obj, self.lookup_field, None)
- kwargs = {self.lookup_field: lookup_field}
+class HyperlinkedIdentityField(HyperlinkedRelatedField):
+ """
+ A read-only field that represents the identity URL for an object, itself.
- # Handle unsaved object case
- if lookup_field is None:
- return None
+ This is in contrast to `HyperlinkedRelatedField` which represents the
+ URL of relationships to other objects.
+ """
+
+ def __init__(self, view_name=None, **kwargs):
+ assert view_name is not None, 'The `view_name` argument is required.'
+ kwargs['read_only'] = True
+ kwargs['source'] = '*'
+ super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)
+
+
+class SlugRelatedField(RelatedField):
+ """
+ A read-write field the represents the target of the relationship
+ by a unique 'slug' attribute.
+ """
+
+ default_error_messages = {
+ 'does_not_exist': _("Object with {slug_name}={value} does not exist."),
+ 'invalid': _('Invalid value.'),
+ }
+ def __init__(self, slug_field=None, **kwargs):
+ assert slug_field is not None, 'The `slug_field` argument is required.'
+ self.slug_field = slug_field
+ super(SlugRelatedField, self).__init__(**kwargs)
+
+ def to_internal_value(self, data):
try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
-
- if self.pk_url_kwarg != 'pk':
- # Only try pk lookup if it has been explicitly set.
- # Otherwise, the default `lookup_field = 'pk'` has us covered.
- kwargs = {self.pk_url_kwarg: obj.pk}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
-
- slug = getattr(obj, self.slug_field, None)
- if slug:
- # Only use slug lookup if a slug field exists on the model
- kwargs = {self.slug_url_kwarg: slug}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
-
- raise NoReverseMatch()
+ return self.get_queryset().get(**{self.slug_field: data})
+ except ObjectDoesNotExist:
+ self.fail('does_not_exist', slug_name=self.slug_field, value=smart_text(data))
+ except (TypeError, ValueError):
+ self.fail('invalid')
+
+ def to_representation(self, obj):
+ return getattr(obj, self.slug_field)
+
+
+class ManyRelation(Field):
+ """
+ Relationships with `many=True` transparently get coerced into instead being
+ a ManyRelation with a child relationship.
+
+ The `ManyRelation` class is responsible for handling iterating through
+ the values and passing each one to the child relationship.
+
+ You shouldn't need to be using this class directly yourself.
+ """
+
+ def __init__(self, child_relation=None, *args, **kwargs):
+ self.child_relation = child_relation
+ assert child_relation is not None, '`child_relation` is a required argument.'
+ super(ManyRelation, self).__init__(*args, **kwargs)
+
+ def bind(self, field_name, parent, root):
+ # ManyRelation needs to provide the current context to the child relation.
+ super(ManyRelation, self).bind(field_name, parent, root)
+ self.child_relation.bind(field_name, parent, root)
+
+ def to_internal_value(self, data):
+ return [
+ self.child_relation.to_internal_value(item)
+ for item in data
+ ]
+
+ def to_representation(self, obj):
+ return [
+ self.child_relation.to_representation(value)
+ for value in obj.all()
+ ]
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 748ebac9..3bf03e62 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -26,6 +26,10 @@ from rest_framework.utils.breadcrumbs import get_breadcrumbs
from rest_framework import exceptions, status, VERSION
+def zero_as_none(value):
+ return None if value == 0 else value
+
+
class BaseRenderer(object):
"""
All renderers should extend this class, setting the `media_type`
@@ -44,13 +48,13 @@ class BaseRenderer(object):
class JSONRenderer(BaseRenderer):
"""
Renderer which serializes to JSON.
- Applies JSON's backslash-u character escaping for non-ascii characters.
"""
media_type = 'application/json'
format = 'json'
encoder_class = encoders.JSONEncoder
- ensure_ascii = True
+ ensure_ascii = not api_settings.UNICODE_JSON
+ compact = api_settings.COMPACT_JSON
# We don't set a charset because JSON is a binary encoding,
# that can be encoded as utf-8, utf-16 or utf-32.
@@ -62,9 +66,10 @@ class JSONRenderer(BaseRenderer):
if accepted_media_type:
# If the media type looks like 'application/json; indent=4',
# then pretty print the result.
+ # Note that we coerce `indent=0` into `indent=None`.
base_media_type, params = parse_header(accepted_media_type.encode('ascii'))
try:
- return max(min(int(params['indent']), 8), 0)
+ return zero_as_none(max(min(int(params['indent']), 8), 0))
except (KeyError, ValueError, TypeError):
pass
@@ -81,10 +86,12 @@ class JSONRenderer(BaseRenderer):
renderer_context = renderer_context or {}
indent = self.get_indent(accepted_media_type, renderer_context)
+ separators = (',', ':') if (indent is None and self.compact) else (', ', ': ')
ret = json.dumps(
data, cls=self.encoder_class,
- indent=indent, ensure_ascii=self.ensure_ascii
+ indent=indent, ensure_ascii=self.ensure_ascii,
+ separators=separators
)
# On python 2.x json.dumps() returns bytestrings if ensure_ascii=True,
@@ -96,14 +103,6 @@ class JSONRenderer(BaseRenderer):
return ret
-class UnicodeJSONRenderer(JSONRenderer):
- ensure_ascii = False
- """
- Renderer which serializes to JSON.
- Does *not* apply JSON's character escaping for non-ascii characters.
- """
-
-
class JSONPRenderer(JSONRenderer):
"""
Renderer which serializes to json,
@@ -196,7 +195,7 @@ class YAMLRenderer(BaseRenderer):
format = 'yaml'
encoder = encoders.SafeDumper
charset = 'utf-8'
- ensure_ascii = True
+ ensure_ascii = False
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
@@ -210,14 +209,6 @@ class YAMLRenderer(BaseRenderer):
return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder, allow_unicode=not self.ensure_ascii)
-class UnicodeYAMLRenderer(YAMLRenderer):
- """
- Renderer which serializes to YAML.
- Does *not* apply character escaping for non-ascii characters.
- """
- ensure_ascii = False
-
-
class TemplateHTMLRenderer(BaseRenderer):
"""
An HTML renderer for use with templates.
@@ -436,13 +427,13 @@ class BrowsableAPIRenderer(BaseRenderer):
if request.method == method:
try:
data = request.DATA
- files = request.FILES
+ # files = request.FILES
except ParseError:
data = None
- files = None
+ # files = None
else:
data = None
- files = None
+ # files = None
with override_method(view, request, method) as request:
obj = getattr(view, 'object', None)
@@ -458,7 +449,7 @@ class BrowsableAPIRenderer(BaseRenderer):
):
return
- serializer = view.get_serializer(instance=obj, data=data, files=files)
+ serializer = view.get_serializer(instance=obj, data=data)
serializer.is_valid()
data = serializer.data
@@ -579,10 +570,10 @@ class BrowsableAPIRenderer(BaseRenderer):
'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes],
'response_headers': response_headers,
- 'put_form': self.get_rendered_html_form(view, 'PUT', request),
- 'post_form': self.get_rendered_html_form(view, 'POST', request),
- 'delete_form': self.get_rendered_html_form(view, 'DELETE', request),
- 'options_form': self.get_rendered_html_form(view, 'OPTIONS', request),
+ # 'put_form': self.get_rendered_html_form(view, 'PUT', request),
+ # 'post_form': self.get_rendered_html_form(view, 'POST', request),
+ # 'delete_form': self.get_rendered_html_form(view, 'DELETE', request),
+ # 'options_form': self.get_rendered_html_form(view, 'OPTIONS', request),
'raw_data_put_form': raw_data_put_form,
'raw_data_post_form': raw_data_post_form,
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index ae56673d..f2d06211 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -19,6 +19,7 @@ import itertools
from collections import namedtuple
from django.conf.urls import patterns, url
from django.core.exceptions import ImproperlyConfigured
+from django.core.urlresolvers import NoReverseMatch
from rest_framework import views
from rest_framework.response import Response
from rest_framework.reverse import reverse
@@ -284,10 +285,19 @@ class DefaultRouter(SimpleRouter):
class APIRoot(views.APIView):
_ignore_model_permissions = True
- def get(self, request, format=None):
+ def get(self, request, *args, **kwargs):
ret = {}
for key, url_name in api_root_dict.items():
- ret[key] = reverse(url_name, request=request, format=format)
+ try:
+ ret[key] = reverse(
+ url_name,
+ request=request,
+ format=kwargs.get('format', None)
+ )
+ except NoReverseMatch:
+ # Don't bail out if eg. no list routes exist, only detail routes.
+ continue
+
return Response(ret)
return APIRoot.as_view()
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index be8ad3f2..d2740fc2 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -10,21 +10,20 @@ python primitives.
2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers.
"""
-from __future__ import unicode_literals
-import copy
-import datetime
-import inspect
-import types
-from decimal import Decimal
-from django.contrib.contenttypes.generic import GenericForeignKey
-from django.core.paginator import Page
+from django.core.exceptions import ImproperlyConfigured, ValidationError
from django.db import models
-from django.forms import widgets
from django.utils import six
from django.utils.datastructures import SortedDict
-from django.core.exceptions import ObjectDoesNotExist
+from collections import namedtuple
+from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings
-
+from rest_framework.utils import html, model_meta, representation
+from rest_framework.utils.field_mapping import (
+ get_url_kwargs, get_field_kwargs,
+ get_relation_kwargs, get_nested_relation_kwargs,
+ lookup_class
+)
+import copy
# Note: We do the following so that users of the framework can use this style:
#
@@ -37,1107 +36,453 @@ from rest_framework.relations import * # NOQA
from rest_framework.fields import * # NOQA
-def _resolve_model(obj):
- """
- Resolve supplied `obj` to a Django model class.
+FieldResult = namedtuple('FieldResult', ['field', 'value', 'error'])
- `obj` must be a Django model class itself, or a string
- representation of one. Useful in situtations like GH #1225 where
- Django may not have resolved a string-based reference to a model in
- another model's foreign key definition.
- String representations should have the format:
- 'appname.ModelName'
+class BaseSerializer(Field):
+ """
+ The BaseSerializer class provides a minimal class which may be used
+ for writing custom serializer implementations.
"""
- if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
- app_name, model_name = obj.split('.')
- return models.get_model(app_name, model_name)
- elif inspect.isclass(obj) and issubclass(obj, models.Model):
- return obj
- else:
- raise ValueError("{0} is not a Django model".format(obj))
-
-
-def pretty_name(name):
- """Converts 'first_name' to 'First name'"""
- if not name:
- return ''
- return name.replace('_', ' ').capitalize()
+ def __init__(self, instance=None, data=None, **kwargs):
+ super(BaseSerializer, self).__init__(**kwargs)
+ self.instance = instance
+ self._initial_data = data
-class RelationsList(list):
- _deleted = []
+ def to_internal_value(self, data):
+ raise NotImplementedError('`to_internal_value()` must be implemented.')
+ def to_representation(self, instance):
+ raise NotImplementedError('`to_representation()` must be implemented.')
-class NestedValidationError(ValidationError):
- """
- The default ValidationError behavior is to stringify each item in the list
- if the messages are a list of error messages.
+ def update(self, instance, attrs):
+ raise NotImplementedError('`update()` must be implemented.')
- In the case of nested serializers, where the parent has many children,
- then the child's `serializer.errors` will be a list of dicts. In the case
- of a single child, the `serializer.errors` will be a dict.
+ def create(self, attrs):
+ raise NotImplementedError('`create()` must be implemented.')
- We need to override the default behavior to get properly nested error dicts.
- """
+ def save(self, extras=None):
+ attrs = self.validated_data
+ if extras is not None:
+ attrs = dict(list(attrs.items()) + list(extras.items()))
- def __init__(self, message):
- if isinstance(message, dict):
- self._messages = [message]
+ if self.instance is not None:
+ self.update(self.instance, attrs)
else:
- self._messages = message
+ self.instance = self.create(attrs)
- @property
- def messages(self):
- return self._messages
+ return self.instance
+ def is_valid(self, raise_exception=False):
+ if not hasattr(self, '_validated_data'):
+ try:
+ self._validated_data = self.to_internal_value(self._initial_data)
+ except ValidationError as exc:
+ self._validated_data = {}
+ self._errors = exc.message_dict
+ else:
+ self._errors = {}
-class DictWithMetadata(dict):
- """
- A dict-like object, that can have additional properties attached.
- """
- def __getstate__(self):
- """
- Used by pickle (e.g., caching).
- Overridden to remove the metadata from the dict, since it shouldn't be
- pickled and may in some instances be unpickleable.
- """
- return dict(self)
+ if self._errors and raise_exception:
+ raise ValidationError(self._errors)
+ return not bool(self._errors)
-class SortedDictWithMetadata(SortedDict):
- """
- A sorted dict-like object, that can have additional properties attached.
- """
- def __getstate__(self):
- """
- Used by pickle (e.g., caching).
- Overriden to remove the metadata from the dict, since it shouldn't be
- pickle and may in some instances be unpickleable.
- """
- return SortedDict(self).__dict__
+ @property
+ def data(self):
+ if not hasattr(self, '_data'):
+ if self.instance is not None:
+ self._data = self.to_representation(self.instance)
+ elif self._initial_data is not None:
+ self._data = dict([
+ (field_name, field.get_value(self._initial_data))
+ for field_name, field in self.fields.items()
+ ])
+ else:
+ self._data = self.get_initial()
+ return self._data
+ @property
+ def errors(self):
+ if not hasattr(self, '_errors'):
+ msg = 'You must call `.is_valid()` before accessing `.errors`.'
+ raise AssertionError(msg)
+ return self._errors
-def _is_protected_type(obj):
- """
- True if the object is a native datatype that does not need to
- be serialized further.
- """
- return isinstance(obj, (
- types.NoneType,
- int, long,
- datetime.datetime, datetime.date, datetime.time,
- float, Decimal,
- basestring)
- )
+ @property
+ def validated_data(self):
+ if not hasattr(self, '_validated_data'):
+ msg = 'You must call `.is_valid()` before accessing `.validated_data`.'
+ raise AssertionError(msg)
+ return self._validated_data
-def _get_declared_fields(bases, attrs):
+class SerializerMetaclass(type):
"""
- Create a list of serializer field instances from the passed in 'attrs',
- plus any fields on the base classes (in 'bases').
+ This metaclass sets a dictionary named `base_fields` on the class.
- Note that all fields from the base classes are used.
+ Any instances of `Field` included as attributes on either the class
+ or on any of its superclasses will be include in the
+ `base_fields` dictionary.
"""
- fields = [(field_name, attrs.pop(field_name))
- for field_name, obj in list(six.iteritems(attrs))
- if isinstance(obj, Field)]
- fields.sort(key=lambda x: x[1].creation_counter)
- # If this class is subclassing another Serializer, add that Serializer's
- # fields. Note that we loop over the bases in *reverse*. This is necessary
- # in order to maintain the correct order of fields.
- for base in bases[::-1]:
- if hasattr(base, 'base_fields'):
- fields = list(base.base_fields.items()) + fields
+ @classmethod
+ def _get_declared_fields(cls, bases, attrs):
+ fields = [(field_name, attrs.pop(field_name))
+ for field_name, obj in list(attrs.items())
+ if isinstance(obj, Field)]
+ fields.sort(key=lambda x: x[1]._creation_counter)
- return SortedDict(fields)
+ # If this class is subclassing another Serializer, add that Serializer's
+ # fields. Note that we loop over the bases in *reverse*. This is necessary
+ # in order to maintain the correct order of fields.
+ for base in bases[::-1]:
+ if hasattr(base, '_declared_fields'):
+ fields = list(base._declared_fields.items()) + fields
+ return SortedDict(fields)
-class SerializerMetaclass(type):
def __new__(cls, name, bases, attrs):
- attrs['base_fields'] = _get_declared_fields(bases, attrs)
+ attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
-class SerializerOptions(object):
- """
- Meta class options for Serializer
- """
- def __init__(self, meta):
- self.depth = getattr(meta, 'depth', 0)
- self.fields = getattr(meta, 'fields', ())
- self.exclude = getattr(meta, 'exclude', ())
-
-
-class BaseSerializer(WritableField):
- """
- This is the Serializer implementation.
- We need to implement it as `BaseSerializer` due to metaclass magicks.
- """
- class Meta(object):
- pass
-
- _options_class = SerializerOptions
- _dict_class = SortedDictWithMetadata
-
- def __init__(self, instance=None, data=None, files=None,
- context=None, partial=False, many=False,
- allow_add_remove=False, **kwargs):
- super(BaseSerializer, self).__init__(**kwargs)
- self.opts = self._options_class(self.Meta)
- self.parent = None
- self.root = None
- self.partial = partial
- self.many = many
- self.allow_add_remove = allow_add_remove
-
- self.context = context or {}
-
- self.init_data = data
- self.init_files = files
- self.object = instance
- self.fields = self.get_fields()
-
- self._data = None
- self._files = None
- self._errors = None
-
- if many and instance is not None and not hasattr(instance, '__iter__'):
- raise ValueError('instance should be a queryset or other iterable with many=True')
-
- if allow_add_remove and not many:
- raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True')
-
- #####
- # Methods to determine which fields to use when (de)serializing objects.
-
- def get_default_fields(self):
- """
- Return the complete set of default fields for the object, as a dict.
- """
- return {}
-
- def get_fields(self):
- """
- Returns the complete set of fields for the object as a dict.
-
- This will be the set of any explicitly declared fields,
- plus the set of fields returned by get_default_fields().
- """
- ret = SortedDict()
-
- # Get the explicitly declared fields
- base_fields = copy.deepcopy(self.base_fields)
- for key, field in base_fields.items():
- ret[key] = field
-
- # Add in the default fields
- default_fields = self.get_default_fields()
- for key, val in default_fields.items():
- if key not in ret:
- ret[key] = val
-
- # If 'fields' is specified, use those fields, in that order.
- if self.opts.fields:
- assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple'
- new = SortedDict()
- for key in self.opts.fields:
- new[key] = ret[key]
- ret = new
-
- # Remove anything in 'exclude'
- if self.opts.exclude:
- assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple'
- for key in self.opts.exclude:
- ret.pop(key, None)
-
- for key, field in ret.items():
- field.initialize(parent=self, field_name=key)
-
- return ret
-
- #####
- # Methods to convert or revert from objects <--> primitive representations.
-
- def get_field_key(self, field_name):
- """
- Return the key that should be used for a given field.
- """
- return field_name
+@six.add_metaclass(SerializerMetaclass)
+class Serializer(BaseSerializer):
+ def __init__(self, *args, **kwargs):
+ self.context = kwargs.pop('context', {})
+ kwargs.pop('partial', None)
+ kwargs.pop('many', None)
- def restore_fields(self, data, files):
- """
- Core of deserialization, together with `restore_object`.
- Converts a dictionary of data into a dictionary of deserialized fields.
- """
- reverted_data = {}
+ super(Serializer, self).__init__(*args, **kwargs)
- if data is not None and not isinstance(data, dict):
- self._errors['non_field_errors'] = ['Invalid data']
- return None
+ # Every new serializer is created with a clone of the field instances.
+ # This allows users to dynamically modify the fields on a serializer
+ # instance without affecting every other serializer class.
+ self.fields = self._get_base_fields()
+ # Setup all the child fields, to provide them with the current context.
for field_name, field in self.fields.items():
- field.initialize(parent=self, field_name=field_name)
- try:
- field.field_from_native(data, files, field_name, reverted_data)
- except ValidationError as err:
- self._errors[field_name] = list(err.messages)
-
- return reverted_data
-
- def perform_validation(self, attrs):
- """
- Run `validate_<fieldname>()` and `validate()` methods on the serializer
- """
+ field.bind(field_name, self, self)
+
+ def __new__(cls, *args, **kwargs):
+ # We override this method in order to automagically create
+ # `ListSerializer` classes instead when `many=True` is set.
+ if kwargs.pop('many', False):
+ kwargs['child'] = cls()
+ return ListSerializer(*args, **kwargs)
+ return super(Serializer, cls).__new__(cls, *args, **kwargs)
+
+ def _get_base_fields(self):
+ return copy.deepcopy(self._declared_fields)
+
+ def bind(self, field_name, parent, root):
+ # If the serializer is used as a field then when it becomes bound
+ # it also needs to bind all its child fields.
+ super(Serializer, self).bind(field_name, parent, root)
for field_name, field in self.fields.items():
- if field_name in self._errors:
- continue
+ field.bind(field_name, self, root)
- source = field.source or field_name
- if self.partial and source not in attrs:
- continue
- try:
- validate_method = getattr(self, 'validate_%s' % field_name, None)
- if validate_method:
- attrs = validate_method(attrs, source)
- except ValidationError as err:
- self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages)
-
- # If there are already errors, we don't run .validate() because
- # field-validation failed and thus `attrs` may not be complete.
- # which in turn can cause inconsistent validation errors.
- if not self._errors:
- try:
- attrs = self.validate(attrs)
- except ValidationError as err:
- if hasattr(err, 'message_dict'):
- for field_name, error_messages in err.message_dict.items():
- self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages)
- elif hasattr(err, 'messages'):
- self._errors['non_field_errors'] = err.messages
+ def get_initial(self):
+ return dict([
+ (field.field_name, field.get_initial())
+ for field in self.fields.values()
+ ])
- return attrs
+ def get_value(self, dictionary):
+ # We override the default field access in order to support
+ # nested HTML forms.
+ if html.is_html_input(dictionary):
+ return html.parse_html_dict(dictionary, prefix=self.field_name)
+ return dictionary.get(self.field_name, empty)
- def validate(self, attrs):
+ def to_internal_value(self, data):
"""
- Stub method, to be overridden in Serializer subclasses
+ Dict of native values <- Dict of primitive datatypes.
"""
- return attrs
+ if not isinstance(data, dict):
+ raise ValidationError({
+ api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data']
+ })
- def restore_object(self, attrs, instance=None):
- """
- Deserialize a dictionary of attributes into an object instance.
- You should override this method to control how deserialized objects
- are instantiated.
- """
- if instance is not None:
- instance.update(attrs)
- return instance
- return attrs
-
- def to_native(self, obj):
- """
- Serialize objects -> primitives.
- """
- ret = self._dict_class()
- ret.fields = self._dict_class()
+ ret = {}
+ errors = {}
+ fields = [field for field in self.fields.values() if not field.read_only]
- for field_name, field in self.fields.items():
- if field.read_only and obj is None:
- continue
- field.initialize(parent=self, field_name=field_name)
- key = self.get_field_key(field_name)
- value = field.field_to_native(obj, field_name)
- method = getattr(self, 'transform_%s' % field_name, None)
- if callable(method):
- value = method(obj, value)
- if not getattr(field, 'write_only', False):
- ret[key] = value
- ret.fields[key] = self.augment_field(field, field_name, key, value)
-
- return ret
-
- def from_native(self, data, files=None):
- """
- Deserialize primitives -> objects.
- """
- self._errors = {}
-
- if data is not None or files is not None:
- attrs = self.restore_fields(data, files)
- if attrs is not None:
- attrs = self.perform_validation(attrs)
- else:
- self._errors['non_field_errors'] = ['No input provided']
-
- if not self._errors:
- return self.restore_object(attrs, instance=getattr(self, 'object', None))
-
- def augment_field(self, field, field_name, key, value):
- # This horrible stuff is to manage serializers rendering to HTML
- field._errors = self._errors.get(key) if self._errors else None
- field._name = field_name
- field._value = self.init_data.get(key) if self._errors and self.init_data else value
- if not field.label:
- field.label = pretty_name(key)
- return field
-
- def field_to_native(self, obj, field_name):
- """
- Override default so that the serializer can be used as a nested field
- across relationships.
- """
- if self.write_only:
- return None
+ for field in fields:
+ validate_method = getattr(self, 'validate_' + field.field_name, None)
+ primitive_value = field.get_value(data)
+ try:
+ validated_value = field.run_validation(primitive_value)
+ if validate_method is not None:
+ validated_value = validate_method(validated_value)
+ except ValidationError as exc:
+ errors[field.field_name] = exc.messages
+ except SkipField:
+ pass
+ else:
+ set_value(ret, field.source_attrs, validated_value)
- if self.source == '*':
- return self.to_native(obj)
+ if errors:
+ raise ValidationError(errors)
- # Get the raw field value
try:
- source = self.source or field_name
- value = obj
-
- for component in source.split('.'):
- if value is None:
- break
- value = get_component(value, component)
- except ObjectDoesNotExist:
- return None
+ return self.validate(ret)
+ except ValidationError as exc:
+ raise ValidationError({
+ api_settings.NON_FIELD_ERRORS_KEY: exc.messages
+ })
- if is_simple_callable(getattr(value, 'all', None)):
- return [self.to_native(item) for item in value.all()]
-
- if value is None:
- return None
-
- if self.many:
- return [self.to_native(item) for item in value]
- return self.to_native(value)
-
- def field_from_native(self, data, files, field_name, into):
- """
- Override default so that the serializer can be used as a writable
- nested field across relationships.
+ def to_representation(self, instance):
"""
- if self.read_only:
- return
-
- try:
- value = data[field_name]
- except KeyError:
- if self.default is not None and not self.partial:
- # Note: partial updates shouldn't set defaults
- value = copy.deepcopy(self.default)
- else:
- if self.required:
- raise ValidationError(self.error_messages['required'])
- return
-
- if self.source == '*':
- if value:
- reverted_data = self.restore_fields(value, {})
- if not self._errors:
- into.update(reverted_data)
- else:
- if value in (None, ''):
- into[(self.source or field_name)] = None
- else:
- # Set the serializer object if it exists
- obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None
-
- # If we have a model manager or similar object then we need
- # to iterate through each instance.
- if (
- self.many and
- not hasattr(obj, '__iter__') and
- is_simple_callable(getattr(obj, 'all', None))
- ):
- obj = obj.all()
-
- kwargs = {
- 'instance': obj,
- 'data': value,
- 'context': self.context,
- 'partial': self.partial,
- 'many': self.many,
- 'allow_add_remove': self.allow_add_remove
- }
- serializer = self.__class__(**kwargs)
-
- if serializer.is_valid():
- into[self.source or field_name] = serializer.object
- else:
- # Propagate errors up to our parent
- raise NestedValidationError(serializer.errors)
-
- def get_identity(self, data):
+ Object instance -> Dict of primitive datatypes.
"""
- This hook is required for bulk update.
- It is used to determine the canonical identity of a given object.
+ ret = SortedDict()
+ fields = [field for field in self.fields.values() if not field.write_only]
- Note that the data has not been validated at this point, so we need
- to make sure that we catch any cases of incorrect datatypes being
- passed to this method.
- """
- try:
- return data.get('id', None)
- except AttributeError:
- return None
+ for field in fields:
+ native_value = field.get_attribute(instance)
+ ret[field.field_name] = field.to_representation(native_value)
- @property
- def errors(self):
- """
- Run deserialization and return error data,
- setting self.object if no errors occurred.
- """
- if self._errors is None:
- data, files = self.init_data, self.init_files
+ return ret
- if self.many is not None:
- many = self.many
- else:
- many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type))
- if many:
- warnings.warn('Implicit list/queryset serialization is deprecated. '
- 'Use the `many=True` flag when instantiating the serializer.',
- DeprecationWarning, stacklevel=3)
-
- if many:
- ret = RelationsList()
- errors = []
- update = self.object is not None
-
- if update:
- # If this is a bulk update we need to map all the objects
- # to a canonical identity so we can determine which
- # individual object is being updated for each item in the
- # incoming data
- objects = self.object
- identities = [self.get_identity(self.to_native(obj)) for obj in objects]
- identity_to_objects = dict(zip(identities, objects))
-
- if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)):
- for item in data:
- if update:
- # Determine which object we're updating
- identity = self.get_identity(item)
- self.object = identity_to_objects.pop(identity, None)
- if self.object is None and not self.allow_add_remove:
- ret.append(None)
- errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']})
- continue
-
- ret.append(self.from_native(item, None))
- errors.append(self._errors)
-
- if update and self.allow_add_remove:
- ret._deleted = identity_to_objects.values()
-
- self._errors = any(errors) and errors or []
- else:
- self._errors = {'non_field_errors': ['Expected a list of items.']}
- else:
- ret = self.from_native(data, files)
+ def validate(self, attrs):
+ return attrs
- if not self._errors:
- self.object = ret
+ def __iter__(self):
+ errors = self.errors if hasattr(self, '_errors') else {}
+ for field in self.fields.values():
+ value = self.data.get(field.field_name) if self.data else None
+ error = errors.get(field.field_name)
+ yield FieldResult(field, value, error)
- return self._errors
+ def __repr__(self):
+ return representation.serializer_repr(self, indent=1)
- def is_valid(self):
- return not self.errors
- @property
- def data(self):
- """
- Returns the serialized data on the serializer.
- """
- if self._data is None:
- obj = self.object
+class ListSerializer(BaseSerializer):
+ child = None
+ initial = []
- if self.many is not None:
- many = self.many
- else:
- many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict))
- if many:
- warnings.warn('Implicit list/queryset serialization is deprecated. '
- 'Use the `many=True` flag when instantiating the serializer.',
- DeprecationWarning, stacklevel=2)
-
- if many:
- self._data = [self.to_native(item) for item in obj]
- else:
- self._data = self.to_native(obj)
+ def __init__(self, *args, **kwargs):
+ self.child = kwargs.pop('child', copy.deepcopy(self.child))
+ assert self.child is not None, '`child` is a required argument.'
+ self.context = kwargs.pop('context', {})
+ kwargs.pop('partial', None)
- return self._data
+ super(ListSerializer, self).__init__(*args, **kwargs)
+ self.child.bind('', self, self)
- def save_object(self, obj, **kwargs):
- obj.save(**kwargs)
+ def bind(self, field_name, parent, root):
+ # If the list is used as a field then it needs to provide
+ # the current context to the child serializer.
+ super(ListSerializer, self).bind(field_name, parent, root)
+ self.child.bind(field_name, self, root)
- def delete_object(self, obj):
- obj.delete()
+ def get_value(self, dictionary):
+ # We override the default field access in order to support
+ # lists in HTML forms.
+ if is_html_input(dictionary):
+ return html.parse_html_list(dictionary, prefix=self.field_name)
+ return dictionary.get(self.field_name, empty)
- def save(self, **kwargs):
+ def to_internal_value(self, data):
"""
- Save the deserialized object and return it.
+ List of dicts of native values <- List of dicts of primitive datatypes.
"""
- # Clear cached _data, which may be invalidated by `save()`
- self._data = None
-
- if isinstance(self.object, list):
- [self.save_object(item, **kwargs) for item in self.object]
-
- if self.object._deleted:
- [self.delete_object(item) for item in self.object._deleted]
- else:
- self.save_object(self.object, **kwargs)
+ if html.is_html_input(data):
+ data = html.parse_html_list(data)
- return self.object
+ return [self.child.run_validation(item) for item in data]
- def metadata(self):
+ def to_representation(self, data):
"""
- Return a dictionary of metadata about the fields on the serializer.
- Useful for things like responding to OPTIONS requests, or generating
- API schemas for auto-documentation.
+ List of object instances -> List of dicts of primitive datatypes.
"""
- return SortedDict(
- [
- (field_name, field.metadata())
- for field_name, field in six.iteritems(self.fields)
- ]
- )
+ return [self.child.to_representation(item) for item in data]
+ def create(self, attrs_list):
+ return [self.child.create(attrs) for attrs in attrs_list]
-class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)):
- pass
+ def save(self):
+ if self.instance is not None:
+ self.update(self.instance, self.validated_data)
+ self.instance = self.create(self.validated_data)
+ return self.instance
-
-class ModelSerializerOptions(SerializerOptions):
- """
- Meta class options for ModelSerializer
- """
- def __init__(self, meta):
- super(ModelSerializerOptions, self).__init__(meta)
- self.model = getattr(meta, 'model', None)
- self.read_only_fields = getattr(meta, 'read_only_fields', ())
- self.write_only_fields = getattr(meta, 'write_only_fields', ())
+ def __repr__(self):
+ return representation.list_repr(self, indent=1)
class ModelSerializer(Serializer):
- """
- A serializer that deals with model instances and querysets.
- """
- _options_class = ModelSerializerOptions
-
- field_mapping = {
+ _field_mapping = {
models.AutoField: IntegerField,
+ models.BigIntegerField: IntegerField,
+ models.BooleanField: BooleanField,
+ models.CharField: CharField,
+ models.CommaSeparatedIntegerField: CharField,
+ models.DateField: DateField,
+ models.DateTimeField: DateTimeField,
+ models.DecimalField: DecimalField,
+ models.EmailField: EmailField,
+ models.Field: ModelField,
+ models.FileField: FileField,
models.FloatField: FloatField,
+ models.ImageField: ImageField,
models.IntegerField: IntegerField,
+ models.NullBooleanField: BooleanField,
models.PositiveIntegerField: IntegerField,
- models.SmallIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField,
- models.DateTimeField: DateTimeField,
- models.DateField: DateField,
- models.TimeField: TimeField,
- models.DecimalField: DecimalField,
- models.EmailField: EmailField,
- models.CharField: CharField,
- models.URLField: URLField,
models.SlugField: SlugField,
+ models.SmallIntegerField: IntegerField,
models.TextField: CharField,
- models.CommaSeparatedIntegerField: CharField,
- models.BooleanField: BooleanField,
- models.NullBooleanField: BooleanField,
- models.FileField: FileField,
- models.ImageField: ImageField,
+ models.TimeField: TimeField,
+ models.URLField: URLField,
}
+ _related_class = PrimaryKeyRelatedField
- def get_default_fields(self):
- """
- Return all the fields that should be serialized for the model.
- """
+ def create(self, attrs):
+ ModelClass = self.Meta.model
- cls = self.opts.model
- assert cls is not None, (
- "Serializer class '%s' is missing 'model' Meta option" %
- self.__class__.__name__
- )
- opts = cls._meta.concrete_model._meta
- ret = SortedDict()
- nested = bool(self.opts.depth)
-
- # Deal with adding the primary key field
- pk_field = opts.pk
- while pk_field.rel and pk_field.rel.parent_link:
- # If model is a child via multitable inheritance, use parent's pk
- pk_field = pk_field.rel.to._meta.pk
-
- serializer_pk_field = self.get_pk_field(pk_field)
- if serializer_pk_field:
- ret[pk_field.name] = serializer_pk_field
-
- # Deal with forward relationships
- forward_rels = [field for field in opts.fields if field.serialize]
- forward_rels += [field for field in opts.many_to_many if field.serialize]
-
- for model_field in forward_rels:
- has_through_model = False
-
- if model_field.rel:
- to_many = isinstance(model_field,
- models.fields.related.ManyToManyField)
- related_model = _resolve_model(model_field.rel.to)
-
- if to_many and not model_field.rel.through._meta.auto_created:
- has_through_model = True
-
- if model_field.rel and nested:
- if len(inspect.getargspec(self.get_nested_field).args) == 2:
- warnings.warn(
- 'The `get_nested_field(model_field)` call signature '
- 'is deprecated. '
- 'Use `get_nested_field(model_field, related_model, '
- 'to_many) instead',
- DeprecationWarning
- )
- field = self.get_nested_field(model_field)
- else:
- field = self.get_nested_field(model_field, related_model, to_many)
- elif model_field.rel:
- if len(inspect.getargspec(self.get_nested_field).args) == 3:
- warnings.warn(
- 'The `get_related_field(model_field, to_many)` call '
- 'signature is deprecated. '
- 'Use `get_related_field(model_field, related_model, '
- 'to_many) instead',
- DeprecationWarning
- )
- field = self.get_related_field(model_field, to_many=to_many)
- else:
- field = self.get_related_field(model_field, related_model, to_many)
- else:
- field = self.get_field(model_field)
+ # Remove many-to-many relationships from attrs.
+ # They are not valid arguments to the default `.create()` method,
+ # as they require that the instance has already been saved.
+ info = model_meta.get_field_info(ModelClass)
+ many_to_many = {}
+ for field_name, relation_info in info.relations.items():
+ if relation_info.to_many and (field_name in attrs):
+ many_to_many[field_name] = attrs.pop(field_name)
- if field:
- if has_through_model:
- field.read_only = True
+ instance = ModelClass.objects.create(**attrs)
- ret[model_field.name] = field
+ # Save many-to-many relationships after the instance is created.
+ if many_to_many:
+ for field_name, value in many_to_many.items():
+ setattr(instance, field_name, value)
- # Deal with reverse relationships
- if not self.opts.fields:
- reverse_rels = []
- else:
- # Reverse relationships are only included if they are explicitly
- # present in the `fields` option on the serializer
- reverse_rels = opts.get_all_related_objects()
- reverse_rels += opts.get_all_related_many_to_many_objects()
-
- for relation in reverse_rels:
- accessor_name = relation.get_accessor_name()
- if not self.opts.fields or accessor_name not in self.opts.fields:
- continue
- related_model = relation.model
- to_many = relation.field.rel.multiple
- has_through_model = False
- is_m2m = isinstance(relation.field,
- models.fields.related.ManyToManyField)
-
- if (
- is_m2m and
- hasattr(relation.field.rel, 'through') and
- not relation.field.rel.through._meta.auto_created
- ):
- has_through_model = True
-
- if nested:
- field = self.get_nested_field(None, related_model, to_many)
- else:
- field = self.get_related_field(None, related_model, to_many)
-
- if field:
- if has_through_model:
- field.read_only = True
-
- ret[accessor_name] = field
-
- # Ensure that 'read_only_fields' is an iterable
- assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'
-
- # Add the `read_only` flag to any fields that have been specified
- # in the `read_only_fields` option
- for field_name in self.opts.read_only_fields:
- assert field_name not in self.base_fields.keys(), (
- "field '%s' on serializer '%s' specified in "
- "`read_only_fields`, but also added "
- "as an explicit field. Remove it from `read_only_fields`." %
- (field_name, self.__class__.__name__))
- assert field_name in ret, (
- "Non-existant field '%s' specified in `read_only_fields` "
- "on serializer '%s'." %
- (field_name, self.__class__.__name__))
- ret[field_name].read_only = True
-
- # Ensure that 'write_only_fields' is an iterable
- assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'
-
- for field_name in self.opts.write_only_fields:
- assert field_name not in self.base_fields.keys(), (
- "field '%s' on serializer '%s' specified in "
- "`write_only_fields`, but also added "
- "as an explicit field. Remove it from `write_only_fields`." %
- (field_name, self.__class__.__name__))
- assert field_name in ret, (
- "Non-existant field '%s' specified in `write_only_fields` "
- "on serializer '%s'." %
- (field_name, self.__class__.__name__))
- ret[field_name].write_only = True
-
- return ret
-
- def get_pk_field(self, model_field):
- """
- Returns a default instance of the pk field.
- """
- return self.get_field(model_field)
-
- def get_nested_field(self, model_field, related_model, to_many):
- """
- Creates a default instance of a nested relational field.
-
- Note that model_field will be `None` for reverse relationships.
- """
- class NestedModelSerializer(ModelSerializer):
- class Meta:
- model = related_model
- depth = self.opts.depth - 1
-
- return NestedModelSerializer(many=to_many)
-
- def get_related_field(self, model_field, related_model, to_many):
- """
- Creates a default instance of a flat relational field.
-
- Note that model_field will be `None` for reverse relationships.
- """
- # TODO: filter queryset using:
- # .using(db).complex_filter(self.rel.limit_choices_to)
-
- kwargs = {
- 'queryset': related_model._default_manager,
- 'many': to_many
- }
-
- if model_field:
- kwargs['required'] = not(model_field.null or model_field.blank)
- if model_field.help_text is not None:
- kwargs['help_text'] = model_field.help_text
- if model_field.verbose_name is not None:
- kwargs['label'] = model_field.verbose_name
-
- if not model_field.editable:
- kwargs['read_only'] = True
-
- if model_field.verbose_name is not None:
- kwargs['label'] = model_field.verbose_name
-
- if model_field.help_text is not None:
- kwargs['help_text'] = model_field.help_text
-
- return PrimaryKeyRelatedField(**kwargs)
-
- def get_field(self, model_field):
- """
- Creates a default instance of a basic non-relational field.
- """
- kwargs = {}
-
- if model_field.null or model_field.blank:
- kwargs['required'] = False
-
- if isinstance(model_field, models.AutoField) or not model_field.editable:
- kwargs['read_only'] = True
-
- if model_field.has_default():
- kwargs['default'] = model_field.get_default()
-
- if issubclass(model_field.__class__, models.TextField):
- kwargs['widget'] = widgets.Textarea
-
- if model_field.verbose_name is not None:
- kwargs['label'] = model_field.verbose_name
-
- if model_field.help_text is not None:
- kwargs['help_text'] = model_field.help_text
-
- # TODO: TypedChoiceField?
- if model_field.flatchoices: # This ModelField contains choices
- kwargs['choices'] = model_field.flatchoices
- if model_field.null:
- kwargs['empty'] = None
- return ChoiceField(**kwargs)
-
- # put this below the ChoiceField because min_value isn't a valid initializer
- if issubclass(model_field.__class__, models.PositiveIntegerField) or\
- issubclass(model_field.__class__, models.PositiveSmallIntegerField):
- kwargs['min_value'] = 0
-
- if model_field.null and \
- issubclass(model_field.__class__, (models.CharField, models.TextField)):
- kwargs['allow_none'] = True
-
- attribute_dict = {
- models.CharField: ['max_length'],
- models.CommaSeparatedIntegerField: ['max_length'],
- models.DecimalField: ['max_digits', 'decimal_places'],
- models.EmailField: ['max_length'],
- models.FileField: ['max_length'],
- models.ImageField: ['max_length'],
- models.SlugField: ['max_length'],
- models.URLField: ['max_length'],
- }
-
- if model_field.__class__ in attribute_dict:
- attributes = attribute_dict[model_field.__class__]
- for attribute in attributes:
- kwargs.update({attribute: getattr(model_field, attribute)})
-
- try:
- return self.field_mapping[model_field.__class__](**kwargs)
- except KeyError:
- return ModelField(model_field=model_field, **kwargs)
-
- def get_validation_exclusions(self, instance=None):
- """
- Return a list of field names to exclude from model validation.
- """
- cls = self.opts.model
- opts = cls._meta.concrete_model._meta
- exclusions = [field.name for field in opts.fields + opts.many_to_many]
-
- for field_name, field in self.fields.items():
- field_name = field.source or field_name
- if (
- field_name in exclusions
- and not field.read_only
- and (field.required or hasattr(instance, field_name))
- and not isinstance(field, Serializer)
- ):
- exclusions.remove(field_name)
- return exclusions
-
- def full_clean(self, instance):
- """
- Perform Django's full_clean, and populate the `errors` dictionary
- if any validation errors occur.
-
- Note that we don't perform this inside the `.restore_object()` method,
- so that subclasses can override `.restore_object()`, and still get
- the full_clean validation checking.
- """
- try:
- instance.full_clean(exclude=self.get_validation_exclusions(instance))
- except ValidationError as err:
- self._errors = err.message_dict
- return None
return instance
- def restore_object(self, attrs, instance=None):
- """
- Restore the model instance.
- """
- m2m_data = {}
- related_data = {}
- nested_forward_relations = {}
- meta = self.opts.model._meta
-
- # Reverse fk or one-to-one relations
- for (obj, model) in meta.get_all_related_objects_with_model():
- field_name = obj.get_accessor_name()
- if field_name in attrs:
- related_data[field_name] = attrs.pop(field_name)
-
- # Reverse m2m relations
- for (obj, model) in meta.get_all_related_m2m_objects_with_model():
- field_name = obj.get_accessor_name()
- if field_name in attrs:
- m2m_data[field_name] = attrs.pop(field_name)
-
- # Forward m2m relations
- for field in meta.many_to_many + meta.virtual_fields:
- if isinstance(field, GenericForeignKey):
- continue
- if field.name in attrs:
- m2m_data[field.name] = attrs.pop(field.name)
+ def update(self, obj, attrs):
+ for attr, value in attrs.items():
+ setattr(obj, attr, value)
+ obj.save()
- # Nested forward relations - These need to be marked so we can save
- # them before saving the parent model instance.
- for field_name in attrs.keys():
- if isinstance(self.fields.get(field_name, None), Serializer):
- nested_forward_relations[field_name] = attrs[field_name]
+ def _get_base_fields(self):
+ declared_fields = copy.deepcopy(self._declared_fields)
- # Create an empty instance of the model
- if instance is None:
- instance = self.opts.model()
+ ret = SortedDict()
+ model = getattr(self.Meta, 'model')
+ fields = getattr(self.Meta, 'fields', None)
+ depth = getattr(self.Meta, 'depth', 0)
+ extra_kwargs = getattr(self.Meta, 'extra_kwargs', {})
+
+ # Retrieve metadata about fields & relationships on the model class.
+ info = model_meta.get_field_info(model)
+
+ # Use the default set of fields if none is supplied explicitly.
+ if fields is None:
+ fields = self._get_default_field_names(declared_fields, info)
+
+ for field_name in fields:
+ if field_name in declared_fields:
+ # Field is explicitly declared on the class, use that.
+ ret[field_name] = declared_fields[field_name]
+ continue
- for key, val in attrs.items():
- try:
- setattr(instance, key, val)
- except ValueError:
- self._errors[key] = [self.error_messages['required']]
-
- # Any relations that cannot be set until we've
- # saved the model get hidden away on these
- # private attributes, so we can deal with them
- # at the point of save.
- instance._related_data = related_data
- instance._m2m_data = m2m_data
- instance._nested_forward_relations = nested_forward_relations
+ elif field_name == api_settings.URL_FIELD_NAME:
+ # Create the URL field.
+ field_cls = HyperlinkedIdentityField
+ kwargs = get_url_kwargs(model)
+
+ elif field_name in info.fields_and_pk:
+ # Create regular model fields.
+ model_field = info.fields_and_pk[field_name]
+ field_cls = lookup_class(self._field_mapping, model_field)
+ kwargs = get_field_kwargs(field_name, model_field)
+ if 'choices' in kwargs:
+ # Fields with choices get coerced into `ChoiceField`
+ # instead of using their regular typed field.
+ field_cls = ChoiceField
+ if not issubclass(field_cls, ModelField):
+ # `model_field` is only valid for the fallback case of
+ # `ModelField`, which is used when no other typed field
+ # matched to the model field.
+ kwargs.pop('model_field', None)
+
+ elif field_name in info.relations:
+ # Create forward and reverse relationships.
+ relation_info = info.relations[field_name]
+ if depth:
+ field_cls = self._get_nested_class(depth, relation_info)
+ kwargs = get_nested_relation_kwargs(relation_info)
+ else:
+ field_cls = self._related_class
+ kwargs = get_relation_kwargs(field_name, relation_info)
+ # `view_name` is only valid for hyperlinked relationships.
+ if not issubclass(field_cls, HyperlinkedRelatedField):
+ kwargs.pop('view_name', None)
- return instance
+ elif hasattr(model, field_name):
+ # Create a read only field for model methods and properties.
+ field_cls = ReadOnlyField
+ kwargs = {}
- def from_native(self, data, files):
- """
- Override the default method to also include model field validation.
- """
- instance = super(ModelSerializer, self).from_native(data, files)
- if not self._errors:
- return self.full_clean(instance)
+ else:
+ raise ImproperlyConfigured(
+ 'Field name `%s` is not valid for model `%s`.' %
+ (field_name, model.__class__.__name__)
+ )
+
+ # Check that any fields declared on the class are
+ # also explicity included in `Meta.fields`.
+ missing_fields = set(declared_fields.keys()) - set(fields)
+ if missing_fields:
+ missing_field = list(missing_fields)[0]
+ raise ImproperlyConfigured(
+ 'Field `%s` has been declared on serializer `%s`, but '
+ 'is missing from `Meta.fields`.' %
+ (missing_field, self.__class__.__name__)
+ )
+
+ # Populate any kwargs defined in `Meta.extra_kwargs`
+ kwargs.update(extra_kwargs.get(field_name, {}))
+
+ # Create the serializer field.
+ ret[field_name] = field_cls(**kwargs)
- def save_object(self, obj, **kwargs):
- """
- Save the deserialized object.
- """
- if getattr(obj, '_nested_forward_relations', None):
- # Nested relationships need to be saved before we can save the
- # parent instance.
- for field_name, sub_object in obj._nested_forward_relations.items():
- if sub_object:
- self.save_object(sub_object)
- setattr(obj, field_name, sub_object)
-
- obj.save(**kwargs)
-
- if getattr(obj, '_m2m_data', None):
- for accessor_name, object_list in obj._m2m_data.items():
- setattr(obj, accessor_name, object_list)
- del(obj._m2m_data)
-
- if getattr(obj, '_related_data', None):
- related_fields = dict([
- (field.get_accessor_name(), field)
- for field, model
- in obj._meta.get_all_related_objects_with_model()
- ])
- for accessor_name, related in obj._related_data.items():
- if isinstance(related, RelationsList):
- # Nested reverse fk relationship
- for related_item in related:
- fk_field = related_fields[accessor_name].field.name
- setattr(related_item, fk_field, obj)
- self.save_object(related_item)
-
- # Delete any removed objects
- if related._deleted:
- [self.delete_object(item) for item in related._deleted]
-
- elif isinstance(related, models.Model):
- # Nested reverse one-one relationship
- fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
- setattr(related, fk_field, obj)
- self.save_object(related)
- else:
- # Reverse FK or reverse one-one
- setattr(obj, accessor_name, related)
- del(obj._related_data)
+ return ret
+ def _get_default_field_names(self, declared_fields, model_info):
+ return (
+ [model_info.pk.name] +
+ list(declared_fields.keys()) +
+ list(model_info.fields.keys()) +
+ list(model_info.forward_relations.keys())
+ )
-class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
- """
- Options for HyperlinkedModelSerializer
- """
- def __init__(self, meta):
- super(HyperlinkedModelSerializerOptions, self).__init__(meta)
- self.view_name = getattr(meta, 'view_name', None)
- self.lookup_field = getattr(meta, 'lookup_field', None)
- self.url_field_name = getattr(meta, 'url_field_name', api_settings.URL_FIELD_NAME)
+ def _get_nested_class(self, nested_depth, relation_info):
+ class NestedSerializer(ModelSerializer):
+ class Meta:
+ model = relation_info.related
+ depth = nested_depth
+ return NestedSerializer
class HyperlinkedModelSerializer(ModelSerializer):
- """
- A subclass of ModelSerializer that uses hyperlinked relationships,
- instead of primary key relationships.
- """
- _options_class = HyperlinkedModelSerializerOptions
- _default_view_name = '%(model_name)s-detail'
- _hyperlink_field_class = HyperlinkedRelatedField
- _hyperlink_identify_field_class = HyperlinkedIdentityField
-
- def get_default_fields(self):
- fields = super(HyperlinkedModelSerializer, self).get_default_fields()
-
- if self.opts.view_name is None:
- self.opts.view_name = self._get_default_view_name(self.opts.model)
-
- if self.opts.url_field_name not in fields:
- url_field = self._hyperlink_identify_field_class(
- view_name=self.opts.view_name,
- lookup_field=self.opts.lookup_field
- )
- ret = self._dict_class()
- ret[self.opts.url_field_name] = url_field
- ret.update(fields)
- fields = ret
-
- return fields
-
- def get_pk_field(self, model_field):
- if self.opts.fields and model_field.name in self.opts.fields:
- return self.get_field(model_field)
-
- def get_related_field(self, model_field, related_model, to_many):
- """
- Creates a default instance of a flat relational field.
- """
- # TODO: filter queryset using:
- # .using(db).complex_filter(self.rel.limit_choices_to)
- kwargs = {
- 'queryset': related_model._default_manager,
- 'view_name': self._get_default_view_name(related_model),
- 'many': to_many
- }
-
- if model_field:
- kwargs['required'] = not(model_field.null or model_field.blank)
- if model_field.help_text is not None:
- kwargs['help_text'] = model_field.help_text
- if model_field.verbose_name is not None:
- kwargs['label'] = model_field.verbose_name
-
- if self.opts.lookup_field:
- kwargs['lookup_field'] = self.opts.lookup_field
-
- return self._hyperlink_field_class(**kwargs)
-
- def get_identity(self, data):
- """
- This hook is required for bulk update.
- We need to override the default, to use the url as the identity.
- """
- try:
- return data.get(self.opts.url_field_name, None)
- except AttributeError:
- return None
+ _related_class = HyperlinkedRelatedField
+
+ def _get_default_field_names(self, declared_fields, model_info):
+ return (
+ [api_settings.URL_FIELD_NAME] +
+ list(declared_fields.keys()) +
+ list(model_info.fields.keys()) +
+ list(model_info.forward_relations.keys())
+ )
- def _get_default_view_name(self, model):
- """
- Return the view name to use if 'view_name' is not specified in 'Meta'
- """
- model_meta = model._meta
- format_kwargs = {
- 'app_label': model_meta.app_label,
- 'model_name': model_meta.object_name.lower()
- }
- return self._default_view_name % format_kwargs
+ def _get_nested_class(self, nested_depth, relation_info):
+ class NestedSerializer(HyperlinkedModelSerializer):
+ class Meta:
+ model = relation_info.related
+ depth = nested_depth
+ return NestedSerializer
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 644751f8..421e146c 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -77,6 +77,7 @@ DEFAULTS = {
# Exception handling
'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler',
+ 'NON_FIELD_ERRORS_KEY': 'non_field_errors',
# Testing
'TEST_REQUEST_RENDERER_CLASSES': (
@@ -96,24 +97,19 @@ DEFAULTS = {
'URL_FIELD_NAME': 'url',
# Input and output formats
- 'DATE_INPUT_FORMATS': (
- ISO_8601,
- ),
- 'DATE_FORMAT': None,
+ 'DATE_FORMAT': ISO_8601,
+ 'DATE_INPUT_FORMATS': (ISO_8601,),
- 'DATETIME_INPUT_FORMATS': (
- ISO_8601,
- ),
- 'DATETIME_FORMAT': None,
+ 'DATETIME_FORMAT': ISO_8601,
+ 'DATETIME_INPUT_FORMATS': (ISO_8601,),
- 'TIME_INPUT_FORMATS': (
- ISO_8601,
- ),
- 'TIME_FORMAT': None,
-
- # Pending deprecation
- 'FILTER_BACKEND': None,
+ 'TIME_FORMAT': ISO_8601,
+ 'TIME_INPUT_FORMATS': (ISO_8601,),
+ # Encoding
+ 'UNICODE_JSON': True,
+ 'COMPACT_JSON': True,
+ 'COERCE_DECIMAL_TO_STRING': True
}
@@ -129,7 +125,6 @@ IMPORT_STRINGS = (
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
'DEFAULT_FILTER_BACKENDS',
'EXCEPTION_HANDLER',
- 'FILTER_BACKEND',
'TEST_REQUEST_RENDERER_CLASSES',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
@@ -196,15 +191,9 @@ class APISettings(object):
if val and attr in self.import_strings:
val = perform_import(val, attr)
- self.validate_setting(attr, val)
-
# Cache the result
setattr(self, attr, val)
return val
- def validate_setting(self, attr, val):
- if attr == 'FILTER_BACKEND' and val is not None:
- # Make sure we can initialize the class
- val()
api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS)
diff --git a/rest_framework/test.py b/rest_framework/test.py
index f89a6dcd..9b40353a 100644
--- a/rest_framework/test.py
+++ b/rest_framework/test.py
@@ -36,7 +36,7 @@ class APIRequestFactory(DjangoRequestFactory):
Encode the data returning a two tuple of (bytes, content_type)
"""
- if not data:
+ if data is None:
return ('', content_type)
assert format is None or content_type is None, (
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index 00ffdfba..174b08b8 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -7,7 +7,6 @@ from django.db.models.query import QuerySet
from django.utils.datastructures import SortedDict
from django.utils.functional import Promise
from rest_framework.compat import force_text
-from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
import datetime
import decimal
import types
@@ -17,45 +16,47 @@ import json
class JSONEncoder(json.JSONEncoder):
"""
JSONEncoder subclass that knows how to encode date/time/timedelta,
- decimal types, and generators.
+ decimal types, generators and other basic python objects.
"""
- def default(self, o):
+ def default(self, obj):
# For Date Time string spec, see ECMA 262
# http://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15
- if isinstance(o, Promise):
- return force_text(o)
- elif isinstance(o, datetime.datetime):
- r = o.isoformat()
- if o.microsecond:
- r = r[:23] + r[26:]
- if r.endswith('+00:00'):
- r = r[:-6] + 'Z'
- return r
- elif isinstance(o, datetime.date):
- return o.isoformat()
- elif isinstance(o, datetime.time):
- if timezone and timezone.is_aware(o):
+ if isinstance(obj, Promise):
+ return force_text(obj)
+ elif isinstance(obj, datetime.datetime):
+ representation = obj.isoformat()
+ if obj.microsecond:
+ representation = representation[:23] + representation[26:]
+ if representation.endswith('+00:00'):
+ representation = representation[:-6] + 'Z'
+ return representation
+ elif isinstance(obj, datetime.date):
+ return obj.isoformat()
+ elif isinstance(obj, datetime.time):
+ if timezone and timezone.is_aware(obj):
raise ValueError("JSON can't represent timezone-aware times.")
- r = o.isoformat()
- if o.microsecond:
- r = r[:12]
- return r
- elif isinstance(o, datetime.timedelta):
- return str(o.total_seconds())
- elif isinstance(o, decimal.Decimal):
- return str(o)
- elif isinstance(o, QuerySet):
- return list(o)
- elif hasattr(o, 'tolist'):
- return o.tolist()
- elif hasattr(o, '__getitem__'):
+ representation = obj.isoformat()
+ if obj.microsecond:
+ representation = representation[:12]
+ return representation
+ elif isinstance(obj, datetime.timedelta):
+ return str(obj.total_seconds())
+ elif isinstance(obj, decimal.Decimal):
+ # Serializers will coerce decimals to strings by default.
+ return float(obj)
+ elif isinstance(obj, QuerySet):
+ return list(obj)
+ elif hasattr(obj, 'tolist'):
+ # Numpy arrays and array scalars.
+ return obj.tolist()
+ elif hasattr(obj, '__getitem__'):
try:
- return dict(o)
+ return dict(obj)
except:
pass
- elif hasattr(o, '__iter__'):
- return [i for i in o]
- return super(JSONEncoder, self).default(o)
+ elif hasattr(obj, '__iter__'):
+ return [item for item in obj]
+ return super(JSONEncoder, self).default(obj)
try:
@@ -106,14 +107,14 @@ else:
SortedDict,
yaml.representer.SafeRepresenter.represent_dict
)
- SafeDumper.add_representer(
- DictWithMetadata,
- yaml.representer.SafeRepresenter.represent_dict
- )
- SafeDumper.add_representer(
- SortedDictWithMetadata,
- yaml.representer.SafeRepresenter.represent_dict
- )
+ # SafeDumper.add_representer(
+ # DictWithMetadata,
+ # yaml.representer.SafeRepresenter.represent_dict
+ # )
+ # SafeDumper.add_representer(
+ # SortedDictWithMetadata,
+ # yaml.representer.SafeRepresenter.represent_dict
+ # )
SafeDumper.add_representer(
types.GeneratorType,
yaml.representer.SafeRepresenter.represent_list
diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py
new file mode 100644
index 00000000..be72e444
--- /dev/null
+++ b/rest_framework/utils/field_mapping.py
@@ -0,0 +1,215 @@
+"""
+Helper functions for mapping model fields to a dictionary of default
+keyword arguments that should be used for their equivelent serializer fields.
+"""
+from django.core import validators
+from django.db import models
+from django.utils.text import capfirst
+from rest_framework.compat import clean_manytomany_helptext
+import inspect
+
+
+def lookup_class(mapping, instance):
+ """
+ Takes a dictionary with classes as keys, and an object.
+ Traverses the object's inheritance hierarchy in method
+ resolution order, and returns the first matching value
+ from the dictionary or raises a KeyError if nothing matches.
+ """
+ for cls in inspect.getmro(instance.__class__):
+ if cls in mapping:
+ return mapping[cls]
+ raise KeyError('Class %s not found in lookup.', cls.__name__)
+
+
+def needs_label(model_field, field_name):
+ """
+ Returns `True` if the label based on the model's verbose name
+ is not equal to the default label it would have based on it's field name.
+ """
+ default_label = field_name.replace('_', ' ').capitalize()
+ return capfirst(model_field.verbose_name) != default_label
+
+
+def get_detail_view_name(model):
+ """
+ Given a model class, return the view name to use for URL relationships
+ that refer to instances of the model.
+ """
+ return '%(model_name)s-detail' % {
+ 'app_label': model._meta.app_label,
+ 'model_name': model._meta.object_name.lower()
+ }
+
+
+def get_field_kwargs(field_name, model_field):
+ """
+ Creates a default instance of a basic non-relational field.
+ """
+ kwargs = {}
+ validator_kwarg = model_field.validators
+
+ if model_field.null or model_field.blank:
+ kwargs['required'] = False
+
+ if model_field.verbose_name and needs_label(model_field, field_name):
+ kwargs['label'] = capfirst(model_field.verbose_name)
+
+ if model_field.help_text:
+ kwargs['help_text'] = model_field.help_text
+
+ if isinstance(model_field, models.AutoField) or not model_field.editable:
+ kwargs['read_only'] = True
+ # Read only implies that the field is not required.
+ # We have a cleaner repr on the instance if we don't set it.
+ kwargs.pop('required', None)
+
+ if model_field.has_default():
+ kwargs['default'] = model_field.get_default()
+ # Having a default implies that the field is not required.
+ # We have a cleaner repr on the instance if we don't set it.
+ kwargs.pop('required', None)
+
+ if model_field.flatchoices:
+ # If this model field contains choices, then return now,
+ # any further keyword arguments are not valid.
+ kwargs['choices'] = model_field.flatchoices
+ return kwargs
+
+ # Ensure that max_length is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ max_length = getattr(model_field, 'max_length', None)
+ if max_length is not None:
+ kwargs['max_length'] = max_length
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MaxLengthValidator)
+ ]
+
+ # Ensure that min_length is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ min_length = getattr(model_field, 'min_length', None)
+ if min_length is not None:
+ kwargs['min_length'] = min_length
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MinLengthValidator)
+ ]
+
+ # Ensure that max_value is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ max_value = next((
+ validator.limit_value for validator in validator_kwarg
+ if isinstance(validator, validators.MaxValueValidator)
+ ), None)
+ if max_value is not None:
+ kwargs['max_value'] = max_value
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MaxValueValidator)
+ ]
+
+ # Ensure that max_value is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ min_value = next((
+ validator.limit_value for validator in validator_kwarg
+ if isinstance(validator, validators.MinValueValidator)
+ ), None)
+ if min_value is not None:
+ kwargs['min_value'] = min_value
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MinValueValidator)
+ ]
+
+ # URLField does not need to include the URLValidator argument,
+ # as it is explicitly added in.
+ if isinstance(model_field, models.URLField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.URLValidator)
+ ]
+
+ # EmailField does not need to include the validate_email argument,
+ # as it is explicitly added in.
+ if isinstance(model_field, models.EmailField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if validator is not validators.validate_email
+ ]
+
+ # SlugField do not need to include the 'validate_slug' argument,
+ if isinstance(model_field, models.SlugField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if validator is not validators.validate_slug
+ ]
+
+ max_digits = getattr(model_field, 'max_digits', None)
+ if max_digits is not None:
+ kwargs['max_digits'] = max_digits
+
+ decimal_places = getattr(model_field, 'decimal_places', None)
+ if decimal_places is not None:
+ kwargs['decimal_places'] = decimal_places
+
+ if isinstance(model_field, models.BooleanField):
+ # models.BooleanField has `blank=True`, but *is* actually
+ # required *unless* a default is provided.
+ # Also note that Django<1.6 uses `default=False` for
+ # models.BooleanField, but Django>=1.6 uses `default=None`.
+ kwargs.pop('required', None)
+
+ if validator_kwarg:
+ kwargs['validators'] = validator_kwarg
+
+ # The following will only be used by ModelField classes.
+ # Gets removed for everything else.
+ kwargs['model_field'] = model_field
+
+ return kwargs
+
+
+def get_relation_kwargs(field_name, relation_info):
+ """
+ Creates a default instance of a flat relational field.
+ """
+ model_field, related_model, to_many, has_through_model = relation_info
+ kwargs = {
+ 'queryset': related_model._default_manager,
+ 'view_name': get_detail_view_name(related_model)
+ }
+
+ if to_many:
+ kwargs['many'] = True
+
+ if has_through_model:
+ kwargs['read_only'] = True
+ kwargs.pop('queryset', None)
+
+ if model_field:
+ if model_field.null or model_field.blank:
+ kwargs['required'] = False
+ if model_field.verbose_name and needs_label(model_field, field_name):
+ kwargs['label'] = capfirst(model_field.verbose_name)
+ if not model_field.editable:
+ kwargs['read_only'] = True
+ kwargs.pop('queryset', None)
+ help_text = clean_manytomany_helptext(model_field.help_text)
+ if help_text:
+ kwargs['help_text'] = help_text
+
+ return kwargs
+
+
+def get_nested_relation_kwargs(relation_info):
+ kwargs = {'read_only': True}
+ if relation_info.to_many:
+ kwargs['many'] = True
+ return kwargs
+
+
+def get_url_kwargs(model_field):
+ return {
+ 'view_name': get_detail_view_name(model_field)
+ }
diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py
index 6d53aed1..470af51b 100644
--- a/rest_framework/utils/formatting.py
+++ b/rest_framework/utils/formatting.py
@@ -2,11 +2,12 @@
Utility functions to return a formatted name and description for a given view.
"""
from __future__ import unicode_literals
+import re
from django.utils.html import escape
from django.utils.safestring import mark_safe
-from rest_framework.compat import apply_markdown
-import re
+
+from rest_framework.compat import apply_markdown, force_text
def remove_trailing_string(content, trailing):
@@ -28,6 +29,7 @@ def dedent(content):
as it fails to dedent multiline docstrings that include
unindented text on the initial line.
"""
+ content = force_text(content)
whitespace_counts = [len(line) - len(line.lstrip(' '))
for line in content.splitlines()[1:] if line.lstrip()]
diff --git a/rest_framework/utils/html.py b/rest_framework/utils/html.py
new file mode 100644
index 00000000..edc591e9
--- /dev/null
+++ b/rest_framework/utils/html.py
@@ -0,0 +1,88 @@
+"""
+Helpers for dealing with HTML input.
+"""
+import re
+
+
+def is_html_input(dictionary):
+ # MultiDict type datastructures are used to represent HTML form input,
+ # which may have more than one value for each key.
+ return hasattr(dictionary, 'getlist')
+
+
+def parse_html_list(dictionary, prefix=''):
+ """
+ Used to suport list values in HTML forms.
+ Supports lists of primitives and/or dictionaries.
+
+ * List of primitives.
+
+ {
+ '[0]': 'abc',
+ '[1]': 'def',
+ '[2]': 'hij'
+ }
+ -->
+ [
+ 'abc',
+ 'def',
+ 'hij'
+ ]
+
+ * List of dictionaries.
+
+ {
+ '[0]foo': 'abc',
+ '[0]bar': 'def',
+ '[1]foo': 'hij',
+ '[2]bar': 'klm',
+ }
+ -->
+ [
+ {'foo': 'abc', 'bar': 'def'},
+ {'foo': 'hij', 'bar': 'klm'}
+ ]
+ """
+ Dict = type(dictionary)
+ ret = {}
+ regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix))
+ for field, value in dictionary.items():
+ match = regex.match(field)
+ if not match:
+ continue
+ index, key = match.groups()
+ index = int(index)
+ if not key:
+ ret[index] = value
+ elif isinstance(ret.get(index), dict):
+ ret[index][key] = value
+ else:
+ ret[index] = Dict({key: value})
+ return [ret[item] for item in sorted(ret.keys())]
+
+
+def parse_html_dict(dictionary, prefix):
+ """
+ Used to support dictionary values in HTML forms.
+
+ {
+ 'profile.username': 'example',
+ 'profile.email': 'example@example.com',
+ }
+ -->
+ {
+ 'profile': {
+ 'username': 'example,
+ 'email': 'example@example.com'
+ }
+ }
+ """
+ ret = {}
+ regex = re.compile(r'^%s\.(.+)$' % re.escape(prefix))
+ for field, value in dictionary.items():
+ match = regex.match(field)
+ if not match:
+ continue
+ key = match.groups()[0]
+ ret[key] = value
+ return ret
diff --git a/rest_framework/utils/humanize_datetime.py b/rest_framework/utils/humanize_datetime.py
new file mode 100644
index 00000000..649f2abc
--- /dev/null
+++ b/rest_framework/utils/humanize_datetime.py
@@ -0,0 +1,47 @@
+"""
+Helper functions that convert strftime formats into more readable representations.
+"""
+from rest_framework import ISO_8601
+
+
+def datetime_formats(formats):
+ format = ', '.join(formats).replace(
+ ISO_8601,
+ 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
+ )
+ return humanize_strptime(format)
+
+
+def date_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]')
+ return humanize_strptime(format)
+
+
+def time_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
+ return humanize_strptime(format)
+
+
+def humanize_strptime(format_string):
+ # Note that we're missing some of the locale specific mappings that
+ # don't really make sense.
+ mapping = {
+ "%Y": "YYYY",
+ "%y": "YY",
+ "%m": "MM",
+ "%b": "[Jan-Dec]",
+ "%B": "[January-December]",
+ "%d": "DD",
+ "%H": "hh",
+ "%I": "hh", # Requires '%p' to differentiate from '%H'.
+ "%M": "mm",
+ "%S": "ss",
+ "%f": "uuuuuu",
+ "%a": "[Mon-Sun]",
+ "%A": "[Monday-Sunday]",
+ "%p": "[AM|PM]",
+ "%z": "[+HHMM|-HHMM]"
+ }
+ for key, val in mapping.items():
+ format_string = format_string.replace(key, val)
+ return format_string
diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py
new file mode 100644
index 00000000..b6c41174
--- /dev/null
+++ b/rest_framework/utils/model_meta.py
@@ -0,0 +1,129 @@
+"""
+Helper function for returning the field information that is associated
+with a model class. This includes returning all the forward and reverse
+relationships and their associated metadata.
+
+Usage: `get_field_info(model)` returns a `FieldInfo` instance.
+"""
+from collections import namedtuple
+from django.db import models
+from django.utils import six
+from django.utils.datastructures import SortedDict
+import inspect
+
+
+FieldInfo = namedtuple('FieldResult', [
+ 'pk', # Model field instance
+ 'fields', # Dict of field name -> model field instance
+ 'forward_relations', # Dict of field name -> RelationInfo
+ 'reverse_relations', # Dict of field name -> RelationInfo
+ 'fields_and_pk', # Shortcut for 'pk' + 'fields'
+ 'relations' # Shortcut for 'forward_relations' + 'reverse_relations'
+])
+
+RelationInfo = namedtuple('RelationInfo', [
+ 'model_field',
+ 'related',
+ 'to_many',
+ 'has_through_model'
+])
+
+
+def _resolve_model(obj):
+ """
+ Resolve supplied `obj` to a Django model class.
+
+ `obj` must be a Django model class itself, or a string
+ representation of one. Useful in situtations like GH #1225 where
+ Django may not have resolved a string-based reference to a model in
+ another model's foreign key definition.
+
+ String representations should have the format:
+ 'appname.ModelName'
+ """
+ if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
+ app_name, model_name = obj.split('.')
+ return models.get_model(app_name, model_name)
+ elif inspect.isclass(obj) and issubclass(obj, models.Model):
+ return obj
+ raise ValueError("{0} is not a Django model".format(obj))
+
+
+def get_field_info(model):
+ """
+ Given a model class, returns a `FieldInfo` instance containing metadata
+ about the various field types on the model.
+ """
+ opts = model._meta.concrete_model._meta
+
+ # Deal with the primary key.
+ pk = opts.pk
+ while pk.rel and pk.rel.parent_link:
+ # If model is a child via multitable inheritance, use parent's pk.
+ pk = pk.rel.to._meta.pk
+
+ # Deal with regular fields.
+ fields = SortedDict()
+ for field in [field for field in opts.fields if field.serialize and not field.rel]:
+ fields[field.name] = field
+
+ # Deal with forward relationships.
+ forward_relations = SortedDict()
+ for field in [field for field in opts.fields if field.serialize and field.rel]:
+ forward_relations[field.name] = RelationInfo(
+ model_field=field,
+ related=_resolve_model(field.rel.to),
+ to_many=False,
+ has_through_model=False
+ )
+
+ # Deal with forward many-to-many relationships.
+ for field in [field for field in opts.many_to_many if field.serialize]:
+ forward_relations[field.name] = RelationInfo(
+ model_field=field,
+ related=_resolve_model(field.rel.to),
+ to_many=True,
+ has_through_model=(
+ not field.rel.through._meta.auto_created
+ )
+ )
+
+ # Deal with reverse relationships.
+ reverse_relations = SortedDict()
+ for relation in opts.get_all_related_objects():
+ accessor_name = relation.get_accessor_name()
+ reverse_relations[accessor_name] = RelationInfo(
+ model_field=None,
+ related=relation.model,
+ to_many=relation.field.rel.multiple,
+ has_through_model=False
+ )
+
+ # Deal with reverse many-to-many relationships.
+ for relation in opts.get_all_related_many_to_many_objects():
+ accessor_name = relation.get_accessor_name()
+ reverse_relations[accessor_name] = RelationInfo(
+ model_field=None,
+ related=relation.model,
+ to_many=True,
+ has_through_model=(
+ hasattr(relation.field.rel, 'through') and
+ not relation.field.rel.through._meta.auto_created
+ )
+ )
+
+ # Shortcut that merges both regular fields and the pk,
+ # for simplifying regular field lookup.
+ fields_and_pk = SortedDict()
+ fields_and_pk['pk'] = pk
+ fields_and_pk[pk.name] = pk
+ fields_and_pk.update(fields)
+
+ # Shortcut that merges both forward and reverse relationships
+
+ relations = SortedDict(
+ list(forward_relations.items()) +
+ list(reverse_relations.items())
+ )
+
+ return FieldInfo(pk, fields, forward_relations, reverse_relations, fields_and_pk, relations)
diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py
new file mode 100644
index 00000000..e64fdd22
--- /dev/null
+++ b/rest_framework/utils/representation.py
@@ -0,0 +1,87 @@
+"""
+Helper functions for creating user-friendly representations
+of serializer classes and serializer fields.
+"""
+from django.db import models
+import re
+
+
+def manager_repr(value):
+ model = value.model
+ opts = model._meta
+ for _, name, manager in opts.concrete_managers + opts.abstract_managers:
+ if manager == value:
+ return '%s.%s.all()' % (model._meta.object_name, name)
+ return repr(value)
+
+
+def smart_repr(value):
+ if isinstance(value, models.Manager):
+ return manager_repr(value)
+
+ value = repr(value)
+
+ # Representations like u'help text'
+ # should simply be presented as 'help text'
+ if value.startswith("u'") and value.endswith("'"):
+ return value[1:]
+
+ # Representations like
+ # <django.core.validators.RegexValidator object at 0x1047af050>
+ # Should be presented as
+ # <django.core.validators.RegexValidator object>
+ value = re.sub(' at 0x[0-9a-f]{4,32}>', '>', value)
+
+ return value
+
+
+def field_repr(field, force_many=False):
+ kwargs = field._kwargs
+ if force_many:
+ kwargs = kwargs.copy()
+ kwargs['many'] = True
+ kwargs.pop('child', None)
+
+ arg_string = ', '.join([smart_repr(val) for val in field._args])
+ kwarg_string = ', '.join([
+ '%s=%s' % (key, smart_repr(val))
+ for key, val in sorted(kwargs.items())
+ ])
+ if arg_string and kwarg_string:
+ arg_string += ', '
+
+ if force_many:
+ class_name = force_many.__class__.__name__
+ else:
+ class_name = field.__class__.__name__
+
+ return "%s(%s%s)" % (class_name, arg_string, kwarg_string)
+
+
+def serializer_repr(serializer, indent, force_many=None):
+ ret = field_repr(serializer, force_many) + ':'
+ indent_str = ' ' * indent
+
+ if force_many:
+ fields = force_many.fields
+ else:
+ fields = serializer.fields
+
+ for field_name, field in fields.items():
+ ret += '\n' + indent_str + field_name + ' = '
+ if hasattr(field, 'fields'):
+ ret += serializer_repr(field, indent + 1)
+ elif hasattr(field, 'child'):
+ ret += list_repr(field, indent + 1)
+ elif hasattr(field, 'child_relation'):
+ ret += field_repr(field.child_relation, force_many=field.child_relation)
+ else:
+ ret += field_repr(field)
+ return ret
+
+
+def list_repr(serializer, indent):
+ child = serializer.child
+ if hasattr(child, 'fields'):
+ return serializer_repr(serializer, indent, force_many=child)
+ return field_repr(serializer)
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 38346ab7..9f08a4ad 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -3,7 +3,7 @@ Provides an APIView class that is the base of all views in REST framework.
"""
from __future__ import unicode_literals
-from django.core.exceptions import PermissionDenied
+from django.core.exceptions import PermissionDenied, ValidationError, NON_FIELD_ERRORS
from django.http import Http404
from django.utils.datastructures import SortedDict
from django.views.decorators.csrf import csrf_exempt
@@ -51,7 +51,8 @@ def exception_handler(exc):
Returns the response that should be used for any given exception.
By default we handle the REST framework `APIException`, and also
- Django's builtin `Http404` and `PermissionDenied` exceptions.
+ Django's built-in `ValidationError`, `Http404` and `PermissionDenied`
+ exceptions.
Any unhandled exceptions may return `None`, which will cause a 500 error
to be raised.
@@ -61,13 +62,22 @@ def exception_handler(exc):
if getattr(exc, 'auth_header', None):
headers['WWW-Authenticate'] = exc.auth_header
if getattr(exc, 'wait', None):
- headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
headers['Retry-After'] = '%d' % exc.wait
return Response({'detail': exc.detail},
status=exc.status_code,
headers=headers)
+ elif isinstance(exc, ValidationError):
+ # ValidationErrors may include the non-field key named '__all__'.
+ # When returning a response we map this to a key name that can be
+ # modified in settings.
+ if NON_FIELD_ERRORS in exc.message_dict:
+ errors = exc.message_dict.pop(NON_FIELD_ERRORS)
+ exc.message_dict[api_settings.NON_FIELD_ERRORS_KEY] = errors
+ return Response(exc.message_dict,
+ status=status.HTTP_400_BAD_REQUEST)
+
elif isinstance(exc, Http404):
return Response({'detail': 'Not found'},
status=status.HTTP_404_NOT_FOUND)
diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py
index bb5b304e..84b4bd8d 100644
--- a/rest_framework/viewsets.py
+++ b/rest_framework/viewsets.py
@@ -20,6 +20,7 @@ from __future__ import unicode_literals
from functools import update_wrapper
from django.utils.decorators import classonlymethod
+from django.views.decorators.csrf import csrf_exempt
from rest_framework import views, generics, mixins
@@ -89,7 +90,7 @@ class ViewSetMixin(object):
# resolved URL.
view.cls = cls
view.suffix = initkwargs.get('suffix', None)
- return view
+ return csrf_exempt(view)
def initialize_request(self, request, *args, **kargs):
"""