diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/authtoken/serializers.py | 5 | ||||
| -rw-r--r-- | rest_framework/authtoken/views.py | 3 | ||||
| -rw-r--r-- | rest_framework/fields.py | 50 | ||||
| -rw-r--r-- | rest_framework/mixins.py | 41 | ||||
| -rw-r--r-- | rest_framework/pagination.py | 36 | ||||
| -rw-r--r-- | rest_framework/relations.py | 2 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 53 |
7 files changed, 103 insertions, 87 deletions
diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index 99e99ae3..edeae857 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -19,11 +19,12 @@ class AuthTokenSerializer(serializers.Serializer): if not user.is_active: msg = _('User account is disabled.') raise serializers.ValidationError(msg) - attrs['user'] = user - return attrs else: msg = _('Unable to login with provided credentials.') raise serializers.ValidationError(msg) else: msg = _('Must include "username" and "password"') raise serializers.ValidationError(msg) + + attrs['user'] = user + return attrs diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index 7c03cb76..94e6f061 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -18,7 +18,8 @@ class ObtainAuthToken(APIView): def post(self, request): serializer = self.serializer_class(data=request.DATA) if serializer.is_valid(): - token, created = Token.objects.get_or_create(user=serializer.object['user']) + user = serializer.validated_data['user'] + token, created = Token.objects.get_or_create(user=user) return Response({'token': token.key}) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 3e0f7ca4..838aa3b0 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,4 +1,5 @@ from rest_framework.utils import html +import inspect class empty: @@ -11,6 +12,22 @@ class empty: pass +def is_simple_callable(obj): + """ + True if the object is a callable that takes no arguments. + """ + function = inspect.isfunction(obj) + method = inspect.ismethod(obj) + + if not (function or method): + return False + + args, _, _, defaults = inspect.getargspec(obj) + len_args = len(args) if function else len(args) - 1 + len_defaults = len(defaults) if defaults else 0 + return len_args <= len_defaults + + def get_attribute(instance, attrs): """ Similar to Python's built in `getattr(instance, attr)`, @@ -98,6 +115,7 @@ class Field(object): self.field_name = field_name self.parent = parent self.root = root + self.context = parent.context # `self.label` should deafult to being based on the field name. if self.label is None: @@ -297,25 +315,55 @@ class IntegerField(Field): self.fail('invalid_integer') return data + def to_primative(self, value): + if value is None: + return None + return int(value) + class EmailField(CharField): pass # TODO +class URLField(CharField): + pass # TODO + + class RegexField(CharField): def __init__(self, **kwargs): self.regex = kwargs.pop('regex') super(CharField, self).__init__(**kwargs) +class DateField(CharField): + def __init__(self, **kwargs): + self.input_formats = kwargs.pop('input_formats', None) + super(DateField, self).__init__(**kwargs) + + +class TimeField(CharField): + def __init__(self, **kwargs): + self.input_formats = kwargs.pop('input_formats', None) + super(TimeField, self).__init__(**kwargs) + + class DateTimeField(CharField): - pass # TODO + def __init__(self, **kwargs): + self.input_formats = kwargs.pop('input_formats', None) + super(DateTimeField, self).__init__(**kwargs) class FileField(Field): pass # TODO +class ReadOnlyField(Field): + def to_primative(self, value): + if is_simple_callable(value): + return value() + return value + + class MethodField(Field): def __init__(self, **kwargs): kwargs['source'] = '*' diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 3e9c9bb3..359740ce 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -13,23 +13,6 @@ from rest_framework.request import clone_request from rest_framework.settings import api_settings -def _get_validation_exclusions(obj, lookup_field=None): - """ - Given a model instance, and an optional pk and slug field, - return the full list of all other field names on that model. - - For use when performing full_clean on a model instance, - so we only clean the required fields. - """ - if lookup_field == 'pk': - pk_field = obj._meta.pk - while pk_field.rel: - pk_field = pk_field.rel.to._meta.pk - lookup_field = pk_field.name - - return [field.name for field in obj._meta.fields if field.name != lookup_field] - - class CreateModelMixin(object): """ Create a model instance. @@ -92,15 +75,14 @@ class UpdateModelMixin(object): if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - lookup_value = self.kwargs[lookup_url_kwarg] - extras = {self.lookup_field: lookup_value} - if self.object is None: + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + lookup_value = self.kwargs[lookup_url_kwarg] + extras = {self.lookup_field: lookup_value} self.object = serializer.save(extras=extras) return Response(serializer.data, status=status.HTTP_201_CREATED) - self.object = serializer.save(extras=extras) + self.object = serializer.save() return Response(serializer.data, status=status.HTTP_200_OK) def partial_update(self, request, *args, **kwargs): @@ -122,21 +104,6 @@ class UpdateModelMixin(object): # return a 404 response. raise - def pre_save(self, obj): - """ - Set any attributes on the object that are implicit in the request. - """ - lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field - lookup_value = self.kwargs[lookup_url_kwarg] - - setattr(obj, self.lookup_field, lookup_value) - - # Ensure we clean the attributes so that we don't eg return integer - # pk using a string representation, as provided by the url conf kwarg. - if hasattr(obj, 'full_clean'): - exclude = _get_validation_exclusions(obj, self.lookup_field) - obj.full_clean(exclude) - class DestroyModelMixin(object): """ diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 83ef97c5..478d32b4 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -13,7 +13,7 @@ class NextPageField(serializers.Field): """ page_field = 'page' - def to_native(self, value): + def to_primative(self, value): if not value.has_next(): return None page = value.next_page_number() @@ -28,7 +28,7 @@ class PreviousPageField(serializers.Field): """ page_field = 'page' - def to_native(self, value): + def to_primative(self, value): if not value.has_previous(): return None page = value.previous_page_number() @@ -48,25 +48,11 @@ class DefaultObjectSerializer(serializers.Field): super(DefaultObjectSerializer, self).__init__(source=source) -# class PaginationSerializerOptions(serializers.SerializerOptions): -# """ -# An object that stores the options that may be provided to a -# pagination serializer by using the inner `Meta` class. - -# Accessible on the instance as `serializer.opts`. -# """ -# def __init__(self, meta): -# super(PaginationSerializerOptions, self).__init__(meta) -# self.object_serializer_class = getattr(meta, 'object_serializer_class', -# DefaultObjectSerializer) - - class BasePaginationSerializer(serializers.Serializer): """ A base class for pagination serializers to inherit from, to make implementing custom serializers more easy. """ - # _options_class = PaginationSerializerOptions results_field = 'results' def __init__(self, *args, **kwargs): @@ -75,14 +61,16 @@ class BasePaginationSerializer(serializers.Serializer): """ super(BasePaginationSerializer, self).__init__(*args, **kwargs) results_field = self.results_field - object_serializer = self.opts.object_serializer_class - - if 'context' in kwargs: - context_kwarg = {'context': kwargs['context']} - else: - context_kwarg = {} - - self.fields[results_field] = object_serializer(source='object_list', **context_kwarg) + try: + object_serializer = self.Meta.object_serializer_class + except AttributeError: + object_serializer = DefaultObjectSerializer + + self.fields[results_field] = serializers.ListSerializer( + child=object_serializer(), + source='object_list' + ) + self.fields[results_field].bind(results_field, self, self) # TODO: Support automatic binding class PaginationSerializer(BasePaginationSerializer): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 42d2c121..0b01394a 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -73,7 +73,7 @@ class HyperlinkedRelatedField(RelatedField): try: http_prefix = value.startswith(('http:', 'https:')) except AttributeError: - self.fail('incorrect_type', type(value).__name__) + self.fail('incorrect_type', data_type=type(value).__name__) if http_prefix: # If needed convert absolute URLs to relative path diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 2f23b4d9..c38d8968 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -142,7 +142,7 @@ class Serializer(BaseSerializer): return super(Serializer, cls).__new__(cls) def __init__(self, *args, **kwargs): - kwargs.pop('context', None) + self.context = kwargs.pop('context', {}) kwargs.pop('partial', None) kwargs.pop('many', False) @@ -202,7 +202,7 @@ class Serializer(BaseSerializer): if errors: raise ValidationError(errors) - return ret + return self.validate(ret) def to_primative(self, instance): """ @@ -217,6 +217,9 @@ class Serializer(BaseSerializer): return ret + def validate(self, attrs): + return attrs + def __iter__(self): errors = self.errors if hasattr(self, '_errors') else {} for field in self.fields.values(): @@ -232,8 +235,7 @@ class ListSerializer(BaseSerializer): def __init__(self, *args, **kwargs): self.child = kwargs.pop('child', copy.deepcopy(self.child)) assert self.child is not None, '`child` is a required argument.' - - kwargs.pop('context', None) + self.context = kwargs.pop('context', {}) kwargs.pop('partial', None) super(ListSerializer, self).__init__(*args, **kwargs) @@ -316,19 +318,19 @@ class ModelSerializer(Serializer): models.PositiveIntegerField: IntegerField, models.SmallIntegerField: IntegerField, models.PositiveSmallIntegerField: IntegerField, - # models.DateTimeField: DateTimeField, - # models.DateField: DateField, - # models.TimeField: TimeField, + models.DateTimeField: DateTimeField, + models.DateField: DateField, + models.TimeField: TimeField, # models.DecimalField: DecimalField, - # models.EmailField: EmailField, + models.EmailField: EmailField, models.CharField: CharField, - # models.URLField: URLField, + models.URLField: URLField, # models.SlugField: SlugField, models.TextField: CharField, models.CommaSeparatedIntegerField: CharField, models.BooleanField: BooleanField, models.NullBooleanField: BooleanField, - # models.FileField: FileField, + models.FileField: FileField, # models.ImageField: ImageField, } @@ -338,6 +340,15 @@ class ModelSerializer(Serializer): self.opts = self._options_class(self.Meta) super(ModelSerializer, self).__init__(*args, **kwargs) + def create(self): + ModelClass = self.opts.model + return ModelClass.objects.create(**self.validated_data) + + def update(self, obj): + for attr, value in self.validated_data.items(): + setattr(obj, attr, value) + obj.save() + def get_fields(self): # Get the explicitly declared fields. fields = copy.deepcopy(self.base_fields) @@ -566,8 +577,8 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): class HyperlinkedModelSerializer(ModelSerializer): _options_class = HyperlinkedModelSerializerOptions _default_view_name = '%(model_name)s-detail' - # _hyperlink_field_class = HyperlinkedRelatedField - # _hyperlink_identify_field_class = HyperlinkedIdentityField + _hyperlink_field_class = HyperlinkedRelatedField + _hyperlink_identify_field_class = HyperlinkedIdentityField def get_default_fields(self): fields = super(HyperlinkedModelSerializer, self).get_default_fields() @@ -575,15 +586,15 @@ class HyperlinkedModelSerializer(ModelSerializer): if self.opts.view_name is None: self.opts.view_name = self._get_default_view_name(self.opts.model) - # if self.opts.url_field_name not in fields: - # url_field = self._hyperlink_identify_field_class( - # view_name=self.opts.view_name, - # lookup_field=self.opts.lookup_field - # ) - # ret = self._dict_class() - # ret[self.opts.url_field_name] = url_field - # ret.update(fields) - # fields = ret + if self.opts.url_field_name not in fields: + url_field = self._hyperlink_identify_field_class( + view_name=self.opts.view_name, + lookup_field=self.opts.lookup_field + ) + ret = fields.__class__() + ret[self.opts.url_field_name] = url_field + ret.update(fields) + fields = ret return fields |
