diff options
| author | Tom Christie | 2014-09-02 17:41:23 +0100 | 
|---|---|---|
| committer | Tom Christie | 2014-09-02 17:41:23 +0100 | 
| commit | f2852811f93863f2eed04d51eeb7ef27716b2409 (patch) | |
| tree | 45799932849f81d45d77edc53cb00269465ba0f1 /rest_framework | |
| parent | ec096a1caceff6a4f5c75a152dd1c7bea9ed281d (diff) | |
| download | django-rest-framework-f2852811f93863f2eed04d51eeb7ef27716b2409.tar.bz2 | |
Getting tests passing
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/authtoken/serializers.py | 5 | ||||
| -rw-r--r-- | rest_framework/authtoken/views.py | 3 | ||||
| -rw-r--r-- | rest_framework/fields.py | 50 | ||||
| -rw-r--r-- | rest_framework/mixins.py | 41 | ||||
| -rw-r--r-- | rest_framework/pagination.py | 36 | ||||
| -rw-r--r-- | rest_framework/relations.py | 2 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 53 | 
7 files changed, 103 insertions, 87 deletions
diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index 99e99ae3..edeae857 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -19,11 +19,12 @@ class AuthTokenSerializer(serializers.Serializer):                  if not user.is_active:                      msg = _('User account is disabled.')                      raise serializers.ValidationError(msg) -                attrs['user'] = user -                return attrs              else:                  msg = _('Unable to login with provided credentials.')                  raise serializers.ValidationError(msg)          else:              msg = _('Must include "username" and "password"')              raise serializers.ValidationError(msg) + +        attrs['user'] = user +        return attrs diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index 7c03cb76..94e6f061 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -18,7 +18,8 @@ class ObtainAuthToken(APIView):      def post(self, request):          serializer = self.serializer_class(data=request.DATA)          if serializer.is_valid(): -            token, created = Token.objects.get_or_create(user=serializer.object['user']) +            user = serializer.validated_data['user'] +            token, created = Token.objects.get_or_create(user=user)              return Response({'token': token.key})          return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 3e0f7ca4..838aa3b0 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,4 +1,5 @@  from rest_framework.utils import html +import inspect  class empty: @@ -11,6 +12,22 @@ class empty:      pass +def is_simple_callable(obj): +    """ +    True if the object is a callable that takes no arguments. +    """ +    function = inspect.isfunction(obj) +    method = inspect.ismethod(obj) + +    if not (function or method): +        return False + +    args, _, _, defaults = inspect.getargspec(obj) +    len_args = len(args) if function else len(args) - 1 +    len_defaults = len(defaults) if defaults else 0 +    return len_args <= len_defaults + +  def get_attribute(instance, attrs):      """      Similar to Python's built in `getattr(instance, attr)`, @@ -98,6 +115,7 @@ class Field(object):          self.field_name = field_name          self.parent = parent          self.root = root +        self.context = parent.context          # `self.label` should deafult to being based on the field name.          if self.label is None: @@ -297,25 +315,55 @@ class IntegerField(Field):              self.fail('invalid_integer')          return data +    def to_primative(self, value): +        if value is None: +            return None +        return int(value) +  class EmailField(CharField):      pass  # TODO +class URLField(CharField): +    pass  # TODO + +  class RegexField(CharField):      def __init__(self, **kwargs):          self.regex = kwargs.pop('regex')          super(CharField, self).__init__(**kwargs) +class DateField(CharField): +    def __init__(self, **kwargs): +        self.input_formats = kwargs.pop('input_formats', None) +        super(DateField, self).__init__(**kwargs) + + +class TimeField(CharField): +    def __init__(self, **kwargs): +        self.input_formats = kwargs.pop('input_formats', None) +        super(TimeField, self).__init__(**kwargs) + +  class DateTimeField(CharField): -    pass  # TODO +    def __init__(self, **kwargs): +        self.input_formats = kwargs.pop('input_formats', None) +        super(DateTimeField, self).__init__(**kwargs)  class FileField(Field):      pass  # TODO +class ReadOnlyField(Field): +    def to_primative(self, value): +        if is_simple_callable(value): +            return value() +        return value + +  class MethodField(Field):      def __init__(self, **kwargs):          kwargs['source'] = '*' diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 3e9c9bb3..359740ce 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -13,23 +13,6 @@ from rest_framework.request import clone_request  from rest_framework.settings import api_settings -def _get_validation_exclusions(obj, lookup_field=None): -    """ -    Given a model instance, and an optional pk and slug field, -    return the full list of all other field names on that model. - -    For use when performing full_clean on a model instance, -    so we only clean the required fields. -    """ -    if lookup_field == 'pk': -        pk_field = obj._meta.pk -        while pk_field.rel: -            pk_field = pk_field.rel.to._meta.pk -        lookup_field = pk_field.name - -    return [field.name for field in obj._meta.fields if field.name != lookup_field] - -  class CreateModelMixin(object):      """      Create a model instance. @@ -92,15 +75,14 @@ class UpdateModelMixin(object):          if not serializer.is_valid():              return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) -        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field -        lookup_value = self.kwargs[lookup_url_kwarg] -        extras = {self.lookup_field: lookup_value} -          if self.object is None: +            lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field +            lookup_value = self.kwargs[lookup_url_kwarg] +            extras = {self.lookup_field: lookup_value}              self.object = serializer.save(extras=extras)              return Response(serializer.data, status=status.HTTP_201_CREATED) -        self.object = serializer.save(extras=extras) +        self.object = serializer.save()          return Response(serializer.data, status=status.HTTP_200_OK)      def partial_update(self, request, *args, **kwargs): @@ -122,21 +104,6 @@ class UpdateModelMixin(object):                  # return a 404 response.                  raise -    def pre_save(self, obj): -        """ -        Set any attributes on the object that are implicit in the request. -        """ -        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field -        lookup_value = self.kwargs[lookup_url_kwarg] - -        setattr(obj, self.lookup_field, lookup_value) - -        # Ensure we clean the attributes so that we don't eg return integer -        # pk using a string representation, as provided by the url conf kwarg. -        if hasattr(obj, 'full_clean'): -            exclude = _get_validation_exclusions(obj, self.lookup_field) -            obj.full_clean(exclude) -  class DestroyModelMixin(object):      """ diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 83ef97c5..478d32b4 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -13,7 +13,7 @@ class NextPageField(serializers.Field):      """      page_field = 'page' -    def to_native(self, value): +    def to_primative(self, value):          if not value.has_next():              return None          page = value.next_page_number() @@ -28,7 +28,7 @@ class PreviousPageField(serializers.Field):      """      page_field = 'page' -    def to_native(self, value): +    def to_primative(self, value):          if not value.has_previous():              return None          page = value.previous_page_number() @@ -48,25 +48,11 @@ class DefaultObjectSerializer(serializers.Field):          super(DefaultObjectSerializer, self).__init__(source=source) -# class PaginationSerializerOptions(serializers.SerializerOptions): -#     """ -#     An object that stores the options that may be provided to a -#     pagination serializer by using the inner `Meta` class. - -#     Accessible on the instance as `serializer.opts`. -#     """ -#     def __init__(self, meta): -#         super(PaginationSerializerOptions, self).__init__(meta) -#         self.object_serializer_class = getattr(meta, 'object_serializer_class', -#                                                DefaultObjectSerializer) - -  class BasePaginationSerializer(serializers.Serializer):      """      A base class for pagination serializers to inherit from,      to make implementing custom serializers more easy.      """ -    # _options_class = PaginationSerializerOptions      results_field = 'results'      def __init__(self, *args, **kwargs): @@ -75,14 +61,16 @@ class BasePaginationSerializer(serializers.Serializer):          """          super(BasePaginationSerializer, self).__init__(*args, **kwargs)          results_field = self.results_field -        object_serializer = self.opts.object_serializer_class - -        if 'context' in kwargs: -            context_kwarg = {'context': kwargs['context']} -        else: -            context_kwarg = {} - -        self.fields[results_field] = object_serializer(source='object_list', **context_kwarg) +        try: +            object_serializer = self.Meta.object_serializer_class +        except AttributeError: +            object_serializer = DefaultObjectSerializer + +        self.fields[results_field] = serializers.ListSerializer( +            child=object_serializer(), +            source='object_list' +        ) +        self.fields[results_field].bind(results_field, self, self)  # TODO: Support automatic binding  class PaginationSerializer(BasePaginationSerializer): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 42d2c121..0b01394a 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -73,7 +73,7 @@ class HyperlinkedRelatedField(RelatedField):          try:              http_prefix = value.startswith(('http:', 'https:'))          except AttributeError: -            self.fail('incorrect_type', type(value).__name__) +            self.fail('incorrect_type', data_type=type(value).__name__)          if http_prefix:              # If needed convert absolute URLs to relative path diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 2f23b4d9..c38d8968 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -142,7 +142,7 @@ class Serializer(BaseSerializer):          return super(Serializer, cls).__new__(cls)      def __init__(self, *args, **kwargs): -        kwargs.pop('context', None) +        self.context = kwargs.pop('context', {})          kwargs.pop('partial', None)          kwargs.pop('many', False) @@ -202,7 +202,7 @@ class Serializer(BaseSerializer):          if errors:              raise ValidationError(errors) -        return ret +        return self.validate(ret)      def to_primative(self, instance):          """ @@ -217,6 +217,9 @@ class Serializer(BaseSerializer):          return ret +    def validate(self, attrs): +        return attrs +      def __iter__(self):          errors = self.errors if hasattr(self, '_errors') else {}          for field in self.fields.values(): @@ -232,8 +235,7 @@ class ListSerializer(BaseSerializer):      def __init__(self, *args, **kwargs):          self.child = kwargs.pop('child', copy.deepcopy(self.child))          assert self.child is not None, '`child` is a required argument.' - -        kwargs.pop('context', None) +        self.context = kwargs.pop('context', {})          kwargs.pop('partial', None)          super(ListSerializer, self).__init__(*args, **kwargs) @@ -316,19 +318,19 @@ class ModelSerializer(Serializer):          models.PositiveIntegerField: IntegerField,          models.SmallIntegerField: IntegerField,          models.PositiveSmallIntegerField: IntegerField, -        # models.DateTimeField: DateTimeField, -        # models.DateField: DateField, -        # models.TimeField: TimeField, +        models.DateTimeField: DateTimeField, +        models.DateField: DateField, +        models.TimeField: TimeField,          # models.DecimalField: DecimalField, -        # models.EmailField: EmailField, +        models.EmailField: EmailField,          models.CharField: CharField, -        # models.URLField: URLField, +        models.URLField: URLField,          # models.SlugField: SlugField,          models.TextField: CharField,          models.CommaSeparatedIntegerField: CharField,          models.BooleanField: BooleanField,          models.NullBooleanField: BooleanField, -        # models.FileField: FileField, +        models.FileField: FileField,          # models.ImageField: ImageField,      } @@ -338,6 +340,15 @@ class ModelSerializer(Serializer):          self.opts = self._options_class(self.Meta)          super(ModelSerializer, self).__init__(*args, **kwargs) +    def create(self): +        ModelClass = self.opts.model +        return ModelClass.objects.create(**self.validated_data) + +    def update(self, obj): +        for attr, value in self.validated_data.items(): +            setattr(obj, attr, value) +        obj.save() +      def get_fields(self):          # Get the explicitly declared fields.          fields = copy.deepcopy(self.base_fields) @@ -566,8 +577,8 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):  class HyperlinkedModelSerializer(ModelSerializer):      _options_class = HyperlinkedModelSerializerOptions      _default_view_name = '%(model_name)s-detail' -    # _hyperlink_field_class = HyperlinkedRelatedField -    # _hyperlink_identify_field_class = HyperlinkedIdentityField +    _hyperlink_field_class = HyperlinkedRelatedField +    _hyperlink_identify_field_class = HyperlinkedIdentityField      def get_default_fields(self):          fields = super(HyperlinkedModelSerializer, self).get_default_fields() @@ -575,15 +586,15 @@ class HyperlinkedModelSerializer(ModelSerializer):          if self.opts.view_name is None:              self.opts.view_name = self._get_default_view_name(self.opts.model) -        # if self.opts.url_field_name not in fields: -        #     url_field = self._hyperlink_identify_field_class( -        #         view_name=self.opts.view_name, -        #         lookup_field=self.opts.lookup_field -        #     ) -        #     ret = self._dict_class() -        #     ret[self.opts.url_field_name] = url_field -        #     ret.update(fields) -        #     fields = ret +        if self.opts.url_field_name not in fields: +            url_field = self._hyperlink_identify_field_class( +                view_name=self.opts.view_name, +                lookup_field=self.opts.lookup_field +            ) +            ret = fields.__class__() +            ret[self.opts.url_field_name] = url_field +            ret.update(fields) +            fields = ret          return fields  | 
