aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/authtoken/serializers.py5
-rw-r--r--rest_framework/authtoken/views.py3
-rw-r--r--rest_framework/fields.py50
-rw-r--r--rest_framework/mixins.py41
-rw-r--r--rest_framework/pagination.py36
-rw-r--r--rest_framework/relations.py2
-rw-r--r--rest_framework/serializers.py53
7 files changed, 103 insertions, 87 deletions
diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py
index 99e99ae3..edeae857 100644
--- a/rest_framework/authtoken/serializers.py
+++ b/rest_framework/authtoken/serializers.py
@@ -19,11 +19,12 @@ class AuthTokenSerializer(serializers.Serializer):
if not user.is_active:
msg = _('User account is disabled.')
raise serializers.ValidationError(msg)
- attrs['user'] = user
- return attrs
else:
msg = _('Unable to login with provided credentials.')
raise serializers.ValidationError(msg)
else:
msg = _('Must include "username" and "password"')
raise serializers.ValidationError(msg)
+
+ attrs['user'] = user
+ return attrs
diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py
index 7c03cb76..94e6f061 100644
--- a/rest_framework/authtoken/views.py
+++ b/rest_framework/authtoken/views.py
@@ -18,7 +18,8 @@ class ObtainAuthToken(APIView):
def post(self, request):
serializer = self.serializer_class(data=request.DATA)
if serializer.is_valid():
- token, created = Token.objects.get_or_create(user=serializer.object['user'])
+ user = serializer.validated_data['user']
+ token, created = Token.objects.get_or_create(user=user)
return Response({'token': token.key})
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 3e0f7ca4..838aa3b0 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,4 +1,5 @@
from rest_framework.utils import html
+import inspect
class empty:
@@ -11,6 +12,22 @@ class empty:
pass
+def is_simple_callable(obj):
+ """
+ True if the object is a callable that takes no arguments.
+ """
+ function = inspect.isfunction(obj)
+ method = inspect.ismethod(obj)
+
+ if not (function or method):
+ return False
+
+ args, _, _, defaults = inspect.getargspec(obj)
+ len_args = len(args) if function else len(args) - 1
+ len_defaults = len(defaults) if defaults else 0
+ return len_args <= len_defaults
+
+
def get_attribute(instance, attrs):
"""
Similar to Python's built in `getattr(instance, attr)`,
@@ -98,6 +115,7 @@ class Field(object):
self.field_name = field_name
self.parent = parent
self.root = root
+ self.context = parent.context
# `self.label` should deafult to being based on the field name.
if self.label is None:
@@ -297,25 +315,55 @@ class IntegerField(Field):
self.fail('invalid_integer')
return data
+ def to_primative(self, value):
+ if value is None:
+ return None
+ return int(value)
+
class EmailField(CharField):
pass # TODO
+class URLField(CharField):
+ pass # TODO
+
+
class RegexField(CharField):
def __init__(self, **kwargs):
self.regex = kwargs.pop('regex')
super(CharField, self).__init__(**kwargs)
+class DateField(CharField):
+ def __init__(self, **kwargs):
+ self.input_formats = kwargs.pop('input_formats', None)
+ super(DateField, self).__init__(**kwargs)
+
+
+class TimeField(CharField):
+ def __init__(self, **kwargs):
+ self.input_formats = kwargs.pop('input_formats', None)
+ super(TimeField, self).__init__(**kwargs)
+
+
class DateTimeField(CharField):
- pass # TODO
+ def __init__(self, **kwargs):
+ self.input_formats = kwargs.pop('input_formats', None)
+ super(DateTimeField, self).__init__(**kwargs)
class FileField(Field):
pass # TODO
+class ReadOnlyField(Field):
+ def to_primative(self, value):
+ if is_simple_callable(value):
+ return value()
+ return value
+
+
class MethodField(Field):
def __init__(self, **kwargs):
kwargs['source'] = '*'
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 3e9c9bb3..359740ce 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -13,23 +13,6 @@ from rest_framework.request import clone_request
from rest_framework.settings import api_settings
-def _get_validation_exclusions(obj, lookup_field=None):
- """
- Given a model instance, and an optional pk and slug field,
- return the full list of all other field names on that model.
-
- For use when performing full_clean on a model instance,
- so we only clean the required fields.
- """
- if lookup_field == 'pk':
- pk_field = obj._meta.pk
- while pk_field.rel:
- pk_field = pk_field.rel.to._meta.pk
- lookup_field = pk_field.name
-
- return [field.name for field in obj._meta.fields if field.name != lookup_field]
-
-
class CreateModelMixin(object):
"""
Create a model instance.
@@ -92,15 +75,14 @@ class UpdateModelMixin(object):
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
- lookup_value = self.kwargs[lookup_url_kwarg]
- extras = {self.lookup_field: lookup_value}
-
if self.object is None:
+ lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
+ lookup_value = self.kwargs[lookup_url_kwarg]
+ extras = {self.lookup_field: lookup_value}
self.object = serializer.save(extras=extras)
return Response(serializer.data, status=status.HTTP_201_CREATED)
- self.object = serializer.save(extras=extras)
+ self.object = serializer.save()
return Response(serializer.data, status=status.HTTP_200_OK)
def partial_update(self, request, *args, **kwargs):
@@ -122,21 +104,6 @@ class UpdateModelMixin(object):
# return a 404 response.
raise
- def pre_save(self, obj):
- """
- Set any attributes on the object that are implicit in the request.
- """
- lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
- lookup_value = self.kwargs[lookup_url_kwarg]
-
- setattr(obj, self.lookup_field, lookup_value)
-
- # Ensure we clean the attributes so that we don't eg return integer
- # pk using a string representation, as provided by the url conf kwarg.
- if hasattr(obj, 'full_clean'):
- exclude = _get_validation_exclusions(obj, self.lookup_field)
- obj.full_clean(exclude)
-
class DestroyModelMixin(object):
"""
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index 83ef97c5..478d32b4 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -13,7 +13,7 @@ class NextPageField(serializers.Field):
"""
page_field = 'page'
- def to_native(self, value):
+ def to_primative(self, value):
if not value.has_next():
return None
page = value.next_page_number()
@@ -28,7 +28,7 @@ class PreviousPageField(serializers.Field):
"""
page_field = 'page'
- def to_native(self, value):
+ def to_primative(self, value):
if not value.has_previous():
return None
page = value.previous_page_number()
@@ -48,25 +48,11 @@ class DefaultObjectSerializer(serializers.Field):
super(DefaultObjectSerializer, self).__init__(source=source)
-# class PaginationSerializerOptions(serializers.SerializerOptions):
-# """
-# An object that stores the options that may be provided to a
-# pagination serializer by using the inner `Meta` class.
-
-# Accessible on the instance as `serializer.opts`.
-# """
-# def __init__(self, meta):
-# super(PaginationSerializerOptions, self).__init__(meta)
-# self.object_serializer_class = getattr(meta, 'object_serializer_class',
-# DefaultObjectSerializer)
-
-
class BasePaginationSerializer(serializers.Serializer):
"""
A base class for pagination serializers to inherit from,
to make implementing custom serializers more easy.
"""
- # _options_class = PaginationSerializerOptions
results_field = 'results'
def __init__(self, *args, **kwargs):
@@ -75,14 +61,16 @@ class BasePaginationSerializer(serializers.Serializer):
"""
super(BasePaginationSerializer, self).__init__(*args, **kwargs)
results_field = self.results_field
- object_serializer = self.opts.object_serializer_class
-
- if 'context' in kwargs:
- context_kwarg = {'context': kwargs['context']}
- else:
- context_kwarg = {}
-
- self.fields[results_field] = object_serializer(source='object_list', **context_kwarg)
+ try:
+ object_serializer = self.Meta.object_serializer_class
+ except AttributeError:
+ object_serializer = DefaultObjectSerializer
+
+ self.fields[results_field] = serializers.ListSerializer(
+ child=object_serializer(),
+ source='object_list'
+ )
+ self.fields[results_field].bind(results_field, self, self) # TODO: Support automatic binding
class PaginationSerializer(BasePaginationSerializer):
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 42d2c121..0b01394a 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -73,7 +73,7 @@ class HyperlinkedRelatedField(RelatedField):
try:
http_prefix = value.startswith(('http:', 'https:'))
except AttributeError:
- self.fail('incorrect_type', type(value).__name__)
+ self.fail('incorrect_type', data_type=type(value).__name__)
if http_prefix:
# If needed convert absolute URLs to relative path
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 2f23b4d9..c38d8968 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -142,7 +142,7 @@ class Serializer(BaseSerializer):
return super(Serializer, cls).__new__(cls)
def __init__(self, *args, **kwargs):
- kwargs.pop('context', None)
+ self.context = kwargs.pop('context', {})
kwargs.pop('partial', None)
kwargs.pop('many', False)
@@ -202,7 +202,7 @@ class Serializer(BaseSerializer):
if errors:
raise ValidationError(errors)
- return ret
+ return self.validate(ret)
def to_primative(self, instance):
"""
@@ -217,6 +217,9 @@ class Serializer(BaseSerializer):
return ret
+ def validate(self, attrs):
+ return attrs
+
def __iter__(self):
errors = self.errors if hasattr(self, '_errors') else {}
for field in self.fields.values():
@@ -232,8 +235,7 @@ class ListSerializer(BaseSerializer):
def __init__(self, *args, **kwargs):
self.child = kwargs.pop('child', copy.deepcopy(self.child))
assert self.child is not None, '`child` is a required argument.'
-
- kwargs.pop('context', None)
+ self.context = kwargs.pop('context', {})
kwargs.pop('partial', None)
super(ListSerializer, self).__init__(*args, **kwargs)
@@ -316,19 +318,19 @@ class ModelSerializer(Serializer):
models.PositiveIntegerField: IntegerField,
models.SmallIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField,
- # models.DateTimeField: DateTimeField,
- # models.DateField: DateField,
- # models.TimeField: TimeField,
+ models.DateTimeField: DateTimeField,
+ models.DateField: DateField,
+ models.TimeField: TimeField,
# models.DecimalField: DecimalField,
- # models.EmailField: EmailField,
+ models.EmailField: EmailField,
models.CharField: CharField,
- # models.URLField: URLField,
+ models.URLField: URLField,
# models.SlugField: SlugField,
models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField,
models.NullBooleanField: BooleanField,
- # models.FileField: FileField,
+ models.FileField: FileField,
# models.ImageField: ImageField,
}
@@ -338,6 +340,15 @@ class ModelSerializer(Serializer):
self.opts = self._options_class(self.Meta)
super(ModelSerializer, self).__init__(*args, **kwargs)
+ def create(self):
+ ModelClass = self.opts.model
+ return ModelClass.objects.create(**self.validated_data)
+
+ def update(self, obj):
+ for attr, value in self.validated_data.items():
+ setattr(obj, attr, value)
+ obj.save()
+
def get_fields(self):
# Get the explicitly declared fields.
fields = copy.deepcopy(self.base_fields)
@@ -566,8 +577,8 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
class HyperlinkedModelSerializer(ModelSerializer):
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
- # _hyperlink_field_class = HyperlinkedRelatedField
- # _hyperlink_identify_field_class = HyperlinkedIdentityField
+ _hyperlink_field_class = HyperlinkedRelatedField
+ _hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields()
@@ -575,15 +586,15 @@ class HyperlinkedModelSerializer(ModelSerializer):
if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name(self.opts.model)
- # if self.opts.url_field_name not in fields:
- # url_field = self._hyperlink_identify_field_class(
- # view_name=self.opts.view_name,
- # lookup_field=self.opts.lookup_field
- # )
- # ret = self._dict_class()
- # ret[self.opts.url_field_name] = url_field
- # ret.update(fields)
- # fields = ret
+ if self.opts.url_field_name not in fields:
+ url_field = self._hyperlink_identify_field_class(
+ view_name=self.opts.view_name,
+ lookup_field=self.opts.lookup_field
+ )
+ ret = fields.__class__()
+ ret[self.opts.url_field_name] = url_field
+ ret.update(fields)
+ fields = ret
return fields